# DeepInfoMax - [Arxiv link](https://arxiv.org/abs/1808.06670)

As [mentioned here](https://www.microsoft.com/en-us/research/blog/deep-infomax-learning-good-representations-through-mutual-information-maximization/):

> DIM is based on two learning principles: mutual information maximization in the vein of the infomax optimization principle and self-supervision, an important unsupervised learning method that relies on intrinsic properties of the data to provide its own annotation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```

```python
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.c0 = nn.Conv2d(3, 64, kernel_size=4, stride=1)
        self.c1 = nn.Conv2d(64, 128, kernel_size=4, stride=1)
        self.c2 = nn.Conv2d(128, 256, kernel_size=4, stride=1)
        self.c3 = nn.Conv2d(256, 512, kernel_size=4, stride=1)
        self.l1 = nn.Linear(512 * 20 * 20, 64)
        self.b1 = nn.BatchNorm2d(128)
        self.b2 = nn.BatchNorm2d(256)
        self.b3 = nn.BatchNorm2d(512)

    def forward(self, x):
        h = F.relu(self.c0(x))
        # the local feature map is taken after the second conv block
        features = F.relu(self.b1(self.c1(h)))
        h = F.relu(self.b2(self.c2(features)))
        h = F.relu(self.b3(self.c3(h)))
        encoded = self.l1(h.view(x.shape[0], -1))
        return encoded, features
```

```python
class GlobalDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.c0 = nn.Conv2d(128, 64, kernel_size=3)
        self.c1 = nn.Conv2d(64, 32, kernel_size=3)
        self.l0 = nn.Linear(32 * 22 * 22 + 64, 512)
        self.l1 = nn.Linear(512, 512)
        self.l2 = nn.Linear(512, 1)

    def forward(self, y, M):
        h = F.relu(self.c0(M))
        h = self.c1(h)
        h = h.view(y.shape[0], -1)
        h = torch.cat((y, h), dim=1)
        h = F.relu(self.l0(h))
        h = F.relu(self.l1(h))
        return self.l2(h)
```

```python
class LocalDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # 1x1 convolutions act as a per-cell MLP over 128 feature + 64 encoded channels
        self.c0 = nn.Conv2d(192, 512, kernel_size=1)
        self.c1 = nn.Conv2d(512, 512, kernel_size=1)
        self.c2 = nn.Conv2d(512, 1, kernel_size=1)

    def forward(self, x):
        h = F.relu(self.c0(x))
        h = F.relu(self.c1(h))
        return self.c2(h)
```

```python
class PriorDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.l0 = nn.Linear(64, 1000)
        self.l1 = nn.Linear(1000, 200)
        self.l2 = nn.Linear(200, 1)

    def forward(self, x):
        h = F.relu(self.l0(x))
        h = F.relu(self.l1(h))
        return torch.sigmoid(self.l2(h))
```

```python
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(64, 15)
        self.bn1 = nn.BatchNorm1d(15)
        self.l2 = nn.Linear(15, 10)
        self.bn2 = nn.BatchNorm1d(10)
        self.l3 = nn.Linear(10, 10)
        self.bn3 = nn.BatchNorm1d(10)

    def forward(self, x):
        # x is the (encoded, features) pair; only the encoded vector is used
        encoded, _ = x
        clazz = F.relu(self.bn1(self.l1(encoded)))
        clazz = F.relu(self.bn2(self.l2(clazz)))
        clazz = F.softmax(self.bn3(self.l3(clazz)), dim=1)
        return clazz
```

### Encoder takes an image of size `(32, 32)` and returns features of size `(128, 26, 26)` and an encoded vector of length `64`

```python
x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    encoded, features = Encoder()(x)
encoded.size(), features.size()
```

    (torch.Size([1, 64]), torch.Size([1, 128, 26, 26]))

### Global discriminator learns to discriminate whether or not `encoded` and `features` come from the same image

Note that `encoded` and `features` are concatenated in the linear layer (`l0`).

```python
with torch.no_grad():
    out = GlobalDiscriminator()(encoded, features)
out.size()
```

    torch.Size([1, 1])
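The discriminator itself only produces a score; which pairs count as positive or negative is decided by how it is fed. As a minimal sketch (the variable names below are illustrative, not from the note), a negative pair can be built by taking `features` from a second image:

```python
x1 = torch.randn(1, 3, 32, 32)
x2 = torch.randn(1, 3, 32, 32)

enc = Encoder()
disc = GlobalDiscriminator()

with torch.no_grad():
    y1, M1 = enc(x1)  # encoded and features from the same image
    _, M2 = enc(x2)   # features from a different image

    score_match = disc(y1, M1)     # pushed up during training
    score_mismatch = disc(y1, M2)  # pushed down during training
```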
### Local discriminator does the same thing, but for each individual spatial cell of the feature map

Note that `encoded` and `features` are concatenated along the channel dimension before the first convolutional layer.

```python
encoded_expanded = encoded.unsqueeze(2).unsqueeze(3).expand(-1, -1, 26, 26)
x = torch.cat((features, encoded_expanded), dim=1)
with torch.no_grad():
    out = LocalDiscriminator()(x)
out.size()
```

    torch.Size([1, 1, 26, 26])

### Prior discriminator simply learns to predict whether or not `encoded` comes from a uniform distribution

```python
with torch.no_grad():
    out = PriorDiscriminator()(encoded)
out.size()
```

    torch.Size([1, 1])

In the [official implementation](https://github.com/DuaneNielsen/DeepInfomaxPytorch/blob/master/train.py#L25-L49):

* `y` is `encoded`
* `M` is `features`
* `M_prime` is `features` from another image

The objective is to maximize the log likelihood for (`y`, `M`) and minimize it for (`y`, `M_prime`).

### How is `M_prime` created?

In every batch, the sequence of the images [is changed](https://github.com/DuaneNielsen/DeepInfomaxPytorch/blob/master/train.py#L91-L92) (in a non-random way: the batch is rotated by one position). This is different from the random way the sequence was changed in MINE.

```python
y, M = encoder(x)

# rotate images to create pairs for comparison
M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
```
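Putting the pieces together, here is a condensed sketch of the full objective, following the structure of the linked implementation. The loss weights `alpha`, `beta`, `gamma` are placeholders (assumptions, not necessarily the repo's values); the global and local terms use the softplus form of the Jensen-Shannon MI estimator, and the prior term is a GAN-style loss pushing `encoded` toward a uniform distribution:

```python
# Sketch of a single DIM training step; alpha, beta, gamma are assumed weights.
alpha, beta, gamma = 0.5, 1.0, 0.1

encoder = Encoder()
global_d = GlobalDiscriminator()
local_d = LocalDiscriminator()
prior_d = PriorDiscriminator()

x = torch.randn(8, 3, 32, 32)
y, M = encoder(x)
# rotate the batch so each encoding is paired with features from another image
M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)

# Global term: Jensen-Shannon MI estimator via softplus
Ej = -F.softplus(-global_d(y, M)).mean()      # positive pairs
Em = F.softplus(global_d(y, M_prime)).mean()  # negative pairs
global_loss = (Em - Ej) * alpha

# Local term: same estimator, applied to every spatial cell
y_exp = y.unsqueeze(2).unsqueeze(3).expand(-1, -1, 26, 26)
y_M = torch.cat((M, y_exp), dim=1)
y_M_prime = torch.cat((M_prime, y_exp), dim=1)
Ej = -F.softplus(-local_d(y_M)).mean()
Em = F.softplus(local_d(y_M_prime)).mean()
local_loss = (Em - Ej) * beta

# Prior term: push the encoded vector toward a uniform distribution
prior = torch.rand_like(y)
term_a = torch.log(prior_d(prior)).mean()
term_b = torch.log(1.0 - prior_d(y)).mean()
prior_loss = -(term_a + term_b) * gamma

loss = global_loss + local_loss + prior_loss  # then loss.backward(), etc.
```

###### tags: `self-supervised-learning`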