---
title: 盡量白話的Diffusion Model基礎知識整理
---
整理給自己看的Diffusion Model基本知識
為了方便理解寫得很白話(不負責任版)
更新日期: 2024/11/08
## **Denoising Diffusion Probabilistic Models (DDPM)**
主要源自[Denoising Diffusion Probabilistic Models(2020)](https://arxiv.org/abs/2006.11239)
一般我們可以假設人眼認為合理的影像應該會滿足某種特定的分布,用數學來表達就是一張真實圖像 $\mathbf{x}_0$ 需要滿足 $\mathbf{x}_0 \sim q(\mathbf{x})$ 這個分布(這邊的$\mathbf{x}_0$不見得一定要是圖像,你想像得出來任何適用以上假設的case都可以)。當然想也知道像圖像這類複雜的分布是很難人為的用任何已知的數學形式去表達出來的,所以有人想到是不是可以用深度學習來模擬這樣的分布,藉此幫助我們做圖像生成。如果粗淺的理解,擴散這個概念就是不管你的初始狀態 $\mathbf{x}_0$ 長怎麼樣,反正經過長時間的擴散,各種不同的 $\mathbf{x}_0$ 他們原始的特徵都會被稀釋掉、分布上也變得趨近一致。所以如果神經網路能逆轉這個擴散行為,用大量的擴散數據去學習反擴散的可能性,是不是就可以反過來從擴散的盡頭推測出一個合理的 $\mathbf{x}_0$?透過這個概念,以下的流程就被設計了出來,能夠對Diffusion為基礎的圖像生成模型進行訓練及採樣:
<center><img src="https://lilianweng.github.io/posts/2021-07-11-diffusion-models/DDPM.png" width="90%" alt="img01"/></center>
### **Forward Diffusion**
在$\mathbf{x}_0 \sim q(\mathbf{x})$這個前提下,如果對$\mathbf{x}_0$添加Gaussian noise,並重複 $T$ 次,生成出 $\mathbf{x}_1, ..., \mathbf{x}_T$ 一系列的加噪圖像。在 $T$ 夠大的情況下,最後生成出來的$\mathbf{x}_T$應該會趨近於Gaussian noise。以上操作數學上表示為: $$\begin{aligned}
q(\mathbf{x}_t|\mathbf{x}_{t-1})
&=\mathcal{N}(\mathbf{x}_t;\sqrt{1-\beta_t}\mathbf{x}_{t-1},\beta_t\mathbf{I}) \\
&=\sqrt{1-\beta_t}\mathbf{x}_{t-1}+\sqrt{\beta_t}\epsilon\qquad\epsilon\sim\mathcal{N}(0, \mathbf{I})
\end{aligned}$$ 其中,每一步加噪的Gaussian noise強度由 $\{ \beta_t \in (0, 1) \}_{t=0}^{T}$ 控制,$\beta_t$ 會隨著 $T$ 的上升也跟這越來越大,另外 $\beta_t$ 也有很多種不同的schedule設計,包含linear、quadratic、cosine等等,會影響在 $T$ 個時間步階中圖像被加噪或去噪的趨勢。<br>
以上這個計算過程有一個好處,就是要得出 $\mathbf{x}_t$ 時,不需要真的把中間過程的每一張圖都算出來,而是可以透過reparameterize的方式簡化(<font color=#800000>**對計算過程沒興趣可以直接跳到下個紅字**</font>):<br>
令$\alpha_t=1-\beta_t$ 且 $\bar{\alpha}_t=\prod_{i=1}^t \alpha_i$ $$\begin{aligned}
\mathbf{x}_t
&= \sqrt{\alpha_t}\mathbf{x}_{t-1} + \sqrt{1 - \alpha_t}\boldsymbol{\epsilon}_{t-1}\quad\text{ ;where } \boldsymbol{\epsilon}_{t-1}, \boldsymbol{\epsilon}_{t-2}, \dots \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\
&= \sqrt{\alpha_t \alpha_{t-1}} \mathbf{x}_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}} \bar{\boldsymbol{\epsilon}}_{t-2} \quad\text{ ;where } \bar{\boldsymbol{\epsilon}}_{t-2} \text{ merges two Gaussians (*).} \\
&= \dots \\
&= \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\boldsymbol{\epsilon} \\
q(\mathbf{x}_t \vert \mathbf{x}_0) &= \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t)\mathbf{I})
\end{aligned}$$
這個簡化之所以可以成立是因為兩個分布的和是這樣計算的: $$\mathcal{N}(\mathbf{0}, \sigma_1^2\mathbf{I})+\mathcal{N}(\mathbf{0}, \sigma_2^2\mathbf{I})=\mathcal{N}(\mathbf{0}, (\sigma_1^2 + \sigma_2^2)\mathbf{I})$$ 這使得 $\epsilon$ 項前面的係數(也就是標準差)能夠在推導的時候輕易的被合併。$$\sqrt{(1 - \alpha_t) + \alpha_t (1-\alpha_{t-1})} = \sqrt{1 - \alpha_t\alpha_{t-1}}$$
這邊講了那麼多<font color=#800000>**其實結論就是,我們要在給定 $\mathbf{x}_0$ 時,求出 $\mathbf{x}_t$ 只需要做一次以下的計算就夠了**</font>: $$q(\mathbf{x}_t \vert \mathbf{x}_0)=\sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\boldsymbol{\epsilon}\tag{1}$$
其中, $\alpha_t=1-\beta_t\text{ ; }\bar{\alpha}_t=\prod_{i=1}^t \alpha_i$。我看到有些人之所以會誤會DDPM的訓練過程,以為訓練需要真的做數百次加噪的迭代,就是因為不清楚這個結論可以直接把迭代過程一步到位。
### **Reverse Diffusion**
如果我們可以逆轉上述流程,反過來用 $q(\mathbf{x}_{t-1} \vert \mathbf{x}_t)$ 去推測 $\mathbf{x}_{t-1}$,理論上就能夠從純粹的Gaussian noise $\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ 中一步步還原出真實圖像 $\mathbf{x}_0 \sim q(\mathbf{x})$。然而,人類很難用現有的數學知識解出 $q(\mathbf{x}_{t-1} \vert \mathbf{x}_t)$。但是,如果我們先假設 $\mathbf{x}_t$ 是從 $\mathbf{x}_0$ 一路加噪過來的,考慮 $\mathbf{x}_0$ 這個已知的條件後,式子就能改成: $$q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0), \tilde{\beta}_t \mathbf{I}) \tag{2}$$
而且式中的均值 $\tilde{\mu}_t$ 和方差 $\tilde{\beta}_t$ 可以用貝氏定理去推導出解析解(就是跟國中學的一元二次公式解差不多的意思): $$q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0)
= q(\mathbf{x}_t \vert \mathbf{x}_{t-1}, \mathbf{x}_0) \frac{ q(\mathbf{x}_{t-1} \vert \mathbf{x}_0) }{ q(\mathbf{x}_t \vert \mathbf{x}_0) }$$
這邊透過貝氏定理就把Reverse的條件機率改成我們已知的形式了,可以直接找Forward Diffusion中的一些結果代進去。跳過複雜的數學推導過程,總之可以得到的 $\tilde{\mu}$ 和 $\tilde{\beta}_t$ 解析解:$$\begin{aligned}
\tilde{\mu}_t (\mathbf{x}_t, \mathbf{x}_0)
&= \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1 - \bar{\alpha}_t} \mathbf{x}_0\\
\tilde{\beta}_t
&={\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t}
\end{aligned}$$
同時把公式(1)裡的項調換一下,就可以得到 $\mathbf{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t - \sqrt{1 - \bar{\alpha}_t}\epsilon_t)$ 代入 $\tilde{\mu}_t$ 中,得到:$$\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}} \Big( \mathbf{x}_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_t \Big)$$
解到這邊可以發現,除了 $\epsilon_t$ 之外,其他的參數都是已知的(能透過Forward Diffusion中人為設定的 $\beta_t$ 推算出來),最後我們就可以把求解 $\epsilon_t$ 這個重責大任丟給神經網路處理了。所以,<font color=#800000>**Reverse Diffusion的結論就是,我們要訓練一個神經網路 $p_\theta$ 來預測 $\epsilon_t$ ,來解出公式(2)中的分布的均值與方差,如此一來便可以採樣出 $q(\mathbf{x}_{t-1} \vert \mathbf{x}_t)$ 的結果了**</font>。
## Training & Sampling
<center><img src="https://lilianweng.github.io/posts/2021-07-11-diffusion-models/DDPM-algo.png" width="100%" alt="img01"/></center>
### **Training**
DDPM的訓練很直接,在原圖為 $\mathbf{x}_0$ 的條件下,先從 $1 \sim T$ 之間隨機選擇一個時間步階 $t$ ,再採樣一個跟 $\mathbf{x}_0$ 尺寸一樣的高斯噪音 $\epsilon$ ,然後套用公式(1)將 $\mathbf{x}_0$ 加噪,生成出 $\mathbf{x}_t$。神經網路要根據 $\mathbf{x}_t$ 和 $t$ 兩個input去推斷 $\epsilon$,而損失函數就計算真實值 $\epsilon$ 和預測值 $\epsilon_\theta$ 之間的MSE就好。
```
class DDPM_Trainer(nn.Module):
"""
Vanilla DDPM訓練時使用的策略
"""
def __init__(self, model, beta_1=1e-4, beta_T=0.02, T=1000):
super().__init__()
self.model = model
self.T = T
betas = torch.linspace(beta_0, beta_T, T)
alphas = 1. - betas
alphas_bar = torch.cumprod(alphas, dim=0)
# for q(x_t | x_{t-1})
self.register_buffer('sqrt_alphas_bar', torch.sqrt(alphas_bar).float())
self.register_buffer('sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar).float())
def forward(self, x_0, condit):
B = x_0.shape[0]
# 根據batch size隨機生成一串time step
t = torch.randint(0, self.T, size=(B,), device=x_0.device)
# time step對應的alpha
sqrt_alphas_bar_t = torch.gather(self.sqrt_alphas_bar, 0, t).to(x_0.device)
sqrt_one_minus_alphas_bar_t = torch.gather(self.sqrt_one_minus_alphas_bar, 0, t).to(x_0.device)
# 生成時間為t時的noise,並施加在原圖上
noise = torch.randn_like(x_0)
x_t = (sqrt_alphas_bar_t.view(-1, 1, 1, 1) * x_0 + sqrt_one_minus_alphas_bar_t.view(-1, 1, 1, 1) * noise)
# 計算損失函數, 我這邊有考慮以圖片作為condition的方式, 但我是參考SR3的作法,直接concat在channel的dimension上
if condit is not None:
x_t = torch.cat([condit, x_t], dim=1)
loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
return loss
```
### Sampling
sampling過程會先從隨機的高斯噪音 $\mathbf{x}_{T}$ 開始,跑T次Reverse Diffusion的計算,也就是透過模型的輸出,計算公式(2)中的平均值跟標準差,來推算 $\mathbf{x}_{t-1}$,直到計算出原圖 $\mathbf{x}_{0}$ 為止。
```
class DDPM_Sampler(nn.Module):
def __init__(self, model, beta_1=1e-4, beta_T=0.02, beta_scdl='linear', T=1000):
"""
Vanilla DDPM inference時使用的策略
"""
super().__init__()
self.model = model
self.T = T
betas = torch.linspace(beta_0, beta_T, T)
alphas = 1. - betas
alphas_bar = torch.cumprod(alphas, dim=0)
alphas_bar_prev = torch.cat((torch.ones(1), alphas_bar[:-1]))
# for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_bar_prev) / (1. - alphas_bar) # beta_wave_t
self.register_buffer('coeff1', torch.sqrt(1. / alphas))
self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
self.register_buffer('posterior_log_variance_clipped', torch.log(np.maximum(posterior_variance, 1e-20)))
def p_mean_variance(self, x_t, t, condit=None):
"""
計算reverse diffusion條件機率中, 解析解對應的均值及方差
"""
B = x_t.shape[0]
tensor_t = torch.ones((B,), device=x_t.device) * t
if condit is not None:
x_t = torch.cat([condit, x_t], dim=1)
eps = self.model(x_t, tensor_t)
assert x_t.shape == eps.shape
mean = self.coeff1[t] * x_t - self.coeff2[t] * eps
log_var = self.posterior_log_variance_clipped[t]
return mean, log_var
def forward(self, condit, x_T):
x_t = x_T
for t in tqdm(reversed(range(self.T)), dynamic_ncols=True, desc='DDPM Sampling', total=self.T):
mean, var= self.p_mean_variance(condit, x_t, t)
# 相當於公式(2)的採樣結果
noise = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)
x_t = mean + noise * torch.exp(0.5 * var)
x_0 = x_t
return torch.clip(x_0, -1, 1)
```
## Reference
1. [What are Diffusion Models?](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/) (這篇很細,數學推導過程很完整)
2. [Diffusion Models:生成扩散模型](https://yinglinzheng.netlify.app/diffusion-model-tutorial/) (簡中的,寫的也還行)