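---
tags: Deep learning
---

# UNet

## Introduction

UNet is a widely used architecture for semantic segmentation. It builds on FCN (Fully Convolutional Network) and treats segmentation as an image-to-image problem, which lets it perform well on pixel-level classification. It is also a frequent winner in Kaggle medical-imaging competitions. UNet's structure is also easier to modify than Mask R-CNN's. As a result, many deep learning architectures proposed in papers are based on UNet, with the internal structure adapted to each task.

For this reason, I think it is worth studying the UNet architecture and the key ideas behind its success, rather than using it as a black box. That understanding may also help in designing more powerful deep learning architectures based on UNet for other image segmentation problems.

## Architecture

[Reference implementation: usuyama/pytorch-unet](https://github.com/usuyama/pytorch-unet)

``` python
import torch
from torch import nn
from torchvision import models

# convrelu() is defined in the "Layer implementation" section below.


class ResNetUNet(nn.Module):
    def __init__(self, n_class):
        super().__init__()

        # Pretrained ResNet18 encoder. torchvision < 0.13 uses
        # models.resnet18(pretrained=True) instead of the weights argument.
        self.base_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.base_layers = list(self.base_model.children())

        self.layer0 = nn.Sequential(*self.base_layers[:3])  # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5])  # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder blocks: input channels = upsampled channels + skip channels
        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        # Full-resolution branch and final 1x1 classifier
        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        # Full-resolution features, used as the last skip connection
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)

        # Encoder (ResNet18)
        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        # Decoder: upsample, concatenate the skip connection, then convrelu
        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)

        return out
```

The figure below is an architecture diagram I drew after tracing the PyTorch code from the GitHub repository above, using a 3x512x512 input as an example. This architecture differs slightly from the standard UNet. It keeps the UNet layout but replaces the convolution blocks of the encoder (contracting path) with pretrained ResNet18 layers. The aim is to improve model performance.

![](https://i.imgur.com/ATz0lck.png)

### Layer implementation

![](https://i.imgur.com/PrJ4D2e.png) : a tensor shape, written as channel x height x width.

![](https://i.imgur.com/QgV9jBO.png) : a block made of a 2D convolution followed by ReLU. The PyTorch code is:

``` python
from torch import nn


def convrelu(in_channels, out_channels, kernel, padding):
    # Conv2d (stride 1) followed by an in-place ReLU. With kernel=3/padding=1
    # or kernel=1/padding=0, the spatial size H x W is preserved.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    )
```

![](https://i.imgur.com/5aEO2Qp.png) : upsampling, which turns a CxHxW tensor into a Cx(2xH)x(2xW) tensor.

![](https://i.imgur.com/5KvacQj.png) : concatenation along the channel dimension, which stacks a C1xHxW tensor and a C2xHxW tensor into one (C1+C2)xHxW tensor.

![](https://i.imgur.com/RHPnkbc.png) : layer0, the first three child modules of ResNet18 (`conv1`, `bn1`, `relu`), i.e. Conv2d + BatchNorm + ReLU. Because `conv1` is a 7x7 convolution with stride 2, it maps a 3xHxW input to 64x(H/2)x(W/2).

![](https://i.imgur.com/mRYPibi.png) : layer1, the ResNet18 child modules at indices 3 and 4 (`maxpool` and `layer1`). It consists of one max-pooling layer and two ResNet BasicBlocks, and maps a CxHxW tensor to Cx(H/2)x(W/2).

![](https://i.imgur.com/m26EJ8m.png) : layer2/3/4, the ResNet18 child modules at indices 5/6/7. Each consists of two ResNet BasicBlocks and maps a CxHxW tensor to (2xC)x(H/2)x(W/2).

**ResNet Basic Block**:

![](https://i.imgur.com/F2PRDV0.png)

The code below is the residual block as implemented in the official PyTorch library (torchvision). The figure above roughly shows its structure, drawn after tracing the code.

``` python
from torch import nn


def conv3x3(in_planes, out_planes, stride=1):
    # 3x3 convolution with padding 1 and no bias (BatchNorm follows it)
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        # inplanes: number of input channels; planes: number of output channels
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when the output shape differs from the input
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
```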
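The forward pass of `BasicBlock` can be summarized as the equations below. The notation is this note's own. $\mathcal{F}$ is the conv3x3, BN, ReLU, conv3x3, BN path, $*$ denotes convolution, and $W_s$ stands for the optional `downsample` module.

$$
\mathcal{F}(\mathbf{x}) = \mathrm{BN}_2\big(W_2 * \mathrm{ReLU}(\mathrm{BN}_1(W_1 * \mathbf{x}))\big),
\qquad
\mathbf{y} = \mathrm{ReLU}\big(\mathcal{F}(\mathbf{x}) + \mathbf{x}\big)
$$

$$
\mathbf{y} = \mathrm{ReLU}\big(\mathcal{F}(\mathbf{x}) + W_s(\mathbf{x})\big) \quad \text{(when downsample is used)}
$$

In torchvision's ResNet, `downsample` is created only when a block's stride is not 1 or its channel count changes. In that case it is a 1x1 convolution with the same stride, followed by BatchNorm. This is how the first block of layer2/3/4 can map CxHxW to (2xC)x(H/2)x(W/2) and still add the shortcut, because both terms of the sum then have the same shape.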
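The shapes in the diagram and the channel counts hard-coded in the decoder layers (`conv_up3` through `conv_original_size2`) can be checked with a small pure-Python sketch. It only does shape bookkeeping and needs no PyTorch. `conv_out`, `up`, `cat` and `trace_resnet_unet` are hypothetical helpers written for this note, and keys such as `conv_up3_in` are names I chose for the input of each decoder conv. The layer hyperparameters follow torchvision's ResNet18: a 7x7 stride-2 `conv1`, 3x3 stride-2 max pooling, and a stride-2 first block in layer2/3/4.

```python
"""Shape bookkeeping for the ResNetUNet above (pure Python, no PyTorch)."""


def conv_out(size, kernel, stride=1, padding=0):
    """Output length along one axis of Conv2d / MaxPool2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def up(shape):
    """nn.Upsample(scale_factor=2): C x H x W -> C x 2H x 2W."""
    c, h, w = shape
    return (c, 2 * h, 2 * w)


def cat(a, b):
    """torch.cat(dim=1): spatial sizes must match, channel counts add up."""
    assert a[1:] == b[1:], f"cannot concatenate {a} and {b}"
    return (a[0] + b[0], a[1], a[2])


def trace_resnet_unet(h, w, n_class=1):
    """Return the (C, H, W) shape after each stage for a 3 x h x w input."""
    s = {}
    # layer0: conv1 (7x7, stride 2, padding 3) + bn1 + relu -> 64 channels
    s["layer0"] = (64, conv_out(h, 7, 2, 3), conv_out(w, 7, 2, 3))
    # layer1: maxpool (3x3, stride 2, padding 1); its BasicBlocks keep the shape
    _, h0, w0 = s["layer0"]
    s["layer1"] = (64, conv_out(h0, 3, 2, 1), conv_out(w0, 3, 2, 1))
    # layer2..4: the first BasicBlock has a stride-2 conv3x3 and doubles channels
    prev = s["layer1"]
    for name in ("layer2", "layer3", "layer4"):
        c, hh, ww = prev
        prev = (2 * c, conv_out(hh, 3, 2, 1), conv_out(ww, 3, 2, 1))
        s[name] = prev

    # Decoder: the 1x1 convrelu layers keep shapes, so only up/cat/conv matter.
    # Each entry: (skip connection, output channels, decoder conv name)
    decoder = [("layer3", 512, "conv_up3"), ("layer2", 256, "conv_up2"),
               ("layer1", 256, "conv_up1"), ("layer0", 128, "conv_up0")]
    x = s["layer4"]
    for skip, out_ch, conv_name in decoder:
        x = cat(up(x), s[skip])
        s[conv_name + "_in"] = x
        x = (out_ch, x[1], x[2])  # 3x3 conv with padding 1 keeps H and W

    # x_original: two full-resolution convrelu blocks with 64 output channels
    x = cat(up(x), (64, h, w))
    s["conv_original_size2_in"] = x
    # conv_original_size2 -> 64 channels, then the 1x1 conv_last -> n_class
    s["out"] = (n_class, x[1], x[2])
    return s


if __name__ == "__main__":
    for stage, shape in trace_resnet_unet(512, 512, n_class=2).items():
        print(f"{stage:>24}: {shape}")
```

For a 3x512x512 input, the sketch gives 768x32x32 into `conv_up3`, 192x512x512 into `conv_original_size2`, and n_class x512x512 at the output. These match the sums written in the layer definitions (256 + 512, 64 + 128, and so on). A 500x500 input trips the assertion in `cat`. The encoder rounds odd sizes up (63x63 after layer2), while upsampling gives 64x64. The real model would likewise fail at `torch.cat`. Keeping H and W divisible by 32 makes every halving exact and avoids this mismatch.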