# 第十二章:領域自適應 (Domain Adaptation) 概述 >上課筆記 * 上課影片連結 * ==[**領域自適應 (Domain Adaptation)**](https://youtu.be/Mnk_oUrgppM)== --- ## 引言 領域自適應 (Domain Adaptation) 是遷移學習 (Transfer Learning) 的一種,旨在解決**訓練資料** (來源域,Source Domain)與**測試資料** (目標域,Target Domain)**分布不同的問題**。 --- ## 領域偏移 (Domain Shift) 當訓練資料和測試資料的**分布存在差異**時,在訓練資料上表現良好的模型,應用於測試資料時性能可能會大幅下降,這種現象稱為領域偏移 (Domain Shift)。 * **[範例](http://proceedings.mlr.press/v37/ganin15.pdf)**:在黑白的 MNIST 手寫數字資料集上訓練的模型,準確率可達 99.5%。但若直接應用於彩色的 MNIST-M 資料集 (目標域),準確率可能驟降至 57.5%。儘管數字形狀相似,但顏色等特徵的分布差異會影響模型判斷。![image](https://hackmd.io/_uploads/HyCgODcTkl.png) * **普遍性**:在許多基準資料集 (Benchmark Corpus) 中,通常假設訓練與測試資料分布相同,這可能導致對模型泛化能力的過度樂觀。真實應用場景中,領域偏移是常見的挑戰。 ### 領域偏移的類型: 1. **輸入分布變化 (Covariate Shift)**:訓練和測試資料的輸入**特徵**分布不同 (例如,黑白 vs 彩色圖片)。==這是本課程主要關注的類型==。![image](https://hackmd.io/_uploads/r1nFOD5aJe.png) 2. **輸出分布變化 (Prior Probability Shift)**:訓練和測試資料的**標籤**分布不同 (例如,訓練集中各數字機率均等,測試集中某些數字出現機率更高)。![image](https://hackmd.io/_uploads/B1q2uDcTkl.png) 3. **條件分布變化 (Concept Shift)**:輸入與輸出**的關係發生變化** (例如,同一種圖像在訓練集中標記為 0,在測試集中標記為 1)。![image](https://hackmd.io/_uploads/HJl0OD5ayx.png) --- ## 領域自適應的情境 領域自適應的目標是利用 Source Domain 的有標註資料,以及對 Target Domain 的部分了解,來訓練一個在 Target Domain 上表現良好的模型。根據對 Target Domain的了解程度,有不同的處理方法: 1. **Target Domain 有少量標註資料 (Little but Labeled)** * **方法**:使用 Source Domain 資料預訓練模型,再利用Target Domain的少量標註資料進行微調 (Fine-tune)。這類似於 [BERT](https://hackmd.io/@Jaychao2099/imrobot8) 等預訓練模型的微調過程,通常只需跑少量 Epochs。 * **挑戰**:Target Domain資料量少,容易過擬合 (Overfitting)。 * **對策**: * 降低學習率 (Learning Rate)。 * 限制微調前後模型參數或輸入輸出關係的差異。 * 其他避免過擬合的技術。 ![image](https://hackmd.io/_uploads/ryEz5DqTkl.png) 2. **Target Domain有大量無標註資料 (Large Amount of Unlabeled Data)** * **方法**:這是更常見且更具挑戰性的情境,也是本課程及作業的重點。目標是利用這些無標註的Target Domain資料,來幫助模型適應Target Domain。 * **核心思想**:訓練一個特徵提取器 (Feature Extractor),使其能夠忽略領域間的差異 (如顏色),提取出共通的、具有判別性的特徵。**使得無論是 Source Domain 還是 Target Domain 的資料,經過特徵提取器後,其特徵分布相似**。 ![image](https://hackmd.io/_uploads/HycwqvcpJl.png) ![image](https://hackmd.io/_uploads/HyHRqDcTJx.png) --- ## 特徵提取器 與 領域對抗訓練 (Domain Adversarial Training) ### 基本架構: 一個典型的分類器可以看作由兩部分組成: * **特徵提取器 (Feature Extractor, $f$)**:一個網路 (例如 CNN 的前幾層),輸入圖像,輸出一組特徵向量 (Feature Vector)。其參數以 $\theta_f$ 表示。 * **標籤預測器 (Label Predictor, $p$)**:一個網路 (例如 CNN 的後幾層),輸入特徵向量,輸出分類結果。其參數以 $\theta_p$ 表示。 ### 訓練目標: * **來源域資料 (Source Domain, Labeled)**:通過特徵提取器和標籤預測器,使其能正確分類。目標是最小化分類損失 $L$,例如交叉熵 (Cross Entropy)。 $$ \min_{\theta_f, \theta_p} L(\theta_f, \theta_p) \text{ on Source Data} $$ * **目標域資料 (Target Domain, Unlabeled)**:雖然沒有標籤無法直接用於訓練標籤預測器,但其通過特徵提取器產生的特徵,應與來源域資料的特徵分布盡可能相似,無法區分。 ![image](https://hackmd.io/_uploads/ry9cjP9ayx.png) ### 領域對抗訓練 (Domain Adversarial Neural Network, DANN): 引入一個**領域分類器 (Domain Classifier, $d$)** 來實現特徵分布的對齊。 * **Domain Classifier ($d$)**:一個二元分類器,輸入特徵提取器輸出的特徵向量,判斷該特徵來自來源域還是目標域。其參數以 $\theta_d$ 表示。 * **訓練過程 (Adversarial Training)**: 1. **訓練 Domain Classifier $d$ 和 Label Predictor $p$**:最小化領域分類損失 $L_d$ (區分來源域和目標域特徵)和標籤預測損失 $L$ (在來源域上正確分類)。$$\theta_p^* = \min_{\theta_p} L(\theta_f, \theta_p) $$$$\theta_d^* = \min_{\theta_d} L_d(\theta_f, \theta_d)$$ 2. **訓練 Feature Extractor $f$ **:目標是**最大化** Domain Classifier 的損失 $L_d$,即**欺騙 Domain Classifier**,使其無法分辨特徵來源;同時也要**最小化** Label Predictor 的損失 $L$,以保證提取的特徵對於分類任務是有用的。$$\theta_f^* =\min_{\theta_f} (L - \lambda L_d)$$$$\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ = \min_{\theta_f} (L(\theta_f, \theta_p) - \lambda L_d(\theta_f, \theta_d))$$其中 $\lambda$ 是超參數,控制對抗損失的權重。原始論文提出的是最小化 $L - L_d$,但這可能導致特徵提取器僅僅是將領域標籤反轉,而非真正混合分布。更好的做法是追求讓 $d$ 完全無法判斷,例如最大化 $L_d$。 ![image](https://hackmd.io/_uploads/SJ9fAv96Jx.png) * **與生成對抗網路 ([GAN](https://hackmd.io/@Jaychao2099/imrobot7)) 的類比**: * Feature Extractor $\approx$ 生成器 (Generator):產生讓判別器混淆的特徵。 * Domain Classifier $\approx$ 判別器 (Discriminator):區分真實 (來源域)和生成 (目標域)的特徵。 * **防止平凡解 (Trivial Solution)**:Feature Extractor 不能簡單地輸出零向量來欺騙領域分類器,因為它還必須產生對 Label Predictor 有用的特徵來最小化 $L$。 ### DANN 的效果: [實驗](https://arxiv.org/abs/1409.7495)表明,DANN 能顯著提升模型在目標域上的性能。例如,MNIST 到 MNIST-M 的準確率從 57.5% 提升到 81%。 ![image](https://hackmd.io/_uploads/S1dbyOcpJl.png) --- ## DANN 的侷限與改進 ### 問題:決策邊界對齊 (Decision Boundary Alignment) DANN 旨在對齊 Source Domain 和 Target Domain 的整體特徵分布,但可能導致 Target Domain 的樣本落在 Source Domain 學習到的決策邊界附近或錯誤的一側。理想情況下,Target Domain 樣本不僅應與 Source Domain 樣本**分布混合**,還應**遠離不同類別間的決策邊界**。 ![image](https://hackmd.io/_uploads/ryG4xd5aJx.png) ### 改進思路:考慮決策邊界 * **目標**:讓無標註的 Target Domain 樣本在通過標籤預測器後,其預測結果置信度高 (即遠離決策邊界)。![image](https://hackmd.io/_uploads/HJBuxdq6ke.png) * **方法示例**: * **最小化預測熵 (Entropy Minimization)**:對於 Target Domain 樣本 $x_t$,希望其預測的類別分布 $p(y | x_t)$ 的熵 $H(p(y | x_t))$ 盡可能小,表示預測集中在某個類別上。 * **[DIRT-T](https://arxiv.org/abs/1802.08735) (Decision-boundary Iterative Refinement Training with a Teacher)**:一種具體的實現方法。 * **[Maximum Classifier Discrepancy](https://arxiv.org/abs/1712.02560)**:另一種鼓勵 Target Domain 樣本遠離決策邊界的方法。 --- ## 更多進階情境與展望 1. **通用領域自適應 ([Universal Domain Adaptation](https://openaccess.thecvf.com/content_CVPR_2019/html/You_Universal_Domain_Adaptation_CVPR_2019_paper.html))** * **問題**:Source Domain 和 Target Domain 的**類別集合可能不完全相同** (可能各有獨有類別,或僅部分重疊)。強行對齊所有特徵可能導致錯誤。 * **目標**:在類別集合可能不同的情況下進行領域自適應。 ![image](https://hackmd.io/_uploads/S16UZ_9pyg.png) 2. **Target Domain 資料稀少且無標註 (Little & Unlabeled)** * **問題**:當 Target Domain 資料極少 (例如只有一張圖像)且無標註時,無法有效估計其分布,DANN 等方法失效。 * **方法示例**:測試時訓練 ([Testing Time Training, TTT](https://arxiv.org/abs/1909.13231)),在測試階段利用單個或少量目標樣本進行模型調整。 ![image](https://hackmd.io/_uploads/HyZgz_9aJl.png) 3. **未知 Target Domain ([Domain Generalization](https://ieeexplore.ieee.org/document/8578664))** * **問題**:對 Target Domain 一無所知,希望模型能泛化到任何未知的領域。![image](https://hackmd.io/_uploads/H1CfGucpJl.png) * **情境一:訓練資料包含多個領域**:利用多樣化的訓練領域學習領域不變的特徵。 * **情境二:訓練資料僅單一領域**:嘗試通過數據增強 (Data Augmentation) 等方法模擬多領域數據,再進行訓練。 ![image](https://hackmd.io/_uploads/ByZwzdqp1x.png) ![image](https://hackmd.io/_uploads/SJldGuq6Jg.png) --- ## 總結 領域自適應是處理**訓練與測試數據分布差異**的重要技術。當 Target Domain 有大量無標註資料時,領域對抗訓練 (DANN) 透過引入領域分類器,使特徵提取器學習領域不變的特徵。然而 DANN 可能有決策邊界對齊問題,需要進一步考慮讓目標樣本遠離邊界。此外,針對類別不匹配、目標數據稀少、 Target Domain 未知等更複雜情境,也有相應的研究方向如通用領域自適應、測試時訓練和領域泛化。 --- 回[主目錄](https://hackmd.io/@Jaychao2099/aitothemoon/)