As mentioned here:
DIM is based on two learning principles: mutual information maximization in the vein of the infomax optimization principle and self-supervision, an important unsupervised learning method that relies on intrinsic properties of the data to provide its own annotation.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        # 3x32x32 input -> 512x20x20 after four 4x4 convolutions (stride 1)
        self.c0 = nn.Conv2d(3, 64, kernel_size=4, stride=1)
        self.c1 = nn.Conv2d(64, 128, kernel_size=4, stride=1)
        self.c2 = nn.Conv2d(128, 256, kernel_size=4, stride=1)
        self.c3 = nn.Conv2d(256, 512, kernel_size=4, stride=1)
        self.l1 = nn.Linear(512 * 20 * 20, 64)
        self.b1 = nn.BatchNorm2d(128)
        self.b2 = nn.BatchNorm2d(256)
        self.b3 = nn.BatchNorm2d(512)

    def forward(self, x):
        h = F.relu(self.c0(x))
        features = F.relu(self.b1(self.c1(h)))     # local feature map M
        h = F.relu(self.b2(self.c2(features)))
        h = F.relu(self.b3(self.c3(h)))
        encoded = self.l1(h.view(x.shape[0], -1))  # global feature vector y
        return encoded, features
class GlobalDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.c0 = nn.Conv2d(128, 64, kernel_size=3)
        self.c1 = nn.Conv2d(64, 32, kernel_size=3)
        self.l0 = nn.Linear(32 * 22 * 22 + 64, 512)
        self.l1 = nn.Linear(512, 512)
        self.l2 = nn.Linear(512, 1)

    def forward(self, y, M):
        # summarize the feature map M, then concatenate it with the encoding y
        h = F.relu(self.c0(M))
        h = self.c1(h)
        h = h.view(y.shape[0], -1)
        h = torch.cat((y, h), dim=1)
        h = F.relu(self.l0(h))
        h = F.relu(self.l1(h))
        return self.l2(h)
class LocalDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # input has 192 channels: 128 feature-map channels + 64 encoding channels
        self.c0 = nn.Conv2d(192, 512, kernel_size=1)
        self.c1 = nn.Conv2d(512, 512, kernel_size=1)
        self.c2 = nn.Conv2d(512, 1, kernel_size=1)

    def forward(self, x):
        h = F.relu(self.c0(x))
        h = F.relu(self.c1(h))
        return self.c2(h)
class PriorDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.l0 = nn.Linear(64, 1000)
        self.l1 = nn.Linear(1000, 200)
        self.l2 = nn.Linear(200, 1)

    def forward(self, x):
        h = F.relu(self.l0(x))
        h = F.relu(self.l1(h))
        return torch.sigmoid(self.l2(h))
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(64, 15)
        self.bn1 = nn.BatchNorm1d(15)
        self.l2 = nn.Linear(15, 10)
        self.bn2 = nn.BatchNorm1d(10)
        self.l3 = nn.Linear(10, 10)
        self.bn3 = nn.BatchNorm1d(10)

    def forward(self, x):
        # x is the (encoded, features) tuple returned by the Encoder;
        # only the encoded vector is used for classification
        encoded, _ = x[0], x[1]
        clazz = F.relu(self.bn1(self.l1(encoded)))
        clazz = F.relu(self.bn2(self.l2(clazz)))
        clazz = F.softmax(self.bn3(self.l3(clazz)), dim=1)
        return clazz
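For completeness, here is a quick, illustrative shape check of the Classifier chained onto the Encoder. The batch of 4 random images and the variable names are assumptions made here for the sketch (BatchNorm1d needs more than one sample in training mode):

# Illustrative only: the Classifier consumes the (encoded, features) tuple
# returned by the Encoder and outputs per-class probabilities.
imgs = torch.randn(4, 3, 32, 32)
with torch.no_grad():
    probs = Classifier()(Encoder()(imgs))
probs.size()
# torch.Size([4, 10])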
The Encoder takes images of size (32,32) and returns features of size (128,26,26) and an encoded vector of length 64.
x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    encoded, features = Encoder()(x)
encoded.size(), features.size()
(torch.Size([1, 64]), torch.Size([1, 128, 26, 26]))
For the GlobalDiscriminator, encoded and features come from the same image. Note that encoded and the convolved features are concatenated just before the linear layers.

with torch.no_grad():
    out = GlobalDiscriminator()(encoded, features)
out.size()
torch.Size([1, 1])
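As a point of reference, here is a hedged sketch of how these scores could feed a Jensen-Shannon-style mutual-information term of the kind used in DIM implementations. The batch of 4, the rolled negatives, and the names enc_b, feat_b, feat_b_prime, and global_d are introduced here for illustration and are not taken from the official code:

# Illustrative sketch (not the official loss code): with a batch larger than 1,
# negatives can be built by rolling the feature maps by one position so each
# encoded vector is scored against features from a different image.
batch = torch.randn(4, 3, 32, 32)
enc_b, feat_b = Encoder()(batch)
feat_b_prime = torch.cat((feat_b[1:], feat_b[:1]), dim=0)

global_d = GlobalDiscriminator()
Ej = -F.softplus(-global_d(enc_b, feat_b)).mean()      # positive pairs
Em = F.softplus(global_d(enc_b, feat_b_prime)).mean()  # negative pairs
global_loss = Em - Ej  # minimizing this maximizes the JSD-based MI estimate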
For the LocalDiscriminator, encoded and features are concatenated at the start, before the convolutional layers.

encoded_expanded = encoded.unsqueeze(2).unsqueeze(3).expand(-1, -1, 26, 26)
x = torch.cat((features, encoded_expanded), dim=1)
with torch.no_grad():
    out = LocalDiscriminator()(x)
out.size()
torch.Size([1, 1, 26, 26])
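A similar hedged sketch for a local term, reusing enc_b, feat_b, and feat_b_prime from the sketch above; the per-location scores of positive and negative pairs are averaged over the 26x26 grid:

# Illustrative sketch (not the official loss code): pair each encoded vector
# with its own 26x26 feature map (positive) and with the map of a different
# image (negative), then average the per-location scores.
enc_exp_b = enc_b.unsqueeze(2).unsqueeze(3).expand(-1, -1, 26, 26)
x_pos = torch.cat((feat_b, enc_exp_b), dim=1)
x_neg = torch.cat((feat_b_prime, enc_exp_b), dim=1)

local_d = LocalDiscriminator()
Ej = -F.softplus(-local_d(x_pos)).mean()  # positives, averaged over locations
Em = F.softplus(local_d(x_neg)).mean()    # negatives
local_loss = Em - Ej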
The PriorDiscriminator estimates whether encoded comes from a uniform distribution (the prior).

with torch.no_grad():
    out = PriorDiscriminator()(encoded)
out.size()
torch.Size([1, 1])
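And a hedged sketch of the prior-matching term: a uniform sample is scored as coming from the prior and the encoder output as not, which is what pushes encoded toward the uniform distribution. The names prior_sample and eps are introduced here for illustration:

# Illustrative sketch (not the official loss code): GAN-style prior matching.
# torch.rand_like draws from U(0, 1); eps guards the logs numerically.
prior_d = PriorDiscriminator()
prior_sample = torch.rand_like(enc_b)
eps = 1e-6
prior_loss = -(torch.log(prior_d(prior_sample) + eps).mean()
               + torch.log(1.0 - prior_d(enc_b) + eps).mean())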
In the official implementation, y is encoded, M is features, and M_prime is features from another image. The objective is to maximize the log-likelihood for (y, M) and minimize it for (y, M_prime).

How is M_prime created? In every batch, the order of the images is changed (in a non-random way). This differs from the random shuffling used in MINE.
y, M = encoder(x)
# roll the batch by one position so each y is paired with features from a different image
M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
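Putting the sketches together, here is one way the three terms could be weighted into a single objective, in the spirit of the weighted sum of global, local, and prior terms in the DIM paper; alpha, beta, and gamma are placeholder values, not the paper's settings:

# Illustrative sketch: weighted sum of the three terms computed above.
alpha, beta, gamma = 0.5, 1.0, 0.1  # placeholders, not tuned values
total_loss = alpha * global_loss + beta * local_loss + gamma * prior_loss
# In a training loop, total_loss.backward() followed by an optimizer step
# would update the encoder together with the three discriminators.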