changed a year ago
Published Linked with GitHub

[Transformer_CV] Masked Autoencoders(MAE)論文筆記

tags: Literature Reading Vision Transformer Paper

AI / ML領域相關學習筆記入口頁面


論文概覽

Masked Autoencoders Are Scalable Vision Learners

Encoder架構為Vision Transformer(ViT)
原始論文:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

  • 在NLP領域中,基於掩蔽自編碼(Masked Autoencoder)的自監督預訓練取得巨大的成功(BERT),而掩蔽自編碼器是一種更通用的去噪自編碼器(Denoised Autoencoder),也適用於計算機視覺,盡管如此,視覺自編碼方法的發展為何還是落後於 NLP?

  • 本篇論文從以下幾點回答:
    1、CV 和 NLP 主流架構不同:

    • 直到 ViT (2020.12) 出現之前,CV 的主流架構一直是以卷積網路(CNN)為主,NLP 的主流架構一直是以 Transformer 為主。CNN是以一個固定窗口(filter)滑過圖片,本質上與Transformer所產生token不同

    2、語言和圖片 (視頻) 的信息密度不同:

    • 語言是經過人類大腦組織過、高訊息密度的訊息,是富含高度語意資訊的。而圖像或影片是自然產生的訊息,在空間上是高度冗餘的。
    • 例如,擋住圖片的一部分,可以很容易地通過從周圍而想像出它的樣子來,但如果遮住一個句子中的單詞,則較難回推。

    3、 Decoder部分在 CV 和 NLP 中充當的角色不同:

    • 在 CV 領域,Decoder 的作用是重建影像,Decoder輸出的語意(訊息)級別是低階的。
    • 在 NLP 領域,Decoder 的作用是重建單詞,Decoder輸出的是高階、富含資訊的語義級別
  • 為了克服這種差异並鼓勵學習有用的特征,論文中展示了:一個簡單的策略在計算機視覺中也能非常有效:掩蔽很大一部分隨機 patch。這種策略在很大程度上减少了冗餘,並創造了一個具有挑戰性的自監督任務,該任務需要超越低級圖像統計的整體理解。

模型架構

  • 訓練階段
    • 在預訓練期間,圖像大部分區域(例如75%)被掩蓋掉。
    • 編碼階段: 以一個較大的編碼器(Encoder)對未被遮蔽的圖塊(patch)進行編碼
    • 解碼階段: 依據位置訊息,將全部的圖塊(包含被遮蔽的),以一個小的解碼器(decoder)來處理,以像素重建原始圖像。
    • 損失函數:
      • 目標是原始與生成圖像間image pixels層級的均方誤差(Mean squared error)最小化
  • Fine-tune(針對下游特定任務的遷移學習)
    • 經過預訓練。解碼器被丟棄,編碼器被應用於未被破壞的圖像(所有圖塊)進行識別等其他下游任務。

MAE 方法嚴格來講屬於一種去噪自編碼器 (Denoising Auto-Encoders (DAE)),去噪自動編碼器是一類自動編碼器,它破壞輸入信號,並學會重構原始的、未被破壞的信號。 MAE 的 Encoder 和 Decoder 結構不同,是非對稱式的。 Encoder 將輸入編碼為 latent representation,而 Decoder 將從 latent representation 重建原始信號。

MAE 和 ViT 的做法一致,將圖像劃分成規則的,不重疊的 patches。然後按照均勻分佈不重複地選擇一些 patches 並且 mask 掉剩餘的 patches。作者採用的 mask ratio 足夠高,因此大大減小了 patches 的冗餘信息,使得在這種情況下重建 images 不那麼容易。

參考資料:

學習資源

核心程式碼筆記

