I had a really hard time learning about VAEs at the beginning of my PhD. I felt very betrayed after spending time deriving and memorising ELBO (the evidence lower bound objective), only to see yet another paper write it in a different way. However, as I matured, my attitude towards this changed: I have now learned to embrace the power of the seemingly infinite forms of ELBO.
Thinking back, this transformation really took place when my supervisor Sid introduced me to a great series of papers covering the evolution of ELBO over the last five or six years. Organising all of them and describing them in non-gibberish took some time, but I hope this will serve as a frustration-free note-to-self for when I revisit the topic, and that it can also help people out there who are feeling as bamboozled as I was a year ago.
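I will discuss the following papers, one in each section: ELBO surgery (Hoffman & Johnson); Importance Weighted Autoencoders (Burda, Grosse & Salakhutdinov); Sticking the Landing (Roeder, Wu & Duvenaud); Tight Variational Bounds are Not Necessarily Better (Rainforth et al.); and Doubly Reparametrised Gradient Estimators for Monte Carlo Objectives (Tucker et al.). Trust me, they each serve a purpose and together they tell a whole story.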
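Before we dive in, let's look at the most basic form of ELBO first. Here it is in all of its glory:
$$\log p_\theta(x) \geq \mathcal{L}_{\text{ELBO}}(x) = \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{p_\theta(x, z)}{q_\phi(z|x)}\right] = \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - \mathrm{KL}\left(q_\phi(z|x) \,\|\, p(z)\right)$$
where $x$ is an observation, $z$ is its latent variable, $p(z)$ is the prior, $p_\theta(x|z)$ is the generative model (decoder) with parameters $\theta$, and $q_\phi(z|x)$ is the inference model (encoder, or variational posterior) with parameters $\phi$.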
If you have this memorised or tattooed on your arm, we are ready to go!
Paper discussed: ELBO surgery: yet another way to carve up the variational evidence lower bound, work by Matthew Hoffman and Matthew Johnson.
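This work provides a very intuitive perspective on the VAE objective by decomposing and rewriting ELBO. For a batch of $N$ observations $\{x_n\}_{n=1}^N$, the average ELBO can be written as
$$\frac{1}{N}\sum_{n=1}^N \mathcal{L}_{\text{ELBO}}(x_n) = \underbrace{\frac{1}{N}\sum_{n=1}^N \mathbb{E}_{q_\phi(z|x_n)}\left[\log p_\theta(x_n|z)\right]}_{(1)\ \text{average reconstruction}} - \underbrace{\mathrm{I}_q[n; z]}_{(2)\ \text{index-code mutual information}} - \underbrace{\mathrm{KL}\left(q_\phi(z) \,\|\, p(z)\right)}_{(3)\ \text{marginal KL to prior}}$$
where $q_\phi(z) = \frac{1}{N}\sum_{n=1}^N q_\phi(z|x_n)$ is the aggregated posterior, and $\mathrm{I}_q[n; z]$ is the mutual information between the data index $n$ and the latent code $z$ under the joint $q(n, z) = \frac{1}{N} q_\phi(z|x_n)$. Terms (2) and (3) together are exactly the batch average of the KL terms in the basic ELBO.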
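So what is the point of all this? Well, what's interesting about this decomposition is that (1) average reconstruction and (2) index-code mutual information have opposing effects on the latent space: maximising term (1) encourages each $q_\phi(z|x_n)$ to concentrate on latent codes that reconstruct $x_n$ well, which tends to push the posteriors of different observations apart; since term (2) is subtracted, maximising ELBO also minimises $\mathrm{I}_q[n; z]$, which pulls the posteriors of different observations together so that $z$ reveals less about which observation it came from.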
We visualise these effects in the graph below for two observations $x_1$ and $x_2$.
Fig. Visualisation of the effects of terms (1) and (2). Dotted lines represent the inference model $q_\phi(z|x_n)$.
This leaves us with term (3), which is the only term that involves the prior. It regularises the aggregated posterior towards the prior by minimising the KL divergence between $q_\phi(z)$ and $p(z)$.
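To make the decomposition concrete, here is a small pure-Python sketch (standard library only) that numerically checks that the average KL term splits into terms (2) and (3). It assumes one-dimensional Gaussian posteriors $q_\phi(z|x_n) = \mathcal{N}(\mu_n, \sigma_n^2)$ and a standard Normal prior; the function names and toy numbers are hypothetical, chosen only for illustration.

```python
import math
import random

def log_normal(z, mu, sigma):
    """Log density of N(mu, sigma^2) evaluated at z."""
    return -0.5 * math.log(2.0 * math.pi * sigma ** 2) - (z - mu) ** 2 / (2.0 * sigma ** 2)

def kl_to_std_normal(mu, sigma):
    """Closed-form KL(N(mu, sigma^2) || N(0, 1))."""
    return 0.5 * (sigma ** 2 + mu ** 2 - 1.0) - math.log(sigma)

def elbo_surgery(mus, sigmas, num_samples=20000, seed=0):
    """Split the batch-averaged KL(q(z|x_n) || p(z)) into I_q[n; z] + KL(q(z) || p(z)).

    Assumes 1-D Gaussian posteriors q(z|x_n) = N(mus[n], sigmas[n]^2) and a
    standard Normal prior; q(z) is the uniform mixture of the N posteriors.
    Returns (exact average KL, MC estimate of term (2), MC estimate of term (3)).
    """
    rng = random.Random(seed)
    n_obs = len(mus)

    def log_q_agg(z):
        # log q(z) = log (1/N) sum_n q(z|x_n), computed with log-sum-exp
        logs = [log_normal(z, m, s) for m, s in zip(mus, sigmas)]
        top = max(logs)
        return top + math.log(sum(math.exp(l - top) for l in logs) / n_obs)

    mutual_info, marginal_kl = 0.0, 0.0
    for _ in range(num_samples):
        n = rng.randrange(n_obs)            # pick a data index uniformly
        z = rng.gauss(mus[n], sigmas[n])    # z ~ q(z|x_n)
        lq = log_q_agg(z)
        mutual_info += log_normal(z, mus[n], sigmas[n]) - lq   # log q(z|n) - log q(z)
        marginal_kl += lq - log_normal(z, 0.0, 1.0)            # log q(z) - log p(z)
    avg_kl = sum(kl_to_std_normal(m, s) for m, s in zip(mus, sigmas)) / n_obs
    return avg_kl, mutual_info / num_samples, marginal_kl / num_samples
```

With well-separated posteriors, e.g. `elbo_surgery([-2.0, 2.0], [0.5, 0.5])`, the estimated mutual information sits close to its maximum of $\log N = \log 2$; with identical posteriors, e.g. `elbo_surgery([0.0, 0.0], [1.0, 1.0])`, it is exactly zero, which is the configuration term (2) on its own would favour.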
The paper Disentangling Disentanglement in Variational Autoencoders also does a great job of analysing and utilising the effect of these three terms for disentanglement in VAEs, and I strongly recommend that you have a look.
Paper discussed: Importance Weighted Autoencoders, work by Yuri Burda, Roger Grosse & Ruslan Salakhutdinov
Hopefully the previous section served as a good warm-up, and you now have a better intuition for how ELBO affects the graphical model. Next, we will move just a tad away from the original ELBO, to a more advanced $K$-sample lower bound estimator: IWAE.
Importance Weighted Autoencoders (IWAE) is probably my favourite machine learning trick (and I know about 4). It is a simple yet powerful way to improve the performance of VAEs, and you're really missing out if you went through the trouble of implementing ELBO but stopped there. Here, I will talk about the formulation of IWAE and its three benefits: a tighter lower bound estimate, importance-weighted gradients and a complex implicit posterior distribution.
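IWAE proposes a tighter estimate of $\log p_\theta(x)$.
A common practice to acquire a better estimate of ELBO is to draw $K$ samples $z_1, \dots, z_K \sim q_\phi(z|x)$ and average over them:
$$\mathcal{L}_{\text{ELBO}}^K(x) = \mathbb{E}_{z_{1:K} \sim q_\phi(z|x)}\left[\frac{1}{K}\sum_{k=1}^K \log w_k\right], \qquad w_k = \frac{p_\theta(x, z_k)}{q_\phi(z_k|x)}.$$
This reduces the variance of the estimate, but its expectation is still the same ELBO, so the bound is no tighter.
IWAE simply switches the order of the sum over $k$ and the logarithm:
$$\mathcal{L}_{\text{IWAE}}^K(x) = \mathbb{E}_{z_{1:K} \sim q_\phi(z|x)}\left[\log \frac{1}{K}\sum_{k=1}^K w_k\right].$$
It is easy to see by Jensen's inequality that $\mathcal{L}_{\text{ELBO}}(x) \leq \mathcal{L}_{\text{IWAE}}^K(x) \leq \log p_\theta(x)$. The original paper further shows that the bound is non-decreasing in $K$ and, if the weights $w_k$ are bounded, converges to $\log p_\theta(x)$ as $K \to \infty$.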
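As a sanity check of this ordering, here is a hedged pure-Python sketch that estimates the IWAE bound for a few values of $K$ on a toy model where $\log p(x)$ is known in closed form. The model ($p(z) = \mathcal{N}(0, 1)$, $p(x|z) = \mathcal{N}(z, 1)$, hence $p(x) = \mathcal{N}(0, 2)$) and the deliberately poor proposal $q(z|x) = p(z)$ are my own assumptions for illustration, not from the paper.

```python
import math
import random

def log_normal(z, mu, var):
    """Log density of N(mu, var) evaluated at z."""
    return -0.5 * math.log(2.0 * math.pi * var) - (z - mu) ** 2 / (2.0 * var)

def log_mean_exp(values):
    """Numerically stable log((1/K) * sum_k exp(values[k]))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values) / len(values))

def iwae_bound(x, k, num_outer=4000, seed=0):
    """Monte Carlo estimate of the K-sample IWAE bound on log p(x).

    Hypothetical toy model: p(z) = N(0, 1), p(x|z) = N(z, 1), so p(x) = N(0, 2).
    The proposal q(z|x) is deliberately set to the prior N(0, 1); the true
    posterior is N(x / 2, 1 / 2), so q is mismatched and K = 1 (the ELBO) is loose.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_outer):
        log_w = []
        for _ in range(k):
            z = rng.gauss(0.0, 1.0)  # z_k ~ q(z|x)
            # log w_k = log p(z_k) + log p(x|z_k) - log q(z_k|x)
            log_w.append(log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0) - log_normal(z, 0.0, 1.0))
        total += log_mean_exp(log_w)  # log (1/K) sum_k w_k
    return total / num_outer
```

For $x = 1$ the ELBO ($K = 1$) is analytically $-\tfrac{1}{2}\log 2\pi - 1 \approx -1.92$, while $\log p(x) = -\tfrac{1}{2}\log 4\pi - \tfrac{1}{4} \approx -1.52$; running `iwae_bound(1.0, k)` for `k` in `(1, 5, 50)` should show the estimate climbing from the former towards the latter.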
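Things become even more interesting if we look at the gradient of IWAE compared to the original ELBO. For $K$ reparametrised samples $z_k \sim q_\phi(z|x)$ with weights $w_k = p_\theta(x, z_k) / q_\phi(z_k|x)$, the gradients inside the expectation are
$$\nabla \frac{1}{K}\sum_{k=1}^K \log w_k \quad \text{(ELBO)} \qquad \text{vs.} \qquad \nabla \log \frac{1}{K}\sum_{k=1}^K w_k = \sum_{k=1}^K \tilde{w}_k \nabla \log w_k \quad \text{(IWAE)}$$
where $\tilde{w}_k = w_k / \sum_{j=1}^K w_j$ are the normalised importance weights, and $\nabla$ is taken w.r.t. either $\theta$ or $\phi$.
So we can see that in the IWAE case each sample's gradient is weighted by its normalised importance weight $\tilde{w}_k$ instead of the uniform weight $1/K$: samples that explain the observation better contribute more to the update.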
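The difference between the two gradients boils down to how per-sample gradients are combined. Below is a minimal sketch, assuming each per-sample gradient $\nabla \log w_k$ has already been computed (represented here as plain scalars); `normalised_weights` and `combine_sample_gradients` are hypothetical helper names.

```python
import math

def normalised_weights(log_w):
    """Self-normalised importance weights w_k / sum_j w_j, computed from log w_k."""
    top = max(log_w)  # subtract the max before exponentiating to avoid overflow
    w = [math.exp(l - top) for l in log_w]
    total = sum(w)
    return [wi / total for wi in w]

def combine_sample_gradients(log_w, sample_grads):
    """Return (ELBO-style, IWAE-style) combinations of per-sample gradients.

    sample_grads[k] stands in for grad log w_k (a scalar here for simplicity).
    The K-sample ELBO averages them uniformly; IWAE weights them by w~_k.
    """
    k = len(sample_grads)
    elbo_grad = sum(sample_grads) / k
    iwae_grad = sum(wt * g for wt, g in zip(normalised_weights(log_w), sample_grads))
    return elbo_grad, iwae_grad
```

For instance, with $\log w = (0, \log 3)$ the normalised weights are $(0.25, 0.75)$, so per-sample gradients $(1, 2)$ combine to $1.5$ under the ELBO but $1.75$ under IWAE.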
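However, this is not all of it: the authors of the original paper also showed that IWAE can be interpreted as standard ELBO, but with a more complex (implicit) posterior distribution $q_{\text{IW}}(z|x)$. To sample from it, draw $K$ samples from $q_\phi(z|x)$ and keep one of them, $z_j$, with probability proportional to its importance weight $w_j$. As $K$ grows, this implicit distribution gets closer to the true posterior $p_\theta(z|x)$.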
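The implicit distribution can be sampled with sampling-importance-resampling. The sketch below assumes a hypothetical toy model ($p(z) = \mathcal{N}(0, 1)$, $p(x|z) = \mathcal{N}(z, 1)$, proposal $q(z|x) = \mathcal{N}(0, 1)$) whose true posterior is $\mathcal{N}(x/2, 1/2)$, so we can check that the mean of $q_{\text{IW}}$ moves from the proposal mean towards $x/2$ as $K$ grows.

```python
import math
import random
import statistics

def sample_q_iw(x, k, rng):
    """Draw one sample from the implicit IWAE distribution q_IW(z|x) by
    sampling-importance-resampling, for the hypothetical toy model
    p(z) = N(0, 1), p(x|z) = N(z, 1) with proposal q(z|x) = N(0, 1)."""
    zs = [rng.gauss(0.0, 1.0) for _ in range(k)]    # z_1..z_K ~ q(z|x)
    log_w = [-0.5 * (x - z) ** 2 for z in zs]        # log w_k up to a constant (q equals the prior here)
    top = max(log_w)
    weights = [math.exp(l - top) for l in log_w]
    return rng.choices(zs, weights=weights, k=1)[0]  # keep z_j with probability w~_j

def q_iw_mean(x, k, num_draws=5000, seed=0):
    """Monte Carlo estimate of the mean of q_IW(z|x)."""
    rng = random.Random(seed)
    return statistics.mean(sample_q_iw(x, k, rng) for _ in range(num_draws))
```

`q_iw_mean(1.0, 1)` should be close to the proposal mean $0$, and `q_iw_mean(1.0, 50)` close to the posterior mean $0.5$.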
Side note: the paper Reinterpreting Importance-Weighted Autoencoders helped me a lot in understanding the IWAE objective; highly recommended. In addition, this blog post by Adam Kosiorek gives a very comprehensive interpretation of the topic.
Paper discussed: Sticking the Landing: Simple, Lower-Variance Gradient Estimators for Variational Inference by Geoffrey Roeder, Yuhuai Wu & David Duvenaud.
So far we have discussed two variational lower bounds in detail, ELBO and IWAE. Now it is high time to take them off their pedestals and talk about what's wrong with them, and as you can guess from the title of this section, this has something to do with gradient variance.
Despite my best effort to sound very excited about all this, I definitely struggled to care about things like "gradient variance" in the past, largely because there seem to be so many different Monte Carlo gradient estimators out there. But not too long ago, I realised that there are only two very common ones that you need to care about: the REINFORCE (score function) estimator and the reparametrisation trick. I'm leaving some details about each of them here as a note-to-self, but here's the key thing you need to remember if you want to skip this part and get to the good stuff: the score function estimator is unbiased and works for almost any function and distribution, but has high variance; the reparametrisation trick usually has much lower variance, but requires a differentiable function and a reparametrisable distribution.
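The score function estimator, also known as REINFORCE, is commonly used in reinforcement learning. It is named after the score function $\nabla_\phi \log q_\phi(z)$, which appears thanks to this "cool little logarithm trick":
$$\nabla_\phi q_\phi(z) = q_\phi(z)\, \nabla_\phi \log q_\phi(z).$$
So now, when we try to estimate the gradient of the expectation of some function $f(z)$, we can write
$$\nabla_\phi \mathbb{E}_{q_\phi(z)}[f(z)] = \int f(z)\, \nabla_\phi q_\phi(z)\, dz = \mathbb{E}_{q_\phi(z)}\left[f(z)\, \nabla_\phi \log q_\phi(z)\right]$$
and now we can easily estimate the gradient by performing MC sampling: taking $S$ samples $z_s \sim q_\phi(z)$ and computing $\frac{1}{S}\sum_{s=1}^S f(z_s)\, \nabla_\phi \log q_\phi(z_s)$.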
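Keep in mind that this score function estimator, despite being unbiased, has very large variance from multiple sources (see here in section 4.3.1 for details). It is, however, very flexible and places no requirement on $f$, which does not even need to be differentiable; we only need to be able to sample from $q_\phi(z)$ and evaluate $\nabla_\phi \log q_\phi(z)$.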
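Here is a minimal sketch of the score function estimator for a one-dimensional Gaussian $q_\mu(z) = \mathcal{N}(\mu, 1)$ (so $\phi = \mu$), for which $\nabla_\mu \log q_\mu(z) = z - \mu$. The choice of distribution and test functions is mine, for illustration only. Note that `f` is only ever evaluated, so even a step function works.

```python
import random
import statistics

def score_function_grad(mu, f, num_samples=10000, seed=0):
    """REINFORCE estimate of d/dmu E_{z ~ N(mu, 1)}[f(z)].

    Uses grad_mu log N(z; mu, 1) = (z - mu); f is only evaluated, never differentiated.
    Returns (mean, variance) of the per-sample estimates.
    """
    rng = random.Random(seed)
    samples = []
    for _ in range(num_samples):
        z = rng.gauss(mu, 1.0)
        samples.append(f(z) * (z - mu))
    return statistics.mean(samples), statistics.variance(samples)
```

For $f(z) = z^2$ the true gradient is $2\mu$; for the step function $f(z) = \mathbb{1}[z > 0]$ at $\mu = 0$ it is the standard Normal density at zero, $1/\sqrt{2\pi} \approx 0.399$, even though $f$ has no useful derivative.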
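I assume you are familiar with the reparametrisation trick if you got all the way here, but I am a completionist, so here's a quick recap:
The reparametrisation trick utilises the property that for a continuous distribution $q_\phi(z)$ we can often obtain a sample by drawing $\epsilon \sim p(\epsilon)$ from a fixed, parameter-free distribution and transforming it with a deterministic, differentiable function $z = t(\epsilon, \phi)$. The gradient can then be moved inside the expectation:
$$\nabla_\phi \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{p(\epsilon)}\left[\nabla_z f(z)\, \nabla_\phi t(\epsilon, \phi)\right], \qquad z = t(\epsilon, \phi).$$
The most common usage of this is seen in VAEs, where instead of directly sampling from the posterior, we typically take a random sample from a standard Normal distribution, $\epsilon \sim \mathcal{N}(0, I)$, and compute $z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon$.
This method is much less general-purpose than the score function estimator, since it requires $f$ to be differentiable and $q_\phi$ to be reparametrisable (for example, it does not directly apply to discrete latent variables).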
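To see the variance gap, the sketch below estimates the same gradient, $\nabla_\mu \mathbb{E}_{\mathcal{N}(\mu, 1)}[z^2] = 2\mu$, with both estimators from the same noise draws. It is a toy comparison under my own choice of $f$ and $q$; for this case the per-sample variances can be worked out by hand as $\mu^4 + 14\mu^2 + 15$ for the score function estimator and exactly $4$ for the reparametrised one.

```python
import random
import statistics

def compare_estimators(mu, num_samples=20000, seed=0):
    """Estimate d/dmu E_{z ~ N(mu, 1)}[z^2] (true value 2 * mu) in two ways.

    Score function: z^2 * (z - mu).
    Reparametrisation: z = mu + eps with eps ~ N(0, 1), so d/dmu f(z) = f'(z) * dz/dmu = 2z.
    Both use the same eps draws; returns (mean, variance) for each estimator.
    """
    rng = random.Random(seed)
    score, reparam = [], []
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)
        z = mu + eps
        score.append(z * z * (z - mu))
        reparam.append(2.0 * z)
    return ((statistics.mean(score), statistics.variance(score)),
            (statistics.mean(reparam), statistics.variance(reparam)))
```

At $\mu = 1.5$ that is roughly $51.6$ versus $4$: more than an order of magnitude, for a perfectly innocent-looking function.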
Side note: For readers who're not afraid of gradients, here is a great survey paper on MC gradient estimators.
At this point we should all be familiar with the reparametrisation trick used in VAEs for gradient estimation, but here we need to formalise it a bit more for the derivation in this section:
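The reparametrisation trick expresses a sample $z$ from a parametric distribution $q_\phi(z)$ as a deterministic function of a random variable $\epsilon$ with some fixed distribution and the parameters $\phi$, i.e. $z = t(\epsilon, \phi)$. For example, if $q_\phi$ is a diagonal Gaussian $\mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$, then $z = \mu + \sigma \odot \epsilon$ for $\epsilon \sim \mathcal{N}(0, I)$.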
We already know that the reparametrisation trick (path derivative) has the benefit of lower variance for gradient estimation compared to the score function estimator. The kicker here is that the gradient of ELBO actually contains a score function term, causing the estimator to have large variance!
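To see this, we can first rewrite ELBO as the following:
$$\mathcal{L}_{\text{ELBO}}(x) = \mathbb{E}_{p(\epsilon)}\left[\log p_\theta(x, z) - \log q_\phi(z|x)\right], \qquad z = t(\epsilon, \phi).$$
We can then take the total derivative of the term within the expectation w.r.t. $\phi$:
$$\nabla_\phi \left[\log p_\theta(x, z) - \log q_\phi(z|x)\right] = \underbrace{\nabla_z \left[\log p_\theta(x, z) - \log q_\phi(z|x)\right] \nabla_\phi t(\epsilon, \phi)}_{\text{path derivative}} - \underbrace{\nabla_\phi \log q_\phi(z|x)}_{\text{score function}}.$$
So we see that the total derivative contains a score function term, $\nabla_\phi \log q_\phi(z|x)$, taken w.r.t. the parameters with $z$ held fixed. Its expectation is zero, since $\mathbb{E}_{q_\phi(z|x)}\left[\nabla_\phi \log q_\phi(z|x)\right] = \nabla_\phi \int q_\phi(z|x)\, dz = \nabla_\phi 1 = 0$, but its variance is not.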
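So, it is not surprising to learn that the large variance of the score function term causes problems: the authors discovered that even when the variational posterior $q_\phi(z|x)$ matches the true posterior $p_\theta(z|x)$ exactly, the path derivative term becomes zero (because $\log p_\theta(x, z) - \log q_\phi(z|x) = \log p_\theta(x)$ no longer depends on $z$) but the score function term does not, so the gradient estimator still has non-zero variance at the optimum.
So what do we do here? Well, the authors propose to simply drop the score function component, which still gives an unbiased gradient estimator because the dropped term has zero expectation:
$$\hat{\nabla}_\phi^{\text{STL}} = \nabla_z \left[\log p_\theta(x, z) - \log q_\phi(z|x)\right] \nabla_\phi t(\epsilon, \phi), \qquad z = t(\epsilon, \phi).$$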
It sounds a bit wacky at first, but this approach works miracles, as the authors show in this plot:
As we can clearly see, by using the path-derivative-only gradient, the variance of the gradient estimate is much lower.
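The effect is easy to reproduce on a toy problem. In the sketch below (a hypothetical model with $p(z) = \mathcal{N}(0, 1)$, $p(x|z) = \mathcal{N}(z, 1)$ and $q_\phi(z|x) = \mathcal{N}(\mu, \sigma^2)$, with gradients w.r.t. $\mu$ only, worked out by hand) the true posterior for $x = 1$ is $\mathcal{N}(0.5, 0.5)$. At that optimum the path derivative is zero for every sample, while the total derivative keeps a variance of $1/\sigma^2 = 2$ that comes entirely from the score function term.

```python
import math
import random
import statistics

def elbo_mu_gradients(mu, sigma, x, eps):
    """Per-sample gradients of the ELBO w.r.t. the variational mean mu.

    Hypothetical toy model: p(z) = N(0, 1), p(x|z) = N(z, 1); q(z|x) = N(mu, sigma^2),
    reparametrised as z = t(eps, phi) = mu + sigma * eps, so dz/dmu = 1.
    Returns (total derivative, path derivative only).
    """
    z = mu + sigma * eps
    dlogp_dz = -z + (x - z)              # d/dz [log p(z) + log p(x|z)]
    dlogq_dz = -(z - mu) / sigma ** 2     # d/dz log q(z|x), parameters held fixed
    path = (dlogp_dz - dlogq_dz) * 1.0    # path derivative, times dz/dmu = 1
    score = (z - mu) / sigma ** 2         # d/dmu log q(z|x) with z held fixed
    return path - score, path

def gradient_stats(mu, sigma, x=1.0, num_samples=20000, seed=0):
    """(mean, variance) of the total and path-only estimators over many eps draws."""
    rng = random.Random(seed)
    draws = [elbo_mu_gradients(mu, sigma, x, rng.gauss(0.0, 1.0)) for _ in range(num_samples)]
    total = [d[0] for d in draws]
    path = [d[1] for d in draws]
    return ((statistics.mean(total), statistics.pvariance(total)),
            (statistics.mean(path), statistics.pvariance(path)))
```

Away from the optimum, e.g. `gradient_stats(0.0, 1.0)`, both estimators average to the true gradient $x - 2\mu = 1$, but the path derivative has variance $1$ instead of $4$.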
Note that this large gradient variance problem applies to any ELBO, including both the standard VAE objective and IWAE. However, as we will see in the next section, IWAE has its own unique problem caused by its $K$ samples: the signal-to-noise ratio of its gradient estimate for the inference network decreases as $K$ increases.
Paper discussed: Tight Variational Bounds are Not Necessarily Better, work by Tom Rainforth, Adam R. Kosiorek, Tuan Anh Le, Chris J. Maddison, Maximilian Igl, Frank Wood & Yee Whye Teh
This paper builds on the previous Sticking the Landing paper, and discovers that the gradient variance caused by the score function becomes an even bigger problem when using a multi-sample estimator like IWAE.
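Here it's not just a variance problem: estimators with small expected values need proportionally smaller variance to be estimated accurately. In other words, what we really care about is the ratio of the expectation to the standard deviation, i.e. the signal-to-noise ratio (SNR):
$$\mathrm{SNR}_{M,K}(\theta) = \left|\frac{\mathbb{E}\left[\Delta_{M,K}(\theta)\right]}{\sigma\left[\Delta_{M,K}(\theta)\right]}\right|$$
Here $\Delta_{M,K}(\theta)$ is the gradient estimate w.r.t. $\theta$ obtained by averaging $M$ independent IWAE estimates, each using $K$ importance samples, and $\sigma[\cdot]$ denotes the standard deviation; $\mathrm{SNR}_{M,K}(\phi)$ is defined in the same way.
Ideally we want a large SNR for the gradient estimators of both $\theta$ and $\phi$. However, the authors show that $\mathrm{SNR}_{M,K}(\theta) = O(\sqrt{MK})$ for the generative network, whereas $\mathrm{SNR}_{M,K}(\phi) = O(\sqrt{M/K})$ for the inference network. Intuitively, as $K \to \infty$ the bound approaches $\log p_\theta(x)$, which does not depend on $\phi$, so the expected gradient w.r.t. $\phi$ shrinks faster than its noise.
This tells us that while increasing the number of IWAE samples $K$ improves the gradient estimate for the generative network, it actually makes the gradient estimate for the inference network worse.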
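The trend can be reproduced empirically. The hedged sketch below uses a hypothetical toy model ($p(z) = \mathcal{N}(0, 1)$, $p_\theta(x|z) = \mathcal{N}(z + \theta, 1)$, $q_\phi(z|x) = \mathcal{N}(\mu, \sigma^2)$, with the proposal fixed at the prior) and hand-derived reparametrised gradients of $\log \frac{1}{K}\sum_k w_k$ with $M = 1$. It measures the SNR of the $\theta$- and $\mu$-gradients across repeated draws; it illustrates the direction of the result, not the paper's exact rates.

```python
import math
import random
import statistics

def iwae_gradients(x, theta, mu, sigma, k, rng):
    """One reparametrised gradient of log (1/K) sum_k w_k w.r.t. theta and mu.

    Hypothetical toy model: p(z) = N(0, 1), p_theta(x|z) = N(z + theta, 1),
    q_phi(z|x) = N(mu, sigma^2) with z_k = mu + sigma * eps_k.
    """
    zs = [mu + sigma * rng.gauss(0.0, 1.0) for _ in range(k)]
    # log w_k up to an additive constant, which cancels in the normalised weights
    log_w = [-0.5 * z * z - 0.5 * (x - z - theta) ** 2
             + 0.5 * ((z - mu) / sigma) ** 2 + math.log(sigma) for z in zs]
    top = max(log_w)
    w = [math.exp(l - top) for l in log_w]
    w_sum = sum(w)
    w_norm = [wi / w_sum for wi in w]
    # d log w_k / d theta = x - z_k - theta
    grad_theta = sum(wn * (x - z - theta) for wn, z in zip(w_norm, zs))
    # total derivative of log w_k w.r.t. mu; log q(z_k) does not depend on mu here: x - theta - 2 z_k
    grad_mu = sum(wn * (x - theta - 2.0 * z) for wn, z in zip(w_norm, zs))
    return grad_theta, grad_mu

def snr(values):
    """|mean| / standard deviation of a list of gradient estimates."""
    return abs(statistics.mean(values)) / statistics.stdev(values)

def estimate_snrs(k, reps=2000, x=1.0, theta=0.0, mu=0.0, sigma=1.0, seed=0):
    """Empirical (SNR for theta, SNR for mu) over `reps` independent IWAE gradients."""
    rng = random.Random(seed)
    draws = [iwae_gradients(x, theta, mu, sigma, k, rng) for _ in range(reps)]
    return snr([d[0] for d in draws]), snr([d[1] for d in draws])
```

For $K = 1$ the SNRs are exactly $1$ for $\theta$ and $0.5$ for $\mu$; comparing `estimate_snrs(1)` with `estimate_snrs(100)`, the $\theta$-SNR should grow while the $\mu$-SNR collapses, because the expected $\mu$-gradient tends to $\nabla_\mu \log p_\theta(x) = 0$.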
The authors give a very comprehensive proof of their finding, so I'm going to leave the mathy heavy lifting to the original paper :) We shall march on to the last section of this blog: an elegant solution to the large variance of ELBO gradient estimators, DReG.
Paper discussed: Doubly Reparametrised Gradient Estimators for Monte Carlo Objectives, work by George Tucker, Dieterich Lawson, Shixiang Gu & Chris J. Maddison.
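In section 3 we talked about the large gradient variance caused by the score function lurking in the gradient estimate, and in section 4 about how this is exacerbated for IWAE. I'll put the total derivative we saw in section 3 here as a reference, but to make it more relevant, this time we rewrite it for IWAE, which uses $K$ samples:
$$\nabla_\phi \mathcal{L}_{\text{IWAE}}^K(x) = \mathbb{E}_{\epsilon_{1:K}}\left[\sum_{k=1}^K \tilde{w}_k \left(\frac{\partial \log w_k}{\partial z_k} \frac{\partial z_k}{\partial \phi} - \frac{\partial \log q_\phi(z_k|x)}{\partial \phi}\right)\right]$$
where $z_k = t(\epsilon_k, \phi)$, $w_k = p_\theta(x, z_k) / q_\phi(z_k|x)$ and $\tilde{w}_k = w_k / \sum_{j=1}^K w_j$.
This is not much of a change from the total derivative of the original ELBO since, as we mentioned in section 2, IWAE simply weights the gradients of the VAE ELBO by the relative importance of each sample, $\tilde{w}_k$.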
We have learned that one way to deal with the score function term is to remove it completely. However, is there a better way than completely discarding a term in the gradient estimate?
Well, obviously I wouldn't be asking this question if the answer weren't yes: the authors of this paper propose to reduce the variance by applying the reparametrisation trick a second time, to the score function term itself! Here's how:
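Taking the score function term in the total derivative of IWAE, we can first take the sum over $k$ out of the expectation and make the expectation over each $z_k$ explicit:
$$\mathbb{E}_{z_{1:K}}\left[\sum_{k=1}^K \tilde{w}_k \frac{\partial \log q_\phi(z_k|x)}{\partial \phi}\right] = \sum_{k=1}^K \mathbb{E}_{z_{-k}}\left[\mathbb{E}_{z_k \sim q_\phi(z_k|x)}\left[\tilde{w}_k \frac{\partial \log q_\phi(z_k|x)}{\partial \phi}\right]\right]$$
Now we can just ignore the sum and focus on the inner expectation $\mathbb{E}_{z_k \sim q_\phi(z_k|x)}\left[\tilde{w}_k \frac{\partial \log q_\phi(z_k|x)}{\partial \phi}\right]$, with $z_{-k}$ (all samples except $z_k$) held fixed.
I should clarify that previously we just had the score function term, but since the expectation was over $\epsilon$ (i.e. reparametrised samples $z = t(\epsilon, \phi)$ instead of actual samples from $q_\phi(z|x)$), it was not actually REINFORCE. Now it has exactly the REINFORCE form $\mathbb{E}_{q_\phi}\left[f(z)\, \partial \log q_\phi(z)/\partial \phi\right]$ with $f = \tilde{w}_k$.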
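This is important because REINFORCE and the reparametrisation trick are interchangeable, as we see below (the derivative w.r.t. $z$ is taken with $\phi$ held fixed):
$$\mathbb{E}_{q_\phi(z)}\left[f(z) \frac{\partial \log q_\phi(z)}{\partial \phi}\right] = \mathbb{E}_{p(\epsilon)}\left[\frac{\partial f(z)}{\partial z} \frac{\partial z}{\partial \phi}\right], \qquad z = t(\epsilon, \phi).$$
Applying this with $f = \tilde{w}_k$ and using $\frac{\partial \tilde{w}_k}{\partial z_k} = \tilde{w}_k (1 - \tilde{w}_k) \frac{\partial \log w_k}{\partial z_k}$, we substitute the result back into the total derivative of IWAE and, after some math montage, simplify it to the doubly reparametrised gradient estimator:
$$\nabla_\phi \mathcal{L}_{\text{IWAE}}^K(x) = \mathbb{E}_{\epsilon_{1:K}}\left[\sum_{k=1}^K \tilde{w}_k^2 \frac{\partial \log w_k}{\partial z_k} \frac{\partial z_k}{\partial \phi}\right]$$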
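The final estimator is short enough to write out by hand for a toy model. The sketch below assumes a hypothetical setup ($p(z) = \mathcal{N}(0, 1)$, $p(x|z) = \mathcal{N}(z, 1)$, $q_\phi(z|x) = \mathcal{N}(\mu, \sigma^2)$, gradient w.r.t. $\mu$ only) and computes both the standard IWAE gradient and the DReG gradient from the same samples, so you can check that their means agree while DReG's variance is smaller. It is an illustration, not the implementation from the paper or from our code.

```python
import math
import random
import statistics

def iwae_and_dreg_mu_grads(x, mu, sigma, k, rng):
    """Standard IWAE and DReG single-draw estimates of d/dmu of the K-sample IWAE bound.

    Hypothetical toy model: p(z) = N(0, 1), p(x|z) = N(z, 1), q(z|x) = N(mu, sigma^2),
    reparametrised as z_k = mu + sigma * eps_k, so dz_k/dmu = 1.
    """
    zs = [mu + sigma * rng.gauss(0.0, 1.0) for _ in range(k)]
    # log w_k = log p(z_k) + log p(x|z_k) - log q(z_k|x), dropping terms that are
    # the same for every k (they cancel in the normalised weights)
    log_w = [-0.5 * z * z - 0.5 * (x - z) ** 2 + 0.5 * ((z - mu) / sigma) ** 2 for z in zs]
    top = max(log_w)
    w = [math.exp(lw - top) for lw in log_w]
    w_sum = sum(w)
    w_norm = [wi / w_sum for wi in w]
    # partial d log w_k / d z_k, with the variational parameters held fixed
    dlogw_dz = [(x - 2.0 * z) + (z - mu) / sigma ** 2 for z in zs]
    # d log q(z_k|x) / d mu with z_k held fixed: the score function term
    score = [(z - mu) / sigma ** 2 for z in zs]
    iwae = sum(wn * (g - s) for wn, g, s in zip(w_norm, dlogw_dz, score))
    dreg = sum(wn ** 2 * g for wn, g in zip(w_norm, dlogw_dz))
    return iwae, dreg

def summarise(k, reps=5000, x=1.0, mu=0.0, sigma=1.0, seed=0):
    """Mean and variance of both estimators over `reps` independent draws."""
    rng = random.Random(seed)
    draws = [iwae_and_dreg_mu_grads(x, mu, sigma, k, rng) for _ in range(reps)]
    iw = [d[0] for d in draws]
    dr = [d[1] for d in draws]
    return (statistics.mean(iw), statistics.variance(iw),
            statistics.mean(dr), statistics.variance(dr))
```

In an autodiff framework, DReG is commonly implemented by stopping gradients through the variational parameters inside $\log q_\phi(z_k|x)$ (keeping the path through $z_k$) and through the squared weights $\tilde{w}_k^2$; treat this as a pointer rather than a recipe.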
This is actually very easy to implement. Cheeky little plug: we used this objective in our paper on multimodal VAE learning, and you can find the code here, which comes with a handy implementation of DReG in PyTorch.
Heartfelt congratulations if you got all the way here, well done! Leave a comment if you have any questions, and if you find this helpful, please share it on Twitter/Facebook :)