###### tags: `sampling` `vi` `monte carlo` `one-offs`
# Interacting Particle Systems for Approximate Inference
**Overview**: In this note, I describe a problem which I came across in a [recent paper](https://arxiv.org/abs/2107.09028) which hybridises mean-field variational inference with Markov chain Monte Carlo methods. I then describe an alternative perspective on the idea, based on Wasserstein gradient flows and interacting particle systems, drawing on some material from a previous post.
## Mean-Field Variational Inference
In variational inference, one approaches the task of understanding a complex probability measure $p$ by approximating it with a much simpler probability measure $q$. A well-known instance of this approach is so-called *mean-field variational inference* (MFVI).
In MFVI, one begins by partitioning the argument of the measures in question into $M$ components, i.e. writing $x = ( x_1, \ldots, x_M )$. The next step is to assert that the approximating measure $q$ will factorise according to this partition, i.e.
\begin{align}
q(dx) = \prod_{m \in [M]} q_m (dx_m).
\end{align}
We write $\mathcal{Q}_\text{MF}$ for the collection of probability measures $q$ which admit such a factorisation.
Finally, one declares the "optimal" $q$ of this form to be the one which minimises the Kullback-Leibler (KL) divergence $\text{KL}(q, p)$, i.e.
\begin{align}
q_* = \arg\min_{q \in \mathcal{Q}_\text{MF}} \left\{ \mathcal{F}(q) := \text{KL} (q, p) \right\}.
\end{align}
This approximation is of course restrictive. Nevertheless, it retains some flexibility, in that the $q_m$ are not constrained a priori to have a parametric form. Moreover, it benefits from relative tractability. Consider minimising $\mathcal{F}$ in the variable $q_m$, holding the other $q_\bar{m}$ fixed. One can write
\begin{align}
\mathcal{F}(q) &= \text{KL} (q, p) \\
&= \mathbf{E}_q \left[ \log q(x) - \log p(x) \right] \\
&= \mathbf{E}_q \left[ \sum_{\bar{m} \in [M]} \log q_\bar{m} (x_\bar{m}) - \log p(x) \right] \\
&= \sum_{\bar{m} \in [M]} \mathbf{E}_{q_\bar{m}} \left[ \log q_\bar{m} \right] - \mathbf{E}_q \left[ \log p(x) \right] \\
&= \mathbf{E}_{q_m} \left[ \log q_m (x_m) + \phi_m (x_m; q_{-m}) \right] + \sum_{\bar{m} \in [M]\setminus\{m\} } \mathbf{E}_{q_\bar{m}} \left[ \log q_\bar{m} \right]
\end{align}
where $\phi_m (x_m; q_{-m})$ is a 'variational potential' given by
\begin{align}
\phi_m (x_m; q_{-m}) = \int \left( \prod_{\bar{m} \in [M]\setminus\{m\} } q_{\bar{m}} (d\bar{x}_{\bar{m}}) \right) \left( -\log p(x_m, \bar{x}_{-m}) \right).
\end{align}
Defining
\begin{align}
Z_m (q_{-m}) &= \int \exp \left( - \phi_m (x_m; q_{-m})\right) dx_m, \\
q_{m, *} (dx_m ; q_{-m}) &= \frac{\exp \left( - \phi_m (x_m; q_{-m})\right)}{Z_m (q_{-m})} dx_m
\end{align}
one can then write that
\begin{align}
\mathcal{F}(q) &= \text{KL} (q_m, q_{m, *} ) - \log Z_m + \sum_{\bar{m} \in [M]\setminus\{m\} } \mathbf{E}_{q_\bar{m}} \left[ \log q_\bar{m} \right].
\end{align}
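As a sanity check on this decomposition, one can verify it numerically in a toy discrete setting, where integrals against $dx$ become sums against counting measure (the discrete setting and the specific table `p` below are assumptions made purely for illustration). The sketch forms $\phi_1$, $Z_1$ and $q_{1, *}$ for $M = 2$ and compares both sides of the identity.

```python
import math


def kl(a, b):
    """KL(a, b) for discrete distributions given as lists of probabilities."""
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)


def mf_objective(p, q1, q2):
    """F(q) = KL(q1 x q2, p) for a joint table p[i][j] (assumes strictly positive entries)."""
    return sum(
        q1[i] * q2[j] * math.log(q1[i] * q2[j] / p[i][j])
        for i in range(len(q1))
        for j in range(len(q2))
    )


def decomposition(p, q1, q2):
    """Return KL(q1, q1*) - log Z1 + E_{q2}[log q2], together with q1* itself."""
    # Variational potential phi_1(i) = E_{q2}[-log p(i, x_2)]
    phi1 = [sum(q2[j] * -math.log(p[i][j]) for j in range(len(q2))) for i in range(len(q1))]
    z1 = sum(math.exp(-f) for f in phi1)
    q1_star = [math.exp(-f) / z1 for f in phi1]
    rhs = kl(q1, q1_star) - math.log(z1) + sum(b * math.log(b) for b in q2 if b > 0)
    return rhs, q1_star


p = [[0.1, 0.2], [0.3, 0.4]]  # hypothetical joint table p(x_1, x_2)
q1, q2 = [0.5, 0.5], [0.25, 0.75]
rhs, q1_star = decomposition(p, q1, q2)
print(mf_objective(p, q1, q2), rhs)  # the two values agree up to rounding
print(mf_objective(p, q1_star, q2) <= mf_objective(p, q1, q2))  # True
```

The second check reflects the fact that only the KL term depends on $q_1$, so replacing $q_1$ by $q_{1, *}$ cannot increase $\mathcal{F}$.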
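In particular, the latter two terms do not depend on $q_m$, and $\text{KL} (q_m, q_{m, *}) \geq 0$ with equality if and only if $q_m = q_{m, *}$. Hence, with $q_{-m}$ held fixed, $\mathcal{F}$ is minimised in $q_m$ exactly by $q_{m, *} (\cdot\,; q_{-m})$. One can then imagine a coordinate descent scheme for minimising $\mathcal{F}$ which iterates the following: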
1. Select $m \in [M]$.
2. Set $q_m (dx_m) = q_{m, *} (dx_m ; q_{-m})$.
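In a conjugate case, step 2 is available in closed form. As a minimal sketch, assume (for illustration only) a bivariate Gaussian target with mean `mu` and precision matrix `lam`, split into $M = 2$ scalar components. Expanding $\phi_m$ shows that $q_{m, *}$ is Gaussian with precision $\Lambda_{mm}$ and mean $\mu_m - \Lambda_{mm}^{-1} \Lambda_{m\bar{m}} (\mathbf{E}_{q_{\bar{m}}}[x_{\bar{m}}] - \mu_{\bar{m}})$, so the coordinate updates only need to track the current means.

```python
def cavi_gaussian(mu, lam, n_sweeps=50, init=(0.0, 0.0)):
    """Coordinate updates q_m <- q_{m,*} for a bivariate Gaussian target.

    mu: target mean (length 2); lam: target precision matrix (2x2 nested lists).
    Returns the means and variances of the Gaussian factors q_1, q_2.
    """
    means = list(init)
    for _ in range(n_sweeps):
        for m in (0, 1):
            other = 1 - m
            # q_{m,*} has mean mu_m - (lam[m][other] / lam[m][m]) * (E[x_other] - mu_other)
            means[m] = mu[m] - lam[m][other] / lam[m][m] * (means[other] - mu[other])
    # q_{m,*} has variance 1 / lam[m][m], whatever the other factor is
    variances = [1.0 / lam[0][0], 1.0 / lam[1][1]]
    return means, variances


means, variances = cavi_gaussian([1.0, -1.0], [[2.0, 0.5], [0.5, 1.0]])
print(means, variances)
```

Here the factor means converge to the target mean $(1, -1)$, while the factor variances $\Lambda_{mm}^{-1} = (0.5, 1)$ are smaller than the true marginal variances $(\Lambda^{-1})_{mm} \approx (0.571, 1.143)$, a familiar symptom of the restrictiveness of the mean-field family.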
## MFVI without Conjugacy
Although the updates in this coordinate descent algorithm are nominally in closed form, in non-conjugate models $q_{m, *}$ will often not have a standard parametric form. This can make it challenging both to store a representation of $q_{m, *}$ and to compute the integrals which define the $\phi_m$. As such, it is worthwhile to consider alternative representations of these measures, including those based on samples.
One simple approach in this spirit could be the following:
1. For $m \in [M]$, approximate
\begin{align}
q_m (dx_m) \approx \hat{q}_m (dx_m) := \frac{1}{N} \sum_{a \in [N]} \delta(x_m^a, dx_m).
\end{align}
2. Approximate the variational potential $\phi_m$ as
\begin{align}
\hat{\phi}_m (x_m) = \int \left( \prod_{\bar{m} \in [M]\setminus\{m\} } \hat{q}_{\bar{m}} (d\bar{x}_{\bar{m}}) \right) \left( -\log p(x_m, \bar{x}_{-m}) \right),
\end{align}
noting that this integral is now a finite sum.
3. Form a new approximation $\hat{q}_m$ by sampling from $\hat{q}_{m, *} (dx_m) \propto \exp ( - \hat{\phi}_m (x_m))$.
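Steps 2 and 3 can be sketched for $M = 2$ scalar components as follows. This is a minimal illustration under assumptions not taken from the source: each $x_m$ is one-dimensional, and step 3 is carried out by discretising $\hat{q}_{m, *}$ on a fixed uniform grid and sampling grid points in proportion to $\exp(-\hat{\phi}_m)$, which is only sensible in very low dimension. The Gaussian target `gaussian_neg_log_p` and all parameter values are hypothetical.

```python
import math
import random


def phi_hat(x_m, others, neg_log_p_m):
    """Finite-sum potential: average of -log p over the other block's particles."""
    return sum(neg_log_p_m(x_m, y) for y in others) / len(others)


def resample_block(others, neg_log_p_m, grid, n, rng):
    """Draw n points approximately from exp(-phi_hat) by weighting grid points (1-D only)."""
    pots = [phi_hat(g, others, neg_log_p_m) for g in grid]
    shift = min(pots)  # subtract the minimum before exponentiating, for numerical stability
    weights = [math.exp(-(u - shift)) for u in pots]
    return rng.choices(grid, weights=weights, k=n)


def sample_based_mfvi(neg_log_p, grid, n=200, sweeps=10, seed=0):
    """Alternate steps 2 and 3 over both blocks; neg_log_p(x1, x2) is -log p up to a constant."""
    rng = random.Random(seed)
    blocks = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(2)]
    # View -log p as a function of (own coordinate, other coordinate) for each block
    per_block = [lambda x, y: neg_log_p(x, y), lambda x, y: neg_log_p(y, x)]
    for _ in range(sweeps):
        for m in (0, 1):
            blocks[m] = resample_block(blocks[1 - m], per_block[m], grid, n, rng)
    return blocks


def gaussian_neg_log_p(x1, x2, mu=(1.0, -1.0), lam=((2.0, 0.5), (0.5, 1.0))):
    """Hypothetical bivariate Gaussian target, -log p up to an additive constant."""
    d1, d2 = x1 - mu[0], x2 - mu[1]
    return 0.5 * (lam[0][0] * d1 * d1 + 2.0 * lam[0][1] * d1 * d2 + lam[1][1] * d2 * d2)


grid = [-5.0 + 0.05 * i for i in range(201)]
blocks = sample_based_mfvi(gaussian_neg_log_p, grid)
```

Note that each sweep discards the previous particles of block $m$ entirely and draws a fresh set, which is the behaviour questioned in the next paragraph.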
While this approach may seem appealing in principle, for high-dimensional $x_m$, we should expect the third step to be costly. Moreover, this approach necessitates forming a completely new approximation of $q_m$ at each step, sampling afresh. It might be more reasonable to instead refine our approximation by moving particles around.
I remark briefly that although the sum defining $\hat{\phi}_m$ is finite, it can contain up to $N^{M-1}$ distinct terms (one for each choice of particle from each of the other $M - 1$ blocks), which may be prohibitive. In settings where $p$ can be represented by a sparse factor graph, the cost of evaluating this sum will often be considerably lower. More generally, one can reduce the cost by computing only a subset of the terms in the sum and reweighting accordingly. One can consider various other strategies for simplifying this step, though I will not discuss them further in this note.
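To make the cost remark concrete, the sketch below evaluates $\hat{\phi}_m(x_m)$ either exactly, over all $N^{M-1}$ combinations of the other blocks' particles, or by averaging over index tuples drawn uniformly at random with replacement. Since every term in the exact sum carries the same weight, the plain average over sampled terms is an unbiased estimate. Scalar blocks, the callable `neg_log_p` acting on the full vector $x$, and the parameter `n_terms` are assumptions made for illustration.

```python
import itertools
import random


def phi_hat_exact(m, x_m, particles, neg_log_p):
    """Average of -log p over all combinations of the other blocks' particles."""
    others = [block for k, block in enumerate(particles) if k != m]
    total, count = 0.0, 0
    for combo in itertools.product(*others):
        x = list(combo)
        x.insert(m, x_m)  # put x_m back in position m of the full vector
        total += neg_log_p(x)
        count += 1
    return total / count


def phi_hat_subsampled(m, x_m, particles, neg_log_p, n_terms, rng):
    """Unbiased estimate of phi_hat_exact from n_terms uniformly drawn index tuples."""
    total = 0.0
    for _ in range(n_terms):
        # Picking one particle uniformly from each other block selects one term uniformly
        x = [x_m if k == m else rng.choice(block) for k, block in enumerate(particles)]
        total += neg_log_p(x)
    return total / n_terms


def neg_log_p(x):
    """Hypothetical standard Gaussian target, -log p up to an additive constant."""
    return 0.5 * sum(v * v for v in x)


particles = [[0.0, 1.0, 2.0], [1.0, -1.0], [0.5, 1.5]]  # M = 3 blocks of scalar particles
exact = phi_hat_exact(0, 1.0, particles, neg_log_p)  # 1.625
approx = phi_hat_subsampled(0, 1.0, particles, neg_log_p, n_terms=2000, rng=random.Random(0))
```

The exact version scales with the product of the other blocks' sizes, whereas the subsampled version costs `n_terms` evaluations regardless of $M$, at the price of Monte Carlo error.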
## Wasserstein Gradient Flows
Towards developing an implementation of the above idea which relies only on *moving* particles, I briefly outline the notion of a *Wasserstein gradient flow* (WGF). In essence, the relationship between WGFs and functionals of probability measures is akin to the relationship between standard gradient flows and functions: we specify some objective which we would like to minimise through local motions of minimal effort, and we come up with a dynamical path which might accomplish this.
Roughly speaking, suppose that a functional $\Phi$ on the space of probability measures admits an integral representation of the form
\begin{align}
\Phi ( q ) = \int \Psi[q](x) dx,
\end{align}
where $\Psi$ may involve $q$, its derivatives, etc. The functional derivative of $\Phi$ is defined so that for appropriate perturbations $\delta q$, it holds that
\begin{align}
\Phi ( q + \delta q ) - \Phi(q) \approx \int \frac{\delta \Phi}{\delta q}(x) \cdot \delta q(x) \, dx.
\end{align}
The Wasserstein derivative of $\Phi$ is then given by
\begin{align}
\nabla_{\mathcal{W}} \Phi(q) (x) = \nabla_x \left( \frac{\delta \Phi}{\delta q}(x)\right).
\end{align}
Now, the Wasserstein gradient flow with respect to the functional $\Phi$ is an evolution of measures, but is most easily described via a corresponding evolution of particles. Let $q_t$ be the state of the measure at time $t$. Under the WGF of $\Phi$, a particle $x_t$ drawn from $q_t$ will evolve according to
\begin{align}
\dot{x}_t = - \nabla_{\mathcal{W}} \Phi(q_t) (x_t).
\end{align}
This induces an evolution of measures through the Liouville PDE
\begin{align}
\partial_t q_t = \text{div}_x \left( q_t \nabla_{\mathcal{W}} \Phi(q_t)\right).
\end{align}
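As a simple worked example of these definitions, take the potential energy functional associated with a smooth potential $V$:
\begin{align}
\Phi(q) = \int V(x) \, q(x) \, dx, \qquad \frac{\delta \Phi}{\delta q}(x) = V(x), \qquad \nabla_{\mathcal{W}} \Phi(q)(x) = \nabla_x V(x).
\end{align}
Each particle then follows $\dot{x}_t = -\nabla_x V(x_t)$, i.e. ordinary gradient descent on $V$, independently of all other particles. Interaction between particles only enters when $\frac{\delta \Phi}{\delta q}$ depends on $q$ itself.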
A useful example is when $\Phi(q) = \text{KL}(q, p)$, in which case:
* $\Psi[q](x) = q(x) \cdot \log \frac{q(x)}{p(x)} - q(x) + p(x)$
* $\frac{\delta \Phi}{\delta q}(x) = \log \frac{q(x)}{p(x)}$
* $\nabla_{\mathcal{W}} \Phi(q) (x) = \nabla_x \log \frac{q(x)}{p(x)}$
The resulting WGF dynamics are then given by
\begin{align}
\dot{x}_t = \nabla_x \log p(x_t) - \nabla_x \log q_t (x_t),
\end{align}
i.e. the particle moves towards modes of $p$, and away from modes of its current law.
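Implementing these dynamics directly requires the score $\nabla_x \log q_t$ of the particles' current law, which is generally unavailable. In the special case of a one-dimensional Gaussian target $p = \mathcal{N}(\mu, \sigma^2)$ with a Gaussian initial law, the velocity field is affine in $x$, so the law stays Gaussian along the flow. The minimal sketch below relies on this assumption: it estimates $\nabla_x \log q_t$ by fitting a Gaussian to the particles' empirical mean and variance, and takes explicit Euler steps of a hypothetical size `step`.

```python
import random


def deterministic_kl_flow(mu, sigma, n=200, step=0.01, n_steps=1000, seed=0):
    """Euler steps of dx/dt = grad log p(x) - grad log q_t(x) for p = N(mu, sigma^2).

    grad log q_t is approximated by the score of a Gaussian fitted to the particles,
    which is only justified here because the flow preserves Gaussianity.
    """
    rng = random.Random(seed)
    xs = [rng.gauss(-3.0, 2.0) for _ in range(n)]
    for _ in range(n_steps):
        m = sum(xs) / n
        s2 = sum((x - m) ** 2 for x in xs) / n
        # velocity = -(x - mu) / sigma^2 (towards p)  +  (x - m) / s2 (away from q_t)
        xs = [x + step * (-(x - mu) / sigma**2 + (x - m) / s2) for x in xs]
    return xs


xs = deterministic_kl_flow(mu=2.0, sigma=0.5)
```

No noise is injected, so the particles move deterministically; the repulsive term $(x - m)/s^2$ is what stops them from collapsing onto the mode $\mu$, and the empirical variance settles at $\sigma^2$.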
It is a remarkable and useful fact that, at the level of measures, one can replace the $- \nabla_x \log q_t (x)$ component of the dynamics by an appropriate white noise process, and obtain the same evolution of measures (of course, the pathwise behaviour will be completely different). That is, if one modifies the dynamics above to
\begin{align}
dx_t = \nabla_x \log p(x_t) dt + \sqrt{2} dW_t,
\end{align}
then, provided the particles start from the same initial law, their law at each time $t$ will be the same as above. Here $W_t$ is a standard Brownian motion, and some readers will recognise this equation as the overdamped Langevin diffusion.
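A minimal sketch of simulating this diffusion for a cloud of independent particles with the Euler-Maruyama scheme (often called the unadjusted Langevin algorithm) is below. The target $\mathcal{N}(2, 0.5^2)$, the step size and the particle count are assumptions for illustration. The time discretisation introduces a small bias in the stationary law, which vanishes as `step` tends to zero.

```python
import math
import random


def langevin_particles(grad_log_p, x0s, step, n_steps, seed=0):
    """Euler-Maruyama for dx = grad log p(x) dt + sqrt(2) dW, applied to independent particles."""
    rng = random.Random(seed)
    xs = list(x0s)
    noise_scale = math.sqrt(2.0 * step)
    for _ in range(n_steps):
        xs = [x + step * grad_log_p(x) + noise_scale * rng.gauss(0.0, 1.0) for x in xs]
    return xs


def grad_log_gaussian(x, mu=2.0, sigma=0.5):
    """Score of the hypothetical target N(mu, sigma^2)."""
    return -(x - mu) / sigma**2


xs = langevin_particles(grad_log_gaussian, [0.0] * 1000, step=0.01, n_steps=1000)
```

Unlike the deterministic flow, no estimate of $\nabla_x \log q_t$ is needed: each particle only queries the score of $p$.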
## WGFs for Mean-Field Variational Inference
In our scenario, we actually have a functional which takes many probability measures as inputs, namely
\begin{align}
\mathcal{F} (q) = \mathcal{F} \left(\{ q_m\}_{m \in [M]} \right).
\end{align}
Working with the factorised parametrisation on the right-hand side, we might consider running simultaneous gradient flows on all of the $q_m$. By our earlier manipulations, we can compute that
\begin{align}
\frac{\delta \mathcal{F}}{\delta q_m} = \log \frac{q_m}{q_{m,*}} = \log q_m + \phi_m + \log Z_m.
\end{align}
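Discarding additive constants, which do not affect the Wasserstein gradient, and replacing the $-\nabla \log q_{m, t}$ component of the dynamics by white noise as before, the corresponding dynamics for $q_m$ can be taken as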
\begin{align}
dx_{m, t} &= - \nabla \phi_{m, t} (x_{m, t}) dt + \sqrt{2} dW_{m, t} \\
\text{where} \quad \phi_{m, t} (x_m) &= \int \left( \prod_{\bar{m} \in [M]\setminus\{m\} } q_{\bar{m}, t} (d\bar{x}_{\bar{m}}) \right) \left( -\log p(x_m, \bar{x}_{-m}) \right)
\end{align}
for $m \in [M]$. One would then hope that this dynamical system converges to a steady state which is a minimiser of the original KL minimisation problem, local or otherwise.
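To see what these dynamics look like in a simple case, suppose (purely for illustration) that $p$ is Gaussian with mean $\mu$ and precision matrix $\Lambda$, and that each $x_m$ is scalar. Writing $\nu_{\bar{m}, t} := \mathbf{E}_{q_{\bar{m}, t}} \left[ x_{\bar{m}} \right]$, a direct computation from the definition of $\phi_{m, t}$ gives
\begin{align}
-\log p(x) &= \frac{1}{2} (x - \mu)^\top \Lambda (x - \mu) + \text{const}, \\
\nabla \phi_{m, t} (x_m) &= \Lambda_{mm} (x_m - \mu_m) + \sum_{\bar{m} \in [M]\setminus\{m\} } \Lambda_{m \bar{m}} \left( \nu_{\bar{m}, t} - \mu_{\bar{m}} \right).
\end{align}
Each $x_{m, t}$ therefore follows an Ornstein-Uhlenbeck process which interacts with the other components only through their current means. At a stationary point, $q_m = \mathcal{N} \left( \mu_m - \Lambda_{mm}^{-1} \sum_{\bar{m} \neq m} \Lambda_{m \bar{m}} (\nu_{\bar{m}} - \mu_{\bar{m}}), \Lambda_{mm}^{-1} \right)$, which is exactly the coordinate update $q_{m, *}$. The stationary means solve $\Lambda (\nu - \mu) = 0$, so for positive definite $\Lambda$ the unique fixed point has $\nu = \mu$ and factor variances $\Lambda_{mm}^{-1}$.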
## Particle Implementation
In order to make the above derivations practical, it is natural to approximate the measures $q_{m, t}$ by particles, i.e.
\begin{align}
q_{m, t} (dx_m) := \frac{1}{N} \sum_{a \in [N]} \delta(x_{m, t}^a, dx_m).
\end{align}
This gives rise to the following system of $NM$ coupled SDEs
\begin{align}
dx_{m, t}^a &= - \nabla \phi_{m, t} (x_{m, t}^a) dt + \sqrt{2} dW_{m, t}^a \quad \text{for } m \in [M], a \in [N]
\end{align}
where now the integral defining $\phi_{m, t}$ is a finite sum over particles. Finally, discretising these SDEs in time (e.g. with the Euler-Maruyama scheme) yields a nominally practical algorithm.
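A minimal sketch of the resulting interacting particle system, discretised with Euler-Maruyama, is below. It assumes (for illustration only) $M = 2$ scalar components, a hypothetical bivariate Gaussian target with mean `mu` and precision `lam`, and a synchronous update of all particles. The gradient of $\phi_{m, t}$ at each particle is computed as the finite sum over the other block's particles, so each step costs $O(N^2)$ gradient evaluations.

```python
import math
import random


def grad_neg_log_p(x1, x2, mu, lam):
    """Gradient of -log p at (x1, x2) for a Gaussian with mean mu and precision lam."""
    d1, d2 = x1 - mu[0], x2 - mu[1]
    return (lam[0][0] * d1 + lam[0][1] * d2, lam[1][0] * d1 + lam[1][1] * d2)


def mean_field_langevin(mu, lam, n=50, step=0.02, n_steps=500, seed=0):
    """Euler-Maruyama for dx^a_m = -grad phi_{m,t}(x^a_m) dt + sqrt(2) dW^a_m, m in {0, 1}."""
    rng = random.Random(seed)
    xs = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(2)]
    noise_scale = math.sqrt(2.0 * step)
    for _ in range(n_steps):
        new_xs = [[], []]
        for m in (0, 1):
            others = xs[1 - m]
            for x in xs[m]:
                # grad phi_{m,t}(x): average of d(-log p)/dx_m over the other block's particles
                if m == 0:
                    g = sum(grad_neg_log_p(x, y, mu, lam)[0] for y in others) / n
                else:
                    g = sum(grad_neg_log_p(y, x, mu, lam)[1] for y in others) / n
                new_xs[m].append(x - step * g + noise_scale * rng.gauss(0.0, 1.0))
        xs = new_xs  # synchronous update: both blocks used positions from the previous step
    return xs


xs = mean_field_langevin(mu=(1.0, -1.0), lam=((2.0, 0.5), (0.5, 1.0)))
```

For this particular target the finite sum collapses to an evaluation at the other block's empirical mean, because $\nabla_{x_m} \log p$ is affine in $x_{\bar{m}}$; the explicit double loop is kept to mirror the general non-conjugate case.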
## Conclusion
In this note, I have presented an approach to approximate mean-field variational inference in non-conjugate models, based around the use of Wasserstein gradient flows, and an implementation via interacting particle systems.
There are many basic questions about this procedure which I have not touched upon here, and which would be of key theoretical and practical importance, such as:
* In appropriate limits, would this algorithm converge?
* If so, is there a unique limit?
* If so, does this limit have desirable properties? For example, do we actually recover a KL minimiser? Is it a good approximation of $p$?
* At what rate should one expect convergence to occur? e.g. exponential, polynomial, indeterminate?
Some of these questions could turn out to be very challenging indeed.

Published on **HackMD**