初始化與整體前向傳播流程

  • norm_layer(embed_dim) 即nn.LayerNorm()為將整層進行標準化
    • 相對於BatchNorm為將整個批次的資料標準化
    • 標準化的對象為被切成多個1維圖塊(序列)的影像,標準化的單位為每個1維扁平化後的圖塊
  • output: return loss, pred, mask
    • 這裡的pred是一個由多個1維圖塊組成的sequence
      • shape : (N, L, embed_dim)
        • N是batch size,L是patch數量,embed_dim是指定的每個圖塊token的內嵌向量維度
    • 預測結果需要去區塊化還原成影像
      • model.unpatchify(pred)
code
from functools import partial import torch import torch.nn as nn from timm.models.vision_transformer import PatchEmbed, Block from util.pos_embed import get_2d_sincos_pos_embed class MaskedAutoencoderViT(nn.Module): """ Masked Autoencoder with VisionTransformer backbone """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=24, num_heads=16, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): super().__init__() # --- skip --- def forward(self, imgs, mask_ratio=0.75): latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] loss = self.forward_loss(imgs, pred, mask) return loss, pred, mask

MAE的功能函式

將圖像區塊化 & 去區塊化

  • imgs: [N, 3, H, W]
  • patch_embedding: [N, L, p x p x c]
  • 轉換2維影像成為L個1維序列
    • 每個序列長度: p x p x c
      • eg: patch size = 16 x 16。 channel=3。
        每個一維的序列長度: 16 x 16 x 3 = 768
    • L (num_patches) = img_h x img_w / p x p
      • 當輸入影像長寬為 224 x 224, Patch size = 16x16
        則有 (224 x 224) / (16 x 16) = 14 x 14 = 196個patch
      • 相當於NLP(自然語言處理)領域的sentence_length
code
def patchify(self, imgs): """ imgs: (N, 3, H, W) x: (N, L, patch_size**2 *3) """ p = self.patch_embed.patch_size[0] assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 h = w = imgs.shape[2] // p x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) x = torch.einsum('nchpwq->nhwpqc', x) x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) return x def unpatchify(self, x): """ x: (N, L, patch_size**2 *3) imgs: (N, 3, H, W) """ p = self.patch_embed.patch_size[0] h = w = int(x.shape[1]**.5) assert h * w == x.shape[1] x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) x = torch.einsum('nhwpqc->nchpwq', x) imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) return

隨機遮蔽圖像

random_masking()

  • 將圖像轉換成L(patch number)個1為的序列後,隨機抽樣決定那些圖塊(1d patch)要被遮蔽

  • input:

    • imgs
      • imgs (N, 3, H, W)-> x (N, L, dim)
        • dim : patch_size**2 * c
        • N是batch
  • output: return x_masked, mask, ids_restore

    • x_masked
      • 沒被遮蔽的1d圖塊(unmasked 1d patch)
      • 透過ids_keep內的指標,取得欲保留(未被遮蔽)的path
      • shape : (N, L*(1-maske_rario), dim)
    • mask
      • shape : (N, L),值為[0, 1]的遮罩。0 is keep, 1 is remove
        • N是batch size,L是patch數量 * 被遮蔽的比例
      • torch.gather(mask, dim=1, index=ids_restore) 根據ids_restore指標位置對mask重新取值/排序
    • ids_restore
      • 取得由小到大的排序index,作為圖塊遮蔽(encoding)與還原時(decoding)的位置指標
      • ids_shuffleids_restore 取得的index值其實是一樣的
        • shape : (N, L)
  • 沒被遮蔽的圖塊(函式中的x_masked)在def forward_encoder()中,會加上cls token送入Transformer blocks中訓練

    code
    ​​​​ def random_masking(self, x, mask_ratio): ​​​​ """ ​​​​ Perform per-sample random masking by per-sample shuffling. ​​​​ Per-sample shuffling is done by argsort random noise. ​​​​ x: [N, L, D], sequence ​​​​ """ ​​​​ N, L, D = x.shape # batch, length, dim ​​​​ len_keep = int(L * (1 - mask_ratio)) ​​​​ noise = torch.rand(N, L, device=x.device) # noise in [0, 1] ​​​​ # sort noise for each sample ​​​​ ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove ​​​​ ids_restore = torch.argsort(ids_shuffle, dim=1) ​​​​ # keep the first subset ​​​​ ids_keep = ids_shuffle[:, :len_keep] ​​​​ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) ​​​​ # generate the binary mask: 0 is keep, 1 is remove ​​​​ mask = torch.ones([N, L], device=x.device) ​​​​ mask[:, :len_keep] = 0 ​​​​ # unshuffle to get the binary mask ​​​​ mask = torch.gather(mask, dim=1, index=ids_restore) ​​​​ return x_masked, mask, ids_restore

