MLP is all you need in vision?
===
Author: Willie Chen

[TOC]
# 一、 前言
   大家都知道深度學習在CV領域的發展過程為 ==MLP -> CNN -> Transformer==,尤其是Transformer,在近幾年幾乎是血洗了各個領域,但在最近可能又繞回去了原點到"MLP"。
   在**2021年5月5日**,Google Brain團隊在arxiv上丟了一篇新論文,題目為 "**MLP-Mixer: An all-MLP Architecture for Vision**",顧名思義,就是為了電腦視覺領域設計一個全部都是MLP的架構,本文會將重點放在這篇論文的細節與內容,以及會附上我自己implemtn的PyTorch程式碼。
下面是論文以及連結:

論文網址: https://arxiv.org/pdf/2105.01601v1.pdf
---
   但有一個不信的消息就是,有人的論文剛好跟Google Brain這篇撞車,作者是 **Luke Melas-Kyriazi**,如果有人在Github上用過Efficient的PyTorch版本 [[Efficient PyTorch on Github]](https://github.com/lukemelas/EfficientNet-PyTorch),應該對這位大大不陌生,因為就是同一位。就在Google Brain發表**MLP-Mixer**的隔一天(**2021年5月6日**),該篇論文題目為 "**Do You Even Need Attention? A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet**",稍稍看了一下論文內容,不管是思路或網路架構都跟MLP-Mixer幾乎一樣,真的是...希望作者節哀順變,附上作者在Reddit的討論串 [[Reddit]](https://www.reddit.com/r/MachineLearning/comments/n62qhn/r_do_you_even_need_attention_a_stack_of/)。
下面是論文以及連結:

論文網址: https://arxiv.org/pdf/2105.02723.pdf
---
### 兩篇論文的架構圖
MLP-Mixer | Do you need attention
:-------------------------:|:-------------------------:
 | 
   從架構圖可以很清楚地發現,這兩篇都是將原始圖片切成N個Patches,再將N個Patches丟入一個全是MLP架構的網路,網路中間還會對Patches進行Transpose的操作,這是為了要模擬出跟CNN差不多的概念,但不使用Conv Layer,最後一層則是Prediction的Layer。
   **總結來說就是提出了一個全MLP的網路架構"MLP-Mixer",完全不需要Conv Layer和Attention mechanism,就可以達到跟CNN和Transformer相媲美的分類型能**,文章後面會更詳細的介紹論文內容。
---
# 二、 MLP-Mixer的整體架構
MLP-Mixer完全不依賴CNN以及self-attention的機制(~~但為了將Parches轉成Embedding,還是有套一層Conv Layer~~),而是使用MLP去萃取圖片的空間位置和特徵通道的資訊,而且沒有使用甚麼複雜的操作,僅僅使用Reshape、Transposition和一系列的矩陣乘法操作。
<br>
**下圖是MLP-Mixer的整體架構**

   一開始的操作跟Vision Transformer相同,就是將圖片拆成N個patches(Shape: patches x channel),這邊其實就是對patch進行Embedding的動作。再將這些Embedding通過Mixer Layer,Mixer利用兩種類型的MLP層:**token-mixing MLPs**和**channel-mixing MLPs**,同時也是本論文的核心所在。
## Token & Channel Mixing
* 1. **token-mixing MLPs**
> **主要用於不同空間位置(論文稱之為: token,這邊指的就是Patches)上的communication,在每個Channel上獨立運行,並以每一個column做為輸入。**
                    
                              **Token-Mixing MLPs的示意圖**
* 2. **channel-mixing MLPs**
> **主要用於不同Channel之間的communication,在每個Token上獨立運行,並以每一個row作為輸入。**
                     
                              **Channel-Mixing MLPs的示意圖**
以上這兩種類型的MLP層是互相交錯的,以實現兩個輸入維度(Patches & Channel)的交互操作。
<br>
## Mixing 初步講解(與CNN的關聯性)
   在極端的例子下,這種架構可以視為一種特別的CNN,<font color="red">**token-mixing MLP可視為full receptive field的single-channel depth-wise convolutions**</font>並且做到了權重共享,<font color="red">**channel-mixing MLPs可視為1x1 convolutions**</font>。然而,作者表示標準的CNN並不是Mixer的special case。此外,Convolution比Mixer中的普通矩陣乘法更為複雜,因為Convolution還需要對矩陣乘法進行專門的implementation讓計算成本降低。
   儘管它的架構很簡單,MLP-Mixer仍然取得了頗具競爭力的結果。尤其是pre-trained在大型的資料集上(Ex: 1億張圖片),就準確率跟訓練成本來說,達到了先前CNN和Transformer的SOTA結果,**87.94% Top-1 Validation Accuracy on ImageNet**,87.94%的結果還不是訓練在ImageNet上,而是使用了Google自己內部的JFT-300M資料集來作預訓練再去做Transfer Learning得到的結果,JFT-300M 顧名思義就是有3億張的Labeld Image,看到這邊可以發現我們根本沒有那麼大量的資料阿~,就算是大家常用的ImageNet也只不過才1400多萬張,不過預訓練在Imagenet-21k上的結果也相當不錯,Top-1 Validation Accuracy達到了84.15%,這成績就已經比大家常用的ResNet的結果還要來的好了。
---
# 三、 Mixer架構
## CNN、Transformer、Mixer彼此的特性
在現今的CNN架構使用了三種方式來進行混合特徵:
* **(i) 在給定的空間位置**
* **(ii) 在不同的空間位置之間**
* **(iii) 將(i)和(ii)相結合**
   在一般的CNN裡,(ii)是透過NxN的convolution(N > 1)以及pooling layer所實現,較深層的神經元具有較大的receptive field。在同一時間,1×1 convolution也執行(i),較大的kernel則同時執行(i)和(ii)。
   在Vision Transformer或是其他self-attention的架構裡,self-attenion layer則是同時包含了(i)和(ii),而輸出的MLP則是執行(i)的部分。 Mixer架構的背後idea是清楚的將 **per-location(Channel mixing)** 的操作和 **cross-location(Token mixing)** 的操作分開來,這兩者全用MLP來實現。

<br>
## Embedding操作
   Mixer將輸入圖片切割成S個沒有重疊的圖片當作輸入(Patches也可以稱為Token),再將每個Patch透過Embedding映射到一個隱藏維度(Hidden dimension),$可以得到 X ∈ R^{SxC}$。 假設輸入的原始影像為 $(H \times W)$,而且每張Patch的解析度為$(P \times P)$,就可以得到 $S = {HW \over P^{2}}$。 Ex: 輸入圖片維度是$256 \times 256$、Patches維度為$16 \times 16$,則可以得到 $S = 256$。所有的Patches都使用相同的projection matrix來進行投影(實作時是透過一層的Conv Layer來進行此操作)。
   Mixer由大小相同的多層Mixer block所組成,每一個Mixer Block都包含兩組的MLP block。
* 第一組是token-mixing MLP block: 它作用於X中的Column,在此操作前要先進行Reshape的操作($batches \times num\_patches \times channel \rightarrow batches \times channel \times num\_patches$),在這邊所有column之間的權重是共享的。
* 第二組是channel-mixing MLP block: 它作用於X中的Row,在此操作前要先進行Reshape的操作($batches \times channel \times num \_ patches \rightarrow batches \times num\_patches \times channel$)。
<br>
## MLP Block
   每個MLP Block都包含兩個fully-connected layers和一個獨立於輸入的nonlinearity layer,Mixer layers可以寫成下面的數學式:

   這邊的$\sigma$是GELU activation function,由於有Skip-connection的設計,因此會將輸入的Tensor透過Element-wise給輸入相加進來。$D_S和D_C$分別為token-mixing和channel-mixing MLP的Hidden dimension,注意!! 這邊的$D_S$跟Patches的數量無關,而是由使用者自行設計的。因此,網絡的計算複雜度在輸入Patches的數量上是線性的,這跟ViT的複雜度是平方的不同。$D_C$則是獨立於Patch的size,整體的計算複雜度是跟圖片得像素數量呈現一種線性關係,類似於典型的CNN。
<br>
# 四、實驗結果
## MLP-Mixer的各種結構

<br>
## Performance比較
   由下表可以看出,Mixer的準確率已經和CNN及Transformer架構的不相上下,但重點可以看到**Throughput**的欄位,在這邊可以很明顯的發現Mixer的Throughput性能是遠高於ViT及BiT的,這對於未來要將模型部屬到產品線上是很重要的一個性能指標。

<br>
   由下面圖右可以觀察出一些結果出來,X軸為訓練集的數量多寡,Y軸為ImageNet Top-1的準確率,我們將較強大的Mixer跟當前SOTA的模型一起比較,可以看到Mixer若要達到跟SOTA模型近似的結果,給它的訓練資料要非常足夠。**MLP-Mixer在訓練資料不夠多時,準確率會下降得更加嚴重**。MLP-Mixer跟SOTA模型的準確率差異,會隨著訓練資料的增加而陸續減少。

<br>
## 權重可視化
下圖顯示Mixer預訓練在JFT-300M上,前面幾個Token-Mixing MLP權重的可視化。因為Token-Mixing允許在不同的Patch之間進行權重共享,導致學習到的權重可以套用於整張影像,

<br>
下圖為Pre-Patch Fully-Connected Layer權重的可視化,左為Mixer-B/16右為Mixer-B/32。

<br>
# 五、 實現代碼(JAX & PyTorch)
## JAX 版本
**JAX版本是直接複製論文中作者所提供的程式碼,Official Github: [models_mixer.py](https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py)**
```python=
import einops
import flax.linen as nn
import jax.numpy as jnp
class MlpBlock(nn.Module) :
mlp_dim : int
@nn.compact
def __call__ (self , x) :
y = nn.Dense ( self.mlp_dim)(x)
y = nn.gelu(y)
return nn.Dense(x.shape[-1])(y)
class MixerBlock(nn.Module):
tokens_mlp_dim : int
channels_mlp_dim : int
@nn . compact
def __call__ ( self , x ) :
y = nn.LayerNorm()(x)
y = jnp.swapaxes(y , 1, 2)
y = MlpBlock(self.tokens_mlp_dim , name = 'token_mixing')(y)
y = jnp.swapaxes(y , 1, 2)
x = x + y
y = nn.LayerNorm()(x)
return x + MlpBlock (self.channels_mlp_dim , name = 'channel_mixing')(y)
class MlpMixer(nn.Module) :
num_classes : int
num_blocks : int
patch_size : int
hidden_dim : int
tokens_mlp_dim : int
channels_mlp_dim : int
@nn.compact
def __call__(self , x) :
s = self.patch_size
x = nn.Conv(self.hidden_dim , (s , s) , strides = (s , s) , name ='stem') ( x )
x = einops.rearrange (x , ''n h w c -> n (h w) c')
for _ in range (self.num_blocks) :
rBlock (self.tokens_mlp_dim , self.channels_mlp_dim)(x)
x = nn.LayerNorm (name ='pre_head_layer_norm')(x)
x = jnp.mean (x , axis =1)
return nn.Dense(self.num_classes , name ='head',
kernel_init = nn . initializers . zeros )(x)
```
<br>
## PyTorch版本
* **Import require library**
```python=
import einops
from einops.layers.torch import Rearrange
import torch
import torch.nn as nn
```
<br>
* **Define sub-module for the MLP-Mixer**
   這邊定義了兩個Class,分別是**GlobalAveragePooling**和**MLP_Block**,其中**MLP_Block**根據paper的定義是兩個Linear layer中間是使用GeLU activation function。
```python=
class GlobalAveragePooling(nn.Module):
def __init__(self, dim = 1):
super().__init__()
self.dim = dim
def forward(self, x):
return x.mean(dim = self.dim)
class MLP_Block(nn.Module):
def __init__(self, input_dim, hidden_dim):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, input_dim)
)
def forward(self, x):
return self.mlp(x)
```
<br>
* **Define the Mixer block**
   MixerBlock裡面主要定義了兩個小模組,分別是token-mixing和channel-mixing。
```python=
class MixerBlock(nn.Module):
def __init__(self, dim, num_patches, token_dim, channel_dim):
super().__init__()
self.token_mixing = nn.Sequential(
nn.LayerNorm(dim),
Rearrange("batch num_patches channel -> batch channel num_patches"),
MLP_Block(num_patches, token_dim),
Rearrange("batch channel num_patches -> batch num_patches channel")
)
self.channel_mixing = nn.Sequential(
nn.LayerNorm(dim),
MLP_Block(dim, channel_dim)
)
def forward(self, x):
x = x + self.token_mixing(x)
x = x + self.channel_mixing(x)
return x
```
* **MLP-Mixer network**
   MLP-Mixer的主體架構,共由四個部分所組成:
1. 將輸入圖片轉成Embedding
2. 將Embedding通過一系列的Mixer block
3. Gloval Average Pooling
4. 用來預測的MLP head
```python=
class MLP_Mixer(nn.Module):
def __init__(
self,
in_channel = 3,
dim = 512,
num_classes = 2,
patch_size = 16,
img_size = 256,
token_dim = 256,
channel_dim = 2048,
num_layers = 8
):
super().__init__()
self.num_patch = (img_size // patch_size) ** 2
self.img2patch = nn.Sequential(
nn.Conv2d(in_channel, dim, patch_size, patch_size),
Rearrange("b c h w -> b (h w) c")
)
mixer_layer = [
MixerBlock(dim, self.num_patch, token_dim, channel_dim)
for _ in range(num_layers)
]
self.mixer_layer = nn.Sequential(*mixer_layer)
self.global_pool = GlobalAveragePooling(dim = 1)
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, x):
x = self.img2patch(x)
x = self.mixer_layer(x)
x = self.global_pool(x)
return self.mlp_head(x)
```
* **Small test**
   簡單的測試,輸入圖片的Shape為(1,3,256,256),丟進MLP-Mixer(num_classes=2),最後輸出的Shape為(1,2)。
```python=
batch = torch.randn(1, 3, 256, 256)
model = MLP_Mixer(num_classes = 2)
model(batch).shape # (1, 2)
```
<br>
# 六、 學術圈對Mixer的討論及看法
   雖然說是全MLP的架構啦,但官方實現的程式碼在第一步patch embedding就是用Conv Layer來實現的。光這點就被Yann Lecun給吐槽了,他表示Mixer是掛羊頭賣狗肉 。

**論文作者的回復如下圖:**

**想吃瓜的可以點進[[TWITTER]](https://mobile.twitter.com/ylecun/status/1390543133474234368)這邊來看,~~說實話蠻精彩的~~。**
# 七、 結論
   作者提出Mixer的純MLP架構(其實不然是純MLP),就可以達到跟CNN、Transformer可以比較的結果。透過簡單的矩陣轉換和一系列的MLP,就可以模擬出與CNN類似的效果,從結果論來看,純MLP似乎在CV領域是可行的,而且結構相對簡單許多跟Transformer相比的化。
   但也有網友點出了Mixer不足的地方:
1. 結果始終沒有達到SOTA,而且有需要大量訓練資料的這個限制在。
2. Mixer使用了現今CNN架構常使用的Residual、LayerNorm的模組,如果去掉這些不知效果是否會驟降?
3. Mixer的擴展性沒有Transformer還要來的強,沒有Encoder-Decoder的架構。