--- tags: Knock Knock! Deep Learning --- Day 20 / DL x CV / 改變世界的 GAN === 大家幾年前一定看過很多人轉傳,非常逼真但不存在的人臉圖: ![non-existing faces](https://i.imgur.com/gT0MDmY.jpg) *—— 不存在的人臉。[1]* 也一定看到很多人在玩 FaceApp 或一些軟體,能將性別轉換、變老變年輕: ![faceapp](https://i.imgur.com/Nju57xJ.jpg) *—— FaceApp 將性別轉換。[2]* 這些技術背後的 ML 架構正是今天的主角 —— Generative Adversarial Nets (GAN)。讓我們從 generative models 說起,以及 GAN 為什麼能有讓人驚豔的效果,還同時啟發了無數變形架構且適用在非常廣泛的場合。 > 接下來要介紹的東西,很多要深入理解的話,背後都有複雜的數學。我們會盡量介紹到能理解應用即可。 ## Generative Models 有別於先前介紹的 model 多半只有單一預測目標做 supervised learning,generative models 旨在學習 output 多種答案,例如"人臉"可以有無限多種可能。 那這種 generative model 具體來說的學習目標是什麼?要怎麼訓練呢? ### Maximum Likelihood Estimation (MLE) 其中一種常見做法是將 model 目標設為模擬真實 data 的 probability distribution $p_{\text{data}}$,並用 **Maximum Likelihood Estimation (MLE)** 找出 parameters $\theta$ 來讓 training data 的 likelihood 最大: $$ \theta^* = {\arg \max}_{\theta} \Pi^m_{i=1} p_{\theta}(x^{(i)}) $$ $p_{\theta}$ 是 model 以 parameter $\theta$ 模擬的 data distribution,$m$ 為 training data 的大小,${\arg \max}_{\theta}$ 是能讓後面那項最大的 $\theta$。 因為我們的目標 $p_{\text{data}}$ 是真實世界所有可能 data 的分佈,但我們只會有其中一小群 $m$ 個 sample,所以我們想模擬的其實是這 $m$ 個 sample 的分佈 $\hat{p}_{\text{data}}$。數學上,上面 maximize likelihood 其實會等於 minimize $\hat{p}_{\text{data}}$ 和 $p_{\theta}$ 的分佈差距 **KL divergence**。 ### 訓練方法 知道 model 目標後,接下來可以分類成很多不同的訓練方法: ![taxonomy of deep generative models](https://i.imgur.com/Zmo6jxk.png) *—— Deep generative models 分類樹狀圖。[4]* MLE 下來分兩類:**explicit density** 和 **implicit density**。Explicit density 中,會把 density function $p_{\theta}(x)$ 的實際 format 直接放在訓練中,並根據上面 MLE 的式子做 optimization。 例如在 **PixelRNN** 中,density function 會被分解成一連串的步驟: $$ p_{\theta}(x) = \Pi^m_{i=1} p_{\theta}(x_i | x_1, \dots, x_{i-1}) $$ 這個形式是 [tractable density](https://stats.stackexchange.com/questions/4417/what-are-the-factors-that-cause-the-posterior-distributions-to-be-intractable),也就是能用 [closed-form](https://en.wikipedia.org/wiki/Closed-form_expression) 的形式表示,運算上比較不複雜。但缺點是要一步一步跑,算完一串步驟運算太花時間。 另一方面在 **Variational Autoencoder (VAE)** 中,用了不一樣的方式分解 density function: $$ p_{\theta}(x) = \int p_{\theta}(z) p_{\theta}(x|z) dz $$ 一個 autoencoder 做的事是把 input $x$ encode 成 latent feature $z$,目標是讓 decoder 能根據 $z$ decode 成原本的 input $x$,並用 $x$ 和 decode 出來的 $\hat{x}$ 的差距當作 loss 訓練。而 VAE 因為是 generative model,要找出的是這個 latent feature $z$ 的 probability distribution $p(z)$ 而非單一值。 > 可以參考 [5] 的詳細解釋。 式子中 $p_{\theta}(x|z)$ 根據 $z$ 預測 $x$ 是 decoder,乘上 prior $p_{\theta}(z)$,這部分都算 tractable。但 integration 要對所有 $z$ 做這件事就不太實際了,因此整體來說屬於 intractable density,需要用 approximate 的方式讓他變 tractable。 簡單來說我們如果在 encoder 部分用 $q_{\phi}(z|x)$ approximate $p_{\theta}(z|x)$,那麼經過複雜數學推導 [3, p.61-70],會得到 maximize $p_{\theta}(x)$ 等於 maximize 下面這個 tractible 的式子: $$ \mathbf{E}_z [\log p_{\theta}(x|z)] - D_{KL}(q_{\phi}(z|x)\|p_{\theta}(z)) $$ 也就是訓練一個 autoencoder 讓 encoder 出來的分佈 $q_{\phi}(z|x)$ 接近 prior $p_{\theta}(z)$,並讓 decoder 出來的結果是 input $x$ 的 likelihood $\mathbf{E}_z [\log p_{\theta}(x|z)]$ 越大越好。 雖然 VAE 用 approximation 方式解決了 PixelRNN 運算慢的問題,但實際 generate 出來的結果還是成效不佳,圖像相當模糊: ![vae output](https://i.imgur.com/v5thvdp.jpg) *—— VAE 圖像生成結果。[3]* 因為 GAN 是今天要介紹的重點,上面兩個方法就簡單介紹到這邊,有興趣可以去延伸閱讀細看 paper。接下來我們就來看一下 GAN 如何不明確使用 density function (implicity density),也能找到方法生成好的圖像。 ## Generative Adversarial Nets (GAN) 上面我們簡單介紹了兩種從 density function 出發,學習 data 的分佈並生成圖像的方法。而本篇的主角 —— **Generative Adversarial Nets (GAN)**,則是選擇繞過找出明確 density function 這條路,直接學習怎麼把一個 random noise 轉換成圖像。結果來說,GAN 的成效打破了其他 generative model 的 performance,也因為他有趣有彈性的架構,從提出以來一直是學術界很熱門的研究主題,也能找到很多有趣的應用。 讓我們先從架構和訓練概念介紹起吧! ### Framework 前面可以看到直接企圖找出 training data 的 distribution 用來 sample 新的 data 實在太難了。GAN 決定換個方向走:**隨意的產生 random noise,並透過 neural network 學習把 random noise 轉換成真的 data 的方法**。 在 GAN 中,我們會先定義一個 **generator $G$**,把 random noise $z$ 轉換成 output $G(z)$。但要怎麼知道生成結果 $G(z)$ 是不是真的可以假裝是從原本的 data distribution sample 出來的呢? 為了提供訓練的 signal,GAN 中另外定義了一個 **discriminator $D$**,學習怎麼區分好與壞的結果。也就是把 $G(z)$ 丟進 $D$ 裡,output 出好與壞兩種結果,或更好一點,output 介於 $0$ 和 $1$ 的分數:越靠近 $0$,代表 discriminator 覺得生成結果是假的;越靠近 $1$,代表他覺得結果是真的。而在訓練 $D$ 時,我們會把 $G(z)$ 假結果跟真的 data $x$ 都丟給 $D$,讓他學習判別。 **Generator 訓練目標是讓 discriminator 覺得自己的生成結果是真的,也就是盡量讓 $D(G(z))$ 靠近 $1$。而 discriminator 訓練目標是駁回 generator 的結果,也就是盡量讓 $D(G(z))$ 靠近 $0$**。透過這樣的 two-player game 互相學習之後,generator 的生成結果越來越能欺騙 discriminator,也就是越來越接近真實的 data 了。 ![GAN framework](https://i.imgur.com/vVMAiDI.png) *—— GAN framework。$G$ 和 $D$ 透過 two-player game 互相學習。[4]* 上圖為 two-player game 的架構。為了訓練 $D$ 區分真的和假的結果,會給他真實 data(左)和 generator 的假結果(右)學習。而 $G$ 會從 $D$ 給的分數 $D(G(z))$ 學習。 ### Training Objective 這樣的架構形成了一個 [minimax game](https://en.wikipedia.org/wiki/Minimax),訓練目標是讓 $D$ 學習 parameters $\theta_d$ 來 maximize $D(x)$ 靠近 $1$ 且 $D(G(z))$ 靠近 $0$ 的機會: $$ \max_{\theta_d} [\mathbb{E}_{x \sim p_{\text{data}}} \log D(x) + \mathbb{E}_{z \sim p_(z)} \log (1 - D(G(z)))] $$ > 右項指 $1 - D(G(z))$ 越靠近 $1$ 越好,意同 $D(G(z))$ 越靠近 $0$ 越好。 反之要讓 $G$ 學習 parameters $\theta_g$ 來 minimize $D(G(z))$ 靠近 $0$ 的機會: $$ \min_{\theta_g} \mathbb{E}_{z \sim p(z)} \log (1 - D(G(z))) $$ 這樣 $G$ 和 $D$ 的訓練目標就定義完成了! > 因為 log 會讓 $D(G(z))$ 在靠近 $0$ 的時候,$\log (1 - D(G(z)))$ 的 gradient 太小而難以學習,所以實際上會以 $\max_{\theta_g} -\mathbb{E}_{z \sim p(z)} \log (1 - D(G(z))) = \max_{\theta_g} \log D(G(z))$ 來訓練 $G$。Minimize 某式 = maximize 某式的負值! ### Training Procedure 整個架構略顯龐大,且需要兩個 network 交互訓練。作者很貼心的提供了 pseudo-code 來讓訓練過程更清楚: ![GAN trining procedure](https://i.imgur.com/LN5nZEx.png) *—— GAN 訓練框架。[6]* 每個 iteration 中,我們先訓練 discriminator $k$ 次。每次提供 $m$ 個真實 data $x$ 和假的 $G(z)$ 做 supervised learning。讓 discriminator 先訓練,是為了提供 generator 有用的分數反饋。 接著再訓練 generator,讓生成結果 $G(z)$ 給 discriminator 打分數並修正自己。 如此循環幾個 iteration。 成功訓練的話,兩邊的理想訓練趨勢大概是這樣: ![GAN alternate training](https://i.imgur.com/hpu1oxM.png) *—— GAN 交互訓練理想圖。綠線:generator 生成分佈。黑點線:真實 data 分佈。藍點線:discriminator 區分界線。黑箭頭:generator 將 random noise 轉換成 data。[6]* 一開始 $(a)$ 中 discriminator 大致能把真實 data 和假結果區隔開來。經過訓練後,$(b)$ 中 discriminator 學會了更嚴謹的判別。接著 $(c)$ 中訓練 generator,藉由 discriminator 的反饋,慢慢往真實分佈調整。最後達到 $(d)$ 之後,discriminator 就再也沒有可靠訊息有效區分兩者,generator 學會生成擬真的結果。 ### Results 最後來看一下 GAN 的生成結果。 ![original GAN output](https://i.imgur.com/OjLx6PN.png) *—— GAN 生成四種 dataset 的圖像。最右邊是 training set 中最接近的 data,可以證明 GAN 學到的不是直接照抄。[6]* 原始 GAN paper 提供的結果。Paper 比較偏理論,圖像算是實驗性質,有些模糊,但不難看出 GAN 的一些潛力。 之後很多 GAN 的變形,能生成越來越好的圖像。最簡單從改用 CNN 開始的 DCGAN: ![DCGAN output](https://i.imgur.com/kYP6kVY.jpg) *—— DCGAN 生成假房間照片。[7]* 甚至能把學到的 representation 拿來玩加減: ![DCGAN vector math](https://i.imgur.com/eatqV55.png) *—— DCGAN representation 做加減:戴墨鏡的男子 - 男子 + 女子 = 戴墨鏡的女子。[7]* ## The GAN Zoo 從 GAN 有趣的 framework 提出以後,學術界中關於 GAN 的發表開始爆炸,越來越多有的沒的 GAN 變形被提出,收羅成 [The GAN Zoo](https://github.com/hindupuravinash/the-gan-zoo)。 稍微改造一下 GAN,就能做很多應用:給圖像框架提供著色、風格轉換、提升解析度、生成圖像中被遮住的部分等等。而 GAN 的概念不只能被應用在 CV 領域中,例如做 NLP 生成還能選擇 SeqGAN! ## 結語 GAN 有趣又有效的架構啟發了無數 CV 界的應用,堪稱經典。下一篇要來介紹我在 CS231n 用 CycleGAN 做的 project:字型風格轉換。也是我做過最滿意的 project! ## Checkpoint - Generative model 和一般 supervised learning 任務,在目標上有什麼根本差異,讓 generative model 的訓練特別困難? - GAN 和前面介紹的 PixelRNN 和 VAE,在目標上有什麼不同? - GAN 中為什麼需要額外訓練一個 discriminator? - Discriminator input 和 output 是什麼?訓練目標為何? - Generator input 和 output 是什麼?訓練目標為何? - 為什麼訓練步驟中要先讓 discriminator 訓練? - 當 $D$ 的 output 大致為什麼值的時候,我們會說 discriminator 已經沒有辨別能力了? ## 參考資料 1. [(Karras et al., 2018) A Style-Based Generator Architecture for Generative Adversarial Networks](https://arxiv.org/pdf/1812.04948.pdf) 2. [The AI Behind FaceApp](https://analyticsindiamag.com/the-ai-behind-faceapp/) 3. [CS231n Lecture Slides: Generative Models](http://cs231n.stanford.edu/slides/2020/lecture_11.pdf) 4. [👍 NIPS 2016 Tutorial: Generative Adversarial Networks](https://arxiv.org/pdf/1701.00160.pdf) 5. [👍 Variational autoencoders.](https://www.jeremyjordan.me/variational-autoencoders/) 6. [(Goodfellow et al., 2014) Generative Adversarial Nets](https://arxiv.org/pdf/1406.2661.pdf) 7. [(Radford et al., 2016) Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks](https://arxiv.org/pdf/1511.06434.pdf) ## 延伸閱讀 1. [(Oord et al., 2016) Pixel Recurrent Neural Networks](https://arxiv.org/pdf/1601.06759.pdf) 2. [(Kingma et al., 2014) Auto-Encoding Variational Bayes](https://arxiv.org/pdf/1312.6114.pdf) 6. [(Yu et al., 2016) SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient](https://arxiv.org/pdf/1609.05473.pdf)