編碼器(encoder)

def forward_encoder(): return return latent, mask, ids_restore

Encoder(ViT)內的Embedding

類別token

cls_tokens

  • shape of : (1, 1, embed_dim) > (N, L, embed_dim)
  • 這裡內容值均為0,由於沒有對應的分類任務訓練,因此MAE中的cls_tokens實際上並沒有學習到實質內容,此階段保留僅為符合ViT模型架構
  • 在後續下游任務遷移學習時,仍可透過訓練學習到對應任務所需的權重
位置編碼

pos_embed

  • shape : (batch, num_patches+1, embed_dim)

    • pos_embed.shape[-1] = embed_dim = 1024
    • int(patch_embed.num_patches**.5) = sqrt(196)=14
    code
    ​​​​ def initialize_weights(self): ​​​​ # initialization ​​​​ # initialize (and freeze) pos_embed by sin-cos embedding ​​​​ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) ​​​​ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) ​​​​ decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) ​​​​ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

初始化建立編碼器(encoder)

  • 位置編碼中的self.pos_embed 中第二維度的num_patches + 1, 這個1是ViT架構中的類別編碼,加在最前面

    code MAE encoder
    ​​​​ # MAE encoder specifics ​​​​ self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) ​​​​ num_patches = self.patch_embed.num_patches ​​​​ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) ​​​​ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding ​​​​ self.blocks = nn.ModuleList([ ​​​​ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) ​​​​ for i in range(depth)]) ​​​​ self.norm = norm_layer(embed_dim)

編碼器前向傳播流程

  • 流程:

    1. x = self.patch_embed(x)
      • 將圖像切割為依照 p * p的大小切分並線性轉換為L個序列(1維圖塊)
      • shape of patch_embed(x) : (batch, num_patches, embed_dim)
      • 每個一維圖塊長度是 patch_size**2 * c
    • 與patchify(imgs)的形狀相同
    1. x = x + self.pos_embed[:, 1:, :]

      • 加上(移除類別項目的)位置編碼
      • pos_embed[:, 0, :]是類別的編碼,因此從1開始取值為捨棄類別編碼
    2. 進行隨機遮蔽self.random_masking(x, mask_ratio)

      • code中為了進行隨機遮蔽的操作,於前後分別移除與加回位置編碼的類別項,主要為了便於隨機遮蔽的操作
      • 這邊使用的ViTransformer核心為用於分類任務的架構,前面自帶類別項目
      • 這是MAE在encoder部分的精妙處,其他部分操作核心與ViT一致
    3. x = torch.cat((cls_tokens, x), dim=1)

      • 把類別編碼(cls_tokens)加回未被遮蔽的1維圖塊
      • cls_tokens 關於MAE內的類別token
        • shape of cls_tokens : (1, 1, embed_dim) > (N, L, embed_dim)
    4. 經過上述處理後把帶有類別與位置編碼訊息的1維圖塊們送入Transformer模型(self attention)內

      • shape of x : [batch, num_patches + 1, embed_dim]
      • 這邊的blocks用的是timm library寫好的self-attention模組
        • Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in range(depth)])
      ​​​​​​​​# apply Transformer blocks ​​​​​​​​for blk in self.blocks: ​​​​​​​​ x = blk(x) ​​​​​​​​x = self.norm(x)
    forward_encoder()
    ​​​​ def forward_encoder(self, x, mask_ratio): ​​​​ # embed patches ​​​​ x = self.patch_embed(x) ​​​​ # add pos embed w/o cls token ​​​​ x = x + self.pos_embed[:, 1:, :] ​​​​ # masking: length -> length * mask_ratio ​​​​ x, mask, ids_restore = self.random_masking(x, mask_ratio) ​​​​ # append cls token ​​​​ cls_token = self.cls_token + self.pos_embed[:, :1, :] ​​​​ cls_tokens = cls_token.expand(x.shape[0], -1, -1) ​​​​ x = torch.cat((cls_tokens, x), dim=1) ​​​​ # apply Transformer blocks ​​​​ for blk in self.blocks: ​​​​ x = blk(x) ​​​​ x = self.norm(x) ​​​​ return x, mask, ids_restore

