# Deep Mutual Learning 閱讀筆記 ###### tags: `Knowledge Distillation` 本論文與傳統學生模仿老師的蒸餾手法不同,提出一種==學生之間互相學習的架構 (peer-teaching)==。作者先提出直觀的解釋,說明為何 mutual learning 是可行的,而不是瞎子帶領瞎子 (the blind lead the blind):因為 conventional 的損失函數,每一個學生在訓練過程中的能力都會上升,且優化過程有一定方向性。訓練了一段時間後,每個模型都會對訓練集有差不多的預測值,但由於初始化權重不同,每個模型學到的表達及類別的機率分佈有所差異,正是這樣的特性,為蒸餾及 mutual learning 提供了額外的知識 >Each student is primarily directed by a conventional supervised learning loss, which means that their performance generally increases and they cannot drift arbitrarily into groupthink as a cohort. With supervised learning, all networks soon predict the same (true) labels for each training instance; but since each network starts from a different initial condition, they learn different representations, and consequently their estimates of the probabilities of the next most likely classes vary. It is these secondary quantities that provide the extra information in distillation as well as mutual learning. 在訓練過程中,每個學生網路會收集輸出標籤的機率分佈估計,藉由彼此分佈的相互對比,可以提昇每個學生網路的後驗熵(posterior entropy),使得網路可以學到更強健且通用的參數 作者在實驗中發現: 1. 相互合作學習的小網路,其表現優於靜態大型網路(static large network) - 小網路的學習架構 2. mutual learning 適用於多種模型架構,甚至不同大小的混合網路構成的異質結構 3. 學習效果隨著隊列中網路的數量增加而提昇 4. 即使目標是得到單一網路,但隊列中的網路也可用來 ensemble 成一個高效能的網路 Deep Mutual Learning (DML) 的網路架構如下 ![](https://i.imgur.com/NtNc5aa.png) 對於樣本 $x_i$ 屬於某個類別 $m$ 的機率及多樣本的交叉熵損失可表示如下 ![](https://i.imgur.com/TAehWRJ.png) ![](https://i.imgur.com/wADpdAj.png) 為了使得模型更通用化,我們引入另一個 peer network 提供其後驗機率的知識,並使用 KL 散度 評估兩者在後驗知識上相似的程度,損失函數可更新如下 ![](https://i.imgur.com/7n8wwE1.png) 由於 KL 散度的特性,兩個網路的 KL 損失並不相同,也可改用兩者的平均作為 KL 損失,但作者表明在經驗上來講,並沒有太大區別 ![](https://i.imgur.com/DrDoBc5.png) 由於 DML 網路是互相學習的,因此論文讓每個學生模型以相同的 mini-batches 進行學習,每個網路都會輸出自己的預測,讓其他網路可以計算 KL 散度,並一次更新全部的網路直到收斂。演算法及多網路學習的損失函數如下 ![](https://i.imgur.com/Mn4QCvd.png) ![](https://i.imgur.com/XN59s2J.png) 作者還提出另一種多模型的 KL 損失函數,但實驗顯示其表現較差,原因是由於先把各網路的後驗知識平均,等於把其他類別的機率分佈都平均掉了,反而突顯了正確標籤的機率 ![](https://i.imgur.com/1h8Bia8.png) - Results on CIFAR-100 以下為不同方法及網路結構在 CIFAR-10 & CIFAR-100 的表現,可以看到 DML 確實帶來提昇的效果 ![](https://i.imgur.com/olreUOg.png) - Results on Market-1501 作者比較有無使用 DML 的表現,每一個 MobileNet 在雙網路的隊列中訓練,並將結果平均,可以看到 DML 明顯優於單獨訓練的結果 ![](https://i.imgur.com/HarrEwq.png) - Results on ImageNet & Distributed Training of DML - 對比 Inception 及 MobileNet 有無使用 DML 的差異,可以看到使用 DML 的模型效果更好 - 右圖比較依照演算法 sequence 更新及分散平行訓練且同時更新的差異,顯示分散平行訓練的方式更好且兩個模型的表現更為一致,這是因為他們有相同訓練迭代次數,作者認為這個結果代表當訓練進程 (learning progress) 的差異被消除時,peer teaching 對模型有最佳的幫助 ![](https://i.imgur.com/9MOgHAF.png) - Comparison with Model Distillation Net1 為 teacher 模型,Net2 為 student 模型,結果顯示一個強大的預訓練模型不是必要的,兩者一同相互訓練對模型性能也有明顯的提昇 ![](https://i.imgur.com/WxX46Aw.png) - DML with Larger Student Cohorts 從結果可以看到越多網路相互訓練,會有更好的效能 ![](https://i.imgur.com/K2dlQ8R.png) - How and Why does DML Work? DML 可以幫助我們找到更 robust 且更通用的最小值 >DML Leads to Better Quality Solutions with More Robust Minima 在訓練過程中,DML 可以完美的 fit 訓練資料,此外在測試集上,表現比單獨訓練來得更加優異,因此 DML 不只幫助網路找到更好的最小值,還幫助網路找到更通用的最小值。論文嘗試對單獨訓練及 DML 模型加上高斯噪音,可以看到隨標準差越來越大,單獨訓練的網路損失急遽的上升,說明 DML 訓練方法的強健性 ![](https://i.imgur.com/2ASiAgP.png) 另外,因為 DML 要求網路之間要互相批配其預測的機率分佈,所以若機率分佈有所差異時,網路會受到懲罰,因此網路會將==更多的重心放在 secondary 的機率上,以及更多對 distinct secondary 機率的關注,代表網路除了藉由主要概率,還透過 secondary 機率的 matching 學習更為通用的模型參數==。從 fig.4( c ) 可以看到 DML 模型的在 secondary probabilities 的分佈上較為平滑,除了主機率外還突顯其他重要的 secondary 機率 - Does DML Makes Models More Similar 作者利用 tSNE 畫出兩個 MobileNets 以不同訓練方法的 feature 分佈,可以看到不論是否使用 DML,訓練出來的網路分佈都有所差異,也間接說明為何 DML 模型之間可以互相學習 ![](https://i.imgur.com/aWFecm7.png) ## Reference 1. [Deep Mutual Learning](https://arxiv.org/abs/1706.00384)