Try   HackMD

A Simple Framework for Contrastive Learning of Visual Representations

tags: Paper 2020 AI CV Unsupervised learning Architecture

Abstract

  • 提出 SimCLR 的 self-supervised learning 架構。
    • without memory bank.
  • 三個重要發現:
    1. Data augmentation 在預測任務中扮演了很重要的角色。
    2. 透過在 representation 和 contrastive loss 中加入 learnable nonlinear transform 可以加強 representation 的學習
    3. 在 contrastive learning 中,增加 batch size 和 training step 比起 supervised learning 能夠取得更好的效果。
  • 透過 fine-tuned 1% 的 label,可以達到比一些 supervised learning 的 model 更好的效果。

Introduction

  • 學習不用人類自行 label 的視覺資訊是一個重要且長久的問題。
  • 目前主要作法有 2 種 : Generative, Discriminative
    1. Generative
      生成 pixel-level 的圖像。但是耗算的資源很大,且 representation 可能不需要這麼多的資訊。
    2. Discriminative
      使用類似 supervised learning 的 objective function,透過啟發式的方法從 unlabeled dataset 進行 label,但可能會對 generality 產生影響
      最近基於在 latent space 的 contrastive learning 的這種方法有了很大的進展。
  • SimCLR 的重點:
    1. 通過組合多個 Data Augmentation 可以對 contrastive learning 產生重大的影響,並且 contrastive learning 對強的 data augmentation 產生的益處大於 supervised learning
    2. 在 representation 和 contrastive loss 中間加入 non-linear transform 可以增強 representation 學習的成效。
    3. Contrastive cross entropy loss 受益於 normalized embeddingtemperature parameter.
    4. Contrastive learning 比起 supervised learning,更明顯受益於大的 batch size 與訓練長度,且跟 supervised learning 一樣,更深更寬的網路都可以加強效能
  • SimCLR 為 SOTA 的 semi-supervised, self-supervised learning 方法。
  • 透過 fine-tuned,SimCLR 在 12 個數據集中的 10 個比一些厲害的 supervised 方法效能更好。

