## Defining residual conv block

```python
def res_conv_block(x, filter_size, size, dropout_rate, batch_norm=False):
    '''
    Residual convolutional block. Two variants exist:
    1. conv - BN - Act - conv - BN - Act - shortcut - BN - shortcut+BN
    2. conv - BN - Act - conv - BN - shortcut - BN - shortcut+BN - Act
    Variant 2 (activation AFTER the addition, as in the original ResNet)
    is what this implementation uses.
    Check fig 4 in https://arxiv.org/ftp/arxiv/papers/1802/1802.06955.pdf

    :param x: input tensor
    :param filter_size: spatial kernel size, applied on all three axes
    :param size: number of filters for both convolutions
    :param dropout_rate: dropout applied after the second conv when > 0
    :param batch_norm: apply BatchNormalization after each conv when True
    :return: output tensor of the residual block
    '''
    conv = Conv3D(size, (filter_size, filter_size, filter_size),
                  kernel_initializer=kernel_initializer, padding='same')(x)
    if batch_norm:
        conv = BatchNormalization()(conv)
    conv = Activation('relu')(conv)

    # FIX: the second conv now uses the same kernel_initializer as the
    # first one (it was previously left at the Keras default, which made
    # the block's initialization inconsistent).
    conv = Conv3D(size, (filter_size, filter_size, filter_size),
                  kernel_initializer=kernel_initializer, padding='same')(conv)
    if batch_norm:
        conv = BatchNormalization()(conv)
    # Variant 1 would apply a ReLU here, before the addition; it is
    # intentionally omitted — the activation comes after the shortcut add.
    if dropout_rate > 0:
        conv = Dropout(dropout_rate)(conv)

    # 1x1x1 projection so the shortcut matches the main path's filter count.
    # FIX: also given the shared kernel_initializer for consistency.
    shortcut = Conv3D(size, kernel_size=(1, 1, 1),
                      kernel_initializer=kernel_initializer, padding='same')(x)
    if batch_norm:
        shortcut = BatchNormalization()(shortcut)

    res_path = tf.keras.layers.add([shortcut, conv])
    # Activation after addition with shortcut (original residual block).
    res_path = Activation('relu')(res_path)
    return res_path


def expend_as(tensor, rep):
    '''
    Repeat `tensor` `rep` times along the channel axis (axis=4 for 5-D
    inputs of shape (batch, H, W, D, C)).

    E.g. a (..., 1)-channel tensor becomes (..., rep) channels, which is
    used below to broadcast single-channel attention coefficients across
    all channels of the gated feature map.

    :param tensor: 5-D tensor to repeat
    :param rep: repetition factor along the channel axis
    :return: tensor with the channel axis expanded by a factor of `rep`
    '''
    my_repeat = Lambda(lambda x, repnum: tf.keras.backend.repeat_elements(x, repnum, axis=4),
                       arguments={'repnum': rep})(tensor)
    return my_repeat


def AttnGatingBlock(x, g, inter_shape):
    '''
    Attention gate (Oktay et al., "Attention U-Net"): uses the gating
    signal `g` (from the coarser decoder level) to compute per-voxel
    attention coefficients that re-weight the skip-connection signal `x`.

    :param x: skip-connection feature map from the encoder
    :param g: gating signal from the coarser level
    :param inter_shape: number of intermediate filters for the gate
    :return: `x` re-weighted by the learned attention coefficients,
             with the same shape (and channel count) as `x`
    '''
    shape_x = int_shape(x)
    shape_g = int_shape(g)

    # Bring the gating signal to `inter_shape` filters.
    phi_g = Conv3D(filters=inter_shape, kernel_size=1, strides=1, padding='same')(g)

    # Downsample x (via strides) to the gating signal's spatial shape.
    theta_x = Conv3D(filters=inter_shape, kernel_size=3,
                     strides=(shape_x[1] // shape_g[1],
                              shape_x[2] // shape_g[2],
                              shape_x[3] // shape_g[3]),
                     padding='same')(x)

    # Element-wise addition of the gating and x signals.
    add_xg = tf.keras.layers.add([phi_g, theta_x])
    add_xg = Activation('relu')(add_xg)

    # 1x1x1 convolution down to a single attention channel in [0, 1].
    psi = Conv3D(filters=1, kernel_size=1, padding='same')(add_xg)
    psi = Activation('sigmoid')(psi)
    shape_sigmoid = int_shape(psi)

    # Upsample the attention map back to x's spatial dimensions.
    upsample_sigmoid_xg = UpSampling3D(
        size=(shape_x[1] // shape_sigmoid[1],
              shape_x[2] // shape_sigmoid[2],
              shape_x[3] // shape_sigmoid[3]))(psi)

    # Broadcast the single attention channel across all channels of x.
    upsample_sigmoid_xg = expend_as(upsample_sigmoid_xg, shape_x[4])

    # Element-wise multiplication of attention coefficients onto x.
    attn_coefficients = multiply([upsample_sigmoid_xg, x])

    # Final 1x1x1 convolution to consolidate back to x's channel count.
    output = Conv3D(filters=shape_x[4], kernel_size=1, strides=1,
                    padding='same')(attn_coefficients)
    output = BatchNormalization()(output)
    return output


def gating_signal(input, out_size, batch_norm=False):
    """
    Resize the down-layer feature map into the same channel dimension as
    the up-layer feature map using a 1x1x1 convolution.

    NOTE(review): the parameter name `input` shadows the Python builtin;
    kept unchanged to preserve the public interface for keyword callers.

    :param input: feature map from the coarser level
    :param out_size: target number of filters
    :param batch_norm: apply BatchNormalization before the ReLU when True
    :return: the gating feature map with `out_size` channels
    """
    x = tf.keras.layers.Conv3D(out_size, (1, 1, 1), padding='same')(input)
    if batch_norm:
        x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)
    return x
```

