The first U-Net encodes the input image into a $K$-way soft segmentation. The second U-Net reverses this process and produces a reconstructed image from the segmentation.
Both the reconstruction error of the autoencoder and a soft normalized cut loss (Shi & Malik, 2000) on the segmented output are jointly optimized.
To improve the segmentation, postprocessing is done in 2 steps: (1) the encoder's output is smoothed with a fully-connected conditional random field (CRF), and (2) segments are then merged hierarchically to produce the final segmentation.
The W-Net architecture has 46 convolutional layers, which are structured into 18 modules (marked with red rectangles in the paper's architecture figure).
Each module consists of two 3x3 convolutional layers, each followed by a ReLU non-linearity and batch normalization. The first nine modules form the dense prediction base of the network and the second nine form the reconstruction decoder.
The encoder has a contracting path (1st half) to capture context and a corresponding expansive path (2nd half) which enables precise localization (as in the original U-Net paper).
Skip connections from the contracting path to the expansive path (as in the original U-Net paper) are used to recover spatial information lost due to downsampling.
The architecture of the decoder $U_{Dec}$ is similar to that of the encoder $U_{Enc}$, except that it reads the output of $U_{Enc}$, which has $K$ channels rather than the number of channels of the input image.
One major modification is the use of depthwise separable convolution layers (Chollet 2016, the Xception paper) in all modules except modules 1, 9, 10, and 18.
The idea behind such an operation is to examine spatial correlations and cross-channel correlations independently (a code sketch follows this list):
First, a depthwise convolution performs a spatial convolution independently over each channel of the input.
Then, a pointwise (1x1) convolution projects the channels produced by the depthwise convolution onto a new channel space.
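A minimal PyTorch sketch of both ideas (not the authors' code; `SeparableConv2d` and `WNetModule` are illustrative names and the channel widths are left to the caller): the depthwise step is a grouped convolution with `groups=in_channels`, the pointwise step is a 1x1 convolution, and a W-Net-style module stacks two such convolutions, each followed by ReLU and batch normalization.

```python
import torch.nn as nn

class SeparableConv2d(nn.Module):
    """Depthwise separable convolution: a per-channel spatial (depthwise)
    convolution followed by a 1x1 (pointwise) convolution across channels."""
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        # groups=in_channels => each input channel gets its own spatial filter
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size,
                                   padding=padding, groups=in_channels)
        # 1x1 convolution mixes information across channels
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

class WNetModule(nn.Module):
    """One W-Net-style module: two 3x3 convolutions, each followed by
    ReLU and batch normalization (separable convs in most modules)."""
    def __init__(self, in_channels, out_channels, separable=True):
        super().__init__()
        def conv(ci, co):
            if separable:
                return SeparableConv2d(ci, co)
            return nn.Conv2d(ci, co, kernel_size=3, padding=1)
        self.block = nn.Sequential(
            conv(in_channels, out_channels),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
            conv(out_channels, out_channels),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        return self.block(x)
```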
The network does not include any fully connected layers, which allows it to process arbitrarily sized images and produce a segmentation prediction of the corresponding size.
Soft Normalized Cut Loss
The output of $U_{Enc}$ is a normalized dense prediction, i.e. a $K$-class pixel-wise softmax prediction. By taking the argmax over the $K$ channels, we can obtain a class prediction for each pixel.
Normalized cut ($NCut$) is computed as a global criterion for segmentation:

$$ NCut_K(V) = \sum_{k=1}^{K} \frac{cut(A_k, V - A_k)}{assoc(A_k, V)} = \sum_{k=1}^{K} \frac{\sum_{u \in A_k,\, v \in V - A_k} w(u, v)}{\sum_{u \in A_k,\, t \in V} w(u, t)} $$

where $A_k$ is the set of pixels in segment $k$, $V$ is the set of all pixels, and $w$ measures the weight (similarity) between two pixels.
Intuition for the above equation (also refer to the original Normalized Cut paper; a small code sketch follows this list):
The problem is treated as a weighted graph, i.e. pixels are the nodes of the graph and a weight is defined between every pair of pixels (nodes).
The outer summation (over $k = 1$ to $K$) runs over all $K$ segments, i.e. one segment is considered at a time in the inner terms.
The numerator of the inner term, $cut(A_k, V - A_k)$, is the sum of weights between pixels of the particular segment and pixels not in that segment.
The denominator, $assoc(A_k, V)$, is the sum of weights between pixels of the particular segment and all pixels in the image.
Why normalized? Because we divide by the sum of weights between the segment's pixels and all pixels, the cut is normalized by the segment's total association. Using only the numerator (the plain cut) is also an option, but it is not normalized and tends to favor cutting off small, isolated groups of pixels, which is exactly the problem the normalization fixes.
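As a toy illustration of the two inner terms, here is a sketch that evaluates the hard $NCut$ criterion for a given labeling and a dense affinity matrix (the function name `hard_ncut` is mine; a dense matrix is only workable for tiny images, and every segment is assumed non-empty):

```python
import torch

def hard_ncut(labels, weights, K):
    """Normalized cut of a hard segmentation.

    labels:  (N,) integer tensor, segment id in [0, K) for each pixel.
    weights: (N, N) affinity matrix w(u, v).
    """
    total = torch.zeros(())
    for k in range(K):
        in_k = (labels == k).float()                                    # indicator of A_k
        # cut(A_k, V - A_k): weights between pixels in A_k and pixels outside it
        cut = (in_k[:, None] * weights * (1.0 - in_k)[None, :]).sum()
        # assoc(A_k, V): weights between pixels in A_k and all pixels
        assoc = (in_k[:, None] * weights).sum()
        total = total + cut / assoc
    return total
```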
But the argmax function is not differentiable, so backpropagation will not work through it. Therefore, a soft version of the loss is defined:

$$
\begin{aligned}
J_{soft\text{-}Ncut}(V, K) &= \sum_{k=1}^{K} \frac{cut(A_k, V - A_k)}{assoc(A_k, V)} \\
&= K - \sum_{k=1}^{K} \frac{assoc(A_k, A_k)}{assoc(A_k, V)} \\
&= K - \sum_{k=1}^{K} \frac{\sum_{u \in V,\, v \in V} w(u, v)\, p(u = A_k)\, p(v = A_k)}{\sum_{u \in V,\, t \in V} w(u, t)\, p(u = A_k)}
\end{aligned}
$$

where $p(u = A_k)$ is the probability of pixel $u$ belonging to class $A_k$, read directly from the encoder's softmax output.
Intuition for the above equations (a code sketch of the soft loss follows this list):
The earlier equation had $cut(A_k, V - A_k)$ and $assoc(A_k, V)$ terms. Those require an argmax calculation, i.e. we need to know exactly which pixel belongs to which class.
However, we only have a probability distribution (the softmax output), so the hard summation over $u \in A_k$ is changed to a summation over all $u \in V$ while adding the weighting term $p(u = A_k)$ (and similarly for $v$ and $t$).
Thus, this change results in a soft version of the Normalized Cut function.
In the last two lines of the equation, the cut term is rewritten using $cut(A_k, V - A_k) = assoc(A_k, V) - assoc(A_k, A_k)$ (so the loss becomes $K$ minus a sum of normalized associations), and the association terms are then written out as summations over the relevant variables.
The inner numerator effectively computes the association between pixels of the same segment. The overall fraction is the normalized association. Thus, we would like to maximize this quantity.
So, the negative of the normalized association has to be minimized; adding that (negative) term to the constant $K$ makes no difference to the optimization problem.
No additional computation is required for the probability terms $p(u = A_k)$, since they are exactly the softmax outputs that the encoder produces anyway.
By training $U_{Enc}$ to minimize the $J_{soft\text{-}Ncut}$ loss, we can simultaneously minimize the total normalized disassociation between the groups and maximize the total normalized association within the groups.
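Below is a minimal sketch of this soft loss for a dense affinity matrix, assuming the encoder output has already been flattened to an `(N, K)` matrix of per-pixel probabilities; the function name `soft_ncut_loss` and the small epsilon are my additions, and real implementations use sparse, windowed weights since a dense $N \times N$ matrix is infeasible for full-size images.

```python
import torch

def soft_ncut_loss(probs, weights):
    """Soft normalized cut: K - sum_k assoc(A_k, A_k) / assoc(A_k, V),
    with hard memberships replaced by softmax probabilities.

    probs:   (N, K) tensor, p(u = A_k) for each of N pixels and K classes.
    weights: (N, N) tensor, pixel affinities w(u, v).
    """
    K = probs.shape[1]
    # assoc(A_k, A_k) = sum_{u,v} w(u,v) p(u=A_k) p(v=A_k), one value per class
    assoc_within = torch.einsum('uk,uv,vk->k', probs, weights, probs)
    # assoc(A_k, V) = sum_{u,t} w(u,t) p(u=A_k), one value per class
    degree = weights.sum(dim=1)                  # sum_t w(u, t) for each pixel u
    assoc_total = torch.einsum('uk,u->k', probs, degree)
    return K - (assoc_within / (assoc_total + 1e-8)).sum()
```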
Regarding the weights $w(u, v)$ (not defined explicitly in this paper; refer to the original Normalized Cut paper):
For the monocular case, the graph is constructed by taking each pixel as a node and defining the edge weight between nodes $i$ and $j$ as the product of a feature-similarity term and a spatial-proximity term:

$$ w_{ij} = e^{-\frac{\|F(i) - F(j)\|_2^2}{\sigma_I^2}} \times \begin{cases} e^{-\frac{\|X(i) - X(j)\|_2^2}{\sigma_X^2}} & \text{if } \|X(i) - X(j)\|_2 < r \\ 0 & \text{otherwise} \end{cases} $$

where $X(i)$ is the spatial location of node $i$, and $F(i)$ is a feature vector based on intensity, color, or texture information at that node (for various definitions of $F$, refer to the original paper; here $F(i)$ can mostly be taken as the intensity of pixel $i$).
Note that the weight $w_{ij}$ is $0$ for any pair of nodes more than $r$ pixels apart. Also note that $\sigma_I$, $\sigma_X$, and $r$ are set manually (so they are hyperparameters).
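A sketch of how such weights could be computed for a small single-channel image, using pixel intensity as the feature $F(i)$; the default hyperparameter values shown are illustrative, and a dense matrix like this is only practical for small crops:

```python
import torch

def ncut_weights(image, sigma_I=10.0, sigma_X=4.0, r=5):
    """Dense pixel affinity matrix w_ij for a single-channel image of shape (H, W).

    The feature term uses pixel intensity as F(i); the spatial term is cut off
    at radius r, so w_ij = 0 for pixel pairs that are r or more pixels apart.
    """
    H, W = image.shape
    ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    coords = torch.stack([ys.flatten(), xs.flatten()], dim=1).float()   # (N, 2)
    feats = image.flatten().float()                                     # (N,)

    dist_X2 = torch.cdist(coords, coords).pow(2)                        # squared spatial distances
    dist_F2 = (feats[:, None] - feats[None, :]).pow(2)                  # squared intensity differences

    w = torch.exp(-dist_F2 / sigma_I**2) * torch.exp(-dist_X2 / sigma_X**2)
    w[dist_X2 >= r**2] = 0.0        # zero out pairs that are not within r pixels
    return w
```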
Reconstruction Loss
As in a classical encoder-decoder architecture, the reconstruction loss is minimized to enforce that the encoded representation contains as much information about the original input as possible.
Minimizing the reconstruction loss in the W-Net architecture makes the segmentation prediction align better with the input images.
The loss is given by:

$$ J_{reconstr} = \left\| X - U_{Dec}\big(U_{Enc}(X; W_{Enc});\, W_{Dec}\big) \right\|_2^2 $$

where $W_{Enc}$ and $W_{Dec}$ are the parameters of the encoder and decoder respectively, and $X$ is the input image.
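As a one-line sketch (using mean squared error as a stand-in for the squared $L_2$ norm, which only rescales the loss):

```python
import torch.nn.functional as F

def reconstruction_loss(x, x_reconstructed):
    # J_reconstr = || X - U_Dec(U_Enc(X)) ||_2^2, here averaged over pixels (MSE)
    return F.mse_loss(x_reconstructed, x)
```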
Optimization
By iteratively alternating updates on $J_{soft\text{-}Ncut}$ (which updates only the encoder $U_{Enc}$) and $J_{reconstr}$ (which updates both $U_{Enc}$ and $U_{Dec}$), the network balances the trade-off between the accuracy of reconstruction and the consistency of the encoded representation layer.
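A schematic, runnable toy of this alternating scheme, reusing the `ncut_weights`, `soft_ncut_loss`, and `reconstruction_loss` sketches above; the tiny `encoder`/`decoder` stand-ins, the optimizer settings, and the image size are placeholders, not the paper's actual architecture or hyperparameters:

```python
import torch
import torch.nn as nn

# Toy stand-ins for U_Enc and U_Dec, just to make the alternating updates concrete.
K, H, W = 4, 16, 16
encoder = nn.Sequential(nn.Conv2d(1, K, 3, padding=1), nn.Softmax(dim=1))
decoder = nn.Sequential(nn.Conv2d(K, 1, 3, padding=1))

enc_opt = torch.optim.SGD(encoder.parameters(), lr=0.01)
full_opt = torch.optim.SGD(list(encoder.parameters()) + list(decoder.parameters()), lr=0.01)

x = torch.rand(1, 1, H, W)                     # dummy input image
weights = ncut_weights(x[0, 0], r=5)           # precomputed pixel affinities

for step in range(10):
    # Step 1: update the encoder alone on the soft N-cut loss
    probs = encoder(x)                         # (1, K, H, W) softmax output
    flat = probs[0].reshape(K, -1).t()         # (H*W, K) layout expected by soft_ncut_loss
    loss_ncut = soft_ncut_loss(flat, weights)
    enc_opt.zero_grad()
    loss_ncut.backward()
    enc_opt.step()

    # Step 2: update encoder + decoder on the reconstruction loss
    x_hat = decoder(encoder(x))
    loss_rec = reconstruction_loss(x, x_hat)
    full_opt.zero_grad()
    loss_rec.backward()
    full_opt.step()
```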