
Python function Explanation & Verification

Since we need to know how Python and U-Net actually perform these computations under the hood, we verify and explain some of the less common APIs used in U-Net.

torchvision.transforms.functional.center_crop

This is used for the gray arrows (copy and crop) in the U-Net diagram below. The function crops the required region outward from the exact center of the image (from the contracting path of U-Net), so we must work out at which pixel the crop starts and where it ends.

(Figure: U-Net architecture diagram)
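
Before the full verification, a minimal usage sketch (my own illustration, not part of the original script) shows what `center_crop` returns for a tiny numbered image:

```python
import torch
import torchvision.transforms.functional as TF

# a numbered 6x6 image: the value at (y, x) is y*6 + x
x = torch.arange(36, dtype=torch.int32).reshape(1, 6, 6)

# crop the central 4x4 region
y = TF.center_crop(x, (4, 4))
print(y[0])
# the top-left surviving value is 7, i.e. pixel (y, x) = (1, 1) of the original image
```

So the crop simply selects the centered window, which is what the handwritten coordinate formula in the verification below reproduces.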

  • verification
```python
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 14 20:48:00 2023

@author: Jay
"""
import torchvision
import torch
import numpy as np
from PIL import Image

channel = 9
input_width = 64   # assume height = width
output_width = 56

torchInputArray = torch.zeros(channel, input_width, input_width, dtype=torch.int32)
sizeList = list(torchInputArray.size())

# initial the input array
num = 1
for i in range(sizeList[0]):
    for j in range(sizeList[1]):
        for k in range(sizeList[2]):
            torchInputArray[i, j, k] = num
            num += 1

torchResult = torchvision.transforms.functional.center_crop(torchInputArray, (output_width, output_width))

# into numpy
torchInputArray_np = torchInputArray.numpy()
torchResult_np = torchResult.numpy()

# ========== handcraft func ==========
myResult_np = np.zeros([channel, output_width, output_width], dtype=np.int32)
start_addr = ((input_width-1)//2) - (output_width - 2)//2
for i in range(channel):
    for j in range(start_addr, start_addr+output_width):
        for k in range(start_addr, start_addr+output_width):
            myResult_np[i, (j-start_addr), (k-start_addr)] = torchInputArray_np[i, j, k]

if((myResult_np == torchResult_np).all()):
    print("function Correct!")
else:
    print("Failed...........")

left_up = (start_addr, start_addr)
right_down = (start_addr+output_width-1, start_addr+output_width-1)
print(left_up, right_down)

# ========== coordinate generator ==========
def coordinate_Gen(sizeIn):
    (input_PicWidth, out_PicWidth) = sizeIn
    #print('input_PicWidth, out_PicWidth = ', input_PicWidth, out_PicWidth)
    start_addr = ((input_PicWidth-1)//2) - (out_PicWidth - 2)//2
    left_up = (start_addr, start_addr)
    right_down = (start_addr+out_PicWidth-1, start_addr+out_PicWidth-1)
    return [left_up, right_down]

Unet_Layer_copy_and_crop = {
    'Layer1': (568, 392),
    'Layer2': (280, 200),
    'Layer3': (136, 104),
    'Layer4': (64, 56),
}

for Layer in range(len(Unet_Layer_copy_and_crop)):
    print("Layer" + str(Layer) + ", [pixel_addr(left_up), pixel_addr(right_down)] = ", end='')
    print(coordinate_Gen(Unet_Layer_copy_and_crop['Layer' + str(Layer+1)]))
```

-> Actually call torchvision.transforms.functional.center_crop and verify that the cropped region matches -> Pass!


  • Result explanation:
    Taking Layer0 as an example, the 392×392 output should be taken from the original 568×568 image starting at address (y, x) = (88, 88), reading to the right up to (88, 479), then jumping to the next row, and so on until (y, x) = (479, 479).
-> This code is only guaranteed to work when the image width and height are even!
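
To show where that even-only assumption matters, here is a small comparison sketch (my own addition; `start_addr_even_only` is just an illustrative name) that checks the handwritten start-address formula against torchvision.transforms.functional.center_crop for a few even and odd sizes:

```python
import torch
import torchvision.transforms.functional as TF

def start_addr_even_only(input_width, output_width):
    # same formula as in the verification script above
    return ((input_width - 1) // 2) - (output_width - 2) // 2

for input_width, output_width in [(64, 56), (568, 392), (65, 56), (137, 104)]:
    x = torch.arange(input_width * input_width, dtype=torch.int32).reshape(1, input_width, input_width)
    ref = TF.center_crop(x, (output_width, output_width))
    first = int(ref[0, 0, 0])                      # value = y*input_width + x of the top-left cropped pixel
    actual = (first // input_width, first % input_width)
    predicted = start_addr_even_only(input_width, output_width)
    print((input_width, output_width), "formula:", (predicted, predicted), "torchvision:", actual)
```

On my understanding of torchvision's rounding, the two can disagree by one pixel for odd sizes, which is why the note above restricts the formula to even widths and heights.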

torch.nn.ConvTranspose2d

This is used for the green arrows (up-conv 2x2) in the U-Net diagram above. This function performs a de-convolution on the input image.
  • Normal convolution vs. de-convolution (full convolution)
    (Figure: hand-drawn comparison of a normal convolution and a de-convolution / full convolution)
```python
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 15 20:28:21 2023

@author: Jay
"""
import torch
import torchvision.transforms.functional
from torch import nn
from scipy.signal import convolve2d
import numpy as np

# initial array
inputArray_np = np.array([i+1 for i in range(5*5)])
inputArray_np = np.reshape(inputArray_np, (5, 5))
kernel = np.array([i+1 for i in range(3*3)])
kernel = np.reshape(kernel, (3, 3))

# normal conv
normalConv_result = convolve2d(inputArray_np, kernel[::-1, ::-1], mode='valid')

# de-conv method1
dConv_result1 = convolve2d(normalConv_result, kernel[::-1, ::-1], mode='full')

# de-conv method2
normalConv_result_pad = np.pad(normalConv_result, 2, 'constant')
dConv_result2 = convolve2d(normalConv_result_pad, kernel[::-1, ::-1], mode='valid')

if((dConv_result1 == dConv_result2).all()):
    print("same!!")
else:
    print("different....")
```

-> If you can't follow what I drew above, you can also refer to the link here.

  • verification
    ConvTranspose2d visualization
    ConvTranspose2d formula:
    $H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1$

    Assuming dilation = 1, padding = 0, stride = 1, the formula reduces to
    $H_{out} = H_{in} + \text{kernel\_size}[0] - 1$,
    and likewise for the width.
```python
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 15 20:28:21 2023

@author: Jay
"""
import torch
import torchvision.transforms.functional
from torch import nn
from scipy.signal import convolve2d
import numpy as np

input_channel = 3
output_channel = 10
kernel_size = 3
input_picture_size = 32

# Calculate by nn.ConvTranspose2d
input_picture = torch.ones(input_channel, input_picture_size, input_picture_size)
Layer = nn.ConvTranspose2d(input_channel, output_channel, kernel_size=kernel_size, stride=1, padding=0, bias=False)
print("input shape = ", input_picture.shape)
output = Layer(input_picture)
print("output shape = ", output.shape)
weight = Layer.weight
print("weight shape = ", weight.size())

# get the nparray of input/output and weight
input_picture_np = input_picture.numpy()
output_np = output.detach().numpy()
weight_np = weight.detach().numpy()

# Calculate by myself to check nn.ConvTranspose2d
myOutput = np.zeros([output_channel, (input_picture_size+kernel_size-1), (input_picture_size+kernel_size-1)])
for cout in range(output_channel):
    outTmp = np.zeros([(input_picture_size+kernel_size-1), (input_picture_size+kernel_size-1)])
    for cin in range(input_channel):
        outTmp += convolve2d(input_picture_np[cin, ], weight_np[cin, cout, :, :], mode='full')
    myOutput[cout, ] = outTmp.copy()

error = myOutput - output_np
if(np.max(np.abs(error)) < 1e-6):  # allow a tiny floating-point difference
    print("same!!")
else:
    print("different....")
```
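
As a quick cross-check of the shape formula above (a sketch of my own; `convtranspose2d_hout` is just an illustrative helper name), plugging in the two configurations used in this note reproduces the printed output sizes:

```python
def convtranspose2d_hout(h_in, stride=1, padding=0, dilation=1, kernel_size=3, output_padding=0):
    # direct transcription of the ConvTranspose2d output-size formula
    return (h_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

print(convtranspose2d_hout(32, stride=1, kernel_size=3))   # 34, matching the stride=1 verification above
print(convtranspose2d_hout(32, stride=2, kernel_size=2))   # 64, matching the up-conv 2x2 verification below
```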

Note: the weight extracted from nn appears to be spatially flipped, for example as in the figure below.
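
One numerical way to see this flipping (a sketch of my own, reusing the stride = 1 setup above; only scipy.signal is assumed): the ConvTranspose2d result equals a full cross-correlation with the spatially flipped weight, which is the same thing as the full convolution used in the check above.

```python
import numpy as np
import torch
from torch import nn
from scipy.signal import convolve2d, correlate2d

torch.manual_seed(0)
x = torch.randn(1, 5, 5)
layer = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=1, padding=0, bias=False)
ref = layer(x).detach().numpy()[0]

x_np = x.numpy()[0]
w = layer.weight.detach().numpy()[0, 0]

conv_full = convolve2d(x_np, w, mode='full')                  # full convolution with w
corr_flipped = correlate2d(x_np, w[::-1, ::-1], mode='full')  # full cross-correlation with flipped w

print(np.max(np.abs(conv_full - ref)) < 1e-6)     # True
print(np.max(np.abs(corr_flipped - ref)) < 1e-6)  # True
```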


up-conv 2x2

The de-convolution computation was briefly explained above, but that example used stride = 1, while the up-conv 2x2 in U-Net uses a 2×2 kernel with stride = 2, so the stride parameter of de-conv needs a bit more explanation.

The stride parameter of de-conv differs from the stride of a normal conv. For a normal conv, stride = 1 means the kernel's sliding window moves over the input in steps of 1. For de-conv, the stride parameter instead means that zeros are inserted between every two pixels of the input image, with stride - 1 zeros inserted, while the sliding-window step stays at 1. This is how the output image ends up with twice the width and height of the input (see the example and the small sketch below).

Example:
input size: 4×4
kernel size: 2×2
stride = 2
output size: 8×8
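
A small code sketch of this example (my own illustration, following the zero-insertion description above):

```python
import numpy as np
from scipy.signal import convolve2d

x = np.arange(1, 17, dtype=np.float64).reshape(4, 4)    # 4x4 input
k = np.array([[1.0, 2.0], [3.0, 4.0]])                  # 2x2 kernel

# insert (stride - 1) = 1 zero between neighbouring pixels -> 7x7 expanded input
expanded = np.zeros((7, 7))
expanded[::2, ::2] = x

# full convolution of the expanded input with the kernel -> 8x8 output
out = convolve2d(expanded, k, mode='full')
print(expanded.shape, out.shape)   # (7, 7) (8, 8), matching the 8x8 output size above
```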

  • verification
```python
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 15 20:28:21 2023

@author: Jay
"""
import torch
import torchvision.transforms.functional
from torch import nn
from scipy.signal import convolve2d
import numpy as np

#=================================================
# stride = 2
#=================================================
input_channel = 3
output_channel = 10
kernel_size = 2
input_picture_size = 32

# Calculate by nn.ConvTranspose2d
input_picture = torch.ones(input_channel, input_picture_size, input_picture_size)
Layer = nn.ConvTranspose2d(input_channel, output_channel, kernel_size=kernel_size, stride=2, padding=0, bias=False)
print("input shape = ", input_picture.shape)
output = Layer(input_picture)
print("output shape = ", output.shape)
weight = Layer.weight
print("weight shape = ", weight.size())

# get the nparray of input/output and weight
input_picture_np = input_picture.numpy()
output_np = output.detach().numpy()
weight_np = weight.detach().numpy()

# input expansion: insert (stride - 1) zeros between every two input pixels
input_picture_tmp_np = np.zeros([input_channel, (input_picture_size*2-1), (input_picture_size*2-1)])
for ic in range(input_channel):
    for y in range(input_picture_np.shape[1]):
        for x in range(input_picture_np.shape[2]):
            input_picture_tmp_np[ic, y*2, x*2] = input_picture_np[ic, y, x]

# Calculate by myself to check nn.ConvTranspose2d
myOutput = np.zeros([output_channel, (input_picture_size*2), (input_picture_size*2)])
for cout in range(output_channel):
    outTmp = np.zeros([(input_picture_size*2), (input_picture_size*2)])
    for cin in range(input_channel):
        outTmp += convolve2d(input_picture_tmp_np[cin, ], weight_np[cin, cout, :, :], mode='full')
    myOutput[cout, ] = outTmp.copy()

error = myOutput - output_np
if(np.max(np.abs(error)) < 1e-6):  # allow a tiny floating-point difference
    print("same!!")
else:
    print("different....")
```