## Defining the residual attention 3d unet

```python
kernel_initializer = 'he_uniform'  # Try others if you want


def Attention_ResUNet(IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS,
                      NUM_CLASSES=1, dropout_rate=0.0, batch_norm=True):
    '''
    Residual 3D U-Net with attention gates on the skip connections.

    :param IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS: input volume shape
    :param NUM_CLASSES: number of output channels (1 => sigmoid output,
                        >1 => softmax over classes)
    :param dropout_rate: kept for interface compatibility (per-level rates
                         are set explicitly below)
    :param batch_norm: use BatchNormalization throughout the network
    :return: an uncompiled tf.keras Model
    '''
    # Network structure hyper-parameters.
    FILTER_NUM = 16    # number of basic filters for the first layer
    FILTER_SIZE = 3    # size of the convolutional filter
    UP_SAMP_SIZE = 2   # size of upsampling filters

    inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS))

    # ---- Downsampling path -------------------------------------------------
    # DownRes 1: residual double-conv + pooling.
    # FIX: batch_norm is now threaded through instead of hard-coded True
    # (default True, so the default behavior is unchanged).
    c1 = res_conv_block(inputs, FILTER_SIZE, FILTER_NUM, dropout_rate=0.1, batch_norm=batch_norm)
    p1 = MaxPooling3D(pool_size=(2, 2, 2))(c1)
    # DownRes 2
    c2 = res_conv_block(p1, FILTER_SIZE, 2*FILTER_NUM, dropout_rate=0.1, batch_norm=batch_norm)
    p2 = MaxPooling3D(pool_size=(2, 2, 2))(c2)
    # DownRes 3
    c3 = res_conv_block(p2, FILTER_SIZE, 4*FILTER_NUM, dropout_rate=0.2, batch_norm=batch_norm)
    p3 = MaxPooling3D(pool_size=(2, 2, 2))(c3)
    # DownRes 4
    c4 = res_conv_block(p3, FILTER_SIZE, 8*FILTER_NUM, dropout_rate=0.2, batch_norm=batch_norm)
    p4 = MaxPooling3D(pool_size=(2, 2, 2))(c4)

    # ---- Bridge ------------------------------------------------------------
    c5 = res_conv_block(p4, FILTER_SIZE, 16*FILTER_NUM, dropout_rate=0.3, batch_norm=batch_norm)

    # ---- Upsampling path ---------------------------------------------------
    # FIX: the Conv3DTranspose filter counts previously hard-coded
    # 8*FILTER_NUM at every level (a copy-paste defect); they now mirror
    # the encoder (8, 4, 2, 1 x FILTER_NUM), matching each level's
    # res_conv_block and attention-gate filter counts.

    # UpRes 6: attention-gated concatenation + upsampling + residual conv.
    gating_0 = gating_signal(c5, 8*FILTER_NUM, batch_norm)
    att_0 = AttnGatingBlock(c4, gating_0, 8*FILTER_NUM)
    u6 = Conv3DTranspose(8*FILTER_NUM, (UP_SAMP_SIZE, UP_SAMP_SIZE, UP_SAMP_SIZE),
                         strides=(2, 2, 2), padding='same')(c5)
    u6 = concatenate([u6, att_0])
    c6 = res_conv_block(u6, FILTER_SIZE, 8*FILTER_NUM, dropout_rate=0.2, batch_norm=batch_norm)

    # UpRes 7
    gating_1 = gating_signal(c6, 4*FILTER_NUM, batch_norm)
    att_1 = AttnGatingBlock(c3, gating_1, 4*FILTER_NUM)
    u7 = Conv3DTranspose(4*FILTER_NUM, (UP_SAMP_SIZE, UP_SAMP_SIZE, UP_SAMP_SIZE),
                         strides=(2, 2, 2), padding='same')(c6)
    u7 = concatenate([u7, att_1])
    c7 = res_conv_block(u7, FILTER_SIZE, 4*FILTER_NUM, dropout_rate=0.2, batch_norm=batch_norm)

    # UpRes 8
    # FIX: previously used 4*FILTER_NUM although c2 has 2*FILTER_NUM filters.
    gating_2 = gating_signal(c7, 2*FILTER_NUM, batch_norm)
    att_2 = AttnGatingBlock(c2, gating_2, 2*FILTER_NUM)
    u8 = Conv3DTranspose(2*FILTER_NUM, (UP_SAMP_SIZE, UP_SAMP_SIZE, UP_SAMP_SIZE),
                         strides=(2, 2, 2), padding='same')(c7)
    u8 = concatenate([u8, att_2])
    c8 = res_conv_block(u8, FILTER_SIZE, 2*FILTER_NUM, dropout_rate=0.1, batch_norm=batch_norm)

    # UpRes 9
    gating_3 = gating_signal(c8, FILTER_NUM, batch_norm)
    att_3 = AttnGatingBlock(c1, gating_3, FILTER_NUM)
    u9 = Conv3DTranspose(FILTER_NUM, (UP_SAMP_SIZE, UP_SAMP_SIZE, UP_SAMP_SIZE),
                         strides=(2, 2, 2), padding='same')(c8)
    u9 = concatenate([u9, att_3])
    c9 = res_conv_block(u9, FILTER_SIZE, FILTER_NUM, dropout_rate=0.1, batch_norm=batch_norm)

    # ---- Output head -------------------------------------------------------
    conv_final = Conv3D(NUM_CLASSES, kernel_size=(1, 1, 1))(c9)
    conv_final = tf.keras.layers.BatchNormalization()(conv_final)
    # FIX: softmax over a single channel is identically 1.0, so the default
    # NUM_CLASSES=1 configuration produced a constant output. Use sigmoid
    # for binary (single-channel) output and softmax for multi-class.
    final_activation = 'sigmoid' if NUM_CLASSES == 1 else 'softmax'
    conv_final = tf.keras.layers.Activation(final_activation)(conv_final)

    model = Model(inputs=[inputs], outputs=[conv_final])
    return model
```

## Model summary

(Summary recorded from the original version of the network; parameter
counts will differ slightly after the decoder filter-count fixes above.)

==================================================================================================
Total params: 6,847,720
Trainable params: 6,842,272
Non-trainable params: 5,448

![](https://i.imgur.com/sfYNq6T.jpg)