Code samples are taken from here and here.
The SimCLR framework has four major components:
Taken from here:
A neural network base encoder f(·) that extracts representation vectors from augmented data examples:
['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', 'fc']
Linear(in_features=512, out_features=50, bias=True)
torch.Size([4, 50])
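A small neural network projection head g(·) that maps representations to the space where contrastive loss is applied. We use a MLP with one hidden layer to obtain $z_i = g(h_i) = W^{(2)}\sigma(W^{(1)}h_i)$, where $\sigma$ is a ReLU nonlinearity.
We find it beneficial to define the contrastive loss on the $z_i$ rather than the $h_i$.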
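A projection head of this kind can be sketched in PyTorch as a two-layer MLP, assuming the form z = W2 · ReLU(W1 · h) with no bias terms. The class name `ProjectionHead` and all layer sizes are hypothetical choices for illustration.

```python
import torch
from torch import nn


class ProjectionHead(nn.Module):
    """MLP with one hidden layer: z = W2 * ReLU(W1 * h)."""

    def __init__(self, in_dim: int = 50, hidden_dim: int = 50, out_dim: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=False),   # W1
            nn.ReLU(),                                   # sigma
            nn.Linear(hidden_dim, out_dim, bias=False),  # W2
        )

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.net(h)


# Hypothetical encoder output: 4 representations of dimension 50.
h = torch.randn(4, 50)
z = ProjectionHead()(h)
print(z.shape)
```

The contrastive loss is then computed on `z`, while `h` is the representation kept for downstream tasks.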
We randomly sample a minibatch of N examples and define the contrastive prediction task on pairs of augmented examples derived from the minibatch, resulting in 2N data points. We do not sample negative examples explicitly. Instead, given a positive pair, we treat the other 2(N−1) augmented examples within a minibatch as negative examples.
No wonder you need such a huge batch size to train.
To keep it simple, we do not train the model with a memory bank. Instead, we vary the training batch size N from 256 to 8192. A batch size of 8192 gives us 16382 negative examples per positive pair from both augmentation views.
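The count of negatives follows from the 2(N−1) rule: each of the N images contributes two views, and the two views in the positive pair are excluded. A quick check of the arithmetic (the helper name `num_negatives` is made up for this sketch):

```python
def num_negatives(batch_size: int) -> int:
    """Negatives per positive pair: all 2N views minus the two in the pair."""
    return 2 * (batch_size - 1)


for n in (256, 8192):
    print(n, num_negatives(n))
# 256 510
# 8192 16382
```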
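Define $\ell(i, j)$ as:

$$\ell(i, j) = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)}$$

where $\mathrm{sim}(u, v) = u^\top v / (\lVert u \rVert \lVert v \rVert)$ is the cosine similarity and $\tau$ is a temperature parameter.
Then the loss is defined as:

$$\mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \left[\ell(2k-1, 2k) + \ell(2k, 2k-1)\right]$$

where adjacent images at indices 2k and 2k-1 are augmentations of the same image.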
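This loss can be sketched in PyTorch as a cross-entropy over the rows of the pairwise cosine-similarity matrix. This is a minimal sketch, not a reference implementation: the function name `nt_xent_loss` and the default temperature of 0.5 are assumptions. Rows are 0-indexed, so the positive pairs (2k-1, 2k) become rows (0, 1), (2, 3), and so on.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """Contrastive loss over 2N projections; rows (0,1), (2,3), ... are positive pairs."""
    n_views = z.shape[0]  # 2N, must be even
    z = F.normalize(z, dim=1)            # unit norm, so dot product = cosine similarity
    sim = z @ z.t() / temperature        # (2N, 2N) scaled similarities
    # Exclude each view's similarity with itself from the softmax denominator.
    self_mask = torch.eye(n_views, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # Index of each row's positive partner: [1, 0, 3, 2, ...].
    targets = torch.arange(n_views, device=z.device).view(-1, 2).flip(1).reshape(-1)
    # cross_entropy averages -log softmax over all 2N rows.
    return F.cross_entropy(sim, targets)


# Hypothetical projections for N = 4 images (8 views), dimension 32.
z = torch.randn(8, 32)
print(nt_xent_loss(z).item())
```

Masking the diagonal with negative infinity gives those entries zero weight in the softmax, so each row is normalized over the other 2N−1 views only.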
Training with large batch size may be unstable when using standard SGD/Momentum with linear learning rate scaling. To stabilize the training, we use the LARS optimizer for all batch sizes. We train our model with Cloud TPUs, using 32 to 128 cores depending on the batch size.
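To see why large batches push the learning rate up, here is a small sketch of linear learning rate scaling. The helper name `linear_scaled_lr` is made up for this example. The base rate of 0.3 per 256 examples is an assumption taken from the default setting reported in the SimCLR paper.

```python
def linear_scaled_lr(batch_size: int, base_lr: float = 0.3, base_batch: int = 256) -> float:
    """Linear scaling rule: the learning rate grows in proportion to the batch size."""
    return base_lr * batch_size / base_batch


for n in (256, 4096, 8192):
    print(n, linear_scaled_lr(n))
```

Going from a batch size of 256 to 8192 multiplies the learning rate by 32, which is the regime where the quoted passage reports instability with plain SGD/Momentum.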
We conjecture that one serious issue when using only random cropping as data augmentation is that most patches from an image share a similar color distribution. Figure 6 shows that color histograms alone suffice to distinguish images. Neural nets may exploit this shortcut to solve the predictive task. Therefore, it is critical to compose cropping with color distortion in order to learn generalizable features.
A nonlinear projection head improves the representation quality of the layer before it