
Contrastive Multiview Coding

Code snippets are used from the official implementation.

Key idea

Different transformations of an image (changing the color space, segmentation, depth view, etc.) still carry the same semantic content, so their representations should also be similar. Thus:

Given a pair of sensory views, a deep representation is learnt by bringing views of the same scene together in embedding space, while pushing views of different scenes apart.


Contrastive objective vs cross-view prediction

Cross-view prediction is the standard encoder-decoder setup, where a pixel-wise loss is measured between the reconstructed output and the target view. A pixel-wise loss treats all pixels equally: it cannot distinguish the pixels that matter from the ones that don't.

In the contrastive objective, two different inputs representing the same semantic content are encoded into two representations, and the loss is measured between those representations. This gives the model a chance to learn which information to keep and which to discard while encoding an image. The learned representation is therefore better: it ignores the noise and retains the important information.
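To make the distinction concrete, here is a minimal sketch (toy tensors and names of my own, not from the official code) contrasting a pixel-wise reconstruction loss with a representation-level, InfoNCE-style contrastive loss:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy batch: 8 "images" flattened to 64 pixels each.
x = torch.randn(8, 64)
x_recon = torch.randn(8, 64)   # hypothetical decoder output

# Representations of the same 8 scenes seen in two different views.
z1 = torch.randn(8, 128)
z2 = torch.randn(8, 128)

# Cross-view prediction: pixel-wise loss, every pixel weighted equally.
pixel_loss = F.mse_loss(x_recon, x)

# Contrastive objective: the loss lives in representation space.
# Matching pairs (z1[i], z2[i]) are positives; every other pair in the
# batch acts as a negative (a softmax over critic scores).
scores = z1 @ z2.t()            # (8, 8) critic scores for all pairs
labels = torch.arange(8)        # positives sit on the diagonal
contrastive_loss = F.cross_entropy(scores, labels)
```

Nothing forces the pixel loss to ignore irrelevant detail, while the contrastive loss only needs whatever features separate positives from negatives.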


Contrastive learning with two views

V1 is a dataset of images in one view (one kind of transformation), and V2 is a dataset of the same images seen in a different view. One sample is drawn from V1 and one from V2. If both samples belong to the same image, we want a critic hθ(·) to output a high value; if they don't, the critic should output a low value.

The loss function is constructed like so:

L_contrast^{V1,V2} = -E[ log( hθ({v1, v2^1}) / Σ_{j=1..k+1} hθ({v1, v2^j}) ) ]

where the set {v2^1, ..., v2^{k+1}} contains the one positive sample v2^1 paired with v1 plus k negative samples. This has the form of a (k+1)-way softmax cross-entropy for picking out the positive pair.


Implementing the critic

To extract compact latent representations of v1 and v2, we employ two encoders fθ1(·) and fθ2(·) with parameters θ1 and θ2 respectively. The latent representations are extracted as z1 = fθ1(v1), z2 = fθ2(v2). On top of these features, the score is computed as the exponential of a bivariate function of z1 and z2, e.g., a bilinear function parameterized by W12: hθ({v1, v2}) = exp(z1ᵀ W12 z2).
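As an illustrative sketch (the class name and weight initialization are my own, not the official implementation), such a bilinear critic could look like:

```python
import torch
import torch.nn as nn

class BilinearCritic(nn.Module):
    """Scores a pair of latents as exp(z1^T W12 z2)."""
    def __init__(self, dim=128):
        super().__init__()
        # Small init keeps the exponent in a numerically safe range.
        self.W12 = nn.Parameter(torch.randn(dim, dim) * 0.01)

    def forward(self, z1, z2):
        # Per-sample bilinear form: row-wise (z1 @ W12) . z2
        return torch.exp(torch.sum((z1 @ self.W12) * z2, dim=1))

torch.manual_seed(0)
critic = BilinearCritic(dim=128)
z1, z2 = torch.randn(4, 128), torch.randn(4, 128)
scores = critic(z1, z2)   # shape (4,), strictly positive scores
```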

We make the loss between the views symmetric:

L(V1, V2) = L_contrast^{V1,V2} + L_contrast^{V2,V1}

Depending on the downstream paradigm, we use the representation z1, z2, or the concatenation of both, [z1, z2].

An example of a critic

The critic takes images in two views: the L space and the AB space. It has fθ1 = l_to_ab and fθ2 = ab_to_l. I find these names misleading, since self.l_to_ab does not map L to AB; it maps L to a vector of dimension feat_dim. The same applies to ab_to_l.

Source

class alexnet(nn.Module):
    def __init__(self, feat_dim=128):
        super(alexnet, self).__init__()

        # Each half-width AlexNet encodes one view into a feat_dim vector.
        self.l_to_ab = alexnet_half(in_channel=1, feat_dim=feat_dim)
        self.ab_to_l = alexnet_half(in_channel=2, feat_dim=feat_dim)

    def forward(self, x, layer=8):
        # Split the LAB image into the L channel and the two AB channels.
        l, ab = torch.split(x, [1, 2], dim=1)
        feat_l = self.l_to_ab(l, layer)
        feat_ab = self.ab_to_l(ab, layer)
        return feat_l, feat_ab

Connection with mutual information

It can be shown that the optimal critic is proportional to the density ratio between p(z1, z2) and p(z1)p(z2).

It can also be shown that

I(z1; z2) >= log(k) - L_contrast

where k is the number of negative pairs in the sample set.

Hence minimizing the objective L maximizes the lower bound on the mutual information I(z1; z2), which is bounded above by I(v1; v2) by the data processing inequality. The dependency on k also suggests that using more negative samples can lead to an improved representation; we show that this is indeed the case.


Contrastive learning with more than two views

There are two ways to do so:

Core graph view

Given M views V1, ..., VM, we can choose to optimize over one view only. This means the model will learn best how to represent images in that particular view.

If we want to optimize over the first view, the loss function is defined as:

L(V1) = Σ_{j=2..M} L(V1, Vj)

More generally, for any view Vi:

L(Vi) = Σ_{j≠i} L(Vi, Vj)

Full graph view

Here you optimize over all views by using every possible (i, j) pair of views in the loss function. There are C(M, 2) = M(M-1)/2 such pairs.
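The two pairing schemes can be enumerated in a few lines (function names are mine, for illustration):

```python
from itertools import combinations

def core_view_pairs(M):
    """Core view: only pairs involving view 1 (1-based view indices)."""
    return [(1, j) for j in range(2, M + 1)]

def full_graph_pairs(M):
    """Full graph: all C(M, 2) unordered pairs of views."""
    return list(combinations(range(1, M + 1), 2))

# e.g. M = 4 views:
core = core_view_pairs(4)    # [(1, 2), (1, 3), (1, 4)]  -> M - 1 pairs
full = full_graph_pairs(4)   # 6 pairs = C(4, 2)
```

The total loss is then the sum of pairwise contrastive losses L(Vi, Vj) over the chosen pair set.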

Both these formulations have the effect that information is prioritized in proportion to the number of views that share it, as can be seen in the information diagrams in the paper.

Thus, in the core view formulation, the mutual information between, say, V2 and V3 is discarded, whereas the full graph formulation retains it.

Under both the core view and full graph objectives, a factor, like "presence of dog", that is common to all views will be preferred over a factor that affects fewer views, such as "depth sensor noise".


Approximating the softmax distribution with noise-contrastive estimation

Let's revisit the contrastive loss:

L_contrast^{V1,V2} = -E[ log( hθ({v1, v2^1}) / Σ_{j=1..k+1} hθ({v1, v2^j}) ) ]

If k in this formula is large, computing the full softmax over all k+1 candidates becomes expensive.

This problem is solved with the noise-contrastive estimation (NCE) trick. Assume the m negative (noise) samples are distributed uniformly, i.e. pn is uniform. Then we have:

P(D=1|v2; v1i ) = pd(v2 | v1i) / [ pd(v2 | v1i) + m*pn(v2 | v1i) ]

The distribution of positive samples pd is unknown. pd is approximated by an unnormalized density hθ(.).

In this paper, hθ(v1i, v2i) = <v1i, v2i> where <.,.> stands for dot product.
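Plugging the unnormalized critic score in for pd, the posterior probability that a pair came from the data rather than the noise distribution can be computed directly (function and variable names are my own sketch; pn = 1/n for a uniform noise distribution over n items):

```python
def p_pair_from_data(h_score, m, n):
    """P(D=1 | v2; v1_i): probability the pair is a true (data) pair,
    given m noise samples drawn uniformly from n items (pn = 1/n).
    h_score approximates the unnormalized density pd."""
    pn = 1.0 / n
    return h_score / (h_score + m * pn)

# A pair scored 0.5 by the critic, with m = 4 noise samples out of n = 100:
p = p_pair_from_data(h_score=0.5, m=4, n=100)   # 0.5 / (0.5 + 0.04)
```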


Implementation of the loss function

This section mostly explains the implementation of the NCE Loss here.

This file uses a trick to allow faster sampling. I won't go into the details here, since it is not relevant to the central idea of the paper. In short, it defines a class called AliasMethod which is used in lieu of multinomial sampling:

self.multinomial = AliasMethod(self.unigrams)

The implementation converts RGB image into LAB and splits it into two views: L and AB.

Storing representations in memory bank

We maintain a memory bank to store latent features for each training sample. Therefore, we can efficiently retrieve m noise samples from the memory bank to pair with each positive sample without recomputing their features. The memory bank is dynamically updated with features computed on the fly.

These representations are stored in the same file that calculates the NCELoss: NCEAverage.py. This is done using the register_buffer property of PyTorch nn.modules:

self.register_buffer('memory_l', torch.rand(outputSize, inputSize))
self.register_buffer('memory_ab', torch.rand(outputSize, inputSize))

Here outputSize is the size of the dataset and inputSize is the size of representations (128 in case of Alexnet).

We want these representations to have unit L2 norm on average. They are initialized by uniform sampling from the interval [-a, a], with a chosen so that the expected squared L2 norm of a vector of size inputSize is 1. In other words,

Σi E[xi²] = 1

Since E[xi²] = Var[xi] + (E[xi])², and the uniform distribution on [-a, a] has E[xi] = 0 and Var[xi] = a²/3, this becomes

inputSize · a² / 3 = 1

Solving gives a = sqrt(3 / inputSize), written in code as:
a = 1. / math.sqrt(inputSize / 3)
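As a quick sanity check (a standalone sketch, not from the repo), we can verify numerically that this choice of a gives an expected squared norm of 1:

```python
import math
import torch

torch.manual_seed(0)
inputSize = 128
stdv = 1. / math.sqrt(inputSize / 3)   # a = sqrt(3 / inputSize)

# Sample many vectors uniformly from [-stdv, stdv) and check that the
# average squared L2 norm is ~1, matching the derivation above.
x = torch.rand(10000, inputSize).mul_(2 * stdv).add_(-stdv)
mean_sq_norm = x.pow(2).sum(dim=1).mean().item()   # ~1.0
```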

Thus the actual initialization of the memory bank looks like:

stdv = 1. / math.sqrt(inputSize / 3)
self.register_buffer('memory_l', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
self.register_buffer('memory_ab', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))

The code below then does the following:

  1. randomly sample negative indices, placing the index of each positive sample in the first slot
  2. gather the corresponding vectors from the memory bank
  3. calculate dot products using batch matrix multiplication (bmm)

# score computation
if idx is None:
    idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1)
    idx.select(1, 0).copy_(y.data)
# sample
weight_l = torch.index_select(self.memory_l, 0, idx.view(-1)).detach()
weight_l = weight_l.view(batchSize, K + 1, inputSize)
out_ab = torch.bmm(weight_l, ab.view(batchSize, inputSize, 1))
# sample
weight_ab = torch.index_select(self.memory_ab, 0, idx.view(-1)).detach()
weight_ab = weight_ab.view(batchSize, K + 1, inputSize)
out_l = torch.bmm(weight_ab, l.view(batchSize, inputSize, 1))
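To see why bmm yields one score per (sample, candidate) pair, here is a standalone shape walkthrough mirroring the snippet above (the toy sizes are mine):

```python
import torch

torch.manual_seed(0)
batchSize, K, inputSize = 4, 7, 128   # 7 negatives + 1 positive each

# K+1 candidate memory-bank vectors per sample, and the current ab features.
weight_l = torch.randn(batchSize, K + 1, inputSize)
ab = torch.randn(batchSize, inputSize)

# bmm: (B, K+1, D) x (B, D, 1) -> (B, K+1, 1), i.e. the dot product of
# each of the K+1 candidate vectors with that sample's ab feature.
out_ab = torch.bmm(weight_l, ab.view(batchSize, inputSize, 1))
```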

Finally the memory bank is updated using momentum here.

l_pos = torch.index_select(self.memory_l, 0, y.view(-1))
l_pos.mul_(momentum)
l_pos.add_(torch.mul(l, 1 - momentum))
l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5)
updated_l = l_pos.div(l_norm)
self.memory_l.index_copy_(0, y, updated_l)
tags: self-supervised-learning