Method

  • SimCLR 透過 contrastive loss,最大化 latent space 中同一個資料不同 augmentation 的資料的一致性來學習 representation。
  • Image Not Showing Possible Reasons
    • The image file may be corrupted
    • The server hosting the image is unavailable
    • The image path is incorrect
    • The image format is not supported
    Learn More →


    hi=f(x~i)=ResNet(x~i)

    zi=g(hi)=W(2)σ(W(1)hi),σ=ReLU

    使用
    zi
    來定義 contrastive loss 比用
    hi
    來得好
  • 方法步驟:
    1. Random sample minibatch
      N
      .
    2. 為 minibatch 內的每一筆 data 做 data augmentation,得到
      2N
      筆 data.
    3. 選定一筆資料和它 data augmentation 的結果作為 positive pair,此時對於選定的資料,在
      2N
      筆資料內,有
      1
      個 positive,
      2(N1)
      個 negative.
      • Let
        sim(u,v)=uv/uv
        , cosine similarity.
      • Loss for single positive pair (i, j) : NT-Xent (Normalized Temperature-scaled Cross Entropy)
        li,j=logexp(sim(zi,zj)/τ)Σk=12N1[ki]exp(sim(zi,zk)/τ)
      • τ:
        temperature parameter
      • sim:


        Image Not Showing Possible Reasons
        • The image file may be corrupted
        • The server hosting the image is unavailable
        • The image path is incorrect
        • The image format is not supported
        Learn More →


        BC 向量相似,AB 向量不相似
      • 最好的
        loss
        li,j
        log1=0

        exp(sim(zi,zj)/τ:
        代表正例間的相似度,最大化此項
        Σk=12N1[k1]exp(sim(zi,zk)/τ):
        代表負例間的相似度,最小化此項
      • sim(zi,zj)/τ=uv/uvτ

        透過
        τ
        ,可將 cosine similarity 的範圍縮小 [-1, 1]
        [-1/
        τ
        , 1/
        τ
        ]
      • exp(sim(zi,zj)/τ)

        由於當 batch size 很大時, NT-Xent 的分母會非常大,從而導致結果過大,因此透過
        τ,exp
        的搭配可以讓分母(負例)縮小


        Image Not Showing Possible Reasons
        • The image file may be corrupted
        • The server hosting the image is unavailable
        • The image path is incorrect
        • The image format is not supported
        Learn More →


        ex.
        logexp(sim(zi,zj))Σk=1100exp(sim(zi,zk))=loge1e100=loge

        logexp(sim(zi,zj)/τ)Σk=1100exp(sim(zi,zk)/τ)=loge1/100e100/100=log1=0

        or

        logexp(sim(zi,zj)/τ)Σk=1100exp(sim(zi,zk)/τ)=loge1/100e50/100=log1e1/2=log0.6=0.22
  • 透過把 batch size 從 256 提升到 8192 以避免 memory bank。
    • 使用這麼大的 batch size 可能會令 SGD, Momentum 這些使用 linear LR 的 optimizer 沒辦法穩定的運作,因此採用 LARS(Layer-wise Adaptive Rate Scaling) optimizer.
      • 為了使訓練加速,通常會使用 batch 來加快訓練。然而 batch size 盲目地增加可能會造成 (1) 參數修正緩慢 (2) Batch size 過大,下降的方向已不再變化。下式為網路迭代的公式:
        wt+1=wtη1nΣxβl(x,wt)

        可知,LR 與 batch size 成反比,其中,LR 影響 model 的收斂狀態,batch size 影響 model 的泛化性能(Generalization Gap)
        為了解決這個問題,許多人提出當 batch size 增大 k 倍時,LR 也增加 k 倍,然而這樣的方法會造成 LR 太大,訓練不穩定,例如在早期階段 LR 太大可能在錯誤的方向上更新很多,導致最終模型成效很差。
      • 將上式簡化為:
        wt+1=wtλL(wt)

        當 LR 過大時,可能會導致
        λL(wt)>wt
        ,造成發散。作者通過分析每個 layer 的
        wt/L(wt)
        ,發現 (1) 每個 layer 的差異很大,以及(2) 訓練的早期階段比值都比較大。warm-up 透過以較小的 LR 開始訓練解決問題 (2),LARS 透過每個 layer 根據自己的情況調整 LR,解決問題 (1),local LR 的計算方式如下:
        λl=η×wlL(wl)

        其中
        η
        是 hyperparameter(通常為0.001),代表更新時會改變參數的置信度(trust),推測為比值正常為 1000 左右的數值。
        有了 local LR 後,就可以替換每個 layer 的 global LR:
        wtl=γ×λl×L(wtl)

        其中
        γ
        代表 global LR。
  • 實驗設置
    • Data Augmentation 方法
      1. Random crop & resize (with random flip)
      2. Color distortions
      3. Gaussian blur
    • Backbone
      1. Base encoder: ResNet50
      2. Projection: 2 layer MLP (project representation to 128D latent space)
    • Loss
      • NT-Xent
    • Optimizer
      1. LARS, LR = 4.8 (0.3
        ×
        batch size / 256)
      2. Weight decay:
        106
      3. Linear warm-up for 10 epochs
      4. Cosine decay schedule
    • Batch size
      • 4096
    • Epoch
      • 100

Data Augmentation for Contrastive Representation Learning

  • Data augmentation 並沒有成為 contrastive learning 的標配。許多現有的方法通過使用不同的網路架構達成 contrastive learning。
    • 有人透過上下文 (鄰近圖片) 做 contrastive learning
      Data-Efficient Image Recognition with Contrastive Predictive Coding

      Image Not Showing Possible Reasons
      • The image file may be corrupted
      • The server hosting the image is unavailable
      • The image path is incorrect
      • The image format is not supported
      Learn More →


      在我的碩論中可透過時間軸的前後替代 data augmentation,時間軸越靠近目標,loss 越小
  • 這些複雜的架構其實是 data augmentation 的一個 subset,因此可以透過簡單的 data augmentation 達成。
    • 並且利用這種方法能夠將 prediction task 跟其他 task 解耦。
  • 由於 ImageNet 的每張圖片大小都不一樣,所以在實驗中採用了 asymmetric data transform,主要概念就是先使用 crop & resize,再使用主要的 transformation 方法,要注意這種 asymmetric 的方法是會影響性能的
  • 從實驗中發現,沒有一種單一的 data transform 的方法可以學習到足夠的 representation,隨著組合方法的增加,預測任務逐漸變難,而表徵也變得更好

    Image Not Showing Possible Reasons
    • The image file may be corrupted
    • The server hosting the image is unavailable
    • The image path is incorrect
    • The image format is not supported
    Learn More →


    在碩論中,除了使用時間軸的前後外,還需要尋找其他 data augmentation 方法
  • 由於同一個圖像的顏色分佈大部分都是相同的,因此若是僅使用 random crop,會造成 NN 通過觀察顏色分佈這種奧步識別圖像。解決方法是透過將 random crop 與 color distortion 結合。

Architecture for Encoder and Head

  • 經過實驗表明,unsupervised learning 比起 supervised learning,更能夠受益於更大的 model.
  • 使用 nonlinear projection
    g()
    可以讓 projection 前的 representation 學習的更好
    • nonlinear > linear >> none
    • 通過 nonlinear projection 後的 representation 比通過之前的 representation 效果更差
      • 推測原因是因為
        g()
        之前的
        h
        是學習到了影像的特徵,而後由於 contrastive loss 的緣故,nonlinear projection 被訓練成對 data transformation 不變的 MLP,也因為這樣,一些對 downstream 有用的訊息被移除掉了。
      • 為了驗證這個想法,通過使用
        h
        以及
        g(h)
        來預測 data transformation,實驗結果如下圖:

        Image Not Showing Possible Reasons
        • The image file may be corrupted
        • The server hosting the image is unavailable
        • The image path is incorrect
        • The image format is not supported
        Learn More →


        可以看到
        h
        比起
        g(h)
        保有更多對於 data augmentation 的資訊。

Loss Functions and Batch Size

5.1 看不懂,我是大便。

  • Batch size 在 epoch 小的時候增大對訓練成效有明顯改善,隨著 epoch 變大,只要 batch 是重新隨機採樣的,batch size 的影響就會逐漸減少甚至消失。這是因為 contrastive learning 的 batch size 越大,負例就越多,從而收斂得更快,而隨著 epoch 變多,被 random sample 到的負例自然也越多

Comparison with State-of-the-art

  • 透過將 ResNet50 hidden layer 的 width 調大,能夠增加性能,如下圖所示:

    Image Not Showing Possible Reasons
    • The image file may be corrupted
    • The server hosting the image is unavailable
    • The image path is incorrect
    • The image format is not supported
    Learn More →

  • 使用 Semi-supervised learning fine-tuned 1%, 10% 的 label 一樣能夠增強效能。
  • 對不同 dataset 的 transfer learning 做了 2 種實驗,分別是 linear evaluation 和 fine-tuned,結果如下圖所示:

    Image Not Showing Possible Reasons
    • The image file may be corrupted
    • The server hosting the image is unavailable
    • The image path is incorrect
    • The image format is not supported
    Learn More →

  • Handcrafted pretext tasks:
    • 近來 SSL 開始變得熱門是因為一些 heuristic 的方法,諸如旋轉、relative patch prediction 等等,但是這種方法缺少了一定的泛化性
  • Contrastive visual representation learning:
    • 透過正例與負例進行學習,以前的方法有將每個 feature vector 視為獨立的 class、memory bank 等,最近則是透過 in-batch memory bank 來實做。

Conclusion

  • 使用了一個簡單的框架實現效果良好的 SSL 方法。
  • 與 supervised learning 相比,不同之處在於對 data augmentation 方法的選擇、nonlinear projection 及 loss function.

Appendix

A. Data Augmentation Details

  • Random crop and resize to 224x224
    • 使用 incception-style random cropping,crop 原始圖像的 0.08 ~ 1.0,長寬比為 3/4 ~ 4/3。
    • crop 完後以 50% 的機率 rotation。這個過程是非必要的,僅會對效能造成一點下降。
  • Color distortion
    • 使用 color jittering & color dropping,越強的 color jittering 通常會有更大的幫助,因此設置了一個 strength parameter.
      • Color jittering: 對亮度、對比、飽和、色調進行調整。
      • Color dropping: 以一定機率將圖片變為灰階。
  • Gaussian blur
    • 小幅增進效能。以 50% 的機率進行模糊,隨機選擇
      σ[0.1,2.0]
      ,kernel size 為圖像的 10%.

B. Additional Experimental Results

  • Batch Size and Training Steps
    • linear scaling LR 在 SGD, Momentum 中是很好的方法,但是在 LARS 中,square root LR 是更好的作法,因此,不同於 linear scaling 的
      LR=0.3×BatchSize/256
      ,而是使用了 square root 的
      LR=0.075×BatchSize
      ,詳細的實驗數據如下:



      從表中可以看到,square root LR 在較小的 batch size 和 training step 下,可以更好的提昇模型效能。
    • Batch size 的 gap 在 8192,但是 epoch 則是還沒有收斂,能夠持續提昇 accuracy.
  • Broader composition of data augmentations further improves performance
    • 套用更多的 data augmentation 方法可以進一步加強模型效能。
  • Effects of Longer Training for Supervised Models
    • 當在 supervised learning 中套用更長的 training step 和 data augmentation 種類時,更長的 training step 並不會讓 model 變好,data augmentation 會讓模型稍有起色,但並不明顯,且不一定適用於任何一種模型 (ex. ResNet, ResNet(4
      ×
      ))
  • Understanding The Non-Linear Projection Head
    • z=g(h)=Wh,WR2048×2048

      從實驗中可以發現,
      W
      的 eigenvalue 的分佈,可以看到很少大的 eigenvalue,也就意味著
      W
      估計是 low-rank 的,實驗圖如下:


      • Low-rank:low-rank 代表 matrix 的每個向量是密切相關的,由於這個緣故,low-rank 的 matrix 可以投影到較低的維度。在線性代數中,可以透過將 matrix 轉換為階梯形式後,看有幾行非 0 向量得到 rank 數值。
      • 使用 t-SNE 視覺化之後,可以發現
        h
        分類的效果比
        g(h)
        更好,如下圖所示:


  • Transfer Learning
    • 使用 linear evaluation, fine-tuned 兩種方法。
    • Transfer learning via a Linear Classifier
      • 使用 frozen pre-trained network 提取到的 feature 丟到
        l2regularized
        的 multinomial logistic regression classifier 進行訓練。
    • Transfer Learning via Fine-Tuning
      • 將整個 pre-trained network 的 weights 當作初始值,將整個 network 微調。
  • CIFAR-10
    • 由於 CIFAR-10 的圖像比 ImageNet 的圖像要小很多,因此將 ResNet50 第一個 7x7, stride=2 的 conv 改成 3x3, stride=1 的 conv,並一除了第一個 max pooling.
    • LR=0.5,1.0,1.5

      Temperature=0.1,0.5,1.0

      BatchSize=256,512,1024,2048,4096
    • 從實驗中可以看到,當訓練到收斂時,最好的 temperature 是 0.5,當 batch size 增加時,0.1 的 temperature 的表現也越來越好,下圖為實驗結果: