# Notes on "[W-Net: A Deep Model for Fully Unsupervised Image Segmentation](https://arxiv.org/abs/1711.08506)"
###### tags: `notes` `segmentation` `unsupervised`
The aim is to perform **fully unsupervised image segmentation** with a deep network; post-processing further improves the segmentation output.
## Brief Outline
![Overview](https://i.imgur.com/yf4e3kW.png)
* Two U-Nets ([Ronneberger et al., 2015](https://arxiv.org/abs/1505.04597)) are used together as a single autoencoder.
* The first U-Net encodes the input image into a $K$-way soft segmentation. The second U-Net reverses this process, reconstructing the image from the segmentation.
* Both the reconstruction error of the autoencoder and a soft normalized cut ([Shi & Malik, 2000](https://dl.acm.org/citation.cfm?id=351611)) loss on the segmented output are jointly optimized.
* To improve the segmentation, post-processing is done in two steps:
* fully connected conditional random field (CRF) smoothing ([Krähenbühl & Koltun, 2012](https://arxiv.org/abs/1210.5644); [Chen et al., 2017](https://arxiv.org/abs/1606.00915)) on the output segments.
* a hierarchical merging method ([Arbeláez et al., 2011](https://ieeexplore.ieee.org/document/5557884)) to obtain the final segmentation.
## Network Architecture
![W-Net Architecture](https://i.imgur.com/9SCjqWd.png)
* The W-Net architecture has 46 convolutional layers, structured into 18 modules (marked with red rectangles in the figure).
* Each module consists of two 3x3 convolutional layers, each followed by a ReLU non-linearity and batch normalization. The first nine modules form the dense-prediction base of the network and the second nine form the reconstruction decoder.
* The encoder $U_{Enc}$ has a contracting path (1st half) to capture context and a corresponding expansive path (2nd half) which enables precise localization (as in the original U-Net paper).
* Skip connections from the contracting path to the expansive path (as in the original U-Net paper) are used to recover spatial information lost to downsampling.
* The architecture of $U_{Dec}$ is similar to that of $U_{Enc}$, except that it reads the output of $U_{Enc}$, which has a different number of channels than the input image.
* One major modification is the use of depthwise separable convolution layers ([Chollet, 2016](https://arxiv.org/abs/1610.02357)) in all modules except 1, 9, 10, and 18. See this [good explanation](https://towardsdatascience.com/a-basic-introduction-to-separable-convolutions-b99ec3102728) of depthwise separable convolutions.
* The idea behind such an operation is to examine spatial correlations and cross-channel correlations independently:
* A depthwise convolution performs spatial convolutions independently over each channel.
* Then, a pointwise convolution projects the channels output by the depthwise convolution onto a new channel space (see the sketch after this list).
* The network does not include any fully connected layers, which allows it to process arbitrarily sized images and produce a segmentation prediction of the corresponding size.
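To make the depthwise separable module concrete, here is a minimal PyTorch sketch; the channel counts are illustrative assumptions, not the paper's exact configuration:

```python
import torch
import torch.nn as nn

class DepthwiseSeparableConv(nn.Module):
    """A 3x3 depthwise convolution followed by a 1x1 pointwise convolution."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Spatial convolution applied to each input channel independently
        # (groups=in_channels makes the convolution depthwise).
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3,
                                   padding=1, groups=in_channels)
        # 1x1 convolution mixing information across channels.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

x = torch.randn(1, 64, 224, 224)
y = DepthwiseSeparableConv(64, 128)(x)  # shape: (1, 128, 224, 224)
```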
### Soft Normalized Cut Loss
* The output of $U_{Enc}$ is a normalized $H \times W \times K$ dense prediction, i.e. a $K$-class pixel-wise softmax. Taking the **argmax** over classes gives a hard class prediction for each pixel.
* Normalized cut ($Ncut$) is computed as a global criterion for segmentation:
$$
Ncut_K(V) = \sum_{k = 1}^K\frac{cut(A_k, V - A_k)}{assoc(A_k, V)}\\
= \sum_{k = 1}^K \frac{\sum_{u \in A_k, v \in V - A_k}w(u, v)}{\sum_{u \in A_k, t \in V}w(u, t)}
$$
where $A_k$ is the set of pixels in segment $k$, $V$ is the set of all pixels, and $w$ measures the weight between two pixels.
* **Intuition for the above equations** ([useful tutorial](http://www.sci.utah.edu/~gerig/CS7960-S2010/handouts/Normalized%20Graph%20cuts.pdf); also refer to the original Normalized Cut paper):
* The image is treated as a weighted graph: pixels are the nodes, and a weight is defined between every pair of pixels (nodes).
* The outer summation ($k = 1$ to $K$) runs over all $K$ segments, i.e. the inner term considers one segment at a time.
* The numerator of the inner term is the sum of weights between pixels inside segment $k$ and pixels outside it.
* The denominator of the inner term is the sum of weights between pixels of segment $k$ and all pixels in the image.
* Why normalized? Because we divide by the total association between the segment's pixels and all pixels. Using only the numerator (the raw cut) is also an option, but it is not normalized and is biased towards cutting off small, isolated groups of pixels, which works poorly in practice (see the tutorial linked above).
* However, the argmax function is not differentiable, so backpropagation cannot work through a hard segmentation. Instead, a soft version of the $Ncut$ loss is defined:
$$
J_{soft-Ncut}(V, K) = \sum_{k = 1}^K\frac{cut(A_k, V - A_k)}{assoc(A_k, V)} \\
= K - \sum_{k = 1}^K\frac{assoc(A_k, A_k)}{assoc(A_k, V)} \\
= K - \sum_{k = 1}^K\frac{\sum_{u \in V, v \in V}w(u, v)p(u = A_k)p(v = A_k)}{\sum_{u, t \in V}w(u, t)p(u = A_k)} \\
= K - \sum_{k = 1}^K\frac{\sum_{u \in V}p(u = A_k)\sum_{v \in V}w(u, v)p(v = A_k)}{\sum_{u \in V}p(u = A_k)\sum_{t \in V}w(u, t)}
$$
* **Intuition for the above equations:**
* The earlier equations had $u \in A_k$ and $v \in A_k$ terms, which require an argmax, i.e. a hard assignment of each pixel to a class.
* However, we only have a probability distribution (the softmax), so we replace $u \in A_k$ with $u \in V$ and weight each term by $p(u = A_k)$ (and similarly for $v$).
* Thus, this change results in a soft version of the Normalized Cut function.
* In the last two equations, the summations are simply factored according to the variable in question.
* The inner numerator effectively computes the association between pixels of the same segment; the overall fraction is the normalized association, which we would like to maximize.
* Equivalently, the negative of the normalized association is minimized; adding the constant $K$ to that (negative) term does not change the optimization problem.
* No additional computation is needed for the probability terms, since $p(u = A_k)$ is exactly the softmax output of the encoder $U_{Enc}$.
* By training $U_{Enc}$ to minimize the $J_{soft-Ncut}$ loss, we simultaneously minimize the total normalized disassociation between the groups and maximize the total normalized association within the groups.
* Regarding the weights $w(u, v)$ **(not defined explicitly in the W-Net paper; the following is from the original Normalized Cut paper)**:
* For the monocular case, the graph is constructed by taking each pixel as a node and defining the edge weight $w(i, j)$ between nodes $i$ and $j$ as the product of a feature similarity term and a spatial proximity term:
$$
w(i, j) = e^{\frac{-||F(i) - F(j)||_2^2}{\sigma_I}} \cdot \begin{cases}
e^{\frac{-||X(i) - X(j)||_2^2}{\sigma_X}} & \text{if $||X(i) - X(j)||_2 < r$} \\
0 & \text{otherwise}
\end{cases}
$$
where $X(i)$ is the spatial location of node $i$ and $F(i)$ is a feature vector based on intensity, color, or texture information at that node (for various definitions of $F(i)$ refer to the original paper; here $F(i)$ can mostly be taken as the intensity at pixel $i$).
* Note that the weight is $0$ for any pair of nodes more than $r$ pixels apart. Also note that $\sigma_I$, $\sigma_X$ and $r$ are set manually, so they are hyperparameters.
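As a concrete sketch of these weights, below is a brute-force PyTorch version assuming intensity features $F(i)$ and a small grayscale image; the dense $(HW) \times (HW)$ matrix and the hyperparameter values are illustrative assumptions, and a real implementation would exploit the radius-$r$ sparsity:

```python
import torch

def ncut_weights(image, sigma_i=10.0, sigma_x=4.0, r=5):
    """Dense pairwise weight matrix w(i, j) for an (H, W) grayscale image.

    Feature-similarity term (intensity) times spatial-proximity term,
    zeroed beyond radius r. Memory is O((H*W)^2): small images only.
    """
    H, W = image.shape
    n = H * W

    # Feature similarity: exp(-|F(i) - F(j)|^2 / sigma_I), with F = intensity.
    f = image.reshape(n, 1).float()
    w_feat = torch.exp(-(f - f.T) ** 2 / sigma_i)

    # Spatial proximity: exp(-|X(i) - X(j)|^2 / sigma_X), zero beyond radius r.
    ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    pos = torch.stack([ys.flatten(), xs.flatten()], dim=1).float()
    d2 = ((pos[:, None, :] - pos[None, :, :]) ** 2).sum(dim=-1)
    w_spatial = torch.exp(-d2 / sigma_x) * (d2.sqrt() < r)

    return w_feat * w_spatial
```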
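Using those weights, here is a minimal sketch of the soft $Ncut$ loss itself, implementing the factored form in the last line of the derivation above; `probs` is assumed to be the $K \times H \times W$ softmax output of $U_{Enc}$ for one image:

```python
def soft_ncut_loss(probs, image):
    """J_soft-Ncut = K - sum_k soft_assoc(A_k, A_k) / soft_assoc(A_k, V)."""
    K = probs.shape[0]
    w = ncut_weights(image)                    # (n, n), symmetric
    p = probs.reshape(K, -1)                   # p[k, u] = p(u = A_k)

    # sum_u p_k(u) * sum_v w(u, v) p_k(v)  -> soft assoc(A_k, A_k)
    assoc_kk = ((p @ w) * p).sum(dim=1)
    # sum_u p_k(u) * sum_t w(u, t)         -> soft assoc(A_k, V)
    assoc_kv = (p * w.sum(dim=1)).sum(dim=1)

    return K - (assoc_kk / assoc_kv).sum()
```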
### Reconstruction Loss
* As in a classical encoder-decoder architecture, the reconstruction loss is minimized to enforce that the encoded representation retains as much information about the original input as possible.
* Minimizing the reconstruction loss in the W-Net architecture makes the segmentation prediction align better with the input images.
* The loss is given by:
$$
J_{reconstr} = ||X - U_{Dec}(U_{Enc}(X; W_{Enc}); W_{Dec})||_2^2
$$
where $W_{Enc}$ and $W_{Dec}$ are the parameters of encoder and decoder respectively, and $X$ is the input image.
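In code this is just a squared error between the input and its reconstruction. A short runnable sketch, with single convolutions standing in for the two U-Nets (placeholders of mine, not the paper's models); note that `F.mse_loss` averages rather than sums, which only rescales the loss:

```python
import torch
import torch.nn.functional as F

# Placeholder stand-ins for the two U-Nets (here K = 4 segment classes).
u_enc = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1)
u_dec = torch.nn.Conv2d(4, 3, kernel_size=3, padding=1)

x = torch.randn(2, 3, 96, 96)                   # input image batch
segmentation = torch.softmax(u_enc(x), dim=1)   # (N, K, H, W) soft segmentation
reconstruction = u_dec(segmentation)            # (N, 3, H, W)
j_reconstr = F.mse_loss(reconstruction, x)      # squared L2 reconstruction error
```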
## Optimization
* By iteratively applying $J_{reconstr}$ and $J_{soft-Ncut}$, the network balances the trade-off between the accuracy of reconstruction and the consistency in the encoded representation layer.
* The algorithm is shown below:
![W-Net mini-batch SGD training](https://i.imgur.com/Anc2ShE.png)
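Below is a minimal sketch of these alternating updates, reusing `soft_ncut_loss` from the earlier snippet; the single-convolution stand-ins for $U_{Enc}$/$U_{Dec}$, the dummy grayscale data, and the learning rate are all placeholder assumptions, not the paper's settings:

```python
import itertools
import torch

# Placeholder stand-ins for U_Enc and U_Dec; grayscale input, K = 4 segments.
u_enc = torch.nn.Conv2d(1, 4, kernel_size=3, padding=1)
u_dec = torch.nn.Conv2d(4, 1, kernel_size=3, padding=1)
dataloader = [torch.rand(2, 1, 32, 32) for _ in range(4)]  # dummy batches

enc_opt = torch.optim.SGD(u_enc.parameters(), lr=0.003)
all_opt = torch.optim.SGD(itertools.chain(u_enc.parameters(),
                                          u_dec.parameters()), lr=0.003)

for x in dataloader:
    # Step 1: update the encoder alone on the soft Ncut loss.
    enc_opt.zero_grad()
    probs = torch.softmax(u_enc(x), dim=1)
    loss_ncut = sum(soft_ncut_loss(probs[i], x[i, 0]) for i in range(len(x)))
    loss_ncut.backward()
    enc_opt.step()

    # Step 2: update encoder and decoder jointly on the reconstruction loss.
    all_opt.zero_grad()
    recon = u_dec(torch.softmax(u_enc(x), dim=1))
    loss_rec = torch.nn.functional.mse_loss(recon, x)
    loss_rec.backward()
    all_opt.step()
```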
## Postprocessing
I don't understand these parts (CRF smoothing and hierarchical merging) yet. Hopefully I'll be able to update this someday :cry:.