# 〈 Diffusion Model 論文研究與實作心得 Part.2 〉 U-Net 模型架構介紹與實作 --- Tu 2023/2/14 ## 一、前言 在上一篇文章[〈 Diffusion Model 論文研究與實作心得 Part.1 〉 前言與圖片雜訊前處理](https://hackmd.io/@Tu32/B1-m6Tuai)中,我完成了對圖片加入雜訊的部分,因此接下來就輪到模型的拆解。 ## 二、U-Net 模型簡介 ![](https://i.imgur.com/H7zAGYE.png) 圖片來源:【Deep Learning for Image Segmentation: U-Net Architecture】 在DDPM論文中,作者使用了U-Net這種模型架構來進行訓練。U-Net是Auto-encoder的變種,可以看到下方一樣有一個bottleneck的部分,且輸入和輸出圖片的大小相同。U-Net在image segmantation的領域有著重大貢獻,與傳統的Auto-encoder不同的是,U-Net在encoder和decoder之間有使用residual connection,以更好的保留原始圖片的特徵。 ## 三、U-Net 架構實作 若要進行U-Net的實作,可以拆解成下方幾個的零件實作。 * 兩層CNN的Block * time embedding * Down(左半邊的Encoder,兩層CNN加上Maxpooling) * Up(右半邊的Decoder,兩層CNN加上Upsample) * self attention * residual connection #### 1. 雙層CNN 先從最常用到的著手,先設計一個有兩層CNN的Block,在之後的地方都會用到 ```python class DoubleConv(nn.Module): def __init__(self): pass def forward(self): pass ``` 填入模型 ```python class DoubleConv(nn.Module): def __init__(self, in_c, out_c): super().__init__() self.conv1 = nn.Sequential( nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), nn.GroupNorm(1, out_c), #equivalent with LayerNorm nn.ReLU() ) self.conv2 = nn.Sequential( nn.Conv2d(out_c, out_c, kernel_size=3, padding=1), nn.GroupNorm(1, out_c), #equivalent with LayerNorm nn.ReLU() ) def forward(self, x): x = self.conv1(x) x = self.conv2(x) return x ``` #### 2. Time Embedding 在訓練U-Net的時候,我一開始以為輸入是一張圖片,輸出只要給出被修復過的圖片就好。但其實這樣有一個問題,就是模型不知道不同timestep的圖片之間的差別,導致模型需要直接面對不同雜訊強度的圖片並進行修復。 embedding的概念簡單來說就是把一個單獨的值加工成一個tensor。比如我們對模型輸入圖片和一個整數(timestep),我們能透過embedding將那個整數換成一個tensor,變成讓模型更容易學習的形式。而DDPM的作者選擇使用Sinusoidal Position Embedding來為單獨timestep做embedding。 ![](https://i.imgur.com/wtIY9yu.jpg) 看起來很厲害的Sinusoidal Position Embedding (圖源:[A Gentle Introduction to Positional Encoding in Transformer Models, Part 1](https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/)) 這個問題有點像Transformer在訓練的時候用attention訓練時,需要將文字再加上一個positional embedding的概念相同,我們也需要為不同雜訊強度的圖片加上一個time embedding來告訴模型這是甚麼強度的圖片。 ```python def pos_encoding(t, channels): t = torch.tensor([t]) inv_freq = 1.0 / ( 10000 ** (torch.arange(0, channels, 2).float() / channels) ) pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq) pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq) pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1) return pos_enc ``` 接收兩個整數後回傳一個embedded好的Tensor,範例如下 ```python pos_encoding(10, 16) #timestep = 10 ``` ``` tensor([[-0.5440, -0.0207, 0.8415, 0.3110, 0.0998, 0.0316, 0.0100, 0.0032, -0.8391, -0.9998, 0.5403, 0.9504, 0.9950, 0.9995, 0.9999, 1.0000]]) ``` 當然這樣一個tensor肯定不能直接與圖片的tensor相加,在size上還需要調整,這個在後面會有提到。 #### 3. Down & Up 接下來是Down和Up,簡單概念就是進行Maxpooling或Upsample後再加個DoubleConv 首先是Down的部分 ```python class Down(nn.Module): def __init__(self, in_c, out_c): super().__init__() self.down = nn.Sequential( nn.MaxPool2d(2), DoubleConv(in_c,out_c,first_residual=True), ) def forward(self, x): x = self.down(x) return x ``` 基本架構差不多是這樣,但是不要忘了我們還要為圖片加上time embedding ```python class Down(nn.Module): def __init__(self, in_c, out_c, emb_dim=128): super().__init__() self.down = nn.Sequential( nn.MaxPool2d(2), DoubleConv(in_c,out_c), ) self.emb_layer = nn.Sequential( nn.ReLU(), nn.Linear(emb_dim, out_c), ) def forward(self, x, t): x = self.down(x) #擴充兩個dimension,然後使用repeat填滿成和圖片相同(如同numpy.tile) t_emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) return x + t_emb ``` Up的架構基本相同,但是如果看上面的圖,可以看到Up還需要接收一個類似residual connection的輸入,所以在forward()裡面會多一個`skip_x`與`x`接起來。 ```python class Up(nn.Module): def __init__(self, in_c, out_c, emb_dim=128): super().__init__() self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) self.conv = DoubleConv(in_c,out_c) self.emb_layer = nn.Sequential( nn.SiLU(), nn.Linear(emb_dim, out_c), ) def forward(self, x, skip_x, t): x = self.up(x) x = torch.cat([skip_x, x], dim=1) x = self.conv(x) emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) return x + emb ``` #### 4. Self Attention Block 這個部分沒打算細講(因為我也沒完全懂),之後可能會再寫一篇Attention is all you need的研究心得之類的。簡單來說Self Attention可以想成輸入一個向量,結果再輸出一個向量的黑盒子。(這邊直接照抄Outlier的程式碼) ```python class SelfAttention(nn.Module): def __init__(self, channels, size): super(SelfAttention, self).__init__() self.channels = channels self.size = size self.mha = nn.MultiheadAttention(channels, 4, batch_first=True) self.ln = nn.LayerNorm([channels]) self.ff_self = nn.Sequential( nn.LayerNorm([channels]), nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels), ) def forward(self, x): x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2) x_ln = self.ln(x) attention_value, _ = self.mha(x_ln, x_ln, x_ln) attention_value = attention_value + x attention_value = self.ff_self(attention_value) + attention_value return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size) ``` #### 5. 組裝U-Net 最後我們把來把上面寫的東西組裝起來 ```python class UNet(nn.Module): def __init__(self, c_in=3, c_out=3, time_dim=128, device="cuda"): super().__init__() self.device = device self.time_dim = time_dim self.inc = DoubleConv(c_in, 64) #(b,3,64,64) -> (b,64,64,64) self.down1 = Down(64, 128) #(b,64,64,64) -> (b,128,32,32) self.sa1 = SelfAttention(128, 32) #(b,128,32,32) -> (b,128,32,32) self.down2 = Down(128, 256) #(b,128,32,32) -> (b,256,16,16) self.sa2 = SelfAttention(256, 16) #(b,256,16,16) -> (b,256,16,16) self.down3 = Down(256, 256) #(b,256,16,16) -> (b,256,8,8) self.sa3 = SelfAttention(256, 8) #(b,256,8,8) -> (b,256,8,8) self.bot1 = DoubleConv(256, 512) #(b,256,8,8) -> (b,512,8,8) self.bot2 = DoubleConv(512, 512) #(b,512,8,8) -> (b,512,8,8) self.bot3 = DoubleConv(512, 256) #(b,512,8,8) -> (b,256,8,8) self.up1 = Up(512, 128) #(b,512,8,8) -> (b,128,16,16) because the skip_x self.sa4 = SelfAttention(128, 16) #(b,128,16,16) -> (b,128,16,16) self.up2 = Up(256, 64) #(b,256,16,16) -> (b,64,32,32) self.sa5 = SelfAttention(64, 32) #(b,64,32,32) -> (b,64,32,32) self.up3 = Up(128, 64) #(b,128,32,32) -> (b,64,64,64) self.sa6 = SelfAttention(64, 64) #(b,64,64,64) -> (b,64,64,64) self.outc = nn.Conv2d(64, c_out, kernel_size=1) #(b,64,64,64) -> (b,3,64,64) def pos_encoding(self, t, channels): t = torch.tensor([t]) inv_freq = 1.0 / ( 10000 ** (torch.arange(0, channels, 2).float() / channels) ) pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq) pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq) pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1) return pos_enc def forward(self, x, t): # (bs,) -> (bs, time_dim) t = t.unsqueeze(-1).type(torch.float) t = self.pos_encoding(t, self.time_dim) #initial conv x1 = self.inc(x) #Down x2 = self.down1(x1, t) x2 = self.sa1(x2) x3 = self.down2(x2, t) x3 = self.sa2(x3) x4 = self.down3(x3, t) x4 = self.sa3(x4) #Bottle neck x4 = self.bot1(x4) x4 = self.bot2(x4) x4 = self.bot3(x4) #Up x = self.up1(x4, x3, t) x = self.sa4(x) x = self.up2(x, x2, t) x = self.sa5(x) x = self.up3(x, x1, t) x = self.sa6(x) #Output output = self.outc(x) return output ``` 確認一下是否能正常運作以及輸出是否正確 ```python sample = torch.randn((32, 3, 64, 64)) t = torch.randint(0, T, (32,)) model = UNet() model(sample, t).shape ``` Output: ``` torch.Size([32, 3, 64, 64]) ``` 水喔,U-Net 模型的部分搞定了 ## 四、結語 本來想多講一點的(圖片修復的部分)但寫到這裡已經快9000字了,下個部分沒意外應該就是完結了,看能不能寫完圖片修復和模型訓練。可能會再額外寫一篇Extra講如何改進什麼的,都是後話了。 ### 相關資料 https://www.youtube.com/watch?v=a4Yfz2FxXiY https://www.youtube.com/watch?v=HoKDTa5jHvg&t=1338s https://huggingface.co/blog/annotated-diffusion https://arxiv.org/pdf/2102.09672.pdf https://arxiv.org/pdf/1503.03585.pdf https://arxiv.org/pdf/2006.11239.pdf https://theaisummer.com/latent-variable-models/#reparameterization-trick https://theaisummer.com/diffusion-models/ https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/ ###### tags: `AI` `Deep Learning` `Diffusion Model`