# [Transformer_CV] Masked Autoencoders(MAE)論文筆記 ###### tags: `Literature Reading` `Vision Transformer` `Paper` ### [AI / ML領域相關學習筆記入口頁面](https://hackmd.io/@YungHuiHsu/BySsb5dfp) --- ## 論文概覽 #### [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) ![image](https://hackmd.io/_uploads/HJlKRlI0ee.png =400x) - [官方Github](https://github.com/facebookresearch/mae) ##### Encoder架構為Vision Transformer(ViT) ###### 原始論文:[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) ![](https://i.imgur.com/1PNfjqR.png =800x) - 見[Vision Transformer(ViT)重點筆記](https://hackmd.io/@YungHuiHsu/ByDHdxBS5) :::success - 在NLP領域中,基於掩蔽自編碼(Masked Autoencoder)的自監督預訓練取得巨大的成功(BERT),而掩蔽自編碼器是一種更通用的去噪自編碼器(Denoised Autoencoder),也適用於計算機視覺,盡管如此,視覺自編碼方法的發展為何還是落後於 NLP? - 本篇論文從以下幾點回答: 1、CV 和 NLP 主流架構不同: - 直到 ViT (2020.12) 出現之前,CV 的主流架構一直是以卷積網路(CNN)為主,NLP 的主流架構一直是以 Transformer 為主。CNN是以一個固定窗口(filter)滑過圖片,本質上與Transformer所產生token不同 2、語言和圖片 (視頻) 的信息密度不同: - 語言是經過人類大腦組織過、高訊息密度的訊息,是富含高度語意資訊的。而圖像或影片是自然產生的訊息,在空間上是高度冗餘的。 - 例如,擋住圖片的一部分,可以很容易地通過從周圍而想像出它的樣子來,但如果遮住一個句子中的單詞,則較難回推。 3、 Decoder部分在 CV 和 NLP 中充當的角色不同: - 在 CV 領域,Decoder 的作用是重建影像,Decoder輸出的語意(訊息)級別是低階的。 - 在 NLP 領域,Decoder 的作用是重建單詞,Decoder輸出的是高階、富含資訊的語義級別 - 為了克服這種差异並鼓勵學習有用的特征,論文中展示了:一個簡單的策略在計算機視覺中也能非常有效:掩蔽很大一部分隨機 patch。這種策略在很大程度上减少了冗餘,並創造了一個具有挑戰性的自監督任務,該任務需要超越低級圖像統計的整體理解。 ::: ### 模型架構 ![image](https://hackmd.io/_uploads/HJlKRlI0ee.png =600x) - 訓練階段 - 在預訓練期間,圖像大部分區域(例如75%)被掩蓋掉。 - 編碼階段: 以一個較大的編碼器(Encoder)對未被遮蔽的圖塊(patch)進行編碼 - 解碼階段: 依據位置訊息,將全部的圖塊(包含被遮蔽的),以一個小的解碼器(decoder)來處理,以像素重建原始圖像。 - 損失函數: - 目標是原始與生成圖像間image pixels層級的均方誤差(Mean squared error)最小化 - Fine-tune(針對下游特定任務的遷移學習) - 經過預訓練。解碼器被丟棄,編碼器被應用於未被破壞的圖像(所有圖塊)進行識別等其他下游任務。 > MAE 方法嚴格來講屬於一種去噪自編碼器 (Denoising Auto-Encoders (DAE)),去噪自動編碼器是一類自動編碼器,它破壞輸入信號,並學會重構原始的、未被破壞的信號。 MAE 的 Encoder 和 Decoder 結構不同,是非對稱式的。 Encoder 將輸入編碼為 latent representation,而 Decoder 將從 latent representation 重建原始信號。 > > MAE 和 ViT 的做法一致,將圖像劃分成規則的,不重疊的 patches。然後按照均勻分佈不重複地選擇一些 patches 並且 mask 掉剩餘的 patches。作者採用的 mask ratio 足夠高,因此大大減小了 patches 的冗餘信息,使得在這種情況下重建 images 不那麼容易。 ### 參考資料: - [大道至簡,何愷明新論文火了:Masked Autoencoders讓計算機視覺通向大模型](https://cdmana.com/2021/11/20211113143300494Z.html) - [Self-Supervised Learning 超详细解读 (六):MAE:通向CV大模型](https://zhuanlan.zhihu.com/p/432950958) ### 學習資源 - 影片 - [李沐_MAE 论文逐段精读【论文精读】](https://www.youtube.com/watch?v=mYlX2dpdHHM&list=PLFXJ6jwg0qW-7UM8iUTj3qKqdhbQULP5I&index=10) - [Masked Autoencoders Are Scalable Vision Learners – Paper explained and animated!](https://www.youtube.com/watch?v=Dp6iICL2dVI) ## 核心程式碼筆記 - [mae/models_mae.py](https://github.com/facebookresearch/mae/blob/main/models_mae.py) ### 初始化與整體前向傳播流程 - norm_layer(embed_dim) 即nn.LayerNorm()為將整層進行標準化 - 相對於BatchNorm為將整個批次的資料標準化 - 標準化的對象為被切成多個1維圖塊(序列)的影像,標準化的單位為每個1維扁平化後的圖塊 - output: return `loss`, `pred`, `mask` - 這裡的pred是一個由多個1維圖塊組成的sequence - shape : (N, L, embed_dim) - N是batch size,L是patch數量,embed_dim是指定的每個圖塊token的內嵌向量維度 - 預測結果需要去區塊化還原成影像 - `model.unpatchify(pred)` ::: spoiler code ```python= from functools import partial import torch import torch.nn as nn from timm.models.vision_transformer import PatchEmbed, Block from util.pos_embed import get_2d_sincos_pos_embed class MaskedAutoencoderViT(nn.Module): """ Masked Autoencoder with VisionTransformer backbone """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=24, num_heads=16, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): super().__init__() # --- skip --- def forward(self, imgs, mask_ratio=0.75): latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] loss = self.forward_loss(imgs, pred, mask) return loss, pred, mask ``` ::: ### MAE的功能函式 #### 將圖像區塊化 & 去區塊化 - imgs: [N, 3, H, W] - patch_embedding: [N, L, p x p x c] - 轉換2維影像成為L個1維序列 - 每個序列長度: p x p x c - eg: patch size = 16 x 16。 channel=3。 每個一維的序列長度: 16 x 16 x 3 = 768 - ==L (num_patches)== = img_h x img_w / p x p - 當輸入影像長寬為 224 x 224, Patch size = 16x16 則有 (224 x 224) / (16 x 16) = 14 x 14 = 196個patch - 相當於NLP(自然語言處理)領域的sentence_length :::spoiler code ```python=95 def patchify(self, imgs): """ imgs: (N, 3, H, W) x: (N, L, patch_size**2 *3) """ p = self.patch_embed.patch_size[0] assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 h = w = imgs.shape[2] // p x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) x = torch.einsum('nchpwq->nhwpqc', x) x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) return x def unpatchify(self, x): """ x: (N, L, patch_size**2 *3) imgs: (N, 3, H, W) """ p = self.patch_embed.patch_size[0] h = w = int(x.shape[1]**.5) assert h * w == x.shape[1] x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) x = torch.einsum('nhwpqc->nchpwq', x) imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) return ``` ::: #### 隨機遮蔽圖像 ==random_masking()== - 將圖像轉換成L(patch number)個1為的序列後,隨機抽樣決定那些圖塊(1d patch)要被遮蔽 - input: - imgs - imgs (N, 3, H, W)-> x (N, L, dim) - dim : patch_size**2 * c - N是batch - output: return `x_masked`, `mask`, `ids_restore` - `x_masked ` - 沒被遮蔽的1d圖塊(unmasked 1d patch) - 透過ids_keep內的指標,取得欲保留(未被遮蔽)的path - shape : (N, L*(1-maske_rario), dim) - `mask` - shape : (N, L),值為[0, 1]的遮罩。0 is keep, 1 is remove - N是batch size,~~L是patch數量 * 被遮蔽的比例~~ - `torch.gather(mask, dim=1, index=ids_restore)` 根據`ids_restore`指標位置對mask重新取值/排序 - `ids_restore` - 取得由小到大的排序index,作為圖塊遮蔽(encoding)與還原時(decoding)的位置指標 - `ids_shuffle`與`ids_restore` 取得的index值其實是一樣的 - shape : (N, L) - 沒被遮蔽的圖塊(函式中的x_masked)在def forward_encoder()中,會加上cls token送入Transformer blocks中訓練 ::: spoiler code ```python=123 def random_masking(self, x, mask_ratio): """ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random noise. x: [N, L, D], sequence """ N, L, D = x.shape # batch, length, dim len_keep = int(L * (1 - mask_ratio)) noise = torch.rand(N, L, device=x.device) # noise in [0, 1] # sort noise for each sample ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove ids_restore = torch.argsort(ids_shuffle, dim=1) # keep the first subset ids_keep = ids_shuffle[:, :len_keep] x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) # generate the binary mask: 0 is keep, 1 is remove mask = torch.ones([N, L], device=x.device) mask[:, :len_keep] = 0 # unshuffle to get the binary mask mask = torch.gather(mask, dim=1, index=ids_restore) return x_masked, mask, ids_restore ``` ::: ### ==編碼器(encoder)== ```python= def forward_encoder(): return return latent, mask, ids_restore ``` #### Encoder(ViT)內的Embedding ##### 類別token ==cls_tokens== - shape of : (1, 1, embed_dim) > (N, L, embed_dim) - 這裡內容值均為0,由於沒有對應的分類任務訓練,因此MAE中的cls_tokens實際上並沒有學習到實質內容,此階段保留僅為符合ViT模型架構 - 在後續下游任務遷移學習時,仍可透過訓練學習到對應任務所需的權重 ##### 位置編碼 ==pos_embed== - shape : (batch, num_patches+1, embed_dim) - ~~pos_embed.shape[-1] = embed_dim = 1024~~ - int(patch_embed.num_patches**.5) = sqrt(196)=14 ::: spoiler code ```python=134 def initialize_weights(self): # initialization # initialize (and freeze) pos_embed by sin-cos embedding pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) ``` ::: #### 初始化建立編碼器(encoder) - 位置編碼中的`self.pos_embed` 中第二維度的num_patches + 1, 這個1是ViT架構中的類別編碼,加在最前面 :::spoiler code MAE encoder ```python=32 # MAE encoder specifics self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding self.blocks = nn.ModuleList([ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in range(depth)]) self.norm = norm_layer(embed_dim) ``` ::: #### 編碼器前向傳播流程 - 流程: 1. `x = self.patch_embed(x)` - 將圖像切割為依照 p * p的大小切分並線性轉換為L個序列(1維圖塊) - shape of patch_embed(x) : (batch, num_patches, embed_dim) - ~~每個一維圖塊長度是 patch_size**2 * c~~ - 與patchify(imgs)的形狀相同 2. `x = x + self.pos_embed[:, 1:, :]` - 加上(移除類別項目的)位置編碼 - pos_embed[:, 0, :]是類別的編碼,因此從1開始取值為捨棄類別編碼 3. ==進行隨機遮蔽`self.random_masking(x, mask_ratio)`== - code中為了進行隨機遮蔽的操作,於前後分別移除與加回位置編碼的類別項,主要為了便於隨機遮蔽的操作 - 這邊使用的ViTransformer核心為用於分類任務的架構,前面自帶類別項目 - [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) - 這是MAE在encoder部分的精妙處,其他部分操作核心與ViT一致 4. `x = torch.cat((cls_tokens, x), dim=1)` - 把類別編碼(cls_tokens)加回==未被遮蔽==的1維圖塊 - cls_tokens 關於MAE內的類別token - shape of cls_tokens : (1, 1, embed_dim) > (N, L, embed_dim) 5. 經過上述處理後把帶有類別與位置編碼訊息的1維圖塊們送入Transformer模型(self attention)內 - shape of x : [batch, num_patches + 1, embed_dim] - 這邊的blocks用的是timm library寫好的self-attention模組 - ` Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in range(depth)])` ```python=165 # apply Transformer blocks for blk in self.blocks: x = blk(x) x = self.norm(x) ``` :::spoiler forward_encoder() ```python=15 def forward_encoder(self, x, mask_ratio): # embed patches x = self.patch_embed(x) # add pos embed w/o cls token x = x + self.pos_embed[:, 1:, :] # masking: length -> length * mask_ratio x, mask, ids_restore = self.random_masking(x, mask_ratio) # append cls token cls_token = self.cls_token + self.pos_embed[:, :1, :] cls_tokens = cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_tokens, x), dim=1) # apply Transformer blocks for blk in self.blocks: x = blk(x) x = self.norm(x) return x, mask, ids_restore ``` ::: ### ==解碼器(decoder)== - 傳入參數為x, ids_restore - x - ids_restore - 圖塊遮蔽(encoding)與還原時(decoding)的位置指標 ```python= def forward_decoder(self, x, ids_restore): ... # remove cls token x = x[:, 1:, :] return x ``` #### 初始化建立解碼器(decoder) ::: spoiler MAE decoder specifics ```python=46 # MAE decoder specifics self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding self.decoder_blocks = nn.ModuleList([ Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in range(decoder_depth)]) self.decoder_norm = norm_layer(decoder_embed_dim) self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch ``` ::: #### 解碼器前向傳播 - decoder與encoder不對稱,深度與寬度都較小 - 解碼前先把mask tokens加回sequence,亦即解碼回去圖像時的input是包含MASK圖塊 ::: spoiler forward_decoder() ```python=172 def forward_decoder(self, x, ids_restore): # embed tokens x = self.decoder_embed(x) # append mask tokens to sequence mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token # add pos embed x = x + self.decoder_pos_embed # apply Transformer blocks for blk in self.decoder_blocks: x = blk(x) x = self.decoder_norm(x) # predictor projection x = self.decoder_pred(x) # remove cls token x = x[:, 1:, :] ``` ::: ### 損失函數計算與前向傳播 ==forward_loss()== - MSE Loss - 先將原始影像區塊後,計算每個patch的loss後取平均 - norm_pix_loss : 將原始影像以patch為單位做標準化 - 再取出被遮蔽的patch,計算被遮蔽patch部分的loss ::: spoiler code ```python= def forward_loss(self, imgs, pred, mask): """ imgs: [N, 3, H, W] pred: [N, L, p*p*3] mask: [N, L], 0 is keep, 1 is remove, """ target = self.patchify(imgs) if self.norm_pix_loss: mean = target.mean(dim=-1, keepdim=True) var = target.var(dim=-1, keepdim=True) target = (target - mean) / (var + 1.e-6)**.5 loss = (pred - target) ** 2 loss = loss.mean(dim=-1) # [N, L], mean loss per patch loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches return loss ``` ::: #### 參考資料 - pytorch函式 - gather() > 对于out指定位置上的值,去寻找input里面对应的索引位置,根据是index > - 定义:从原tensor中获取指定dim和指定index的数据 > 从完整数据中按索引取值 > - 用途:方便从批量tensor中获取指定索引下的数据,该索引是高度自定义化的,可乱序的 - 引用自: - [torch.gather() 和torch.sactter_()的用法简析](https://blog.csdn.net/Teeyohuang/article/details/82186666) - [图解PyTorch中的torch.gather函数](https://zhuanlan.zhihu.com/p/352877584) - repeat()和 expand() - expand > 函数对返回的张量不会分配新内存,即在原始张量上返回只读视图,返回的张量内存是不连续的。类似于numpy中的broadcast_to函数的作用。如果希望张量内存连续,可以调用contiguous函数。 > 扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view),一个大小(size)等于1的维度扩展到更大的尺寸。 - repeat > torch.repeat用法类似np.tile,就是将原矩阵横向、纵向地复制。与torch.expand不同的是torch.repeat返回的张量在内存中是连续的。 > 沿着特定的维度重复这个张量,和expand()不同的是,这个函数拷贝张量的数据 - 引用自: - [PyTorch学习笔记——repeat()和expand()区别](https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html) - [pytorch torch.expand和torch.repeat的区别详解](https://www.jb51.net/article/173582.htm) ## Deep Learning相關筆記 ### Self-supervised Learning - [[Self-supervised] Self-supervised Learning 與 Vision Transformer重點筆記與近期發展](https://hackmd.io/7t35ALztT56STzItxo3UiA) - [[Time Series] - TS2Vec(Towards Universal Representation of Time Series) 論文筆記](https://hackmd.io/OE9u1T9ETbSdiSzM1eMkqA) ### Object Detection - [[Object Detection_YOLO] YOLOv7 論文筆記](https://hackmd.io/xhLeIsoSToW0jL61QRWDcQ) ### ViT與Transformer相關 - [[Transformer_CV] Vision Transformer(ViT)重點筆記](https://hackmd.io/tMw0oZM6T860zHJ2jkmLAA) - [[Transformer] Self-Attention與Transformer](https://hackmd.io/fmJx3K4ySAO-zA0GEr0Clw) - [[Explainable AI] Transformer Interpretability Beyond Attention Visualization。Transformer可解釋性與視覺化](https://hackmd.io/SdKCrj2RTySHxLevJkIrZQ) - [[Transformer_CV] Masked Autoencoders(MAE)論文筆記](https://hackmd.io/lTqNcOmQQLiwzkAwVySh8Q) ### Autoencoder相關 - [[Autoencoder] Variational Sparse Coding (VSC)論文筆記](https://hackmd.io/MXxa8zesRhym4ahu7OJEfQ) - [[Transformer_CV] Masked Autoencoders(MAE)論文筆記](https://hackmd.io/lTqNcOmQQLiwzkAwVySh8Q)