###### tags: `sampling` `vi` `monte carlo` `one-offs`
# Interacting Particle Systems for Approximate Inference
**Overview**: In this note, I will describe a problem which I came across in a [recent paper](https://arxiv.org/abs/2107.09028) which hybridises mean-field variational inference with Markov chain Monte Carlo methods. I will describe an alternative perspective on the idea which makes use of Wasserstein gradient flows and interacting particle systems, drawing on some ideas from a previous post.
## Mean-Field Variational Inference
In variational inference, one approaches the task of understanding a complex probability measure $p$ by approximating it by a much simpler probability measure $q$. A well-recognised instance of this approach is so-called *mean-field variational inference* (MFVI).
In MFVI, one begins by partitioning the argument of the measures in question into $M$ components, i.e. writing $x = ( x_1, \ldots, x_M )$. The next step is to assert that the approximating measure $q$ will factorise according to this partition, i.e.
\begin{align}
q(dx) = \prod_{m \in [M]} q_m (dx_m).
\end{align}
We write $\mathcal{Q}_\text{MF}$ for the collection of probability measures $q$ which admit such a factorisation.
Finally, one asserts that the "optimal" $q$ of this form is the minimiser of the Kullback-Leibler (KL) divergence from $p$ to $q$, i.e.
\begin{align}
q_* = \arg\min_{q \in \mathcal{Q}_\text{MF}} \left\{ \mathcal{F}(q) := \text{KL} (q, p) \right\}.
\end{align}
This approximation is of course restrictive. Nevertheless, it retains some flexibility, in that the $q_m$ are not constrained a priori to have a parametric form. Moreover, it benefits from relative tractability. Consider minimising $\mathcal{F}$ with respect to $q_m$, holding the other factors $q_{\bar{m}}$, $\bar{m} \neq m$, fixed. One can write
\begin{align}
\mathcal{F}(q) &= \text{KL} (q, p) \\
&= \mathbf{E}_q \left[ \log q(x) - \log p(x) \right] \\
&= \mathbf{E}_q \left[ \sum_{\bar{m} \in [M]} \log q_\bar{m} (x_\bar{m}) - \log p(x) \right] \\
&= \sum_{\bar{m} \in [M]} \mathbf{E}_{q_\bar{m}} \left[ \log q_\bar{m} \right] - \mathbf{E}_q \left[ \log p(x) \right] \\
&= \mathbf{E}_{q_m} \left[ \log q_m (x_m) + \phi_m (x_m; q_{-m}) \right] + \sum_{\bar{m} \in [M]\setminus\{m\} } \mathbf{E}_{q_\bar{m}} \left[ \log q_\bar{m} \right]
\end{align}
where $\phi_m (x_m; q_{-m})$ is a 'variational potential' given by
\begin{align}
\phi_m (x_m; q_{-m}) = \int \left( \prod_{\bar{m} \in [M]\setminus\{m\} } q_{\bar{m}} (d\bar{x}_{\bar{m}}) \right) \left( -\log p(x_m, \bar{x}_{-m}) \right).
\end{align}
Defining
\begin{align}
Z_m (q_{-m}) &= \int \exp \left( - \phi_m (x_m; q_{-m})\right) dx_m, \\
q_{m, *} (dx_m ; q_{-m}) &= \frac{\exp \left( - \phi_m (x_m; q_{-m})\right)}{Z_m (q_{-m})} dx_m
\end{align}
one can then write that
\begin{align}
\mathcal{F}(q) &= \text{KL} (q_m, q_{m, *} ) - \log Z_m + \sum_{\bar{m} \in [M]\setminus\{m\} } \mathbf{E}_{q_\bar{m}} \left[ \log q_\bar{m} \right].
\end{align}
In particular, one can note that the latter two terms do not depend on $q_m$. One can then imagine a coordinate descent scheme for optimising $q$ which iterates the following (a concrete Gaussian example is sketched after the list):
1. Select $m \in [M]$.
2. Set $q_m (dx_m) = q_{m, *} (dx_m ; q_{-m})$.
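As a concrete illustration, consider the classical case in which $p$ is a bivariate Gaussian with precision matrix $\Lambda = \Sigma^{-1}$: the optimal factors $q_{m, *}$ are then themselves Gaussian, with precision $\Lambda_{mm}$ and a mean which depends linearly on the mean of the other factor, so each coordinate update reduces to updating a scalar mean. Here is a minimal sketch of the resulting scheme, with illustrative parameter values:

```python
import numpy as np

# Coordinate descent sketch for a bivariate Gaussian target p = N(mu, Sigma).
# In this conjugate case the optimal factors are Gaussian,
#   q_{m,*} = N(nu_m, 1 / Lam[m, m]),
#   nu_m = mu_m - Lam[m, -m] * (nu_{-m} - mu_{-m}) / Lam[m, m],
# where Lam = Sigma^{-1}, so each update only touches the factor mean nu_m.
mu = np.array([1.0, -1.0])
Sigma = np.array([[1.0, 0.8],
                  [0.8, 1.0]])
Lam = np.linalg.inv(Sigma)

nu = np.array([10.0, 10.0])   # arbitrary initialisation of the factor means
for _ in range(50):           # sweep repeatedly over the coordinates
    for m in range(2):
        mb = 1 - m            # index of the "other" block
        nu[m] = mu[m] - Lam[m, mb] * (nu[mb] - mu[mb]) / Lam[m, m]

print(nu)  # converges to mu; each factor variance is 1 / Lam[m, m]
```

Of course, it is precisely when such closed-form updates are unavailable that the developments below become relevant.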
## MFVI without Conjugacy
Although the updates in the coordinate descent algorithm are nominally in closed form, in non-conjugate models it will often be the case that $q_{m, *}$ does not have a standard parametric form. This can mean that storing a representation of $q_{m, *}$ can be challenging, as can computing the integrals which define the $\phi_m$. As such, it is worthwhile to consider alternative representations of these measures, including those based on samples.
One simple approach in this spirit could be the following:
1. For $m \in [M]$, approximate
\begin{align}
q_m (dx_m) \approx \hat{q}_m (dx_m) := \frac{1}{N} \sum_{a \in [N]} \delta(x_m^a, dx_m).
\end{align}
2. Approximate the variational potential $\phi_m$ as
\begin{align}
\hat{\phi}_m (x_m) = \int \left( \prod_{\bar{m} \in [M]\setminus\{m\} } \hat{q}_{\bar{m}} (d\bar{x}_{\bar{m}}) \right) \left( -\log p(x_m, \bar{x}_{-m}) \right),
\end{align}
noting that this integral is now a finite sum.
3. Form a new approximation $\hat{q}_m$ by sampling from $\hat{q}_{m, *} (dx_m) \propto \exp ( - \hat{\phi}_m (x_m))$.
While this approach may seem appealing in principle, for high-dimensional $x_m$, we should expect the third step to be costly. Moreover, this approach necessitates forming a completely new approximation of $q_m$ at each step, sampling afresh. It might be more reasonable to instead refine our approximation by moving particles around.
I remark briefly that although the sum defining $\hat{\phi}_m$ is finite, in principle it can contain up to $N^{M-1}$ distinct terms, which may be prohibitive. In settings where $p$ can be represented by a sparse factor graph, the cost of evaluating this sum will often be considerably lower. In general, one can simply reduce the cost by computing only a subset of the terms in the sum and reweighting accordingly; a minimal sketch of such a subsampled estimator is given below. One can consider various other strategies for simplifying this step, though I will not discuss them further in this note.
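To be concrete about the subsampling strategy just mentioned, here is a minimal sketch of an estimator of $\hat{\phi}_m$ which averages over a random subset of particle combinations rather than the full product measure. The interface `neg_log_p` (a function returning $-\log p$ at a full configuration) and the layout of `particles` are assumptions made purely for illustration:

```python
import numpy as np

def phi_hat(x_m, m, particles, neg_log_p, n_sub=100, rng=None):
    """Subsampled Monte Carlo estimate of the variational potential phi_m.

    particles : list of length M; particles[k] is an (N, d_k) array whose rows
                are the particles representing q_k.
    neg_log_p : callable mapping a list of block values [x_1, ..., x_M] to
                -log p(x); its form depends on the model at hand.
    n_sub     : number of random particle combinations averaged over, in place
                of the full N^(M-1)-term sum.
    """
    rng = np.random.default_rng() if rng is None else rng
    N = particles[0].shape[0]
    total = 0.0
    for _ in range(n_sub):
        # Draw one particle per block uniformly at random ...
        x = [blk[rng.integers(N)] for blk in particles]
        # ... but hold the m-th argument fixed at the query point.
        x[m] = x_m
        total += neg_log_p(x)
    return total / n_sub
```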
## Wasserstein Gradient Flows
Towards developing an implementation of the above idea which relies only on *moving* particles, I briefly outline the notion of a *Wasserstein gradient flow* (WGF). In essence, the relationship between WGFs and functionals of probability measures is akin to the relationship between standard gradient flows and functions: we specify some objective which we would like to minimise through local motions of minimal effort, and we come up with a dynamical path which might accomplish this.
Roughly speaking, define a functional $\Phi$ on the space of probability measures through an integral representation as
\begin{align}
\Phi ( q ) = \int \Psi[q](x) dx,
\end{align}
where $\Psi$ may involve $q$, its derivatives, etc. The functional derivative of $\Phi$ is defined so that for appropriate perturbations $\delta q$, it holds that
\begin{align}
\Phi ( q + \delta q ) - \Phi(q) \approx \int \frac{\delta \Phi}{\delta q}(x) \cdot \delta q(x) \, dx.
\end{align}
The Wasserstein derivative of $\Phi$ is then given by
\begin{align}
\nabla_{\mathcal{W}} \Phi(q) (x) = \nabla_x \left( \frac{\delta \Phi}{\delta q}(x)\right).
\end{align}
Now, the Wasserstein gradient flow with respect to the functional $\Phi$ is an evolution of measures, but is most easily described via a corresponding evolution of particles. Let $q_t$ be the state of the measure at time $t$. Under the WGF of $\Phi$, a particle $x_t$ drawn from $q_t$ will evolve according to
\begin{align}
\dot{x}_t = - \nabla_{\mathcal{W}} \Phi(q_t) (x_t).
\end{align}
This induces an evolution of measures through the Liouville PDE
\begin{align}
\partial_t q_t - \text{div}_x \left( q_t \nabla_{\mathcal{W}} \Phi(q_t)\right) = 0.
\end{align}
A useful example is when $\Phi(q) = \text{KL}(q, p)$, in which case:
* $\Psi[q](x) = q(x) \cdot \log \frac{q(x)}{p(x)} - q(x) + p(x)$
* $\frac{\delta \Phi}{\delta q}(x) = \log \frac{q(x)}{p(x)}$
* $\nabla_{\mathcal{W}} \Phi(q) (x) = \nabla_x \log \frac{q(x)}{p(x)}$
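To verify the middle claim, one can perturb $q$ in the direction $\delta q$ and differentiate the integrand:
\begin{align}
\frac{d}{d\varepsilon} \bigg|_{\varepsilon = 0} \Psi[q + \varepsilon \, \delta q](x) &= \delta q(x) \cdot \log \frac{q(x)}{p(x)} + q(x) \cdot \frac{\delta q(x)}{q(x)} - \delta q(x) \\
&= \delta q(x) \cdot \log \frac{q(x)}{p(x)},
\end{align}
which matches the defining property of the functional derivative above.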
The resulting WGF dynamics are then given by
\begin{align}
\dot{x}_t = \nabla_x \log p(x_t) - \nabla_x \log q_t (x_t),
\end{align}
i.e. the particle moves towards modes of $p$, and away from modes of its current law.
It is a remarkable and useful fact that, at the level of measures, one can replace the $- \nabla_x \log q_t (x)$ component of the dynamics by an appropriate white noise process, and obtain the same evolution of measures (of course, the pathwise behaviour will be completely different). That is, if one modifies the dynamics above to
\begin{align}
dx_t = \nabla_x \log p(x_t) dt + \sqrt{2} dW_t,
\end{align}
then the law of the particles at time $t$ will be the same as above. Some readers will recognise this equation as the dynamics of the overdamped Langevin diffusion.
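One way to see this fact is to note that, for this choice of $\Phi$, the Liouville equation coincides with the Fokker-Planck equation of the Langevin diffusion:
\begin{align}
\partial_t q_t = \text{div}_x \left( q_t \nabla_x \log \frac{q_t}{p} \right) = \text{div}_x \left( \nabla_x q_t - q_t \nabla_x \log p \right) = \Delta_x q_t - \text{div}_x \left( q_t \nabla_x \log p \right),
\end{align}
and the final expression is precisely the right-hand side of the Fokker-Planck equation associated with the SDE above.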
## WGFs for Mean-Field Variational Inference
In our scenario, we actually have a functional which takes many probability measures as inputs, namely
\begin{align}
\mathcal{F} (q) = \mathcal{F} \left(\{ q_m\}_{m \in [M]} \right).
\end{align}
Making use of the latter parametrisation, we might consider running simultaneous gradient flows on all of the $q_m$. By our earlier manipulations, we can compute that
\begin{align}
\frac{\delta \mathcal{F}}{\delta q_m} = \log \frac{q_m}{q_{m,*}} = \log q_m + \phi_m + \log Z_m.
\end{align}
Noting that the constant $\log Z_m$ vanishes upon taking a gradient in $x_m$, and replacing the $\nabla_x \log q_{m, t}$ contribution by noise as before, the corresponding WGF dynamics for $q_m$ can thus be taken as
\begin{align}
dx_{m, t} &= - \nabla \phi_{m, t} (x_{m, t}) dt + \sqrt{2} dW_{m, t} \\
\text{where} \quad \phi_{m, t} (x_m) &= \int \left( \prod_{\bar{m} \in [M]\setminus\{m\} } q_{\bar{m}, t} (d\bar{x}_{\bar{m}}) \right) \left( -\log p(x_m, \bar{x}_{-m}) \right)
\end{align}
for $m \in [M]$. One would then hope that this dynamical system converges to a steady state which is a minimiser of the original KL minimisation problem, local or otherwise.
## Particle Implementation
In order to make the above derivations practical, it is natural to approximate the measures $q_{m, t}$ by particles, i.e.
\begin{align}
q_{m, t} (dx_m) := \frac{1}{N} \sum_{a \in [N]} \delta(x_{m, t}^a, dx_m).
\end{align}
This gives rise to the following system of $NM$ coupled SDEs
\begin{align}
dx_{m, t}^a &= - \nabla \phi_{m, t} (x_{m, t}^a) dt + \sqrt{2} dW_{m, t}^a \quad \text{for } m \in [M], a \in [N]
\end{align}
where now the integral defining $\phi_{m, t}$ is a finite sum over particles. Finally, by discretising these SDEs, one can obtain a nominally practical algorithm.
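To make this concrete, here is a minimal sketch of a single Euler-Maruyama step of the coupled particle system, reusing the subsampling idea from earlier. The interface `grad_neg_log_p(m, x)` (returning the gradient of $-\log p(x)$ with respect to the $m$-th block at the full configuration $x$) and the layout of `particles` are again assumptions made purely for illustration:

```python
import numpy as np

def particle_step(particles, grad_neg_log_p, step, n_sub=10, rng=None):
    """One Euler-Maruyama step of the NM coupled SDEs above.

    particles      : list of length M; particles[m] is an (N, d_m) array.
    grad_neg_log_p : callable mapping (m, [x_1, ..., x_M]) to the gradient of
                     -log p(x) with respect to the m-th block.
    n_sub          : number of particle combinations used to estimate each
                     gradient of phi_{m, t}, as discussed earlier.
    """
    rng = np.random.default_rng() if rng is None else rng
    N = particles[0].shape[0]
    new_particles = []
    for m in range(len(particles)):
        drift = np.zeros_like(particles[m])
        for a in range(N):
            grad = np.zeros_like(particles[m][a])
            for _ in range(n_sub):
                # Sample the other blocks from their empirical measures,
                # holding the m-th block fixed at the current particle.
                x = [blk[rng.integers(N)] for blk in particles]
                x[m] = particles[m][a]
                grad += grad_neg_log_p(m, x)
            drift[a] = -grad / n_sub   # estimate of -grad phi_{m,t}(x_{m,t}^a)
        noise = np.sqrt(2.0 * step) * rng.normal(size=particles[m].shape)
        new_particles.append(particles[m] + step * drift + noise)
    return new_particles
```

Iterating this step with a suitably small step size yields the nominally practical algorithm described above.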
## Conclusion
In this note, I have presented an approach to approximate mean-field variational inference in non-conjugate models, based around the use of Wasserstein gradient flows, and an implementation via interacting particle systems.
There are many basic questions about this procedure which I have not touched upon here, which would be of key theoretical and practical importance, such as:
* In appropriate limits, would this algorithm converge?
* If so, is there a unique limit?
* If so, does this limit have desirable properties? e.g. Do we actually recover a KL minimiser? Is it a good approximation of $p$?
* At what rate should one expect convergence to occur? e.g. exponential, polynomial, indeterminate?
Some of these questions could turn out to be very challenging indeed.