--- Tu 2023/6/22 ***「Where there is no vision, there is no hope.」*** **- George Washington Carver** ![ran-berkovich-kSLNVacFehs-unsplash (1)](https://hackmd.io/_uploads/Hyc825BYT.jpg) ## 一、前言 三個月前,我完成了[三篇有關Diffusion Model實作的系列文章](https://hackmd.io/@Tu32/B1-m6Tuai)。本來預計如果有額外篇會拿來展示模型改進的成果或是有更好的運算能力下的訓練結果,但我最近突發奇想所以先寫一篇記錄著。 這篇的內容是有關模型改進的,如題,我把原本的U-Net換成了LinkNet。我起初根本不知道LinkNet是什麼,只是突然想到「如果把U-Net的Encoder換成pre-trained的ResNet50會怎樣」,一查才發現有這篇paper。 我在寫的當下這個模型正在訓練,我也不知道效果會如何。 ## 二、LinkNet簡介 沒啥好說的,就是把encoder的部分換成ResNet50提取特徵。 而且我沒有顯卡,所以還是用Google Colab的免費GPU在訓練我的模型,95000張64x64的圖訓練起來要快兩個小時。我本來以為把ResNet50的Layer freeze會算比較快,但看起來好像沒有。 ## 三、LinkNet實作 雖然概念是一樣的,但實作內容還是花了我不少時間和心血。大概可以拆分成這幾個步驟: 1. 下載並導入ResNet50,從中提取需要的Layers 2. Down Block重新定義 3. 重新定義BottleNeck 4. Up Block重新定義 5. 把timestep embedding加進去 我這次是參考[這個github repo](https://github.com/rawmarshmellows/pytorch-unet-resnet-50-encoder),但是他的是一般的U-Net,不符合我的需求,所以等等會講我怎麼改。 ### 下載並導入ResNet50 Pytorch的torchvision有提供很多pre-trained的model,我這邊是直接用 ```python import torchvision resnet = torchvision.models.resnet.resnet50(pretrained=True) ``` 先說我們需要的layer是input相關的3個layers,以及下面的BottleNeck(有downsize效果)。有關提取的方式等等再說。 ## Down Block重新定義 把本來的Down Block換成ResNet的層,另外加上timestep embedding,我懶得再寫一格了。 p.s 理論上要unfreeze,寫的時候才發現 ```python class DownBlockWithResnet50Unet(nn.Module): def __init__(self, index, out_c, emb_dim=128): super().__init__() resnet = torchvision.models.resnet.resnet50(pretrained=True) for param in resnet.parameters(): param.requires_grad = False # freeze all layers down_blocks = [] for i, bottleneck in enumerate(list(resnet.children())): if isinstance(bottleneck, nn.Sequential): down_blocks.append(bottleneck) self.down_block = down_blocks[index] self.emb_layer = nn.Sequential( nn.ReLU(), nn.Linear(emb_dim, out_c), ) def forward(self, x, t): x = self.down_block(x) t_emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) return x + t_emb ``` out_c要自己去看ResNet架構 ## 重新定義BottleNeck 照抄原作者的Bridge,直接上 ```python class ConvBlock(nn.Module): """ Helper module that consists of a Conv -> BN -> ReLU """ def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True): super().__init__() self.conv = nn.Conv2d(in_channels, out_channels, padding=padding, kernel_size=kernel_size, stride=stride) self.bn = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU() self.with_nonlinearity = with_nonlinearity def forward(self, x): x = self.conv(x) x = self.bn(x) if self.with_nonlinearity: x = self.relu(x) return x class Bridge(nn.Module): """ This is the middle layer of the UNet which just consists of some """ def __init__(self, in_channels, out_channels): super().__init__() self.bridge = nn.Sequential( ConvBlock(in_channels, out_channels), ConvBlock(out_channels, out_channels) ) def forward(self, x): return self.bridge(x) ``` ## Up Block重新定義 原作者的Up Blocks沒有timestep embedding,所以我改了一下: ```python class UpBlockForUNetWithResNet50(nn.Module): """ Up block that encapsulates one up-sampling step which consists of Upsample -> ConvBlock -> ConvBlock """ def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None, upsampling_method="conv_transpose", emb_dim = 128): super().__init__() if up_conv_in_channels == None: up_conv_in_channels = in_channels if up_conv_out_channels == None: up_conv_out_channels = out_channels if upsampling_method == "conv_transpose": self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2) elif upsampling_method == "bilinear": self.upsample = nn.Sequential( nn.Upsample(mode='bilinear', scale_factor=2), nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1) ) self.conv_block_1 = ConvBlock(in_channels, out_channels) self.conv_block_2 = ConvBlock(out_channels, out_channels) self.emb_layer = nn.Sequential( nn.SiLU(), nn.Linear(emb_dim, out_channels), ) def forward(self, up_x, down_x, t): """ :param up_x: this is the output from the previous up block :param down_x: this is the output from the down block :return: upsampled feature map """ x = self.upsample(up_x) x = torch.cat([x, down_x], 1) x = self.conv_block_1(x) x = self.conv_block_2(x) emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) return x + emb ``` ## 完整LinkNet ```python class LinkNet(nn.Module): DEPTH = 6 def __init__(self, n_classes=3, time_dim=128): super().__init__() resnet = torchvision.models.resnet.resnet50(pretrained=True) for param in resnet.parameters(): param.requires_grad = False # freeze all layers down_blocks = [] up_blocks = [] emb_dim = 128 self.input_block = nn.Sequential(*list(resnet.children()))[:3] self.input_pool = list(resnet.children())[3] down_blocks.append(DownBlockWithResnet50Unet(0, 256)) down_blocks.append(DownBlockWithResnet50Unet(1, 512)) self.down_blocks = nn.ModuleList(down_blocks) self.bridge = Bridge(512, 512) up_blocks.append(UpBlockForUNetWithResNet50(512, 256)) up_blocks.append(UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128, up_conv_in_channels=256, up_conv_out_channels=128)) up_blocks.append(UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64, up_conv_in_channels=128, up_conv_out_channels=64)) self.up_blocks = nn.ModuleList(up_blocks) self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1) self.time_mlp = nn.Sequential( SinusoidalPositionEmbeddings(time_dim), nn.Linear(time_dim, time_dim), nn.ReLU() ) def forward(self, x, t, with_output_feature_map=False): time = self.time_mlp(t) pre_pools = dict() pre_pools[f"layer_0"] = x x = self.input_block(x) pre_pools[f"layer_1"] = x x = self.input_pool(x) for i, block in enumerate(self.down_blocks, 2): x = block(x, time) if i == (UNetWithResnet50Encoder.DEPTH - 1): continue pre_pools[f"layer_{i}"] = x x = self.bridge(x) for i, block in enumerate(self.up_blocks, 1): key = f"layer_{UNetWithResnet50Encoder.DEPTH - 3 - i}" x = block(x, pre_pools[key], time) output_feature_map = x x = self.out(x) del pre_pools if with_output_feature_map: return x, output_feature_map else: return x ``` ## 四、成果展示 感覺是水了一篇 ![](https://hackmd.io/_uploads/S1TUA54d3.png) ## 五、結語 做了些小改動,感覺可以再做更多的實驗來比較結果,畢竟改動模型的複雜程度或亂加東西不一定能保證表現的提升與否。