# Python function Explanation & Verification

Because we need to know what computation Python and Unet actually perform under the hood, we verify and study some of the less common APIs used in Unet.

## torchvision.transforms.functional.center_crop

<font color="#f00"> Used for the <font color="#999"> gray arrows (copy and crop) </font> in the figure below. The function extracts the required region outward from the exact center of the image (from the contracting path of Unet), so we must work out which pixel the crop starts at and where it ends. </font>

![](https://hackmd.io/_uploads/SkNtr6Zya.png)

- verification

```python=
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 14 20:48:00 2023

@author: Jay
"""

import torchvision
import torch
import numpy as np

channel = 9
input_width = 64   # assume height = width
output_width = 56

torchInputArray = torch.zeros(channel, input_width, input_width, dtype=torch.int32)
sizeList = list(torchInputArray.size())

# initialize the input array with unique values 1, 2, 3, ...
num = 1
for i in range(sizeList[0]):
    for j in range(sizeList[1]):
        for k in range(sizeList[2]):
            torchInputArray[i, j, k] = num
            num += 1

torchResult = torchvision.transforms.functional.center_crop(
    torchInputArray, (output_width, output_width))

# into numpy
torchInputArray_np = torchInputArray.numpy()
torchResult_np = torchResult.numpy()

# ========== handcraft func ==========
myResult_np = np.zeros([channel, output_width, output_width], dtype=np.int32)

# index of the first row/column kept by the crop (only checked for even widths)
start_addr = ((input_width-1)//2) - (output_width - 2)//2

for i in range(channel):
    for j in range(start_addr, start_addr+output_width):
        for k in range(start_addr, start_addr+output_width):
            myResult_np[i, (j-start_addr), (k-start_addr)] = torchInputArray_np[i, j, k]

if (myResult_np == torchResult_np).all():
    print("function Correct!")
else:
    print("Failed...........")

left_up = (start_addr, start_addr)
right_down = (start_addr+output_width-1, start_addr+output_width-1)
print(left_up, right_down)

# ========== coordinate generator ==========
def coordinate_Gen(sizeIn):
    (input_PicWidth, out_PicWidth) = sizeIn
    #print('input_PicWidth, out_PicWidth = ', input_PicWidth, out_PicWidth)
    start_addr = ((input_PicWidth-1)//2) - (out_PicWidth - 2)//2
    left_up = (start_addr, start_addr)
    right_down = (start_addr+out_PicWidth-1, start_addr+out_PicWidth-1)
    return [left_up, right_down]

# (input width, output width) of each copy-and-crop in Unet
Unet_Layer_copy_and_crop = {
    'Layer1': (568, 392),
    'Layer2': (280, 200),
    'Layer3': (136, 104),
    'Layer4': (64, 56),
}

# note: the printed index is zero-based, while the dictionary keys are one-based
for Layer in range(len(Unet_Layer_copy_and_crop)):
    print("Layer" + str(Layer) + ", [pixel_addr(left_up), pixel_addr(right_down)] = ", end='')
    print(coordinate_Gen(Unet_Layer_copy_and_crop['Layer' + str(Layer+1)]))
```

-> Actually call torchvision.transforms.functional.center_crop and verify that the cropped region matches -> Pass!

![](https://hackmd.io/_uploads/BkgZ_pZy6.png)

- Result explanation:
Taking Layer0 (the $568 \times 568$ to $392 \times 392$ crop) as an example, the offset is $(568 - 392) / 2 = 88$, so the $392 \times 392$ output is read from the original $568 \times 568$ image starting at (y,x) = (88, 88), moving right to (88, 479), then advancing to the next row, and so on until (y,x) = (479, 479).

<font color="#fff"> -> This code is only guaranteed to work when the image width and height are even! </font>

---

## torch.nn.ConvTranspose2d

<font color="#f00"> Used for the <font color="#0b0"> green arrows (up-conv 2x2) </font> in the figure above. The function performs de-convolution on the input image. </font>

- Ordinary convolution vs. de-convolution (full convolution)

![](https://hackmd.io/_uploads/rJ2euyfk6.png)

```python=
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 15 20:28:21 2023

@author: Jay
"""

from scipy.signal import convolve2d
import numpy as np

# initialize a 5x5 input and a 3x3 kernel with values 1, 2, 3, ...
inputArray_np = np.array([i+1 for i in range(5*5)])
inputArray_np = np.reshape(inputArray_np, (5, 5))

kernel = np.array([i+1 for i in range(3*3)])
kernel = np.reshape(kernel, (3, 3))

# normal conv: convolve2d flips its kernel, so pre-flipping it gives
# the sliding-window (cross-correlation) result used by CNNs
normalConv_result = convolve2d(inputArray_np, kernel[::-1, ::-1], mode='valid')

# de-conv method1: 'full' mode
dConv_result1 = convolve2d(normalConv_result, kernel[::-1, ::-1], mode='full')

# de-conv method2: pad (kernel size - 1) = 2 zeros on every border, then 'valid' mode
normalConv_result_pad = np.pad(normalConv_result, 2, 'constant')
dConv_result2 = convolve2d(normalConv_result_pad, kernel[::-1, ::-1], mode='valid')

if (dConv_result1 == dConv_result2).all():
    print("same!!")
else:
    print("different....")
```

![](https://hackmd.io/_uploads/rkK5K1MkT.png)

-> 
[If my drawing above is unclear, this post may also help](https://blog.csdn.net/qq_27261889/article/details/86304061)

- verification

[ConvTranspose2d visualization](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md)

ConvTranspose2d formula:
$H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1$

Assume dilation = 1, padding = 0, stride = 1 (and output_padding left at its default of 0); the formula becomes:
$H_{out} = H_{in} + \text{kernel\_size}[0] - 1$, and the width follows the same formula.

```python=
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 15 20:28:21 2023

@author: Jay
"""

import torch
from torch import nn
from scipy.signal import convolve2d
import numpy as np

input_channel = 3
output_channel = 10
kernel_size = 3
input_picture_size = 32

# Calculate by nn.ConvTranspose2d
input_picture = torch.ones(input_channel, input_picture_size, input_picture_size)
Layer = nn.ConvTranspose2d(input_channel, output_channel, kernel_size=kernel_size,
                           stride=1, padding=0, bias=False)
print("input shape = ", input_picture.shape)
output = Layer(input_picture)
print("output shape = ", output.shape)
weight = Layer.weight   # layout: (input_channel, output_channel, kH, kW)
print("weight shape = ", weight.size())

# get the nparray of input/output and weight
input_picture_np = input_picture.numpy()
output_np = output.detach().numpy()
weight_np = weight.detach().numpy()

# Calculate by myself to check nn.ConvTranspose2d
out_size = input_picture_size + kernel_size - 1
myOutput = np.zeros([output_channel, out_size, out_size])
for cout in range(output_channel):
    outTmp = np.zeros([out_size, out_size])
    # sum the 'full' convolution of every input channel with its kernel
    for cin in range(input_channel):
        outTmp += convolve2d(input_picture_np[cin], weight_np[cin, cout, :, :], mode='full')
    myOutput[cout] = outTmp.copy()

# compare absolute differences; float32 rounding leaves a tiny mismatch
error = myOutput - output_np
if np.max(np.abs(error)) < 1e-5:
    print("same!!")
else:
    print("different....")
```

![](https://hackmd.io/_uploads/Sk-ubgfk6.png)

<font color="#f00"> Note: the weight extracted from nn appears to be flipped, as in the figure below. This is consistent with the check above, where convolve2d (which flips its kernel) matches when the weight is passed as-is. </font>
![](https://hackmd.io/_uploads/H1XBQlz1a.png)

---

### <font color="#0b0"> up-conv 2x2 </font>

The section above briefly explained how de-convolution is computed, but that example used stride 1. The $2 \times 2$ de-conv in Unet uses a $2 \times 2$ kernel with stride=2, so the stride parameter of de-conv needs a further explanation.

<font color="#f00"> The stride parameter of de-conv differs from the stride of an ordinary conv. In an ordinary conv, stride = 1 means the kernel's sliding window moves over the input in steps of 1. In de-conv, the stride parameter instead means that zeros are inserted between every two neighbouring pixels of the input image, with stride - 1 zeros per gap, while the sliding window still moves in steps of 1. This is how the output ends up with twice the height and width of the input. </font>

Example:
input size: $4 \times 4$
kernel size: $2 \times 2$
stride = $2$
![](https://hackmd.io/_uploads/S1AtC7z1a.png)
output = $8 \times 8$, since $(4 - 1) \times 2 + 2 = 8$

- verification

```python=
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 15 20:28:21 2023

@author: Jay
"""

import torch
from torch import nn
from scipy.signal import convolve2d
import numpy as np

#=================================================
#                   stride = 2
#=================================================
input_channel = 3
output_channel = 10
kernel_size = 2
input_picture_size = 32

# Calculate by nn.ConvTranspose2d
input_picture = torch.ones(input_channel, input_picture_size, input_picture_size)
Layer = nn.ConvTranspose2d(input_channel, output_channel, kernel_size=kernel_size,
                           stride=2, padding=0, bias=False)
print("input shape = ", input_picture.shape)
output = Layer(input_picture)
print("output shape = ", output.shape)
weight = Layer.weight   # layout: (input_channel, output_channel, kH, kW)
print("weight shape = ", weight.size())

# get the nparray of input/output and weight
input_picture_np = input_picture.numpy()
output_np = output.detach().numpy()
weight_np = weight.detach().numpy()

# input expansion: insert (stride - 1) = 1 zero between neighbouring pixels
input_picture_tmp_np = np.zeros([input_channel, (input_picture_size*2-1), (input_picture_size*2-1)])
for ic in range(input_channel):
    for y in range(input_picture_np.shape[1]):
        for x in range(input_picture_np.shape[2]):
            input_picture_tmp_np[ic, y*2, x*2] = input_picture_np[ic, y, x]

# Calculate by myself to check nn.ConvTranspose2d
out_size = input_picture_size*2
myOutput = np.zeros([output_channel, out_size, out_size])
for cout in range(output_channel):
    outTmp = np.zeros([out_size, out_size])
    # 'full' convolution of the expanded input: (2n - 1) + 2 - 1 = 2n
    for cin in range(input_channel):
        outTmp += convolve2d(input_picture_tmp_np[cin], weight_np[cin, cout, :, :], mode='full')
    myOutput[cout] = outTmp.copy()

# compare absolute differences; float32 rounding leaves a tiny mismatch
error = myOutput - output_np
if np.max(np.abs(error)) < 1e-5:
    print("same!!")
else:
    print("different....")
```

![](https://hackmd.io/_uploads/HJqQ14fy6.png)
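
The zero-insertion view of the de-conv stride can also be checked by hand on the $4 \times 4$ example, without any library. The sketch below is only an illustration under simplifying assumptions: one channel, a square input and kernel, no padding, no dilation. The two helper functions (`transposed_conv_scatter`, `transposed_conv_zero_insert`) and the kernel values are hypothetical names and numbers chosen for this note, not part of PyTorch. The first helper "stamps" a scaled copy of the kernel for each input pixel; the second inserts zeros, pads the border, and slides a 180°-rotated kernel with step 1.

```python
def transposed_conv_scatter(x, w, stride):
    """Each input pixel stamps a scaled copy of the kernel onto the output."""
    n, k = len(x), len(w)
    size = (n - 1) * stride + k
    out = [[0] * size for _ in range(size)]
    for i in range(n):
        for j in range(n):
            for a in range(k):
                for b in range(k):
                    out[i * stride + a][j * stride + b] += x[i][j] * w[a][b]
    return out


def transposed_conv_zero_insert(x, w, stride):
    """Zero insertion + border padding + stride-1 window with a flipped kernel."""
    n, k = len(x), len(w)
    up = (n - 1) * stride + 1      # size after inserting (stride - 1) zeros per gap
    pad = k - 1                    # border zeros that turn 'valid' into 'full'
    padded = [[0] * (up + 2 * pad) for _ in range(up + 2 * pad)]
    for i in range(n):
        for j in range(n):
            padded[pad + i * stride][pad + j * stride] = x[i][j]
    size = up + pad                # = (up + 2 * pad) - k + 1
    out = [[0] * size for _ in range(size)]
    for p in range(size):
        for q in range(size):
            # the window moves in steps of 1; the kernel is rotated by 180 degrees
            out[p][q] = sum(padded[p + a][q + b] * w[k - 1 - a][k - 1 - b]
                            for a in range(k) for b in range(k))
    return out


x = [[4 * r + c + 1 for c in range(4)] for r in range(4)]   # 4x4 input, values 1..16
w = [[1, 2], [3, 4]]                                        # 2x2 kernel (made-up values)

out_a = transposed_conv_scatter(x, w, stride=2)
out_b = transposed_conv_zero_insert(x, w, stride=2)
print(len(out_a), len(out_a[0]))   # 8 8
print(out_a == out_b)              # True
```

Integers are used so the two results can be compared exactly. The kernel rotation in the second helper plays the same role as the flipped weight noted earlier.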