## Defining the residual conv block
```
def res_conv_block(x, filter_size, size, dropout_rate, batch_norm=False):
    '''
    Residual convolutional block: two stacked 3D convolutions plus a
    1x1x1-projected shortcut.

    Two variants exist for where to put the final activation:
    1. conv - BN - Activation - conv - BN - Activation
       - shortcut - BN - shortcut+BN
    2. conv - BN - Activation - conv - BN
       - shortcut - BN - shortcut+BN - Activation
    Variant 2 (activation after the addition, as proposed in the original
    ResNet) is used here.
    Check fig 4 in https://arxiv.org/ftp/arxiv/papers/1802/1802.06955.pdf

    :param x: input tensor
    :param filter_size: spatial extent of the convolution kernel
    :param size: number of filters in each convolution
    :param dropout_rate: dropout applied after the second conv (0 disables it)
    :param batch_norm: apply BatchNormalization after each convolution
    :return: output tensor of the residual block
    '''
    kernel = (filter_size, filter_size, filter_size)
    conv = Conv3D(size, kernel, kernel_initializer=kernel_initializer,
                  padding='same')(x)
    if batch_norm is True:
        conv = BatchNormalization()(conv)
    conv = Activation('relu')(conv)
    # Fix: the second conv previously omitted kernel_initializer, making its
    # initialization inconsistent with the first conv of the block.
    conv = Conv3D(size, kernel, kernel_initializer=kernel_initializer,
                  padding='same')(conv)
    if batch_norm is True:
        conv = BatchNormalization()(conv)
    # Activation deliberately NOT applied here (variant 2: activate after the
    # addition with the shortcut).
    if dropout_rate > 0:
        conv = Dropout(dropout_rate)(conv)
    # 1x1x1 projection so the shortcut matches the main path's filter count.
    shortcut = Conv3D(size, kernel_size=(1, 1, 1),
                      kernel_initializer=kernel_initializer,
                      padding='same')(x)
    if batch_norm is True:
        shortcut = BatchNormalization()(shortcut)
    res_path = tf.keras.layers.add([shortcut, conv])
    # Activation after addition with shortcut (original residual block).
    res_path = Activation('relu')(res_path)
    return res_path
def expend_as(tensor, rep):
    """Repeat *tensor* `rep` times along the channel axis (axis=4).

    E.g. with a tensor of channel depth 1, the result has channel depth
    `rep`; spatial dimensions are unchanged.
    """
    repeat_layer = Lambda(
        lambda t, repnum: tf.keras.backend.repeat_elements(t, repnum, axis=4),
        arguments={'repnum': rep},
    )
    return repeat_layer(tensor)
def AttnGatingBlock(x, g, inter_shape):
    '''
    Attention gate: computes a single-channel sigmoid attention map from the
    skip-connection features x and the gating signal g, then re-weights x by
    that map.

    :param x: skip-connection feature map (larger spatial size than g)
    :param g: gating signal from a coarser level
    :param inter_shape: number of intermediate filters for the additive attention
    :return: x re-weighted by attention coefficients, followed by a 1x1x1 conv
             and BatchNormalization
    '''
    shape_x = int_shape(x)
    shape_g = int_shape(g)
    # Getting the gating signal to the same number of filters as the inter_shape
    phi_g = Conv3D(filters=inter_shape,
                   kernel_size=1,
                   strides=1,
                   padding='same')(g)
    # Getting the x signal to the same shape as the gating signal.
    # The strides downsample x by the spatial ratio between x and g, so the
    # two tensors can be added element-wise below.
    theta_x = Conv3D(filters=inter_shape,
                     kernel_size=3,
                     strides=(shape_x[1] // shape_g[1],
                              shape_x[2] // shape_g[2],
                              shape_x[3] // shape_g[3]),
                     padding='same')(x)
    # Element-wise addition of the gating and x signals (additive attention)
    add_xg = tf.keras.layers.add([phi_g, theta_x])
    add_xg = Activation('relu')(add_xg)
    # 1x1x1 convolution collapses the features to one channel; sigmoid turns
    # it into attention coefficients in [0, 1]
    psi = Conv3D(filters=1, kernel_size=1, padding='same')(add_xg)
    psi = Activation('sigmoid')(psi)
    shape_sigmoid = int_shape(psi)
    # Upsampling psi back to the original dimensions of x signal
    upsample_sigmoid_xg = UpSampling3D(
        size=(shape_x[1] // shape_sigmoid[1], shape_x[2] // shape_sigmoid[2],
              shape_x[3] // shape_sigmoid[3]))(psi)
    # Expanding the filter axis to the number of filters in the original x signal
    upsample_sigmoid_xg = expend_as(upsample_sigmoid_xg, shape_x[4])
    # Element-wise multiplication of attention coefficients back onto original x signal
    attn_coefficients = multiply([upsample_sigmoid_xg, x])
    # Final 1x1x1 convolution to consolidate attention signal to original x dimensions
    output = Conv3D(filters=shape_x[4],
                    kernel_size=1,
                    strides=1,
                    padding='same')(attn_coefficients)
    output = BatchNormalization()(output)
    return output
def gating_signal(input, out_size, batch_norm=False):
    """
    Resize the down-layer feature map to the same channel depth as the
    up-layer feature map using a 1x1x1 convolution.

    :param input: feature map from the deeper (down) level
    :param out_size: desired number of output channels
    :param batch_norm: apply BatchNormalization before the activation
    :return: the gating feature map with `out_size` channels
    """
    gated = tf.keras.layers.Conv3D(out_size, (1, 1, 1), padding='same')(input)
    if batch_norm:
        gated = tf.keras.layers.BatchNormalization()(gated)
    return tf.keras.layers.Activation('relu')(gated)
```
## Defining the residual attention 3D U-Net
```
kernel_initializer = 'he_uniform'  # weight initializer shared by the conv blocks; try others if you want
def Attention_ResUNet(IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS, NUM_CLASSES=1, dropout_rate=0.0, batch_norm=True):
    '''
    Residual 3D U-Net with attention gates on the skip connections.

    :param IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS: input volume shape
    :param NUM_CLASSES: number of output channels (1 -> sigmoid, >1 -> softmax)
    :param dropout_rate: kept for interface compatibility; per-level dropout
                         rates are currently hard-coded below — TODO confirm
                         whether this argument should override them
    :param batch_norm: apply BatchNormalization inside the residual blocks
    :return: an uncompiled Keras Model
    '''
    # network structure
    FILTER_NUM = 16   # number of basic filters for the first layer
    FILTER_SIZE = 3   # size of the convolutional filter
    UP_SAMP_SIZE = 2  # size of upsampling filters

    # input data
    inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS))

    # Downsampling path: residual conv block + 2x2x2 max-pooling per level,
    # doubling the filter count at each level. The batch_norm argument is now
    # honored instead of being hard-coded to True.
    # DownRes 1
    c1 = res_conv_block(inputs, FILTER_SIZE, FILTER_NUM, dropout_rate=0.1, batch_norm=batch_norm)
    p1 = MaxPooling3D(pool_size=(2, 2, 2))(c1)
    # DownRes 2
    c2 = res_conv_block(p1, FILTER_SIZE, 2*FILTER_NUM, dropout_rate=0.1, batch_norm=batch_norm)
    p2 = MaxPooling3D(pool_size=(2, 2, 2))(c2)
    # DownRes 3
    c3 = res_conv_block(p2, FILTER_SIZE, 4*FILTER_NUM, dropout_rate=0.2, batch_norm=batch_norm)
    p3 = MaxPooling3D(pool_size=(2, 2, 2))(c3)
    # DownRes 4
    c4 = res_conv_block(p3, FILTER_SIZE, 8*FILTER_NUM, dropout_rate=0.2, batch_norm=batch_norm)
    p4 = MaxPooling3D(pool_size=(2, 2, 2))(c4)
    # Bridge (DownRes 5, convolution only)
    c5 = res_conv_block(p4, FILTER_SIZE, 16*FILTER_NUM, dropout_rate=0.3, batch_norm=batch_norm)

    # Expansive path: at each level, attention-gate the skip connection,
    # transpose-convolve up, concatenate, and run a residual block.
    # UpRes 6
    gating_0 = gating_signal(c5, 8*FILTER_NUM, batch_norm)
    att_0 = AttnGatingBlock(c4, gating_0, 8*FILTER_NUM)
    u6 = Conv3DTranspose(8*FILTER_NUM, (UP_SAMP_SIZE, UP_SAMP_SIZE, UP_SAMP_SIZE), strides=(2, 2, 2), padding='same')(c5)
    u6 = concatenate([u6, att_0])
    c6 = res_conv_block(u6, FILTER_SIZE, 8*FILTER_NUM, dropout_rate=0.2, batch_norm=batch_norm)
    # UpRes 7
    # Fix: transpose-conv filters were 8*FILTER_NUM (copy-paste from UpRes 6);
    # they now halve per level to match the residual block at this depth.
    gating_1 = gating_signal(c6, 4*FILTER_NUM, batch_norm)
    att_1 = AttnGatingBlock(c3, gating_1, 4*FILTER_NUM)
    u7 = Conv3DTranspose(4*FILTER_NUM, (UP_SAMP_SIZE, UP_SAMP_SIZE, UP_SAMP_SIZE), strides=(2, 2, 2), padding='same')(c6)
    u7 = concatenate([u7, att_1])
    c7 = res_conv_block(u7, FILTER_SIZE, 4*FILTER_NUM, dropout_rate=0.2, batch_norm=batch_norm)
    # UpRes 8
    # Fix: gating/attention used 4*FILTER_NUM although the skip connection c2
    # is the 2*FILTER_NUM level; now consistent with every other level.
    gating_2 = gating_signal(c7, 2*FILTER_NUM, batch_norm)
    att_2 = AttnGatingBlock(c2, gating_2, 2*FILTER_NUM)
    u8 = Conv3DTranspose(2*FILTER_NUM, (UP_SAMP_SIZE, UP_SAMP_SIZE, UP_SAMP_SIZE), strides=(2, 2, 2), padding='same')(c7)
    u8 = concatenate([u8, att_2])
    c8 = res_conv_block(u8, FILTER_SIZE, 2*FILTER_NUM, dropout_rate=0.1, batch_norm=batch_norm)
    # UpRes 9
    gating_3 = gating_signal(c8, FILTER_NUM, batch_norm)
    att_3 = AttnGatingBlock(c1, gating_3, FILTER_NUM)
    u9 = Conv3DTranspose(FILTER_NUM, (UP_SAMP_SIZE, UP_SAMP_SIZE, UP_SAMP_SIZE), strides=(2, 2, 2), padding='same')(c8)
    u9 = concatenate([u9, att_3])
    c9 = res_conv_block(u9, FILTER_SIZE, FILTER_NUM, dropout_rate=0.1, batch_norm=batch_norm)

    # 1x1x1 convolutional output head
    conv_final = Conv3D(NUM_CLASSES, kernel_size=(1, 1, 1))(c9)
    conv_final = tf.keras.layers.BatchNormalization()(conv_final)
    # Fix: softmax over a single channel is constantly 1.0; use sigmoid for
    # binary segmentation and softmax only for multi-class output.
    final_activation = 'sigmoid' if NUM_CLASSES == 1 else 'softmax'
    conv_final = tf.keras.layers.Activation(final_activation)(conv_final)

    # Model integration
    model = Model(inputs=[inputs], outputs=[conv_final])
    return model
```
## Model summary
==================================================================================================
Total params: 6,847,720
Trainable params: 6,842,272
Non-trainable params: 5,448
