---
# System prepended metadata

title: UNet
tags: [Deep learning]

---

---
tags: Deep learning
---

# UNet

## Introduction
UNet是在語意分割問題中十分泛用的架構。基於FCN，UNet可以image-to-image的處理segmentation的問題，在像素級別的分類問題上能達到很好的效果。它也是在kaggle的醫學影像比賽中的常勝軍。另一方面，UNet的結構調整相對Mask-RCNN來的容易，許多論文提出的深度學習架構都是基於UNet，對應不同的任務調整內部的架構。

因此，透過學習UNet架構與其成功的關鍵思維，而不是只是單純把它當成黑盒子來使用，我認為是有其必要的。或許未來也可以針對不同的影像分割問題，基於UNet設計出更強大的DL架構。

## Arichitecture

[參考架構:usuyama/pytorch-unet](https://github.com/usuyama/pytorch-unet)

``` python
class ResNetUNet(nn.Module):
    def __init__(self, n_class):
        super().__init__()

        self.base_model = models.resnet18(pretrained=True)
        self.base_layers = list(self.base_model.children())

        self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)

        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)
        
        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
       
        
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)
 
        
        
        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)
        
        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)
        
        return out
```
下圖是參考Github的pytorch架構，以input3x512x512大小為例，trace code之後畫出來的架構圖。和一般的UNet稍有不同的是，這個架構是在UNet的基礎上，將部分的convolution block換成pre trained residual block，以得到更好的模型performance。
![](https://i.imgur.com/ATz0lck.png)

### layer implementation
![](https://i.imgur.com/PrJ4D2e.png) : 
表示channel x height x width 的tensor shape
![](https://i.imgur.com/QgV9jBO.png) : 
表示一個2d convolution + ReLu  組成的Block，詳細的pytorch程式如下

``` python
def convrelu(in_channels, out_channels, kernel, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    )
```
![](https://i.imgur.com/5aEO2Qp.png) : 
進行upsampling，將CxHxW的tensor變成 Cx(2xH)x(2xW)

![](https://i.imgur.com/5KvacQj.png)
concate, 將C1xHxW 與 C2xHxW 的2個tensor疊合成一個(C1+C2)xHxW 的tensor

![](https://i.imgur.com/RHPnkbc.png) :
layer0, 是ResNet18的前兩層，包含了Conv2d+BatchNormalize+ReLu 的架構

![](https://i.imgur.com/mRYPibi.png) : 
layer1, 是ResNet18的第3,4層，包含了一層Maxpooling+兩層resnet basic block。使CxHxW的tensor變成Cx(H/2)x(W/2)

![](https://i.imgur.com/m26EJ8m.png) : 
layer2/3/4, 是ResNet18的第5/6/7層，都是由兩層的resnet basic block組成。使CxHxW的tensor變成(2xC)x(H/2)x(W/2)

**ResNet Basic Block**:

![](https://i.imgur.com/F2PRDV0.png)

如以下程式，這是由pytorch官方所實現的Residual Block架構。trace code之後畫出來的架構大致如上圖所示。


``` python
class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None):    
    # inplanes代表input channel，planes代表output channel。
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
```









