owned this note
owned this note
Published
Linked with GitHub
# How 2022 became the year of generative AI? Diffusion !
> 我深刻地體會到生成模型除了數學,還有數學,跟數學
![](https://i.imgur.com/1JVNZJh.png)
and then August 2022, **Stable Diffusion** release
# A New Generative model
![](https://i.imgur.com/qCBDlgI.png)
# **Diffusion: From data to noise and back**
2015 年一篇由 Sohl-Dickstein 所發表的 **[Deep Unsupervised Learning using Nonequilibrium Thermodynamics](https://arxiv.org/abs/1503.03585)** 開創了 Diffusion model,而作者提到這是受到 equilibrium statistical physics 領域啟發,而我完全不知道這是甚麼。在這篇論文中作者展示了如何使用 diffusion model 來產生 CIFAR-10 的資料,效果非常普通而且論文充滿數學。
> A diffusion probabilistic model is a **parameterized Markov chain** trained using
variational inference to produce samples matching the data after finite time. - ****[Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)****
2020 年,**[Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)** (DDPM) 使用參數化方法簡化了整個 diffusion model 的學習過程並透過 variational inference 來進行建模,建立了一個 latent vector 與輸入圖片同維度的模型,就此樹立了 diffusion model 的發展基石,這篇論文中所建構出來的訓練步驟如下:
1. Sample random noise to be added to the inputs.
2. Apply the forward process to diffuse the inputs with the sampled noise.
3. Your model takes these noisy samples as inputs and outputs the noise prediction for each time step.
4. Given true noise and predicted noise, we calculate the loss values
5. We then calculate the gradients and update the model weights.
![](https://i.imgur.com/bq0PLQI.jpg)
DDPM 的過程就像是將一個資料分布形成的 manifold 推到邊界之外,然後再訓練一個模型把他推回來,但是一個在 manifold 外面的雜訊會有無數種可能,這是不可能計算得出來的
![](https://i.imgur.com/0xy2b37.png)
![](https://i.imgur.com/bsUoyDb.png)
# Forward & Reverse diffusion
## Forward diffusion process: Markov Chain
Diffusion 的關鍵機制在於透過固定的 forward diffusion process 漸漸的摧毀資料的結構(分布),然後再利用 reverse diffusion process 來恢復資料的結構。利用這種方式訓練完 Diffusion model 之後它就可以從一個完全的 noise 產生出 image。
- 這個訓練完的 diffusion model 就是一個參數化的 Markov chain
- 所謂的 Markov chain 就是一系列會隨時間進行而呈現出不同可能的現象或狀態,其中每個時間點的現象或狀態,就稱為狀態空間。
- 下圖例子有三個狀態空間,分別是停滯市場(Stagnant), 熊市(Bear) 和牛市 (Bull),以及馬可夫鍊的狀態轉移矩陣 $P$
![](https://i.imgur.com/4mH4jGE.png)
- Markov Chain 每一個時間點的狀態皆為根據上一個時間點的狀態而產生
$h^t=h^{t−1}A$
$y^t=h^tB$
- y 為外部狀態,h 為內部狀態,通常我們只會知道外部狀態
- 例如我們目前有 Steven, Jack, Wayne 三個人一年來每一週買飲料的歷史資料,當給定某一週買飲料的清單時,我們可以根據歷史資料找出每一個人的購買習慣,也就是計算出 transition matrix 來訓練模型,最後利用這個模型來判斷下一週的飲料可能是哪一個人購買的。
- 長得像 RNN 但 Markov Chain 更簡單
![](https://i.imgur.com/Nw9R1A5.png)
## Modelling Markov chain as forward diffusion process
- Forward diffusion process (q) 使用馬可夫鏈(Markov Chain)來逐漸對資料加入 Gaussian noise 直到資料的結構被完全破壞成為完全的 noise
- $B_n$為 variance,介於 0~1 之間,通常越後面的 step 會使用越大的 variance,這是一開始就要給定的一系列超參數,如 [GLIDE](https://arxiv.org/abs/2112.10741) 會使用線性插值的方式給定
![](https://i.imgur.com/AqRnaT4.jpg)
```python
def gather(consts: Tensor, t: torch.Tensor) -> Tensor:
"""Gather consts for $t$ and reshape to feature map shape"""
c = consts.gather(-1, t)
return c.reshape(-1, 1, 1, 1)
# Calculate q
def q_xt_xtminus1(xtm1: Tensor, t: Tensor) -> Tensor:
beta = torch.linspace(0.0001, 0.04, N_STEPS) # all hyper-parameters
mean = gather(1. - beta, t) ** 0.5 * xtm1 # √(1−βt)*xtm1
var = gather(beta, t) # βt I
eps = torch.randn_like(xtm1)
return mean + (var ** 0.5) * eps
```
## Reverse diffusion process
- 這裡是從最後一個時間點往回推到第0個時間點
![](https://i.imgur.com/VWPsxGJ.jpg)
Reverse diffusion process (p) 則是訓練一個 **diffusion model** 來**逐漸**對已經充滿噪音的圖片做降噪
- 強調逐漸是因為在 diffusion 的理念中這樣更好**控制** (taking many small steps is more tractable than a large step)
- 最大的問題在於如何反向 diffusion process (如何 denoising) ?
![](https://i.imgur.com/ias0HEo.png)
![](https://i.imgur.com/E0HU4oO.png)
![](https://i.imgur.com/pipT22N.jpg)
所謂的反向 diffusion process 其實就是在尋找每一個 time step 所的 noise 量有多大,也就是等同於在問**如何近似出 $p(x_0,...,x_N)$?**
![](https://i.imgur.com/vfBHFp7.jpg)
如何將 $q(x_0,...,x_N)$ 近似於 $p(x_0,...,x_N)$? → 最小化兩者的 [KL Divergence](https://hackmd.io/uWPmJsNuRL6lmpMCjfNyDg?view#KL-divergence-KL-%E6%95%A3%E5%BA%A6) → 最大化 ELBO (Evidence Lover Bound),這個推導我們需要的就是 variational inference
![](https://i.imgur.com/5aieys4.png)
> 這裡我們所面臨的問題是給定一張圖片 x 和 latent vector z,我們想知道後驗機率 p(z|x),但這東西是真實的後驗機率,這東西並沒有辦法求解,所以才需要去使用近似估計
# Training of Diffusion model: Variational Inference
- 他借用了 VAE 針對 AutoEncoder 中間所產生的 latent vector sample 出來的分布做變化的技巧來將 Diffusion model 每個中間的過程所產生出來的 z 都當成是 latent vector,就可以使用 variational lower bound 來近似出 p(x) 的 lower bound ≥ variational lower bound
- 所謂 VAE 的變化技巧指的是因為從高斯分布 $N(\mu,\sigma)$中採樣會造成採樣過程不可微,會導致無法用模型去學,所以 VAE 就從標準常態分布 $N(0,1)$中採樣出來 $\epsilon$,然後利用 $z={\sigma}*{\epsilon}+\mu$ 的方式讓採樣過程維持可導
- Diffusion model 其實就像是 Encoder 固定的 VAE
- 完整的推導請見 [https://lilianweng.github.io/posts/2021-07-11-diffusion-models/](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)。
![](https://i.imgur.com/SCZ2JBo.png)
- 先從 KL 散度開始,這是一個用來衡量兩個資料分布距離的 metrics,當我們把它寫成期望值的型式的時候就可以產生出 ELBO
$KL(q(z)||p(z|x) = E[(log(q(z))] - E[(p(z,x))] + log(p(x))$
→ $ELBO(q) = E[log(p(z,x))] - E[log(q(z))]$
→ $log(p(x)) = ELBO(q) + KL(q(z)||p(z|x))$
- q(z) 為目標分布,p(z|x) 為預測出來的分布,log(p(x))則代表我們目前所擁有的資料分布
- 此時左邊 log(p(x)) 是常數,那麼最小化 KL 就等同於最大化 ELBO
- 而 ELBO 被稱為下界的原因是因為 KL 散度有非負的特性,所以其實可得到
$log(p(x)) = ELBO(q) + KL(q(z)||p(z|x)) >= ELBO(q)$
## KL & ELBO 推導過程
- KL Divergence 期望值形式的推導過程
![](https://i.imgur.com/X95XWdH.png)
- ELBO 詳細推導過程
![](https://i.imgur.com/ggB8ckR.png)
## loss function: 計算預測噪音和實際噪音的 L2 loss
- 訓練的目標是找出 variational lower bound 來最小化 KL divergence of q and p,經過論文中一連串的數學轉換之後可以得到下面這串
![](https://i.imgur.com/SteyE54.png)
- $L_T$ 計算的是模型最後算出來的噪音分布和先驗分布的 KL 散度,這兩者都是近似為 $N(0,I)$,可視為 0
- $L_{t-1}$ 則是計算估計分布 $p_{\theta}(x_{t-1}|x_t)$和真實後驗分布 $q(x_{t-1}|x_t,x_0)$的 KL 散度
- $L_0$ 則是因為 DDPM 會將圖片 normalize 到 (-1, 1) 之間,這也是個常數
作者接下來**固定了 variance** 因此模型只需要計算高斯雜訊的 mean,所以左邊那串才會是常數,他固定的原因是發現這樣訓練更穩定
![](https://i.imgur.com/8LKkImT.png)
- 這不只簡化了 $p_{\theta}(x_{t-1}|x_t)$,也令他其中的 variance ${\sigma}^2$ 可以被設為常數 $B$
![](https://i.imgur.com/YjkhBpR.png)
最後會發現左右邊都被移除了只剩中間
![](https://i.imgur.com/1DikwQQ.png)
而中間那串就是估計分布和真實後驗分布之間的 KL 散度,將參數帶入 KL 散度的公式之後變成
![](https://i.imgur.com/Vbnhuft.png)
- KL 散度公式:
![](https://i.imgur.com/xKKsg6l.png)
帶入可得 $L_{t-1}$ 就是
![](https://i.imgur.com/hYNGQaR.png)
- 從這個公式可以發現我們希望網路學習的是他所預測出來的噪音的 $\mu_{\theta}$ 和後驗分布的 $\hat{\mu}(x_t, x_0)$ 一致,只是 DDPM 發現這並不是最好的選擇,所以他又對其做了重參數,他令
![](https://i.imgur.com/vTMzsXt.png)
帶入之後可得
![](https://i.imgur.com/dSnPq7Q.png)
這串超級複雜,所以我們可以也對 $\mu_{\theta}$做重參數
![](https://i.imgur.com/xlv200m.png)
一樣帶入
![](https://i.imgur.com/SBTG6vq.png)
最後 DDPM 再去掉了所有不同 t 的權重係數來簡化得到
![](https://i.imgur.com/fVkxMEW.png)
可寫成 $L_{\text{diffusion}}=\mathbb{E}_{t, \mathbf{x}_{0}, \epsilon}\left[\left\|\epsilon-\epsilon_{\theta}\left(\mathbf{x}_{t}, t\right)\right\|^{2}\right],$
- 這裡的 t 是在 [1, T] 內取值,例如 t=1 就對應 $L_0$
- $\epsilon$ 就是真實的噪音,而 $\epsilon_{\theta}$就是預測噪音
- 從他的實驗結果可以發現預測噪音本身比預測噪音的均值還要好
![](https://i.imgur.com/Zvo528i.png)
## Summary of Diffusion process
隨機選擇一個訓練樣本 → 從 1~T 中隨機抽樣一個 t → 隨機產生噪音 → 計算當前所產生的噪音數據 (藍色標記處) → 輸入模型預測噪音 → 計算模型所產生的噪音與預測噪音的 L2 loss → 計算梯度並更新網路
![](https://i.imgur.com/ntsL2Yx.png)
# 模型架構
> 什麼! 竟然是…
[The Annotated Diffusion Model](https://huggingface.co/blog/annotated-diffusion)
- 整個架構其實是一個 U-Net,首先會先使用一層 Conv layer 以及 sinusoidal position embedding 作為 time embedding 來對圖片加入噪音
- 然後接著一連串的下採樣 stage,每一個下採樣 stage 裡面共有 2個 ResNet blocks + groupnorm + self-attention + residual connection + downsample 操作
- 到了中間的部分,又交錯使用了 ResNet + attention
- 接下來是一連串的上採樣,每一個 stage 中一樣有 2個 ResNet blocks + groupnorm + self-attention + residual connection + upsample 操作
- 最後,一個 ResNet block + Conv layer 會作為輸出層
![](https://i.imgur.com/yIHHDmW.jpg)
![](https://i.imgur.com/ZUvnDsH.jpg)
```python
class Unet(nn.Module):
def __init__(
self,
dim,
init_dim=None,
out_dim=None,
dim_mults=(1, 2, 4, 8),
channels=3,
self_condition=False,
resnet_block_groups=4,
):
super().__init__()
# determine dimensions
self.channels = channels
self.self_condition = self_condition
input_channels = channels * (2 if self_condition else 1)
init_dim = default(init_dim, dim)
self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0) # changed to 1 and 0 from 7,3
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
block_klass = partial(ResnetBlock, groups=resnet_block_groups)
# time embeddings
time_dim = dim * 4
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(dim),
nn.Linear(dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim),
)
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(
nn.ModuleList(
[
block_klass(dim_in, dim_in, time_emb_dim=time_dim),
block_klass(dim_in, dim_in, time_emb_dim=time_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Downsample(dim_in, dim_out)
if not is_last
else nn.Conv2d(dim_in, dim_out, 3, padding=1),
]
)
)
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
is_last = ind == (len(in_out) - 1)
self.ups.append(
nn.ModuleList(
[
block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Upsample(dim_out, dim_in)
if not is_last
else nn.Conv2d(dim_out, dim_in, 3, padding=1),
]
)
)
self.out_dim = default(out_dim, channels)
self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
def forward(self, x, time, x_self_cond=None):
if self.self_condition:
x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
x = torch.cat((x_self_cond, x), dim=1)
x = self.init_conv(x)
r = x.clone()
t = self.time_mlp(time)
h = []
# Down sampling Blocks
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
h.append(x)
x = block2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
# Middle Blocks
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
# Upsampling blocks
for block1, block2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = block1(x, t)
x = torch.cat((x, h.pop()), dim=1)
x = block2(x, t)
x = attn(x)
x = upsample(x)
x = torch.cat((x, r), dim=1)
# Output block
x = self.final_res_block(x, t)
return self.final_conv(x)
```
以上還只是 unconditional 情況下的 DDPM,但實際上還有 Conditional 的 DDPM,其成果就是我們現今所看到的 DALL-E 2, Imagen, Stable diffusion,但我猜我講不動了。
# 缺點
- DDPM 並沒有所謂超越其他生成模型之說,他的計算複雜度會由於馬可夫鏈而非常巨大,整個採樣的速度會很慢,需要走過完整的 T 步才能完成一次訓練,在 DDPM 中 T=1000
- DDPM 生成出來的東西在 FID 上並沒有超越 GAN,效果待驗證
# Reference
[https://mpatacchiola.github.io/blog/2021/01/25/intro-variational-inference.html](https://mpatacchiola.github.io/blog/2021/01/25/intro-variational-inference.html)
[https://zhuanlan.zhihu.com/p/385341342](https://zhuanlan.zhihu.com/p/385341342)
[https://yang-song.net/blog/2021/score/](https://yang-song.net/blog/2021/score/)
[https://lilianweng.github.io/posts/2021-07-11-diffusion-models/](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)
[https://zhuanlan.zhihu.com/p/563661713](https://zhuanlan.zhihu.com/p/563661713)
[https://medium.com/ai-blog-tw/邊實作邊學習diffusion-model-從ddpm的簡化概念理解-4c565a1c09c](https://medium.com/ai-blog-tw/%E9%82%8A%E5%AF%A6%E4%BD%9C%E9%82%8A%E5%AD%B8%E7%BF%92diffusion-model-%E5%BE%9Eddpm%E7%9A%84%E7%B0%A1%E5%8C%96%E6%A6%82%E5%BF%B5%E7%90%86%E8%A7%A3-4c565a1c09c)