# Generative Adversal Network (GAN) 筆記 ## 基本概念 Generative Adversal Network:生成式對抗網路 Network 可大致分為 Generator 和 Discriminator 兩部份 兩者輪流訓練來對抗對方,共同成長與進步。 ### Generator ![Screenshot_20240125_152618](https://hackmd.io/_uploads/BJImatJ5p.png) 除了自訂的 input $x$ 之外, 還在一個簡單的分佈(例:Guassian Distribution)裡面 Sample 數值當作 input。 目的是讓 Model 有一定程度的創造性。 Network 會輸出一個比原本更複雜的 Output,這一個部份是 GAN 中的 Generator ### Discriminator ![Screenshot_20240125_154409](https://hackmd.io/_uploads/H1SUZqk56.png) input 一張圖片,輸出一個「分數」,判別這張圖片是否是被生成的。 ### Adversarial Training ![Screenshot_20240125_155544](https://hackmd.io/_uploads/By6-Eqkqp.png) #### **擬人化的說法:** 類似於演化的概念。 在一個 Training iteration 中,首先 Generator 生成一堆圖片 Discriminator 學習分辨真實的圖片與生成的圖片 下一個 iteration 中,Generator 要學習如何生成新的圖片嘗試騙過 Discriminator Discriminator 也要再繼續學習怎麼分辨真實的圖片與生成的圖片。 #### **更具體的步驟:** Generator 和 Discriminator 都是擁有 Unknown Parameters 的 Network(Function) 並且連接成一個較大的 Network,其中某個 hidden layer 的輸出就是生成的圖片 ![Screenshot_20240125_164005-1](https://hackmd.io/_uploads/r1wq05Jca.png) 在每一個 iteration 中: * Step 1:定住屬於 Generator 部份的參數,只更新 Discriminator 部份的參數,來讓最後輸出的分數越低越好(代表 Dicriminator 知道圖片是生成的),達到訓練 Discriminator。 * Step 2:定住屬於 Discriminator 部份的參數,只更新 Generator 部份的參數,來讓最後輸出的分數越高越好(代表可以生成出無法被 Discriminator 辨認的成果),達到訓練 Generator。 ## 以數學的角度理解 ![Screenshot_20240125_171641](https://hackmd.io/_uploads/rk3ZPj156.png) $P_G$:Generated Data 形成的 Distribution $P_{data}$:Training Data 形成的 Distribution 目標是找出一組參數 $G^*$,使得 $P_G$ 和 $P_{data}$ 越像越好(算出的 Divergence 越小越好) ![Screenshot_20240125_171303](https://hackmd.io/_uploads/r1JnUsk5T.png) 但是在複雜的 Distribution 上幾乎沒有辦法計算 Divergence 所以在 GAN 中,Discriminator 會在 $P_G$ 和 $P_{data}$ 中各別 Sample 出一些 Data 並用以下式子計算分數: ![Screenshot_20240126_092238](https://hackmd.io/_uploads/H1ettKxq6.png) * $E_{y \sim P_{data}}[logD(y)]$:$y$ 取樣自 $P_{data}$(真實圖片),Discriminator計算分數後取 log * $E_{y \sim P_{G}}[logD(y)]$:$y$ 取樣自 $P_{G}$(生成的圖片),Discriminator計算分數後取 log D 為含有未知參數的 Network(function),算出的分數應介於 0~1 之間 而看到真實圖片時,D 算出的分數要越高越好 相反的,看到生成的圖片時,D 算出的分數要越低越好 而 $V(G,D)$ 越大代表 Discriminator 能把分類工作做得很好,如下圖: ![Screenshot_20240126_085339](https://hackmd.io/_uploads/SJg9uFlcp.png) 所以訓練 Discriminator 的目的就是要找到一組參數 $D^*$,使得 $V(G,D)$ 最大: ![Screenshot_20240126_094802](https://hackmd.io/_uploads/HkA8kce96.png) 而從上圖又可以發現,最後找出的 $\underset{D}{max}V(D,G)$ 會與 Divergence (JS Divergence) 相關 * $\underset{D}{max} V(D,G)$ 大(好分類),Divergence 小 * $\underset{D}{max} V(D,G)$ 小(不好分類),Divergence 大 而我們的終極目標是讓生成的圖片與真實圖片 Divergence 越小越好 所以將以上的概念統整成以下的式子: ![Screenshot_20240125_171318](https://hackmd.io/_uploads/ByEn8jy9T.png) 同時我們要找到一組參數 $D^*$,使得 $V(D,G)$ 可以越大越好 也要找到一組參數 $G^*$,讓 $\underset{D}{max}V(D,G)$ 越小越好 即產生這樣的式子,與前述概念中提到的運作方式相符 註:以上 $V(G,D)$ 可以改成其他算法,此處因最初設計成 Binary Classifier,所以用Negative Cross Entropy 的方式。 ## 優化 ### JS Divergence 缺點 $P_G$ 和 $P_{data}$ 在很高維度的時候,幾乎不可能重疊 而用 JS Divergence 算出來只要不重疊,算出來都是 $log2$ ![Screenshot_20240126_110059](https://hackmd.io/_uploads/HkIdxse5T.png) 這導致在 Training 過程中,這樣的數值無法代表任何意義 ### Wasserstein Distance ![Screenshot_20240126_110744](https://hackmd.io/_uploads/SknWzixc6.png) 像是推土機一樣,找方法將 Distribution「移動」成一樣的 其中最小的平均移動距離就是 Wasserstein Distance 比 JS Divergence 好: ![Screenshot_20240126_111133](https://hackmd.io/_uploads/HkNGXig56.png) WGAN 使用 Wesserstein Distance,實際上的式子: ![Screenshot_20240126_111647](https://hackmd.io/_uploads/SJiXEolcT.png) 其中多了一項限制:D 不可以是隨便的 Function,要夠平滑才可以 ## Conditional Generation 訓練會需要預先標記好的 Data ![Screenshot_20240126_113836](https://hackmd.io/_uploads/H1YBKog56.png) 自訂輸入 $x$,產生的 Output 要跟 $x$ 相關 為了避免 Model 直接忽略 $x$ 來產生隨機的圖片,所以 Discriminator 也要加入 $x$ 來評分 ## Cycle GAN 目標:Learning From Unpaired Data 例如:輸入照片、生成動漫人物圖 ![Screenshot_20240126_114811](https://hackmd.io/_uploads/SJLFsogqT.png) 為了要確保最後生成的圖片真的與輸入的照片相關,要加入以下機制: ![Screenshot_20240126_122922](https://hackmd.io/_uploads/ry4aeTx96.png) 有兩個 Generator,一個負責從 X Domain 轉為 Y Domain,另外一個則是相反 除了確保生成圖片與 Domain Y 相近之外,也保證生成的圖片與輸入相關。 ## 判定生成的好壞 ### Diversity 生成的 Data 可能會特別聚集於某個空間,造成多樣性不足: ![Screenshot_20240126_133657](https://hackmd.io/_uploads/ryrbr6l96.png) 可以透過 Classifier 來對生成的圖片進行分類,並進行以下運算: ![Screenshot_20240126_134326](https://hackmd.io/_uploads/B1tFLaeqa.png) 種類分佈的越平均越好