
InfoMax derivation of $\beta$-VAE

Some notation:

  • $p_\mathcal{D}(x)$: data distribution
  • $q_\psi(z|x)$: representation distribution
  • $q_\psi(z) = \int p_\mathcal{D}(x)\, q_\psi(z|x)\, \mathrm{d}x$: aggregate posterior, the marginal distribution of the representation $Z$
  • $q_\psi(x|z) = \frac{q_\psi(z|x)\, p_\mathcal{D}(x)}{q_\psi(z)}$: "inverted posterior"

Setup

We'll start from just the representation $q_\psi(z|x)$, with no generative model of the data. We'd like this representation to satisfy two properties:

  1. Independence: We'd like the aggregate posterior $q_\psi(z)$ to exhibit coordinate-wise independence, and in particular to be close to a fixed, factorized prior distribution $p(z) = \prod_i p(z_i)$.
  2. Maximum information: We'd like the representation $Z$ to retain as much information as possible about the input data $X$.

Note that without (1), (2) alone is insufficient, because then any deterministic and invertible function of $X$ would satisfy (2), no matter how entangled the coordinates of $Z$ are. Similarly, without (2), (1) alone is insufficient, because $q_\psi(z|x) = p(z)$ would satisfy (1) but would be a pretty useless representation of the data, since $Z$ doesn't depend on $X$ at all.

Deriving a practical objective

We can achieve a combination of (1) and (2) by minimizing an objective that is a weighted combination of two terms corresponding to the two goals we set out above:

$$\mathcal{L}(\psi) = \mathrm{KL}[q_\psi(z)\,\|\,p(z)] - \lambda\, I_{q_\psi(z|x)p_\mathcal{D}(x)}[X, Z]$$
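Here $I_{q_\psi(z|x)p_\mathcal{D}(x)}[X, Z]$ denotes the mutual information between $X$ and $Z$ under the joint distribution $q_\psi(z|x)\,p_\mathcal{D}(x)$:

$$I_{q_\psi(z|x)p_\mathcal{D}(x)}[X, Z] = \mathbb{E}_{q_\psi(z|x)p_\mathcal{D}(x)} \log \frac{q_\psi(z|x)\,p_\mathcal{D}(x)}{q_\psi(z)\,p_\mathcal{D}(x)} = \mathbb{E}_{q_\psi(z|x)p_\mathcal{D}(x)} \log \frac{q_\psi(z|x)}{q_\psi(z)}$$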

Now we're going to show how this objective can be related to the $\beta$-VAE objective. Let's look at the first term:

$$\begin{aligned}
\mathrm{KL}[q_\psi(z)\,\|\,p(z)] &= \mathbb{E}_{q_\psi(z)} \log \frac{q_\psi(z)}{p(z)} \\
&= \mathbb{E}_{q_\psi(z|x)p_\mathcal{D}(x)} \log \frac{q_\psi(z)}{p(z)} \\
&= \mathbb{E}_{q_\psi(z|x)p_\mathcal{D}(x)} \log \frac{q_\psi(z)}{q_\psi(z|x)} + \mathbb{E}_{q_\psi(z|x)p_\mathcal{D}(x)} \log \frac{q_\psi(z|x)}{p(z)} \\
&= \mathbb{E}_{q_\psi(z|x)p_\mathcal{D}(x)} \log \frac{q_\psi(z)\,p_\mathcal{D}(x)}{q_\psi(z|x)\,p_\mathcal{D}(x)} + \mathbb{E}_{q_\psi(z|x)p_\mathcal{D}(x)} \log \frac{q_\psi(z|x)}{p(z)} \\
&= -I_{q_\psi(z|x)p_\mathcal{D}(x)}[X, Z] + \mathbb{E}_{p_\mathcal{D}} \mathrm{KL}[q_\psi(z|x)\,\|\,p(z)]
\end{aligned}$$

Putting this back together, we have that

$$\begin{aligned}
\mathcal{L}(\psi) &= \mathrm{KL}[q_\psi(z)\,\|\,p(z)] - \lambda\, I_{q_\psi(z|x)p_\mathcal{D}(x)}[X, Z] \\
&= \mathbb{E}_{p_\mathcal{D}} \mathrm{KL}[q_\psi(z|x)\,\|\,p(z)] - (\lambda + 1)\, I_{q_\psi(z|x)p_\mathcal{D}(x)}[X, Z]
\end{aligned}$$

Now we have the KL-divergence term from the $\beta$-VAE, but we're still missing the reconstruction term (and we haven't even defined a generative model $p_\theta(x|z)$). As we will see, we can recover this term, too, by using a variational approximation to the mutual information.

Variational bound on mutual information

Note the following equality:

$$I[X, Z] = H[X] - H[X|Z]$$

The first term, the entropy of $X$, is constant with respect to $\psi$, since we sample $X$ from the data distribution $p_\mathcal{D}$. The second term can be bounded by the cross-entropy of any conditional model of $X$ given $Z$ (using Jensen's inequality):

$$H[X|Z] = -\mathbb{E}_{q_\psi(z|x)p_\mathcal{D}(x)} \log q_\psi(x|z) \leq \inf_\theta \left\{ -\mathbb{E}_{q_\psi(z|x)p_\mathcal{D}(x)} \log p_\theta(x|z) \right\}$$

In this step, we introduce $p_\theta(x|z)$ as an auxiliary distribution to make a variational approximation to the mutual information.
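To see why this inequality holds for every $\theta$, note that the gap between the right-hand side and the left-hand side is an expected KL divergence, which is non-negative:

$$-\mathbb{E}_{q_\psi(z|x)p_\mathcal{D}(x)} \log p_\theta(x|z) + \mathbb{E}_{q_\psi(z|x)p_\mathcal{D}(x)} \log q_\psi(x|z) = \mathbb{E}_{q_\psi(z)} \mathrm{KL}[q_\psi(x|z)\,\|\,p_\theta(x|z)] \geq 0$$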

Putting this bound back together:

$$\mathcal{L}(\psi) \leq \mathbb{E}_{p_\mathcal{D}} \mathrm{KL}[q_\psi(z|x)\,\|\,p(z)] - (1 + \lambda)\, \mathbb{E}_{q_\psi(z|x)p_\mathcal{D}(x)} \log p_\theta(x|z) + \mathrm{const.}$$

And this is essentially the $\beta$-VAE objective function, where $\beta$ is related to the previous $\lambda$: rescaling the bound by $\frac{1}{1+\lambda}$ gives the usual form with $\beta = \frac{1}{1+\lambda}$.
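For concreteness, here is a minimal sketch of the resulting per-batch loss, assuming a Gaussian encoder $q_\psi(z|x)$, a standard normal prior $p(z)$, and a Bernoulli decoder $p_\theta(x|z)$ (these modelling choices are not prescribed by the derivation above), written in PyTorch:

```python
import torch
import torch.nn.functional as F

def beta_vae_loss(x, x_logits, mu, logvar, beta):
    """beta-VAE loss for a Gaussian encoder and a Bernoulli decoder.

    x          : batch of targets with values in [0, 1]
    x_logits   : decoder outputs (logits of p_theta(x|z)) for z ~ q_psi(z|x)
    mu, logvar : mean and log-variance of the diagonal Gaussian q_psi(z|x)
    beta       : weight on the KL term; plays the role of 1/(1+lambda) above
    """
    batch_size = x.shape[0]
    # Reconstruction term: -E_q[log p_theta(x|z)], estimated with one sample of z.
    recon = F.binary_cross_entropy_with_logits(x_logits, x, reduction="sum") / batch_size
    # KL[q_psi(z|x) || N(0, I)] in closed form for a diagonal Gaussian encoder.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
    return recon + beta * kl
```

Setting `beta = 1` recovers the standard VAE (negative ELBO); values of $\beta$ below 1 correspond to $\lambda > 0$ in the InfoMax objective above.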

Additional ramblings

Conceptually, this is interesting because the recognition model $q_\psi(z|x)$ is now the main object of interest.

The "latent variable model"

qψ(z|x)pD(x) parametrizes LVMs which has a marginal distribution on observable
x
that is exactly the same as the data distribution
pD
. So one can say
qψ(z|x)pD(x)
is a parametric family of latent variable models with whose likelihood is maximal.

We then ask the question: out of models of this form, which one should we choose? The generative model $p_\theta$ is introduced as an auxiliary distribution while constructing a lower bound on the mutual information, but that's perhaps not the best way to do this.

So there are two families of joint distributions over latents and observables here. On one hand we have $q_\psi(z|x)\,p_\mathcal{D}(x)$, and on the other we have $p(z)\,p_\theta(x|z)$. The $\beta$-VAE (or plain VAE) objective tries to move these two models closer to one another. From the perspective of $q_\psi(z|x)\,p_\mathcal{D}(x)$, this can be understood as trying to maximise mutual information while reproducing the prior $p(z)$. From the perspective of $p(z)\,p_\theta(x|z)$, it can be understood as trying to maximise the data likelihood, i.e. to reproduce $p_\mathcal{D}$, and, if the $\beta$-VAE objective is used, to additionally maximise information, too.

This symmetry of variational learning has been noted a few times:
  • Ying-Yang machines
  • Adversarially Learned Inference