# 【論文筆記】DOT: A Distillation-Oriented Trainer 論文連結:https://arxiv.org/abs/2307.08436 發表於 ICCV 2023 ## Introduction 這篇研究探討 knowledge distillation 在 optimization 過程中所出現的特性,並且提出 DOT,一個新的 knowledge distillation optimization 方法。 ![截圖 2024-02-29 上午1.55.26](https://hackmd.io/_uploads/HJjq7x6np.png) Knowledge distillation 用來將知識從比較大的模型 (teacher model) 轉移到比較小的模型上 (student model),典型的 knowledge distillation objective function 通常包含兩個部分:task loss (e.g. cross-entropy loss) 以及 distillation loss (e.g. KL divergence),亦即: $$ \mathcal{L} = \alpha \mathcal{L}_\text{CE}(x, y; \theta) + (1-\alpha) \mathcal{L}_\text{KD}(x, \phi; \theta) $$ Task loss 用來衡量和真實數據之間的差異,distillation loss 則是用到衡量 student model 和 teacher model 兩者輸出的差異。 作者經過研究發現,對於一個圖片分類的任務,加入 distillation loss 可以幫助 student 收斂到一個更平坦的 minimum,因此 student model 可以有更好的 generalization 能力(上圖中)。然而他們還同時發現加入 distillation loss 反而會讓 task loss 無法完全收斂,和只使用 cross-entropy loss 訓練的 baseline 相比,會出現一個 trade-off(上圖右)。 為何會出現 trade-off 呢?他們認為可能的原因是因為在普遍實作中,task loss 和 distillation loss 會透過相加合併在一起,使得學習的目標變成是一個 multi-task learning,模型此時會嘗試在兩個 loss 當中找取平衡。 已知 teacher model 因為模型大小的關係,總是可以得到比 student 更低的 task loss,因此他們猜想,如果 distillation loss 可以**被充分的最佳化**,就等同於讓 student model 和 teacher model 更相似,那麼 task loss 也會被降低。他們提出了這篇研究最重要的核心假設:充分 optimize distillation loss,便可以打破 trade-off。 研究提出了 DOT,讓 distillation loss 可以在 knowledge distillation 的過程中支配 task loss,實驗結果發現他們的方法的確讓 knowledge distillation 得到更好的結果,並且他們的研究也是首次嘗試以 optimization 的角度來解析 knowloedge distillation 機制。 ## Methods DOT 的主軸想法是利用 momentum 來控制 optimization 的方向,來達到讓 distillation loss 支配梯度的效果。任何有使用 momentum 的 optimization 方法都可以利用 DOT,論文當中選擇以 Stochastic Gradient Descent (SGD) optimizer 作為實驗對象。 ### Optimizer with Momentum SGD (with momentum) 在更新模型參數時,會同時使用到當前的梯度 $g$ 以及過往的梯度紀錄,具體來說,SGD 使用一個 momentum buffer $v$ 來紀錄梯度,每一個 training iteration,都更新一次 momentum buffer: $$ v \leftarrow g + \mu v $$ 並更新模型參數: $$ \theta \leftarrow \theta - \gamma v $$ 其中 $\mu$ 表示 momentum coefficient。使用 momentum 時因為利用到了歷史的梯度紀錄,可以加速模型的收斂。 ### Distillation Oriented Trainer ![截圖 2024-02-29 上午2.42.43](https://hackmd.io/_uploads/rkg2Agp36.png) DOT 將 task loss 和 distillation loss 分開來考慮,給兩者所得到的梯度設定獨立的 momentum,控制 optimization 的方向。設定兩個 momentum buffer 為 $v_\text{ce}$ 和 $v_\text{kd}$,這裡引入了一個新的 hyperparameter $\Delta$,每一個 training iteration,更新 momentum buffer 的方式令為: $$ v_\text{ce} \leftarrow g_\text{ce} + (\mu-\Delta) v_\text{ce} $$ $$ v_\text{kd} \leftarrow g_\text{kd} + (\mu+\Delta) v_\text{kd} $$ 更新模型參數則是透過 $$ \theta \leftarrow \theta - \gamma (v_\text{ce}+v_\text{kd}) $$ 可以看到 DOT 對 distillation loss 使用比較大的 momentum,而對 task loss 使用比較小的 momentum,在這樣的設定下,每次更新模型的時候都考慮比較多 ditillation loss 提供的梯度,表示 optimization 的方向能夠被 ditillation loss 的梯度所支配,也就可以更充分讓 distillation loss 收斂。 ## Experiments ### Settings 實驗中使用到的 datasets 包含 CIFAR-100、Tiny-ImageNet 和 ImageNet-1k,但多數的實驗使用到的是 CIFAR-100,因此以下實驗結果都是以 CIFAR-100 為對象。參數設定的部分,在 CIFAR-100 上他們發現 $\mu$ 值設 0.09,$\Delta$ 設定 0.05 到 0.075 可以有最好的結果。 ### Motivation Validations 以下實驗驗證他們的猜想正確,並且提出的方法是有用的: #### Does KD Loss dominate the optimization? 為了證明 DOT 的確使 distillation loss 支配 optimization,他們分析了 optimization 過程中梯度的變化。下圖分別是 distillation loss 和 total loss 的 momentum buffer 中梯度的 cosine similarity,以及 task loss 和 total loss 梯度之間的 cosine similarity。他們發現使用 DOT,distillation loss 和 total loss 兩者梯度的 cosine similarity 會隨著迭代大幅增加,表示 momentum buffer 中的梯度的確受到 distillation loss 支配。 ![截圖 2024-02-29 上午11.10.45](https://hackmd.io/_uploads/BJeaBOa2p.png =500x) #### Does KD and Task Losses converge better? 下圖視覺化了 task loss 和 distillation loss 在不同方法下的 loss curve。可以看到 DOT 可以同時讓 task loss 和 distillation loss 都達到更低,因此充分 optimize distillation loss 可以降低 task loss 的假設是正確的。 ![截圖 2024-02-29 上午11.18.18](https://hackmd.io/_uploads/r18tPuT3T.png =500x) #### Are independent momentums necessary? 為了證明實驗中表現的改善的確來自 DOT 獨立 momentum 的機制,下面的實驗探討不同的 $\Delta$ 數值對結果的影響。設定 $\Delta = 0$ 表示使用原始的 SGD 方法,$\Delta > 0$ 時表示 distillation loss 支配 optimization 過程,$\Delta < 0$ 則表示 task loss 支配 optimization。實驗結果顯示當 $\Delta > 0$,也就是 "distillation-oriented" 可以讓表現進步,反而 "task-oriented" 造成表現退步。 ![截圖 2024-02-29 上午11.48.19](https://hackmd.io/_uploads/HJRKRd6n6.png =500x) #### Are improvements attributed to tuning momentum? 下面實驗證明表現的提升是來自於 DOT 獨立 momentum 的設計,而非單純調整 momentum 的值得到的。使用原始 SGD 並實驗在 CIFAR-100 上不同 $\mu$ 數值的影響,結果顯示微調 $\mu$ 值的確能造成很些微的進步,但進步幅度不如使用 DOT 來得大。 ![截圖 2024-02-29 下午2.35.16](https://hackmd.io/_uploads/ByWhHjp2a.png =500x) ### Main Results 下表顯示 CIFAR-100 在各種 distillation 方法上的結果。將 DOT 應用在傳統 knowledge distillation 上,可以有 1.79% 的進步。為了實驗 DOT 是可泛化的,他們還將 DOT 實驗在其他不同類型的 distillation method 上,實驗結果也顯示 DOT 能夠應用在各種方法上。 ![截圖 2024-02-29 下午2.46.32](https://hackmd.io/_uploads/rJrUusphT.png) ### More Analysis 直觀而言,我們可能會覺得調整 loss function $$ \mathcal{L} = \alpha \mathcal{L}_\text{CE}(x, y; \theta) + (1-\alpha) \mathcal{L}_\text{KD}(x, \phi; \theta) $$ 當中的 $\alpha$ 數值,某種程度上可以讓 task loss 收斂更好,因此以下針對不同的 $\alpha$ 值進行實驗。但從實驗結果可以看到,調整 $\alpha$ 值對最後的 distillation 結果幾乎沒有影響,可以推斷單純調整 $\alpha$ 並不能像使用 DOT 一樣明顯提升 distillation 的結果。 ![截圖 2024-02-29 下午3.03.59](https://hackmd.io/_uploads/BJ0whiT3a.png =500x) ## Conclusion 研究發現引入 distillation loss 會限制 task loss 的收斂,產生 trade-off。作者推測充分 optimize distillation loss 可以避免這個問題,因此提出了 Distillation-Oriented Trainer (DOT),透過獨立的 momentum 來引導梯度。實驗結果發現 DOT 的確可以同時降低 distillation 和 task loss,並且可以套用在各種 distillation 方法上。