# SGMCMC design
_Response to [Remi's design document](https://hackmd.io/XFJQb_PHQ5Gi7i8Ar4oCMw?view)_
## About the kernel’s user-facing API
`state, info = sgld(rng_key, state, batch)`
I really think that the default API should do minibatching automatically. It should of course also be possible to define your own minibatches (if your data doesn’t fit in memory), but I think the default should be to not worry about that.
The reason is that I think most of the time these algorithms will be used when the data fits in memory, so the default API should be aimed at that. Specifically, I think a main use case is academics using sgmcmc algorithms in their papers, which almost always use datasets that fit in memory. Examples (these are probably some of the only current uses of the library tbh):
- I’ve used sgmcmcjax in my [recent paper](https://arxiv.org/abs/2105.13059) (that’s why I wrote the library!): everything fits in memory
- Kevin Murphy in his [pyprobml repo](https://github.com/probml/pyprobml/blob/5beb9b63ea26c72c88ad49ee4963ca62e9118e1b/scripts/subspace_demo.py#L148)
- [CNN example](https://colab.research.google.com/github/probml/probml-notebooks/blob/main/notebooks/sg_mcmc_jax.ipynb#scrollTo=9hchex74Qwf9): using minibatches is still really useful even if data fits in memory
In sgmcmcjax, if you want to define your own minibatches you have to use `diffusions.py`, and if you want batching done for you, you use `kernels.py`.
But following the design you suggested, one option could be to add another layer on top of the kernel you wrote:
1. The main update function would be `_ = sgld.update(rng_key, state)`, which does minibatching for you.
2. If you need to define your own minibatches, you use the alternative update function: `_ = sgld.update_with_minibatch(rng_key, state, batch)`.
The `update()` function would simply build on `update_with_minibatch`:
```python
from jax import random

def update(rng_key, state):
    # one subkey samples the minibatch, the other drives the kernel's noise
    rng_key, subkey = random.split(rng_key)
    # get_minibatch already has access to the dataset through a closure
    batch = get_minibatch(subkey)
    return update_with_minibatch(rng_key, state, batch)
```
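To make the closure idea concrete, here is a minimal sketch of how `get_minibatch` could be built, assuming the dataset is a tuple of in-memory arrays sharing a leading axis. `build_get_minibatch` is a hypothetical helper, not part of BlackJAX or sgmcmcjax; sampling without replacement is one choice among several.

```python
import jax.numpy as jnp
from jax import random


def build_get_minibatch(data, batch_size):
    """Hypothetical helper returning `get_minibatch(rng_key)` that closes over `data`.

    Assumes `data` is a tuple of arrays (e.g. (X, y)) with the same leading
    dimension, and that the whole dataset fits in memory.
    """
    data_size = data[0].shape[0]

    def get_minibatch(rng_key):
        # Sample `batch_size` distinct row indices uniformly at random
        idx = random.choice(rng_key, data_size, shape=(batch_size,), replace=False)
        # Index every array in the tuple with the same rows
        return tuple(arr[idx] for arr in data)

    return get_minibatch


# Usage: rows of X and entries of y stay paired after subsampling
X = jnp.arange(20.0).reshape(10, 2)
y = jnp.arange(10.0)
get_minibatch = build_get_minibatch((X, y), batch_size=4)
X_batch, y_batch = get_minibatch(random.PRNGKey(0))
```

Because the dataset is captured when the kernel is built, `update(rng_key, state)` never needs to see it.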
But then you would need to pass in the dataset when you build the kernel:
`sgld = blackjax.sgld(grad_estimator_fn, schedule_fn, data, batch_size)`
And if your data doesn't fit in memory, maybe pass in `data=None, batch_size=None`?
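A rough sketch of what such a factory could look like, assuming `data=None` simply means "no `update`, only `update_with_minibatch`". Everything here (`sgld`, `SGLDState`, `SGLDKernel`, the `iteration` field) is hypothetical and not an existing BlackJAX API; the `info` output is omitted for brevity. The step uses the convention `position + step_size * grad + sqrt(2 * step_size) * noise`.

```python
from typing import Callable, NamedTuple, Optional

import jax.numpy as jnp
from jax import random


class SGLDState(NamedTuple):
    position: jnp.ndarray
    iteration: int  # lets schedule_fn compute the current step size


class SGLDKernel(NamedTuple):
    update_with_minibatch: Callable
    update: Optional[Callable]  # None when no dataset was passed in


def sgld(grad_estimator_fn, schedule_fn, data=None, batch_size=None):
    """Hypothetical kernel factory; `data` is a tuple of in-memory arrays or None."""

    def update_with_minibatch(rng_key, state, batch):
        step_size = schedule_fn(state.iteration)
        grad = grad_estimator_fn(state.position, batch)
        noise = random.normal(rng_key, jnp.shape(state.position))
        # Langevin step with a stochastic gradient estimate
        position = state.position + step_size * grad + jnp.sqrt(2 * step_size) * noise
        return SGLDState(position, state.iteration + 1)

    if data is None:
        # Data does not fit in memory: the user supplies every minibatch
        return SGLDKernel(update_with_minibatch, None)

    data_size = data[0].shape[0]

    def update(rng_key, state):
        # Minibatching done for the user, built on update_with_minibatch
        rng_key, subkey = random.split(rng_key)
        idx = random.choice(subkey, data_size, shape=(batch_size,), replace=False)
        batch = tuple(arr[idx] for arr in data)
        return update_with_minibatch(rng_key, state, batch)

    return SGLDKernel(update_with_minibatch, update)


# Usage: Gaussian mean with N(0, 1) prior and N(theta, 1) likelihood
y = jnp.ones(100)

def grad_estimator_fn(theta, batch):
    (y_batch,) = batch
    # Minibatch log-likelihood gradient rescaled by N / n
    return -theta + (y.shape[0] / y_batch.shape[0]) * jnp.sum(y_batch - theta)

kernel = sgld(grad_estimator_fn, lambda i: 1e-3, data=(y,), batch_size=10)
state = kernel.update(random.PRNGKey(0), SGLDState(jnp.zeros(()), 0))
```

Both code paths share a single `update_with_minibatch`, so the out-of-memory case costs nothing extra to support.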
Thoughts on this?
## Gradient factory
Would you not prefer to have the gradient factory within the kernel factory? In your design the gradient estimator is built separately and passed in: `sgld = blackjax.sgld(grad_estimator_fn, schedule_fn)`.
In sgmcmcjax you build the kernel with a single call: `build_sgld_kernel(dt, loglikelihood, logprior, data, batch_size)`. So here it would be: `sgld = blackjax.sgld(loglikelihood, logprior, schedule_fn, data, batch_size)`.
But then maybe it's better to separate these. For example if we do my suggestion above (doing minibatching by default), it would be better to pass in `grad_estimator_fn`; this would work for both cases of data fitting / not fitting in memory.
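For reference, a `grad_estimator_fn` built from `loglikelihood` and `logprior` could look like the sketch below. `build_grad_estimator` is a hypothetical name; it assumes `loglikelihood(theta, x)` scores a single data point and that `batch` is one array whose leading axis indexes data points. It uses the standard unbiased estimator that rescales the minibatch log-likelihood by `N / n`.

```python
import jax
import jax.numpy as jnp


def build_grad_estimator(loglikelihood, logprior, data_size):
    """Hypothetical helper: minibatch estimate of the log-posterior gradient."""

    def logpost_estimate(theta, batch):
        batch_size = batch.shape[0]
        # Per-datapoint log-likelihoods over the minibatch
        loglik = jax.vmap(loglikelihood, in_axes=(None, 0))(theta, batch)
        # Rescale by N / n so the sum is unbiased for the full-data log-likelihood
        return logprior(theta) + (data_size / batch_size) * jnp.sum(loglik)

    # Differentiate with respect to theta (the first argument)
    return jax.grad(logpost_estimate)


# Usage: Gaussian likelihood and prior, dataset of 100 points
loglikelihood = lambda theta, x: -0.5 * (x - theta) ** 2
logprior = lambda theta: -0.5 * theta ** 2
grad_estimator_fn = build_grad_estimator(loglikelihood, logprior, data_size=100)
g = grad_estimator_fn(0.0, jnp.ones(10))
```

The kernel only ever sees `grad_estimator_fn(theta, batch)`, which is why passing it in works whether or not the data fits in memory.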
## Initialising state in sgld
In the “user interaction” section: sgld needs to calculate a gradient when initialising the state. But this needs a minibatch, and in this API the `sgld.init(position)` function doesn’t have access to one. The position and gradient should then be “aligned” (i.e., the gradient corresponds to that position and not to the position of the next/previous iteration).
This is more annoying to figure out than for full-batch methods because of the minibatching. In sgmcmcjax I did the following:
- The [gradient factory](https://github.com/jeremiecoullon/SGMCMCJax/blob/e6e81ea3e45fdb00b7af32c04938978864c28218/sgmcmcjax/gradient_estimation.py#L16) returns `init_gradient` as well as `estimate_gradient`. Note that the former also takes in a random key to sample the minibatch.
- This `init_gradient` function is used to build the [`init` function of the Langevin kernel](https://github.com/jeremiecoullon/SGMCMCJax/blob/e6e81ea3e45fdb00b7af32c04938978864c28218/sgmcmcjax/kernels.py#L29).
What are your thoughts on how to do this?
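A minimal sketch of that pattern, assuming the data is a single in-memory array. All names (`build_gradient_estimation`, `build_init`, `SGLDState`) are hypothetical and this is not the sgmcmcjax code; it only illustrates how an `init` that takes a random key keeps the stored gradient aligned with the initial position.

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax import random


class SGLDState(NamedTuple):
    position: jnp.ndarray
    gradient: jnp.ndarray  # gradient estimate evaluated at `position`


def build_gradient_estimation(grad_logpost_fn, data, batch_size):
    """Hypothetical factory returning (estimate_gradient, init_gradient)."""
    data_size = data.shape[0]

    def estimate_gradient(position, batch):
        return grad_logpost_fn(position, batch)

    def init_gradient(rng_key, position):
        # Sample a minibatch so the initial gradient matches `position`
        idx = random.choice(rng_key, data_size, shape=(batch_size,), replace=False)
        return estimate_gradient(position, data[idx])

    return estimate_gradient, init_gradient


def build_init(init_gradient):
    # Note the extra rng_key argument compared with `sgld.init(position)`
    def init(rng_key, position):
        return SGLDState(position, init_gradient(rng_key, position))

    return init


# Usage with a Gaussian-mean gradient estimator (N = 50 data points)
data = jnp.ones(50)
grad_logpost_fn = lambda theta, batch: -theta + (50 / batch.shape[0]) * jnp.sum(batch - theta)
estimate_gradient, init_gradient = build_gradient_estimation(grad_logpost_fn, data, batch_size=5)
init = build_init(init_gradient)
state = init(random.PRNGKey(0), jnp.array(0.0))
```

The cost of this approach is that `init` needs a random key, which differs from the full-batch `init(position)` signature.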
## Schedules
I agree that passing in the step size at every update is annoying. I don’t quite understand what you mean by “vmap over step sizes”; what use-case is this?
What might be helpful in choosing this API is considering the uses of schedules in sgmcmc. As far as I know, there are 2 main use cases (in papers):
1. A decreasing step size (as in the original SGLD paper). You set up the schedule at the beginning and don’t worry about it again. I think this use case favours the optax design. The state would have to include the iteration number.
2. [Cyclical sgmcmc](https://arxiv.org/abs/1902.03932): the step size increases and decreases in cycles. When the step size is large you use an optimiser (and don’t keep samples), and when it’s small you keep samples. There are more details to do with tempering, but I’m not super clear on these. For example, I don’t understand why they use a separate optimiser (rather than the chosen sgmcmc algorithm) when the schedule is at large step sizes. To build this meta-algorithm you could have a “cyclical sgmcmc factory” that takes in a kernel factory and a schedule. These 2 would be combined into a kernel whose state records whether the current sample belongs to the optimisation or the sampling phase. This would fit the optax design.
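Both use cases can be written as functions of the iteration number, optax-style. The sketch below is hypothetical (the function names are mine): a polynomial decay `a * (b + t) ** (-gamma)` for case 1, and, for case 2, the cosine cyclical schedule from the cyclical SG-MCMC paper, plus an `is_sampling_fn` that a cyclical meta-kernel could use to tag each state as optimisation (exploration) or sampling. Iterations are assumed to start at 0.

```python
import jax.numpy as jnp


def polynomial_decay_schedule(a, b, gamma):
    """Decreasing step size a * (b + t) ** (-gamma)."""
    def schedule_fn(iteration):
        return a * (b + iteration) ** (-gamma)
    return schedule_fn


def cyclical_cosine_schedule(initial_step_size, num_iterations, num_cycles, exploration_ratio):
    """Cosine cyclical step size; returns (schedule_fn, is_sampling_fn)."""
    cycle_length = -(-num_iterations // num_cycles)  # ceil(K / M)

    def position_in_cycle(iteration):
        # Fraction of the current cycle that has elapsed, in [0, 1)
        return (iteration % cycle_length) / cycle_length

    def schedule_fn(iteration):
        # Starts each cycle at initial_step_size and decays towards 0
        r = position_in_cycle(iteration)
        return 0.5 * initial_step_size * (jnp.cos(jnp.pi * r) + 1)

    def is_sampling_fn(iteration):
        # Early in a cycle (large steps): explore/optimise; later (small steps): keep samples
        return position_in_cycle(iteration) >= exploration_ratio

    return schedule_fn, is_sampling_fn


# Usage: 100 iterations split into 4 cycles of 25, last 20% of each cycle is sampling
schedule_fn, is_sampling_fn = cyclical_cosine_schedule(0.1, 100, 4, 0.8)
decay_fn = polynomial_decay_schedule(1.0, 1.0, 0.55)
```

Since both schedules only need the iteration count, storing that count in the kernel state is enough; the cyclical meta-kernel would additionally call `is_sampling_fn` to decide whether to keep the sample.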
So overall I would personally prefer to follow optax’s design.