# BlackJAX Developer doc
## Design principle
Overall, BlackJAX is designed to be modular and easy to experiment with, while achieving good performance through JAX's modern compiler and easy parallelization.
Following JAX, BlackJAX adheres to a strong functional programming style. This means you should expect to see and use function closures and `functools.partial`. You will see them in places such as:
- Functions that take another callable _and_ a tensor as input and output another tensor. Since JAX transformations (`jit`, `grad`, `vmap`, etc.) expect pure functions that are tensor-in, tensor-out, we need a way to transform non-pure functions into pure ones (see the sketch after this list);
- Functions that take a static parameter as input, where taking the gradient with respect to that parameter is either not of interest or not possible (e.g., because it is discrete), such as the number of leapfrog steps in HMC.

Note that some JAX transformations also provide a way to mark input arguments as static.
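As a minimal sketch (the function names here are illustrative, not BlackJAX APIs): a closure or `functools.partial` turns a density that depends on data into a pure tensor-in/tensor-out function, and a discrete parameter such as the number of leapfrog steps can be fixed the same way before `jax.jit`:
```python
import functools

import jax
import jax.numpy as jnp


def loglikelihood(position, data):
    # Depends on both the position and the data; JAX transformations want a
    # function of the position only.
    return jnp.sum(jax.scipy.stats.norm.logpdf(data, loc=position))


data = jnp.array([0.1, -0.3, 0.5])

# A closure captures `data`, leaving a pure position -> scalar function
# that `jit`/`grad`/`vmap` can transform.
logdensity_fn = lambda position: loglikelihood(position, data)
grad_fn = jax.grad(logdensity_fn)

# `functools.partial` achieves the same thing.
logdensity_fn = functools.partial(loglikelihood, data=data)


# A discrete parameter (e.g. the number of leapfrog steps) is fixed through a
# closure/partial, or marked static when jitting.
def run_chain(num_steps, position):
    ...

run_ten_steps = jax.jit(functools.partial(run_chain, 10))
# equivalently: jax.jit(run_chain, static_argnums=0)
```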
## High-level signatures
### Density functions
In BlackJAX, the density functions have the following signatures:
- `logprior_fn(states: pytree) -> Array`
- `loglikelihood_fn(states: pytree, data=None) -> Array`
- `logposterior_fn(states: pytree, data=None) -> Array`
`data=None` is the default for the likelihood and posterior functions (which means the user has already conditioned on the data). Algorithms like SGMCMC pass mini-batches of data to the density function.
Notes:
- We assume all unknown parameters in `states` have support on $\mathbb{R}$, and it is the user's responsibility to make sure of that (i.e., BlackJAX does not supply functionality for change of variables).
- Internally, HMC-based algorithms use a `potential_fn`, which is the negative of the density functions above.
- BlackJAX expects the density functions to return a scalar tensor. In cases where we need to evaluate a batch of states (as for MCMC kernels that pool information across multiple chains, like ChEES), we apply `jax.vmap`.
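For instance (a sketch with a made-up density, not BlackJAX code), a scalar log posterior and its evaluation over a batch of chain states:
```python
import jax
import jax.numpy as jnp


def logposterior_fn(states):
    # Scalar output for a single state (here a 1-D array of parameters).
    return -0.5 * jnp.sum(states ** 2)


single_state = jnp.zeros(3)
logposterior_fn(single_state)                # scalar

batch_of_states = jnp.zeros((4, 3))          # e.g. 4 chains
jax.vmap(logposterior_fn)(batch_of_states)   # shape (4,)
```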
### MCMC kernel
From a high level, a forward kernel takes a state $x$ in the parameter space (i.e., something you can plug into the density function) and returns a new state $x'$:
$T(x) = x'$
For computational convenience, we often store a bit more information, which results in a signature like:
```python
def kernel(rng_key: PRNGKey, kernel_state):
    ...
    return new_kernel_state, kernel_info
```
`kernel_state` is usually the state $x$ together with some additional information such as the log-probability value, while auxiliary information is returned in `kernel_info` (e.g., whether the trajectory diverged in HMC).
As the design pattern suggests, information required by the kernel is stored in `kernel_state` (usually a `NamedTuple`), with the state in the parameter space $x$ as a required component. Usually the evaluation of the density function at $x$ is also stored in `kernel_state` (so we do not need to recompute it). There may be other information that changes at each step, such as the step size in HMC. To achieve that, we usually write a function closure around the kernel initialization function, for example:
```python
kernel = lambda step_size: blackjax.nuts(
logposterior_fn, step_size, static_inv_mass_matrix
)
```
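Back to the `kernel_state` container itself: a hypothetical definition could look like the following (field names are illustrative, not the exact BlackJAX definitions).
```python
from typing import NamedTuple

import jax.numpy as jnp


class KernelState(NamedTuple):
    position: jnp.ndarray         # the state x in the parameter space
    logdensity: jnp.ndarray       # cached density evaluated at `position`
    logdensity_grad: jnp.ndarray  # cached gradient, reused by gradient-based kernels
```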
#### Proposal for signature change
Basically, we currently have (HMC as an example):
```python
def base_hmc(log_prob, step_size, inv_mass_matrix,
             **kwargs) -> Callable:  # returns a kernel: (rng_key, state) -> (state, info)
    ...


log_prob = ...
inv_mass_matrix = ...
hmc_meta = lambda step_size: base_hmc(log_prob, step_size, inv_mass_matrix)
for i in range(num_samples):
    # The kernel is re-initialized at every step, likely inefficient due to retracing
    state, info = jit(hmc_meta(step_size))(key, state)
```
Instead, in the new design the kernel has the signature `kernel(rng_key, state, **parameters)`, so we can either create a function closure over static arguments or jit the function with `static_argnums`.
```python
def base_hmc(log_prob) -> Callable:  # returns a kernel: (rng_key, state, step_size, **parameters) -> (state, info)
    ...


log_prob = ...
inv_mass_matrix = ...
hmc = jit(base_hmc(log_prob), static_argnums=3)
for i in range(num_samples):
    # No retracing
    state, info = hmc(key, state, step_size, inv_mass_matrix)

# Alternatively
hmc_meta = base_hmc(log_prob)
hmc = jit(functools.partial(hmc_meta, inv_mass_matrix=inv_mass_matrix))
for i in range(num_samples):
    # No retracing
    state, info = hmc(key, state, step_size)
```
#### Back to current design
In terms of the high-level signature, we usually have:
```python
# One step
new_state, info = kernel(rng_key, state)

# MCMC sampling loop
rng_key = ...
state = ...
all_states, all_infos = [], []
for i in range(num_samples):
    rng_key, sample_key = jax.random.split(rng_key, 2)
    state, info = kernel(sample_key, state)
    # Store state
    all_states.append(state)
    # Store info
    all_infos.append(info)
```
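Because the kernel is a pure function of `(rng_key, state)`, the same loop can be written with `jax.lax.scan` so that the whole chain is compiled once; a sketch assuming `kernel`, an `initial_state`, and `num_samples` as above:
```python
import jax


def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, sample_key):
        state, info = kernel(sample_key, state)
        return state, (state, info)

    sample_keys = jax.random.split(rng_key, num_samples)
    # `all_states` and `all_infos` have a leading axis of size `num_samples`.
    _, (all_states, all_infos) = jax.lax.scan(one_step, initial_state, sample_keys)
    return all_states, all_infos
```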
#### HMC
##### Symplectic integrator
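As a starting point, a minimal sketch of one step of the leapfrog (velocity Verlet) integrator, assuming an identity mass matrix, array-valued positions, and the `potential_fn` convention described above (the actual BlackJAX integrators operate on pytrees and support general mass matrices):
```python
import jax


def velocity_verlet(potential_fn, step_size):
    potential_grad = jax.grad(potential_fn)

    def one_step(position, momentum):
        momentum = momentum - 0.5 * step_size * potential_grad(position)  # half momentum kick
        position = position + step_size * momentum                        # full position drift
        momentum = momentum - 0.5 * step_size * potential_grad(position)  # half momentum kick
        return position, momentum

    return one_step
```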
## Writing unit tests
We use [Chex](https://github.com/deepmind/chex) to handle unit tests. It provides convenient functionality to test a function under [different variants](https://github.com/deepmind/chex#test-variants-variantspy) (`jit`, non-`jit`, etc.).
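For example, a minimal (hypothetical) test case using Chex variants:
```python
import chex
import jax.numpy as jnp
from absl.testing import absltest


class KernelTest(chex.TestCase):
    @chex.variants(with_jit=True, without_jit=True)
    def test_one_step(self):
        # `self.variant` wraps the function so the same test body runs both
        # jitted and non-jitted.
        @self.variant
        def add(x, y):
            return x + y

        chex.assert_trees_all_close(add(jnp.ones(3), jnp.ones(3)), 2.0 * jnp.ones(3))


if __name__ == "__main__":
    absltest.main()
```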