# MLP-Mixer: An All-MLP Architecture for Vision
- Mostly the same authors as the ViT paper.
- Adequate (84.15% top-1 on ImageNet with Mixer-L/16) but not SOTA; benefits much more from scaling up.
- Common part with ViT: divide the image into non-overlapping patches, flatten (unroll) each patch, and apply a shared linear projection (the patch embedding); see the sketch below.
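- A minimal sketch of that patch-embedding step (my own illustration with toy numbers, not the timm code): a convolution with kernel size and stride equal to the patch size is equivalent to flattening each patch and applying a shared linear layer.
```python
import torch
import torch.nn as nn

patch_size, dim = 16, 512  # toy values for illustration
patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)

img = torch.randn(1, 3, 224, 224)     # (batch, channels, H, W)
x = patch_embed(img)                  # (1, 512, 14, 14): one embedding vector per patch
x = x.flatten(2).transpose(1, 2)      # (1, 196, 512) = (batch, seq_len, dim)
print(x.shape)
```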
- A simple `Mlp = Linear -> Activation -> Dropout -> Linear -> Dropout` style MLP layer is implemented [here](https://github.com/rwightman/pytorch-image-models/blob/85c5ff26d741b2de29d990e0637b06014ec8ad15/timm/models/layers/mlp.py):
```python
import torch.nn as nn
from timm.models.layers import to_2tuple  # turns a scalar into a 2-tuple, e.g. 0.1 -> (0.1, 0.1)


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs = to_2tuple(drop)  # separate dropout rates after each Linear

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
```
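- Quick usage/shape check (my own example, assuming the `Mlp` class above is in scope): `nn.Linear` acts on the last dimension, so `Mlp` can be applied along either the token or the channel axis and preserves the input shape when `out_features` is left at its default.
```python
import torch

mlp = Mlp(in_features=512, hidden_features=2048)  # channel-mixing style MLP
x = torch.randn(8, 196, 512)                      # (batch, tokens, channels)
print(mlp(x).shape)                               # torch.Size([8, 196, 512]) -- shape preserved
```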
- The core layer [implementation](https://github.com/rwightman/pytorch-image-models/blob/85c5ff26d741b2de29d990e0637b06014ec8ad15/timm/models/mlp_mixer.py#L146) is pretty straightforward:
```python
from functools import partial

import torch.nn as nn
from timm.models.layers import DropPath, Mlp, to_2tuple  # Mlp is the class shown above


class MixerBlock(nn.Module):
    """ Residual Block w/ token mixing and channel MLPs
    Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
    """
    def __init__(
            self, dim, seq_len, mlp_ratio=(0.5, 4.0), mlp_layer=Mlp,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
        super().__init__()
        tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
        self.norm1 = norm_layer(dim)
        # token-mixing MLP: acts across the seq_len (patch) dimension
        self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        # channel-mixing MLP: acts across the dim (channel) dimension
        self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        # x: (batch, seq_len, dim); transpose so the token MLP sees seq_len last
        x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
        x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
        return x
```
- Note how `mlp_tokens` and `mlp_channels` operate along the two non-batch dimensions of the `(batch, seq_len, dim)` input: `mlp_tokens` mixes information across patches (hence the two transposes), while `mlp_channels` mixes across channels within each patch. A sketch of a full model built from these blocks follows below.
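- Putting the blocks together (a rough, hypothetical sketch of the overall architecture as I read the paper; the actual timm `MlpMixer` class differs in details such as the stem and initialization): patch-embedding stem, a stack of `MixerBlock`s, a final LayerNorm, global average pooling over tokens, and a linear classifier head. Assumes `MixerBlock` from above is in scope; all hyperparameters are toy values.
```python
import torch
import torch.nn as nn


class TinyMixer(nn.Module):
    """Hypothetical minimal Mixer: stem -> MixerBlocks -> norm -> pool -> head."""
    def __init__(self, num_classes=1000, img_size=224, patch_size=16, dim=512, depth=8):
        super().__init__()
        seq_len = (img_size // patch_size) ** 2
        self.stem = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        self.blocks = nn.Sequential(*[MixerBlock(dim, seq_len) for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.stem(x).flatten(2).transpose(1, 2)  # (B, seq_len, dim)
        x = self.blocks(x)
        x = self.norm(x)
        x = x.mean(dim=1)                            # global average pool over tokens
        return self.head(x)


print(TinyMixer()(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 1000])
```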
- `DropPath` implements stochastic depth (randomly skipping the residual branch for a subset of samples during training), implemented [here](https://github.com/rwightman/pytorch-image-models/blob/85c5ff26d741b2de29d990e0637b06014ec8ad15/timm/models/layers/drop.py#L160); a rough sketch follows below.
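- A rough sketch of the idea (my own re-implementation for illustration, not the exact timm code): each sample's residual branch is zeroed out with probability `drop_prob` during training, and surviving branches are rescaled by `1 / keep_prob` so the expected output is unchanged.
```python
import torch


def drop_path_sketch(x, drop_prob: float = 0.1, training: bool = True):
    """Hypothetical stochastic-depth helper: drop whole residual branches per sample."""
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dimensions.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.bernoulli(torch.full(mask_shape, keep_prob, device=x.device, dtype=x.dtype))
    return x * mask / keep_prob
```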
- TODO: Understand `SpatialGatingUnit` and `SpatialGatingBlock` defined [here](https://github.com/rwightman/pytorch-image-models/blob/85c5ff26d741b2de29d990e0637b06014ec8ad15/timm/models/mlp_mixer.py#L201-L241).
###### tags: `vit`