# 第十六章:元學習 (Meta Learning) >上課筆記 * 上課影片連結 * ==[**元學習(一) - 元學習跟機器學習一樣也是三個步驟**](https://youtu.be/xoastiYx9JU)== * ==[**元學習(二) - 萬物皆可 Meta**](https://youtu.be/Q68Eh-wm1Ts)== --- ## 什麼是元學習? 元學習的核心概念是「學會如何學習 (Learn to learn)」。 一般而言,「Meta-X」指的是「關於 X 的 X (X about X)」。因此,Meta Learning 就是「關於學習的學習」。 --- ## 動機 在機器學習領域,特別是在業界,常常需要耗費大量資源 (如使用 1000 個 GPU) 嘗試不同的超參數組合,如同大海撈針。而在學術界,研究者則往往依賴經驗或直覺 (被戲稱為「通靈」) 來尋找好的超參數。 一個重要的問題是:我們能否讓機器自動地學習如何決定這些超參數,甚至自動地學習整個學習過程本身?這就是元學習試圖解決的問題。 --- ## 機器學習回顧 ([第一章](https://hackmd.io/@Jaychao2099/imrobot1)) 傳統的機器學習 (監督學習) 可以看作是尋找一個未知的目標函數 $f$。其過程主要包含三個步驟: 1. **定義包含未知參數的函數 (Function with unknown)**: * 選擇一個模型架構,例如神經網路。這個網路的權重 (weights) 和偏差 (biases) 是未知的,我們用 $\theta$ 來表示這些所有可學習的參數。 * 模型可以表示為 $f_\theta(x)$,輸入 $x$ (例如一張圖片),輸出預測結果 (例如 "貓" 或 "狗" 的機率分佈)。![image](https://hackmd.io/_uploads/H1RAbL0Ckl.png) 2. **定義損失函數 (Define loss function)**: * 損失函數 $L(\theta)$ 用來衡量模型參數 $\theta$ 的好壞。 * 對於訓練資料中的每一個樣本 $k$,計算模型輸出 $f_\theta(x_k)$ 與真實標籤 $y_k$ (Ground Truth) 之間的差異,得到單一樣本的損失 $e_k$ (例如使用 Cross-entropy)。 * 總損失通常是所有訓練樣本損失的總和或平均:$L(\theta) = \sum_{k=1}^{K} e_k$。![image](https://hackmd.io/_uploads/ry0PMLC0Jx.png) 3. **優化 (Optimization)**: * 目標是找到一組參數 $\theta^*$,使得損失函數 $L(\theta)$ 最小化:$\theta^* = arg\min_{\theta} L(\theta)$。 * 常用的方法是梯度下降法 (Gradient Descent) 及其變種,透過計算損失函數對參數的梯度 $\nabla_\theta L(\theta)$ 來迭代更新參數 $\theta$。 * 最終得到的 $f_{\theta^*}$ 就是透過學習演算法從資料中學習到的函數。 --- ## 元學習:學習如何學習 相較於機器學習旨在從資料中學習一個特定的函數 $f_{\theta^*}$,元學習的目標是學習一個「學習演算法」本身,我們稱這個學習演算法為 $F$。 * **機器學習 (ML)**:輸入是訓練資料,輸出是一個訓練好的模型 $f_{\theta^*}$。學習演算法 (如梯度下降) 通常是人工設計且固定的。 * $f_{\theta^*} = \text{ML_Algorithm}(\text{Training Data})$ * **元學習 (Meta Learning)**:輸入是多個不同的「任務 (Task)」,輸出是一個能夠快速學習新任務的「學習演算法 $F$」。 * $F_{\phi^*} = \text{Meta_Learning_Algorithm}(\text{Training Tasks})$ * 然後 $f_{\theta^*} = F_{\phi^*}(\text{Task Data})$ ![image](https://hackmd.io/_uploads/rygaQ8A0yx.png) 元學習也遵循類似機器學習的三個步驟,但應用的層次不同: ### 步驟一:定義學習演算法 $F_\phi$ 首先,要定義學習演算法 $F$ 中哪些部分是可學習的。這些可學習的元件或參數,我們用 $\phi$ 來表示。 傳統的學習演算法 (如梯度下降) 包含許多元件,其中一些可能可以透過元學習來自動學習,例如: * 網路架構 (Net Architecture) * 初始參數 (Initial Parameters) * 學習率 (Learning Rate) * 優化器規則 (Optimizer rules) * 資料處理或增強策略 (Data processing/augmentation strategy) 根據 $\phi$ 代表的不同元件,可以將元學習方法進行分類。 ![image](https://hackmd.io/_uploads/HySZV80R1x.png) ### 步驟二:定義學習演算法的損失函數 $L(\phi)$ 為了學習 $\phi$,需要定義一個損失函數 $L(\phi)$ 來衡量學習演算法 $F_\phi$ 的好壞。 與機器學習使用單一任務的訓練樣本計算損失不同,元學習需要**多個不同的任務**來進行訓練。這些任務稱為「訓練任務 (Training Tasks)」。 對於每一個訓練任務 $n$ (例如:任務 1 是蘋果和橘子的分類,任務 2 是汽車和腳踏車的分類): 1. **任務內數據劃分**:每個訓練任務 $n$ 的資料會被劃分為: * **訓練集 (Training set / Support set)**:用來給學習演算法 $F_\phi$ 作為輸入,讓它學習該任務。 * **測試集 (Testing set / Query set)**:用來評估 $F_\phi$ 在這個任務上學習到的成果。 2. **任務內學習**:將任務 $n$ 的 Support set 輸入給學習演算法 $F_\phi$,它會輸出一組針對該任務的特定模型參數 $\theta^{n*}$。即 $f_{\theta^{n*}} = F_\phi(\text{Support Set}_n)$。 3. **任務內評估**:使用 $f_{\theta^{n*}}$ 在任務 $n$ 的 Query set 上進行預測,並計算其損失 $l^n$ (例如 Cross-entropy)。這個 $l^n$ 反映了學習演算法 $F_\phi$ 在任務 $n$ 上的學習效果。 4. **元學習損失**:最終的元學習損失 $L(\phi)$ 是所有 N 個訓練任務的損失 $l^n$ 的總和或平均:$L(\phi) = \sum_{n=1}^{N} l^n$。 ![image](https://hackmd.io/_uploads/SkBA4LCCJe.png) **關鍵區別**: * 機器學習的損失 $L(\theta)$ 是在**訓練樣本**上計算的。 * 元學習的損失 $L(\phi)$ 是基於學習演算法在各個訓練任務的**測試樣本 (Query set)** 上的表現來計算的。這並非使用最終要評估元學習演算法的、從未見過的「測試任務」的資料,而是在定義 $L(\phi)$ 的過程中,每個用於訓練 $F_\phi$ 的任務內部,都需要用到該任務自己的測試集 (Query set) 來評估 $F_\phi$ 產生的 $f_{\theta^{n*}}$ 的好壞。 ### 步驟三:優化學習演算法 目標是找到一組元參數 $\phi^*$,使得元學習損失 $L(\phi)$ 最小化:$\phi^* = arg\min_{\phi} L(\phi)$。 * **基於梯度的方法**:如果 $L(\phi)$ 對於 $\phi$ 是可微分的 (這通常需要學習演算法 $F_\phi$ 的內部運算也是可微分的),則可以使用**梯度下降法**來優化 $\phi \leftarrow \phi - \eta \nabla_\phi L(\phi)$。計算 $\nabla_\phi L(\phi)$ 的過程可能比較複雜,因為 $L(\phi)$ 是透過 $F_\phi$ 產生 $f_{\theta^{n*}}$,再用 $f_{\theta^{n*}}$ 在 Query set 上計算 $l^n$ 而來,涉及到巢狀的優化過程。 * **無梯度方法**:如果 $L(\phi)$ 不可微分 (例如 $\phi$ 代表離散的網路架構),則需要使用其他優化方法,如[強化學習 (Reinforcement Learning)](https://hackmd.io/@Jaychao2099/imrobot13) 或[演化演算法 (Evolutionary Algorithm)](https://en.wikipedia.org/wiki/Evolutionary_algorithm)。 完成優化後,得到的 $F_{\phi^*}$ 就是一個學習到的、效果更好的「學習演算法」。 --- ## 元學習框架總結 1. **元訓練階段 (Meta-Training)**: * 準備 N 個不同的訓練任務 $\{Task_1, Task_2, ..., Task_N\}$。 * 每個任務 $Task_n$ 包含 Support Set$_n$ 和 Query Set$_n$。 * 使用這些訓練任務,透過最小化 $L(\phi) = \sum_{n=1}^{N} l^n$ 來學習元參數 $\phi^*$,得到學習演算法 $F_{\phi^*}$。這個過程稱為「**跨任務訓練 (Across-task Training)**」,其中包含了「**任務內訓練與測試 (Within-task Training and Testing)**」。 ![image](https://hackmd.io/_uploads/rkhxvICAyl.png) 2. **元測試階段 (Meta-Testing)**: * 準備一個全新的、在元訓練階段從未見過的測試任務 $Task_{test}$。 * 測試任務也包含 Support Set$_{test}$ 和 Query Set$_{test}$。通常 Support Set$_{test}$ 的樣本量很少 (符合 Few-shot learning 的情境)。 * 將 Support Set$_{test}$ 輸入給學習到的演算法 $F_{\phi^*}$,得到針對測試任務的模型 $f_{\theta_{test}^*}$。即 $f_{\theta_{test}^*} = F_{\phi^*}(\text{Support Set}_{test})$。 * 最後,使用 $f_{\theta_{test}^*}$ 在 Query Set$_{test}$ 上評估最終性能。 ![image](https://hackmd.io/_uploads/r1K78IC0kl.png) --- ## 機器學習 (ML) vs. 元學習 (Meta Learning) | 特徵 | 機器學習 (ML) | 元學習 (Meta Learning) | | ------------ | -------------------------------------------- | ------------------------------------------------------------------ | | **目標** | 學習一個特定任務的函數 $f_\theta$ | 學習一個「學習演算法」$F_\phi$ | | **訓練資料** | 單一任務的大量標註樣本 | 多個不同任務,每個任務資料量可多可少 | | **訓練過程** | 任務內訓練 (Within-task Training) | 跨任務訓練 (Across-task Training),包含任務內的訓練與測試 | | **測試過程** | 任務內測試 (Within-task Testing) | 跨任務測試 (Across-task Testing),在新任務上進行任務內的訓練與測試 | | **損失函數** | 基於訓練樣本計算 ($L(\theta) = \sum e_k$) | 基於訓練任務的測試樣本 (Query set) 計算 ($L(\phi) = \sum l^n$) | | **優化對象** | 模型參數 $\theta$ | 學習演算法的元參數 $\phi$ | | **輸出** | 針對特定任務訓練好的模型 $f_{\theta^*}$ | 一個通用的學習演算法 $F_{\phi^*}$ | | **資料集區分** | 訓練集 / 驗證集 / 測試集 | 訓練任務 (Support/Query Sets) / 測試任務 (Support/Query Sets) | | **過擬合** | 可能在訓練資料上過擬合 | 可能在訓練任務上過擬合 (需要更多訓練任務、任務增強 Task augmentation) | | **其他** | 超參數需人工調整 | 元學習本身也有超參數 (如元學習率),需要開發任務 (Development task) 進行調整 | 元學習的訓練過程通常涉及**內外循環 (Inner/Outer Loop)**: * **內循環 (Inner Loop)**:對於每個訓練任務,使用 $F_\phi$ 和 Support Set 學習特定任務的參數 $\theta^{n*}$ (Within-task Training)。 * **外循環 (Outer Loop)**:使用 $\theta^{n*}$ 在 Query Set 上計算損失 $l^n$,並根據 $L(\phi)=\sum l^n$ 的梯度更新元參數 $\phi$ (Across-task Training)。 --- ## 元學習能學什麼? 元學習的具體實現取決於我們選擇讓哪些元件 $\phi$ 成為可學習的。 ### 學習初始化參數 (Learning to Initialize) 這是元學習中最常見的一類方法。目標是學習**一組好的初始模型參數** $\phi^0$,使得模型從這個初始點出發,只需要在新的、只有少量樣本的任務上進行少量幾步梯度下降,就能快速適應並獲得良好性能。 ![image](https://hackmd.io/_uploads/ryzjxdCCyx.png) * **[MAML](https://arxiv.org/abs/1703.03400) (Model-Agnostic Meta-Learning)**:[[2019上課影片]](https://youtu.be/mxqzGwP_Qys) * **核心思想**:尋找一個初始參數 $\phi$,使得在任何任務上,只要基於該任務的 Support Set 做少量梯度下降,就能在該任務的 Query Set 上獲得很低的損失。 * MAML 的損失函數 $L(\phi)$ 計算的是「在各個任務上進行少量梯度更新後」的模型在 Query Set 上的損失總和。 * MAML 的優化需要計算梯度的梯度 (二階導數),計算量較大。 * **FOMAML (First-Order MAML)**:是 MAML 的簡化版,忽略了二階導數,降低了計算複雜度。[[2019上課影片]](https://youtu.be/3z997JhL9Oo) ![image](https://hackmd.io/_uploads/H1dy3PARkx.png) * **[Reptile](https://arxiv.org/abs/1803.02999)**:[[2019上課影片]](https://youtu.be/9jJe2AD35P8) * 對於每個任務,從當前 $\phi$ 開始,在該任務上執行多步梯度下降得到 $\theta^{n*}$。 * 然後將 $\phi$ 朝著所有任務的 $\theta^{n*}$ 的方向移動一小步 (例如 $\phi \leftarrow \phi + \epsilon \frac{1}{N} \sum (\theta^{n*} - \phi)$)。 ![image](https://hackmd.io/_uploads/B1etKP001e.png) * **與預訓練 (Pre-training) / 遷移學習 (Transfer Learning) / 多任務學習 (Multi-task Learning) 的比較**:[[2019上課影片]](https://youtu.be/vUwOA3SNb_E) * **預訓練/遷移學習**:通常在一個大規模通用資料集 (或相關的源任務) 上訓練模型,然後將學到的參數 (或部分參數) 作為新任務的初始點進行微調 (Fine-tuning)。其目標是學習對源任務/通用數據有用的表示。 ![image](https://hackmd.io/_uploads/rJzf6PCAJl.png) * **多任務學習**:同時學習多個相關任務,通常共享部分模型參數,目標是提升所有任務的性能。 ![image](https://hackmd.io/_uploads/rke1TDC0ke.png) * **MAML**:目標是學習一個「容易適應新任務」的初始參數,而非在源任務上達到最佳性能。它明確地針對「快速適應少量樣本」進行優化。 ![image](https://hackmd.io/_uploads/rkzBavRRJe.png) * 研究 (如 [ANIL](https://arxiv.org/abs/1909.09157)) 指出,MAML 的成功可能很大程度上來自於**特徵重用 (Feature Reuse)**,即學習到的 $\phi$ 本身就是一個很好的特徵提取器,只需要微調最後的分類層即可快速適應。 ![image](https://hackmd.io/_uploads/ryNIqv0A1l.png) ### 學習優化器 (Learning the Optimizer) 傳統優化器 (如 Adam, RMSprop, 詳見[第三章](https://hackmd.io/@Jaychao2099/imrobot3#%E5%8F%83%E6%95%B8%E7%9B%B8%E4%BE%9D%E7%9A%84%E5%AD%B8%E7%BF%92%E9%80%9F%E7%8E%87-Parameter-Dependent-Learning-Rate)) 有固定的更新規則和需要調整的超參數 (如學習率)。元學習可以**學習優化器本身**。![image](https://hackmd.io/_uploads/rkjNyO0RJl.png) * 可以學習優化器的**超參數** (如學習率 $\lambda$)。 * 甚至可以學習整個**參數更新規則**。例如,使用一個 RNN 模型來代替固定的更新公式,RNN 的輸入是參數的梯度,輸出是參數的更新量。這個 RNN 的參數 $\phi$ 可以透過元學習訓練得到。 * [實驗](https://arxiv.org/abs/1606.04474)表明,學習到的優化器在某些特定任務上可能比人工設計的通用優化器效果更好。 ![image](https://hackmd.io/_uploads/SJj8luRAyl.png) ### 網路架構搜索 (Network Architecture Search - NAS) 讓機器自動學習最佳的網路架構 $\phi$。 ![image](https://hackmd.io/_uploads/BJsyZdCC1l.png) * **基於強化學習 ([RL](https://hackmd.io/@Jaychao2099/imrobot13))**:用一個控制器 (通常是 RNN) 生成網路架構的描述。將這個架構在目標任務上訓練、評估其性能 (例如準確率),將性能作為獎勵 (Reward) 回傳給控制器,訓練控制器使其能生成更高性能的架構。這個過程的 $L(\phi)$ 是負的 Reward,不可微,需要用 RL 方法優化控制器參數 $\phi$。 * 相關文獻:[[1]](https://arxiv.org/abs/1611.01578)[[2]](https://arxiv.org/abs/1707.07012)[[3]](https://arxiv.org/abs/1802.03268) ![image](https://hackmd.io/_uploads/BJaNZuA01l.png) * **基於演化演算法 (EA)**:維護一個網路架構的群體 (Population),透過突變 (Mutation) 和交叉 (Crossover) 操作產生新的架構,根據性能進行選擇,不斷進化出更好的架構。 * 相關文獻:[[1]](https://arxiv.org/abs/1703.01041)[[2]](https://arxiv.org/abs/1802.01548)[[3]](https://arxiv.org/abs/1711.00436) * **基於梯度的方法 ([Differentiable Architecture Search - DARTS](https://arxiv.org/abs/1806.09055))**:將離散的架構選擇鬆弛 (Relax) 為連續的權重組合 (可微分),使得架構參數 $\phi$ 可以和模型參數 $\theta$ 一起**透過梯度下降**進行優化,效率較高。 ![image](https://hackmd.io/_uploads/B1sF-u0C1g.png) ### 學習資料處理/增強/重加權策略 (Data Processing) ![image](https://hackmd.io/_uploads/SJlHQuACkx.png) 可以學習如何自動地: * **數據增強 (Data Augmentation)**:學習一系列數據增強操作的最優組合策略。 * 相關文獻:[[1]](https://arxiv.org/abs/2003.03780)[[2]](https://arxiv.org/abs/1905.05393)[[3]](https://arxiv.org/abs/1805.09501) ![image](https://hackmd.io/_uploads/Hk_wXOCRJl.png) * **樣本重加權 (Sample Reweighting)**:為不同的訓練樣本賦予不同的權重,例如給難樣本更高的權重,或給可能有噪聲標籤的樣本較低的權重。學習一個函數來決定這些權重 (如 Meta-Weight-Net)。 * 相關文獻:[[1]](https://arxiv.org/abs/1902.07379)[[2]](https://arxiv.org/abs/1803.09050) ![image](https://hackmd.io/_uploads/Bk_97_CAyl.png) ### 超越梯度下降 (Beyond Gradient Descent) 元學習甚至可以嘗試發明全新的學習演算法,完全不依賴梯度下降。例如,[Meta-Learning with Latent Embedding Optimization](https://arxiv.org/abs/1807.05960) 提出了一種方法,將模型參數嵌入到一個潛在空間中,並在這個空間中進行優化。 ![image](https://hackmd.io/_uploads/S1mF4dC0yl.png) ### 學習比較 (Learning to Compare / Metric-based) 這類方法通常**不顯式地區分**學習演算法 $F$ 和最終分類器 $f$。而是直接學習一個函數 (或度量空間),使得來自同一類別的樣本在該空間中距離近,不同類別的樣本距離遠。在預測時,將新樣本與 Support Set 中的樣本進行比較來判斷其類別。 ![image](https://hackmd.io/_uploads/By6bBd0Cyg.png) * 2019上課影片:[[1]](https://youtu.be/yyKaACh_j3M)[[2]](https://youtu.be/scK2EIT7klw)[[3]](https://youtu.be/semSxPP2Yzg)[[4]](https://youtu.be/ePimv_k-H24) --- ## 應用 ### 少樣本圖像分類 (Few-shot Image Classification) 這是元學習最經典的應用場景。目標是在每個類別只有極少量 (例如 1 個或 5 個) 標註樣本的情況下進行分類。 * **N-ways K-shot**:指每個任務包含 N 個類別,每個類別有 K 個標註樣本 (Support Set)。 ![image](https://hackmd.io/_uploads/Sy-ardR0yl.png) * **[Omniglot](https://github.com/brendenlake/omniglot) 數據集**:一個常用於 Few-shot Learning 的手寫字符數據集,包含來自 50 種不同字母系統的 1623 種字符,每個字符有 20 個樣本。 ![image](https://hackmd.io/_uploads/S1JZL_0Cyx.png) * **元學習流程**: 1. 將所有字符 (類別) 劃分為元訓練集和元測試集。 2. 在元訓練階段,反覆從元訓練集中抽樣 N 個字符,每個字符抽取 K 個樣本作為 Support Set,再抽取一些樣本作為 Query Set,構成一個訓練任務,用來更新 $F_\phi$。 3. 在元測試階段,從元測試集中抽樣 N 個字符,每個字符抽取 K 個樣本作為 Support Set,再抽取一些樣本作為 Query Set,構成一個測試任務。使用學好的 $F_{\phi^*}$ 和 Support Set 得到 $f_{\theta^*}$,然後在 Query Set 上評估性能。 ![image](https://hackmd.io/_uploads/r1vmLdAAke.png) ### 其他應用領域 元學習已被應用於多種機器學習問題,特別是那些**數據稀疏**或**需要快速適應新環境**的場景: * 聲音事件檢測 (Sound Event Detection) * 關鍵字識別 (Keyword Spotting) * 文本分類 (Text Classification) * 語音克隆 (Voice Cloning) * 序列標註 (Sequence Labeling) * 機器翻譯 (Machine Translation) * 語音辨識 (Speech Recognition) * 知識圖譜 (Knowledge Graph) * 對話系統 / 聊天機器人 (Dialogue / Chatbot) * 句法分析 (Parsing) * 詞嵌入 (Word Embedding) * 多模態學習 (Multi-model) 這些應用通常利用元學習來實現快速適應 (如 MAML) 、學習比較 (Metric-based) 或自動化 (如 NAS、學習優化器) 等目標。 --- 回[主目錄](https://hackmd.io/@Jaychao2099/aitothemoon/)