# MLP-Mixer: An All-MLP Architecture for Vision

- Mostly the same authors as the ViT paper.
- Adequate (84.15% top-1 on ImageNet for Mixer-L/16) but not SOTA; benefits much more from scaling up.
- Shared with ViT: divide the image into non-overlapping NxN patches, unroll (flatten) each patch, and apply a linear projection to get per-patch tokens.
- Some simple `Mlp = Linear -> Activation -> Dropout -> Linear -> Dropout` style MLP layers are implemented [here](https://github.com/rwightman/pytorch-image-models/blob/85c5ff26d741b2de29d990e0637b06014ec8ad15/timm/models/layers/mlp.py):

```python
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
```

- The core layer [implementation](https://github.com/rwightman/pytorch-image-models/blob/85c5ff26d741b2de29d990e0637b06014ec8ad15/timm/models/mlp_mixer.py#L146) is pretty straightforward:

```python
class MixerBlock(nn.Module):
    """ Residual Block w/ token mixing and channel MLPs
    Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
    """
    def __init__(
            self, dim, seq_len, mlp_ratio=(0.5, 4.0), mlp_layer=Mlp,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
        super().__init__()
        tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
        self.norm1 = norm_layer(dim)
        self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        # Token mixing: transpose to (B, channels, tokens) so the MLP mixes information across patches.
        x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
        # Channel mixing: the MLP is applied to each patch token independently.
        x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
        return x
```

- Note how `mlp_tokens` and `mlp_channels` operate along the two different axes of the `(tokens, channels)` activation table: `mlp_tokens` mixes across patches (hence the transposes), while `mlp_channels` mixes within each patch. See the shape check at the end of these notes.
- `DropPath` implements stochastic depth (randomly dropping the residual branch per sample during training), implemented [here](https://github.com/rwightman/pytorch-image-models/blob/85c5ff26d741b2de29d990e0637b06014ec8ad15/timm/models/layers/drop.py#L160). A minimal sketch is at the end of these notes.

TODO: Understand `SpatialGatingUnit` and `SpatialGatingBlock` defined [here](https://github.com/rwightman/pytorch-image-models/blob/85c5ff26d741b2de29d990e0637b06014ec8ad15/timm/models/mlp_mixer.py#L201-L241).

###### tags: `vit`
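
A quick shape check for the `MixerBlock` above, using the standard 224px image / 16px patch setup (196 tokens). This imports the class from the timm module linked above; the hyperparameters here are just the defaults, not a claim about any specific Mixer config.

```python
import torch
from timm.models.mlp_mixer import MixerBlock  # the class quoted above

B, seq_len, dim = 2, 196, 512          # 196 = (224 / 16) ** 2 patch tokens
x = torch.randn(B, seq_len, dim)

block = MixerBlock(dim=dim, seq_len=seq_len)
print(block(x).shape)                  # torch.Size([2, 196, 512]) -- shape preserved

# With the default mlp_ratio=(0.5, 4.0): token mixing is a Linear over the
# 196 patches, channel mixing is a Linear over the 512 channels.
print(block.mlp_tokens.fc1)            # Linear(in_features=196, out_features=256, bias=True)
print(block.mlp_channels.fc1)          # Linear(in_features=512, out_features=2048, bias=True)
```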
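
To see how the patch-embedding step and the blocks fit together end to end, here is my own simplified sketch of a full Mixer classifier. `TinyMixer` and its hyperparameters are made up for illustration; it is not timm's `MlpMixer` class, which adds more configuration and the named configs (Mixer-B/16, Mixer-L/16, ...).

```python
import torch
import torch.nn as nn
from timm.models.mlp_mixer import MixerBlock  # the block quoted above

class TinyMixer(nn.Module):
    """Minimal Mixer classifier sketch (illustration only, not timm's MlpMixer)."""
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 dim=512, depth=8, num_classes=1000):
        super().__init__()
        # A conv with kernel = stride = patch size is equivalent to
        # "unroll each patch and apply a shared linear projection".
        self.stem = nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size)
        seq_len = (img_size // patch_size) ** 2              # number of patch tokens
        self.blocks = nn.Sequential(*[MixerBlock(dim, seq_len) for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.stem(x)                   # (B, dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)   # (B, seq_len, dim)
        x = self.norm(self.blocks(x))
        x = x.mean(dim=1)                  # global average pool over tokens
        return self.head(x)

logits = TinyMixer()(torch.randn(2, 3, 224, 224))
print(logits.shape)                        # torch.Size([2, 1000])
```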
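
On the `DropPath` point: stochastic depth keeps the residual branch for each sample with probability `1 - drop_prob` (rescaling it) and zeroes it otherwise. The sketch below is a hedged paraphrase for intuition, not the timm code linked above.

```python
import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    """Per-sample stochastic depth (intuition-level rewrite, not timm's DropPath)."""
    def __init__(self, drop_prob=0.1):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x                        # no-op at inference time
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast across all remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = (torch.rand(shape, dtype=x.dtype, device=x.device) < keep_prob).to(x.dtype)
        # Divide by keep_prob so the expected value of the branch is unchanged.
        return x * mask / keep_prob
```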