# Notes on "[W-Net: A Deep Model for Fully Unsupervised Image Segmentation](https://arxiv.org/abs/1711.08506)"

###### tags: `notes` `segmentation` `unsupervised`

The aim is to use a deep network for **fully unsupervised segmentation**. Post-processing further improves the segmentation output.

## Brief Outline

![Overview](https://i.imgur.com/yf4e3kW.png)

* Two U-Nets ([Ronneberger et al. 2015](https://arxiv.org/abs/1505.04597)) are combined into a single autoencoder.
* The first U-Net encodes the input image into a $K$-way soft segmentation. The second U-Net reverses this process and reconstructs the image from the segmentation.
* The reconstruction error of the autoencoder and a soft normalized cut loss ([Shi et al. 2000](https://dl.acm.org/citation.cfm?id=351611)) on the segmentation output are optimized jointly.
* Postprocessing improves the segmentation in 2 steps:
    * fully connected conditional random field (CRF) smoothing of the output segments ([Krahenbuhl et al. 2012](https://arxiv.org/abs/1210.5644), [Chen et al. 2017](https://arxiv.org/abs/1606.00915));
    * a hierarchical merging method ([Arbelaez et al. 2011](https://ieeexplore.ieee.org/document/5557884)) to obtain the final segmentation.

## Network Architecture

![W-Net Architecture](https://i.imgur.com/9SCjqWd.png)

* The W-Net architecture has 46 convolutional layers, structured into 18 modules (marked with red rectangles in the figure).
* Each module consists of two 3x3 convolutional layers, each followed by a ReLU non-linearity and batch normalization. The first nine modules form the dense prediction base of the network ($U_{Enc}$) and the second nine form the reconstruction decoder ($U_{Dec}$).
* The encoder $U_{Enc}$ has a contracting path (first half) that captures context and a corresponding expansive path (second half) that enables precise localization, as in the original U-Net paper.
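A minimal NumPy sketch of one module as described above: two 3x3 convolutions, each followed by ReLU and then batch normalization. This is an illustration under assumptions, not the authors' implementation: the NHWC layout, "same" zero padding, batch normalization without learned scale/shift, and the function names (`conv3x3`, `separable_conv3x3`, `batch_norm`, `module`) are all my own hypothetical choices. It also includes the depthwise separable 3x3 convolution that the paper uses in most modules, written as a per-channel spatial step followed by a 1x1 pointwise step.

```python
import numpy as np


def conv3x3(x, w, b):
    """'Same'-padded 3x3 convolution (cross-correlation, as in DL libraries).

    x: (N, H, W, C_in), w: (3, 3, C_in, C_out), b: (C_out,)
    """
    _, h, wd, _ = x.shape
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1), (0, 0)))  # zero-pad H and W by 1
    out = np.zeros(x.shape[:3] + (w.shape[-1],))
    for dy in range(3):
        for dx in range(3):
            # shifted window @ kernel tap: mixes input channels into output channels
            out += xp[:, dy:dy + h, dx:dx + wd, :] @ w[dy, dx]
    return out + b


def separable_conv3x3(x, dw, pw, b):
    """Depthwise 3x3 conv (one kernel per channel), then a pointwise 1x1 conv.

    dw: (3, 3, C_in), pw: (C_in, C_out), b: (C_out,)
    """
    _, h, wd, _ = x.shape
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1), (0, 0)))
    depth = np.zeros(x.shape)
    for dy in range(3):
        for dx in range(3):
            # spatial filtering only: each channel uses its own kernel, no mixing
            depth += xp[:, dy:dy + h, dx:dx + wd, :] * dw[dy, dx]
    return depth @ pw + b  # 1x1 conv: cross-channel mixing only


def batch_norm(x, eps=1e-5):
    """Normalize each channel over batch and spatial axes (no learned scale/shift)."""
    mean = x.mean(axis=(0, 1, 2), keepdims=True)
    var = x.var(axis=(0, 1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def module(x, params):
    """One module: two 3x3 convs, each followed by ReLU and then batch norm."""
    for w, b in params:
        x = batch_norm(np.maximum(conv3x3(x, w, b), 0.0))
    return x


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 8, 3))  # two 8x8 RGB images, NHWC layout
params = [
    (rng.standard_normal((3, 3, 3, 16)) * 0.1, np.zeros(16)),
    (rng.standard_normal((3, 3, 16, 16)) * 0.1, np.zeros(16)),
]
y = module(x, params)
print(y.shape)  # (2, 8, 8, 16): 'same' padding keeps H and W

# weight counts (biases ignored) for a 64 -> 64 channel 3x3 layer
c_in, c_out = 64, 64
print(9 * c_in * c_out, 9 * c_in + c_in * c_out)  # 36864 4672
```

The separable version is equivalent to a standard convolution whose kernel factorizes as `dw[dy, dx, c] * pw[c, o]`; that restriction is where the large drop in weights comes from.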
* Skip connections from the contracting path to the expansive path (as in the original U-Net paper) recover spatial information lost to downsampling.
* The architecture of $U_{Dec}$ is similar to that of $U_{Enc}$, except that it reads the output of $U_{Enc}$, which has a different number of channels ($K$) than the input image.
* One major modification is the use of depthwise separable convolution layers ([Chollet 2016](https://arxiv.org/abs/1610.02357)) in all modules except 1, 9, 10, and 18. Here is a [good explanation](https://towardsdatascience.com/a-basic-introduction-to-separable-convolutions-b99ec3102728) of depthwise separable convolutions.
    * The idea behind such an operation is to examine spatial correlations and cross-channel correlations independently:
        * A depthwise convolution performs a spatial convolution independently over each channel.
        * Then, a pointwise (1x1) convolution projects the channels output by the depthwise convolution onto a new channel space.
* The network does not include any fully connected layers, which allows it to process arbitrarily sized images and make a segmentation prediction of the corresponding size.

### Soft Normalized Cut Loss

* The output of $U_{Enc}$ is a normalized $H \times W \times K$ dense prediction, i.e. a $K$-class pixel-wise softmax. Taking the **argmax** gives a class prediction for each pixel.
* The normalized cut ($Ncut$) is computed as a global criterion for segmentation:
  $$
  \begin{aligned}
  Ncut_K(V) &= \sum_{k = 1}^K \frac{cut(A_k, V - A_k)}{assoc(A_k, V)} \\
  &= \sum_{k = 1}^K \frac{\sum_{u \in A_k, v \in V - A_k} w(u, v)}{\sum_{u \in A_k, t \in V} w(u, t)}
  \end{aligned}
  $$
  where $A_k$ is the set of pixels in segment $k$, $V$ is the set of all pixels, and $w(u, v)$ is the weight between pixels $u$ and $v$.
* **Intuition for the above equations** ([useful tutorial](http://www.sci.utah.edu/~gerig/CS7960-S2010/handouts/Normalized%20Graph%20cuts.pdf); also see the original Normalized Cut paper):
    * The image is treated as a weighted graph, i.e. pixels are the nodes and every pair of pixels (nodes) is connected by a weighted edge.
    * The outer summation ($k = 1$ to $K$) runs over all $K$ segments, i.e. each inner term considers one segment at a time.
    * The numerator of the inner term is the sum of weights between pixels in the segment and pixels outside it.
    * The denominator of the inner term is the sum of weights between pixels in the segment and all pixels in the image.
    * Why normalized? Each term is divided by the total weight between the segment and all pixels. Using only the numerator (the plain minimum cut) is also an option, but it is not normalized and does not work well: it favors cutting off small, isolated sets of pixels (see the tutorial linked above).
* However, the argmax function is not differentiable, so backpropagation cannot pass through it. Therefore, a soft version of the $Ncut$ loss is defined:
  $$
  \begin{aligned}
  J_{soft-Ncut}(V, K) &= \sum_{k = 1}^K \frac{cut(A_k, V - A_k)}{assoc(A_k, V)} \\
  &= K - \sum_{k = 1}^K \frac{assoc(A_k, A_k)}{assoc(A_k, V)} \\
  &= K - \sum_{k = 1}^K \frac{\sum_{u \in V, v \in V} w(u, v)\, p(u = A_k)\, p(v = A_k)}{\sum_{u, t \in V} w(u, t)\, p(u = A_k)} \\
  &= K - \sum_{k = 1}^K \frac{\sum_{u \in V} p(u = A_k) \sum_{v \in V} w(u, v)\, p(v = A_k)}{\sum_{u \in V} p(u = A_k) \sum_{t \in V} w(u, t)}
  \end{aligned}
  $$
* **Intuition for the above equations:**
    * The second line uses $cut(A_k, V - A_k) = assoc(A_k, V) - assoc(A_k, A_k)$, so each fraction equals $1 - \frac{assoc(A_k, A_k)}{assoc(A_k, V)}$; summing the $1$s over the $K$ segments gives the leading $K$.
    * The earlier equations contain $u \in A_k$ and $v \in A_k$ terms. Those require an argmax, i.e. we must know exactly which pixel belongs to which class.
    * However, we only have a probability distribution (the softmax), so $u \in A_k$ is replaced by $u \in V$ together with a weighting factor $p(u = A_k)$ (and likewise for $v$).
    * This change yields a soft version of the normalized cut.
    * The last line only regroups the sums: $p(u = A_k)$ does not depend on $v$ or $t$, so it is pulled out of the inner sums.
    * The inner numerator is the (soft) association between pixels of the same segment, and the whole fraction is the normalized association. We would like to maximize this quantity.
    * So the negative of the normalized association has to be minimized. Adding the constant $K$ does not change the minimizer.
* No additional computation is needed for the probability terms $p(u = A_k)$, since they are exactly the softmax outputs of the encoder $U_{Enc}$.
* By training $U_{Enc}$ to minimize $J_{soft-Ncut}$, we simultaneously minimize the total normalized disassociation between the groups and maximize the total normalized association within the groups.
* Regarding the weights $w(u, v)$ (**not stated explicitly in this paper**; they come from the original Normalized Cut paper):
    * For the monocular case, the graph is constructed by taking each pixel as a node and defining the edge weight $w(i, j)$ between nodes $i$ and $j$ as the product of a feature similarity term and a spatial proximity term:
      $$
      w(i, j) = e^{\frac{-||F(i) - F(j)||_2^2}{\sigma_I}} \cdot \begin{cases} e^{\frac{-||X(i) - X(j)||_2^2}{\sigma_X}} & \text{if } ||X(i) - X(j)||_2 < r \\ 0 & \text{otherwise} \end{cases}
      $$
      where $X(i)$ is the spatial location of node $i$ and $F(i)$ is a feature vector based on intensity, color, or texture information at that node (the original paper gives several definitions of $F(i)$; here $F(i)$ can mostly be taken as the intensity of $i$).
    * The weight is $0$ for any pair of nodes at least $r$ pixels apart. $\sigma_I$, $\sigma_X$, and $r$ are set manually, so they are hyperparameters.

### Reconstruction Loss

* As in a classical encoder-decoder architecture, the reconstruction loss is minimized to force the encoded representation to contain as much information about the original input as possible.
* In W-Net, minimizing the reconstruction loss makes the segmentation prediction align better with the input images.
* The loss is given by:
  $$
  J_{reconstr} = ||X - U_{Dec}(U_{Enc}(X; W_{Enc}); W_{Dec})||_2^2
  $$
  where $W_{Enc}$ and $W_{Dec}$ are the parameters of the encoder and decoder respectively, and $X$ is the input image.

## Optimization

* By iteratively applying $J_{reconstr}$ and $J_{soft-Ncut}$, the network balances the trade-off between reconstruction accuracy and the consistency of the encoded representation layer.
* The training algorithm is shown below:

![W-Net mini-batch SGD training](https://i.imgur.com/Anc2ShE.png)

## Postprocessing

I don't understand these parts (CRF smoothing and hierarchical merging) yet. Hopefully I will be able to update this someday :cry:.
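Returning to the two losses from the Optimization section: the NumPy sketch below only *evaluates* $J_{soft-Ncut}$ and $J_{reconstr}$ on a toy 4x4 grayscale image; it does no training and is not the authors' code. Assumptions: $F(i)$ is the pixel intensity, the weight matrix $w(i, j)$ is built densely as an $HW \times HW$ array (only viable for tiny images), $\sigma_I = 0.1$, $\sigma_X = 4$, $r = 5$ are arbitrary hypothetical values, and the helper names (`ncut_weights`, `hard_ncut`, `soft_ncut`, `reconstruction_loss`) are hypothetical.

```python
import numpy as np


def ncut_weights(img, sigma_i=0.1, sigma_x=4.0, r=5.0):
    """Dense w(i, j): intensity similarity times spatial proximity, 0 beyond r.

    F(i) is the pixel intensity, X(i) its (row, col) position.
    The sigma / r defaults are hypothetical, hand-picked values.
    """
    h, w = img.shape
    rows, cols = np.mgrid[0:h, 0:w]
    pos = np.stack([rows.ravel(), cols.ravel()], axis=1).astype(float)
    feat = img.reshape(-1, 1).astype(float)
    d_feat = ((feat[:, None, :] - feat[None, :, :]) ** 2).sum(-1)  # ||F(i) - F(j)||^2
    d_pos = ((pos[:, None, :] - pos[None, :, :]) ** 2).sum(-1)     # ||X(i) - X(j)||^2
    wts = np.exp(-d_feat / sigma_i) * np.exp(-d_pos / sigma_x)
    wts[np.sqrt(d_pos) >= r] = 0.0  # zero weight for pixels at least r apart
    return wts


def hard_ncut(labels, wts, k):
    """Ncut_K(V) = sum_k cut(A_k, V - A_k) / assoc(A_k, V) for hard labels."""
    labels = labels.ravel()
    total = 0.0
    for c in range(k):
        in_c = labels == c
        cut = wts[in_c][:, ~in_c].sum()  # weights leaving segment c
        assoc = wts[in_c].sum()          # weights from segment c to all pixels
        total += cut / assoc
    return total


def soft_ncut(probs, wts):
    """J_soft-Ncut = K - sum_k assoc(A_k, A_k) / assoc(A_k, V).

    probs: (HW, K) pixel-wise softmax output, i.e. p(u = A_k).
    """
    k = probs.shape[1]
    assoc_aa = np.sum(probs * (wts @ probs), axis=0)  # sum_u p(u) sum_v w(u, v) p(v)
    assoc_av = wts.sum(axis=1) @ probs                 # sum_u p(u) sum_t w(u, t)
    return k - np.sum(assoc_aa / assoc_av)


def reconstruction_loss(x, x_hat):
    """J_reconstr = ||X - X_hat||_2^2, summed over all entries."""
    return np.sum((x - x_hat) ** 2)


img = np.zeros((4, 4))
img[:, 2:] = 1.0                              # dark left half, bright right half
wts = ncut_weights(img)
good = (np.arange(16) % 4 >= 2).astype(int)  # split along the intensity edge
bad = np.arange(16) // 8                      # top/bottom split across the edge
one_hot = np.eye(2)[good]
print(hard_ncut(good, wts, 2), soft_ncut(one_hot, wts))  # agree up to rounding
print(hard_ncut(good, wts, 2) < hard_ncut(bad, wts, 2))  # True
```

With one-hot probabilities the soft loss collapses to the hard $Ncut$, which is a quick numerical check of the derivation in the Soft Normalized Cut Loss section; with genuine softmax outputs, `soft_ncut` stays differentiable in `probs`.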