owned this note
owned this note
Published
Linked with GitHub
# [Transformer_CV] Masked Autoencoders(MAE)論文筆記
###### tags: `Literature Reading` `Vision Transformer` `Paper`
### [AI / ML領域相關學習筆記入口頁面](https://hackmd.io/@YungHuiHsu/BySsb5dfp)
---
## 論文概覽
#### [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)

- [官方Github](https://github.com/facebookresearch/mae)
##### Encoder架構為Vision Transformer(ViT)
###### 原始論文:[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)

- 見[Vision Transformer(ViT)重點筆記](https://hackmd.io/@YungHuiHsu/ByDHdxBS5)
:::success
- 在NLP領域中,基於掩蔽自編碼(Masked Autoencoder)的自監督預訓練取得巨大的成功(BERT),而掩蔽自編碼器是一種更通用的去噪自編碼器(Denoised Autoencoder),也適用於計算機視覺,盡管如此,視覺自編碼方法的發展為何還是落後於 NLP?
- 本篇論文從以下幾點回答:
1、CV 和 NLP 主流架構不同:
- 直到 ViT (2020.12) 出現之前,CV 的主流架構一直是以卷積網路(CNN)為主,NLP 的主流架構一直是以 Transformer 為主。CNN是以一個固定窗口(filter)滑過圖片,本質上與Transformer所產生token不同
2、語言和圖片 (視頻) 的信息密度不同:
- 語言是經過人類大腦組織過、高訊息密度的訊息,是富含高度語意資訊的。而圖像或影片是自然產生的訊息,在空間上是高度冗餘的。
- 例如,擋住圖片的一部分,可以很容易地通過從周圍而想像出它的樣子來,但如果遮住一個句子中的單詞,則較難回推。
3、 Decoder部分在 CV 和 NLP 中充當的角色不同:
- 在 CV 領域,Decoder 的作用是重建影像,Decoder輸出的語意(訊息)級別是低階的。
- 在 NLP 領域,Decoder 的作用是重建單詞,Decoder輸出的是高階、富含資訊的語義級別
- 為了克服這種差异並鼓勵學習有用的特征,論文中展示了:一個簡單的策略在計算機視覺中也能非常有效:掩蔽很大一部分隨機 patch。這種策略在很大程度上减少了冗餘,並創造了一個具有挑戰性的自監督任務,該任務需要超越低級圖像統計的整體理解。
:::
### 模型架構

- 訓練階段
- 在預訓練期間,圖像大部分區域(例如75%)被掩蓋掉。
- 編碼階段: 以一個較大的編碼器(Encoder)對未被遮蔽的圖塊(patch)進行編碼
- 解碼階段: 依據位置訊息,將全部的圖塊(包含被遮蔽的),以一個小的解碼器(decoder)來處理,以像素重建原始圖像。
- 損失函數:
- 目標是原始與生成圖像間image pixels層級的均方誤差(Mean squared error)最小化
- Fine-tune(針對下游特定任務的遷移學習)
- 經過預訓練。解碼器被丟棄,編碼器被應用於未被破壞的圖像(所有圖塊)進行識別等其他下游任務。
> MAE 方法嚴格來講屬於一種去噪自編碼器 (Denoising Auto-Encoders (DAE)),去噪自動編碼器是一類自動編碼器,它破壞輸入信號,並學會重構原始的、未被破壞的信號。 MAE 的 Encoder 和 Decoder 結構不同,是非對稱式的。 Encoder 將輸入編碼為 latent representation,而 Decoder 將從 latent representation 重建原始信號。
>
> MAE 和 ViT 的做法一致,將圖像劃分成規則的,不重疊的 patches。然後按照均勻分佈不重複地選擇一些 patches 並且 mask 掉剩餘的 patches。作者採用的 mask ratio 足夠高,因此大大減小了 patches 的冗餘信息,使得在這種情況下重建 images 不那麼容易。
### 參考資料:
- [大道至簡,何愷明新論文火了:Masked Autoencoders讓計算機視覺通向大模型](https://cdmana.com/2021/11/20211113143300494Z.html)
- [Self-Supervised Learning 超详细解读 (六):MAE:通向CV大模型](https://zhuanlan.zhihu.com/p/432950958)
### 學習資源
- 影片
- [李沐_MAE 论文逐段精读【论文精读】](https://www.youtube.com/watch?v=mYlX2dpdHHM&list=PLFXJ6jwg0qW-7UM8iUTj3qKqdhbQULP5I&index=10)
- [Masked Autoencoders Are Scalable Vision Learners – Paper explained and animated!](https://www.youtube.com/watch?v=Dp6iICL2dVI)
## 核心程式碼筆記
- [mae/models_mae.py](https://github.com/facebookresearch/mae/blob/main/models_mae.py)
### 初始化與整體前向傳播流程
- norm_layer(embed_dim) 即nn.LayerNorm()為將整層進行標準化
- 相對於BatchNorm為將整個批次的資料標準化
- 標準化的對象為被切成多個1維圖塊(序列)的影像,標準化的單位為每個1維扁平化後的圖塊
- output: return `loss`, `pred`, `mask`
- 這裡的pred是一個由多個1維圖塊組成的sequence
- shape : (N, L, embed_dim)
- N是batch size,L是patch數量,embed_dim是指定的每個圖塊token的內嵌向量維度
- 預測結果需要去區塊化還原成影像
- `model.unpatchify(pred)`
::: spoiler code
```python=
from functools import partial
import torch
import torch.nn as nn
from timm.models.vision_transformer import PatchEmbed, Block
from util.pos_embed import get_2d_sincos_pos_embed
class MaskedAutoencoderViT(nn.Module):
""" Masked Autoencoder with VisionTransformer backbone
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3,
embed_dim=1024, depth=24, num_heads=16,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
super().__init__()
# --- skip ---
def forward(self, imgs, mask_ratio=0.75):
latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
loss = self.forward_loss(imgs, pred, mask)
return loss, pred, mask
```
:::
### MAE的功能函式
#### 將圖像區塊化 & 去區塊化
- imgs: [N, 3, H, W]
- patch_embedding: [N, L, p x p x c]
- 轉換2維影像成為L個1維序列
- 每個序列長度: p x p x c
- eg: patch size = 16 x 16。 channel=3。
每個一維的序列長度: 16 x 16 x 3 = 768
- ==L (num_patches)== = img_h x img_w / p x p
- 當輸入影像長寬為 224 x 224, Patch size = 16x16
則有 (224 x 224) / (16 x 16) = 14 x 14 = 196個patch
- 相當於NLP(自然語言處理)領域的sentence_length
:::spoiler code
```python=95
def patchify(self, imgs):
"""
imgs: (N, 3, H, W)
x: (N, L, patch_size**2 *3)
"""
p = self.patch_embed.patch_size[0]
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
h = w = imgs.shape[2] // p
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
return x
def unpatchify(self, x):
"""
x: (N, L, patch_size**2 *3)
imgs: (N, 3, H, W)
"""
p = self.patch_embed.patch_size[0]
h = w = int(x.shape[1]**.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
return
```
:::
#### 隨機遮蔽圖像
==random_masking()==
- 將圖像轉換成L(patch number)個1為的序列後,隨機抽樣決定那些圖塊(1d patch)要被遮蔽
- input:
- imgs
- imgs (N, 3, H, W)-> x (N, L, dim)
- dim : patch_size**2 * c
- N是batch
- output: return `x_masked`, `mask`, `ids_restore`
- `x_masked `
- 沒被遮蔽的1d圖塊(unmasked 1d patch)
- 透過ids_keep內的指標,取得欲保留(未被遮蔽)的path
- shape : (N, L*(1-maske_rario), dim)
- `mask`
- shape : (N, L),值為[0, 1]的遮罩。0 is keep, 1 is remove
- N是batch size,~~L是patch數量 * 被遮蔽的比例~~
- `torch.gather(mask, dim=1, index=ids_restore)` 根據`ids_restore`指標位置對mask重新取值/排序
- `ids_restore`
- 取得由小到大的排序index,作為圖塊遮蔽(encoding)與還原時(decoding)的位置指標
- `ids_shuffle`與`ids_restore` 取得的index值其實是一樣的
- shape : (N, L)
- 沒被遮蔽的圖塊(函式中的x_masked)在def forward_encoder()中,會加上cls token送入Transformer blocks中訓練
::: spoiler code
```python=123
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
```
:::
### ==編碼器(encoder)==
```python=
def forward_encoder():
return return latent, mask, ids_restore
```
#### Encoder(ViT)內的Embedding
##### 類別token
==cls_tokens==
- shape of : (1, 1, embed_dim) > (N, L, embed_dim)
- 這裡內容值均為0,由於沒有對應的分類任務訓練,因此MAE中的cls_tokens實際上並沒有學習到實質內容,此階段保留僅為符合ViT模型架構
- 在後續下游任務遷移學習時,仍可透過訓練學習到對應任務所需的權重
##### 位置編碼
==pos_embed==
- shape : (batch, num_patches+1, embed_dim)
- ~~pos_embed.shape[-1] = embed_dim = 1024~~
- int(patch_embed.num_patches**.5) = sqrt(196)=14
::: spoiler code
```python=134
def initialize_weights(self):
# initialization
# initialize (and freeze) pos_embed by sin-cos embedding
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
```
:::
#### 初始化建立編碼器(encoder)
- 位置編碼中的`self.pos_embed` 中第二維度的num_patches + 1, 這個1是ViT架構中的類別編碼,加在最前面
:::spoiler code MAE encoder
```python=32
# MAE encoder specifics
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
self.blocks = nn.ModuleList([
Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
```
:::
#### 編碼器前向傳播流程
- 流程:
1. `x = self.patch_embed(x)`
- 將圖像切割為依照 p * p的大小切分並線性轉換為L個序列(1維圖塊)
- shape of patch_embed(x) : (batch, num_patches, embed_dim)
- ~~每個一維圖塊長度是 patch_size**2 * c~~
- 與patchify(imgs)的形狀相同
2. `x = x + self.pos_embed[:, 1:, :]`
- 加上(移除類別項目的)位置編碼
- pos_embed[:, 0, :]是類別的編碼,因此從1開始取值為捨棄類別編碼
3. ==進行隨機遮蔽`self.random_masking(x, mask_ratio)`==
- code中為了進行隨機遮蔽的操作,於前後分別移除與加回位置編碼的類別項,主要為了便於隨機遮蔽的操作
- 這邊使用的ViTransformer核心為用於分類任務的架構,前面自帶類別項目
- [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)
- 這是MAE在encoder部分的精妙處,其他部分操作核心與ViT一致
4. `x = torch.cat((cls_tokens, x), dim=1)`
- 把類別編碼(cls_tokens)加回==未被遮蔽==的1維圖塊
- cls_tokens 關於MAE內的類別token
- shape of cls_tokens : (1, 1, embed_dim) > (N, L, embed_dim)
5. 經過上述處理後把帶有類別與位置編碼訊息的1維圖塊們送入Transformer模型(self attention)內
- shape of x : [batch, num_patches + 1, embed_dim]
- 這邊的blocks用的是timm library寫好的self-attention模組
- ` Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
for i in range(depth)])`
```python=165
# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
```
:::spoiler forward_encoder()
```python=15
def forward_encoder(self, x, mask_ratio):
# embed patches
x = self.patch_embed(x)
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
x, mask, ids_restore = self.random_masking(x, mask_ratio)
# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x, mask, ids_restore
```
:::
### ==解碼器(decoder)==
- 傳入參數為x, ids_restore
- x
- ids_restore
- 圖塊遮蔽(encoding)與還原時(decoding)的位置指標
```python=
def forward_decoder(self, x, ids_restore):
...
# remove cls token
x = x[:, 1:, :]
return x
```
#### 初始化建立解碼器(decoder)
::: spoiler MAE decoder specifics
```python=46
# MAE decoder specifics
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
self.decoder_blocks = nn.ModuleList([
Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
for i in range(decoder_depth)])
self.decoder_norm = norm_layer(decoder_embed_dim)
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
```
:::
#### 解碼器前向傳播
- decoder與encoder不對稱,深度與寬度都較小
- 解碼前先把mask tokens加回sequence,亦即解碼回去圖像時的input是包含MASK圖塊
::: spoiler forward_decoder()
```python=172
def forward_decoder(self, x, ids_restore):
# embed tokens
x = self.decoder_embed(x)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# add pos embed
x = x + self.decoder_pos_embed
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
# remove cls token
x = x[:, 1:, :]
```
:::
### 損失函數計算與前向傳播
==forward_loss()==
- MSE Loss
- 先將原始影像區塊後,計算每個patch的loss後取平均
- norm_pix_loss : 將原始影像以patch為單位做標準化
- 再取出被遮蔽的patch,計算被遮蔽patch部分的loss
::: spoiler code
```python=
def forward_loss(self, imgs, pred, mask):
"""
imgs: [N, 3, H, W]
pred: [N, L, p*p*3]
mask: [N, L], 0 is keep, 1 is remove,
"""
target = self.patchify(imgs)
if self.norm_pix_loss:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.e-6)**.5
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
return loss
```
:::
#### 參考資料
- pytorch函式
- gather()
> 对于out指定位置上的值,去寻找input里面对应的索引位置,根据是index
> - 定义:从原tensor中获取指定dim和指定index的数据
> 从完整数据中按索引取值
> - 用途:方便从批量tensor中获取指定索引下的数据,该索引是高度自定义化的,可乱序的
- 引用自:
- [torch.gather() 和torch.sactter_()的用法简析](https://blog.csdn.net/Teeyohuang/article/details/82186666)
- [图解PyTorch中的torch.gather函数](https://zhuanlan.zhihu.com/p/352877584)
- repeat()和 expand()
- expand
> 函数对返回的张量不会分配新内存,即在原始张量上返回只读视图,返回的张量内存是不连续的。类似于numpy中的broadcast_to函数的作用。如果希望张量内存连续,可以调用contiguous函数。
> 扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view),一个大小(size)等于1的维度扩展到更大的尺寸。
- repeat
> torch.repeat用法类似np.tile,就是将原矩阵横向、纵向地复制。与torch.expand不同的是torch.repeat返回的张量在内存中是连续的。
> 沿着特定的维度重复这个张量,和expand()不同的是,这个函数拷贝张量的数据
- 引用自:
- [PyTorch学习笔记——repeat()和expand()区别](https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html)
- [pytorch torch.expand和torch.repeat的区别详解](https://www.jb51.net/article/173582.htm)
## Deep Learning相關筆記
### Self-supervised Learning
- [[Self-supervised] Self-supervised Learning 與 Vision Transformer重點筆記與近期發展](https://hackmd.io/7t35ALztT56STzItxo3UiA)
- [[Time Series] - TS2Vec(Towards Universal Representation of Time Series) 論文筆記](https://hackmd.io/OE9u1T9ETbSdiSzM1eMkqA)
### Object Detection
- [[Object Detection_YOLO] YOLOv7 論文筆記](https://hackmd.io/xhLeIsoSToW0jL61QRWDcQ)
### ViT與Transformer相關
- [[Transformer_CV] Vision Transformer(ViT)重點筆記](https://hackmd.io/tMw0oZM6T860zHJ2jkmLAA)
- [[Transformer] Self-Attention與Transformer](https://hackmd.io/fmJx3K4ySAO-zA0GEr0Clw)
- [[Explainable AI] Transformer Interpretability Beyond Attention Visualization。Transformer可解釋性與視覺化](https://hackmd.io/SdKCrj2RTySHxLevJkIrZQ)
- [[Transformer_CV] Masked Autoencoders(MAE)論文筆記](https://hackmd.io/lTqNcOmQQLiwzkAwVySh8Q)
### Autoencoder相關
- [[Autoencoder] Variational Sparse Coding (VSC)論文筆記](https://hackmd.io/MXxa8zesRhym4ahu7OJEfQ)
- [[Transformer_CV] Masked Autoencoders(MAE)論文筆記](https://hackmd.io/lTqNcOmQQLiwzkAwVySh8Q)