解碼器(decoder)

  • 傳入參數為x, ids_restore
    • x
    • ids_restore
      • 圖塊遮蔽(encoding)與還原時(decoding)的位置指標
def forward_decoder(self, x, ids_restore): ... # remove cls token x = x[:, 1:, :] return x

初始化建立解碼器(decoder)

MAE decoder specifics
# MAE decoder specifics self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding self.decoder_blocks = nn.ModuleList([ Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in range(decoder_depth)]) self.decoder_norm = norm_layer(decoder_embed_dim) self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch

解碼器前向傳播

  • decoder與encoder不對稱,深度與寬度都較小
  • 解碼前先把mask tokens加回sequence,亦即解碼回去圖像時的input是包含MASK圖塊
forward_decoder()
def forward_decoder(self, x, ids_restore): # embed tokens x = self.decoder_embed(x) # append mask tokens to sequence mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token # add pos embed x = x + self.decoder_pos_embed # apply Transformer blocks for blk in self.decoder_blocks: x = blk(x) x = self.decoder_norm(x) # predictor projection x = self.decoder_pred(x) # remove cls token x = x[:, 1:, :]

損失函數計算與前向傳播

forward_loss()

  • MSE Loss
  • 先將原始影像區塊後,計算每個patch的loss後取平均
  • norm_pix_loss : 將原始影像以patch為單位做標準化
  • 再取出被遮蔽的patch,計算被遮蔽patch部分的loss
code
def forward_loss(self, imgs, pred, mask): """ imgs: [N, 3, H, W] pred: [N, L, p*p*3] mask: [N, L], 0 is keep, 1 is remove, """ target = self.patchify(imgs) if self.norm_pix_loss: mean = target.mean(dim=-1, keepdim=True) var = target.var(dim=-1, keepdim=True) target = (target - mean) / (var + 1.e-6)**.5 loss = (pred - target) ** 2 loss = loss.mean(dim=-1) # [N, L], mean loss per patch loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches return loss

參考資料

  • pytorch函式
    • gather()

      对于out指定位置上的值,去寻找input里面对应的索引位置,根据是index

      • 定义:从原tensor中获取指定dim和指定index的数据
        从完整数据中按索引取值
      • 用途:方便从批量tensor中获取指定索引下的数据,该索引是高度自定义化的,可乱序的
    • 引用自:

    • repeat()和 expand()

      • expand

        函数对返回的张量不会分配新内存,即在原始张量上返回只读视图,返回的张量内存是不连续的。类似于numpy中的broadcast_to函数的作用。如果希望张量内存连续,可以调用contiguous函数。
        扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view),一个大小(size)等于1的维度扩展到更大的尺寸。

      • repeat

        torch.repeat用法类似np.tile,就是将原矩阵横向、纵向地复制。与torch.expand不同的是torch.repeat返回的张量在内存中是连续的。
        沿着特定的维度重复这个张量,和expand()不同的是,这个函数拷贝张量的数据

      • 引用自:

Deep Learning相關筆記

Self-supervised Learning

Object Detection

ViT與Transformer相關

Autoencoder相關

Select a repo