# Notes on "[Semi-Supervised Semantic Segmentation with Cross-Consistency Training](https://openaccess.thecvf.com/content_CVPR_2020/papers/Ouali_Semi-Supervised_Semantic_Segmentation_With_Cross-Consistency_Training_CVPR_2020_paper.pdf)"
###### tags: `notes` `segmentation` `semi-supervised`
CVPR '20 paper; [Official Code Release](https://github.com/yassouali/CCT)
Author: [Akshay Kulkarni](https://akshayk07.weebly.com/)
## Brief Outline
This paper proposes cross-consistency training, where an invariance of the predictions is enforced over different perturbations applied to the outputs of the encoder (in a shared encoder and multiple decoder architecture).
## Introduction
* Semi-Supervised Learning (SSL) takes advantage of a large amount of unlabeled data and limits the need for labeled examples. The current dominant SSL methods in deep learning are
* Consistency Training ([NIPS '15](https://papers.nips.cc/paper/5947-semi-supervised-learning-with-ladder-networks), [ICLR '17](https://arxiv.org/abs/1610.02242), [NIPS '17](https://arxiv.org/abs/1703.01780) and [TPAMI '18](https://arxiv.org/abs/1704.03976))
* Pseudo Labelling ([ICMLW '13](http://deeplearning.net/wp-content/uploads/2013/03/pseudo_label_final.pdf))
* Entropy Minimization ([NIPS '05](https://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization))
* Generative Modeling ([NIPS '17](https://papers.nips.cc/paper/7137-semi-supervised-learning-with-gans-manifold-invariance-with-improved-inference) and [ICCV '17](https://openaccess.thecvf.com/content_ICCV_2017/papers/Souly__Semi_Supervised_ICCV_2017_paper.pdf))
* Weakly-supervised approaches require weakly labeled examples along with pixel-level labels, and hence they don't exploit the unlabeled data to extract additional training signal.
* Methods based on adversarial training exploit the unlabeled data, but can be harder to train.
* To address these limitations, they propose a simple consistency-based semi-supervised method for semantic segmentation. The objective of consistency training is to enforce an invariance of the model's prediction to small perturbations applied to the inputs, and make the model robust to such changes.
* They consider a shared encoder and a main decoder which are trained using the labeled data. To use unlabeled data, they have auxiliary decoders whose inputs are perturbed versions of the output of the shared encoder.
* Their contributions are as follows:
* They propose a cross-consistency training (CCT) method where invariance of predictions is enforced over different perturbations injected into the encoder's output.
* They conduct an exhaustive study of various types of perturbations.
* They extend their approach to use weakly-labeled data, and exploit pixel-level labels across different domains to jointly train the segmentation network.
* They compare with the SOTA and analyze the approach through an ablation study.
## Methodology
### Cluster Assumption in Semantic Segmentation
* Check the [Cluster Assumption page on Wikipedia](https://en.wikipedia.org/wiki/Cluster_hypothesis). Basically, the assumption is that if points are in the same cluster, they are likely to be of the same class. There may be multiple clusters forming a single class.
* It is equivalent to the low-density separation assumption, which states that the decision boundary should lie in a low-density region. To see this, suppose the decision boundary crosses a cluster (a high-density region); then the cluster contains points from 2 different classes, which violates the cluster assumption.
Note: The above two points are taken from the linked Wikipedia page, and are not in the paper.
* A simple way to examine the cluster assumption is to estimate the local smoothness by measuring the local variations between the value of each pixel and its local neighbours.
* For this, they compute the average Euclidean distance between each spatial location and its 8 immediate neighbours, for both the inputs and the hidden representations. Check the paper for exact computation details (a rough sketch follows this list).
![Cluster Assumption Evaluation](https://i.imgur.com/983ee7Q.png)
* They observe that the cluster assumption is violated at the input level since the low density regions do not align with the class boundaries.
* For the encoder's outputs, the cluster assumption is maintained where the class boundaries have high average distance, corresponding to low density regions.
* These observations motivate the approach to apply perturbations to the encoder's outputs rather than the inputs.
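As a rough illustration, the local smoothness measure can be computed in a few lines of PyTorch. The padding choice and the exact normalization below are my assumptions, not the paper's exact recipe:

```python
import torch
import torch.nn.functional as F

def local_distance_map(feat):
    """Average Euclidean distance between each spatial position and its 8
    immediate neighbours, for a feature map of shape (C, H, W).
    (Sketch; the paper's exact computation may differ.)"""
    feat = feat.float()
    c, h, w = feat.shape
    # Pad by 1 so that border pixels also have 8 (replicated) neighbours.
    padded = F.pad(feat.unsqueeze(0), (1, 1, 1, 1), mode="replicate")
    # Extract all 3x3 neighbourhoods: (1, C*9, H*W) -> (C, 9, H, W).
    patches = F.unfold(padded, kernel_size=3).view(c, 9, h, w)
    center = feat.unsqueeze(1)                              # (C, 1, H, W)
    # Euclidean distance (over channels) to each of the 9 window positions.
    dist = ((patches - center) ** 2).sum(dim=0).sqrt()      # (9, H, W)
    # Drop the centre position (index 4) and average over the 8 neighbours.
    neighbours = torch.cat([dist[:4], dist[5:]], dim=0)
    return neighbours.mean(dim=0)                           # (H, W)

# Usage: pass an RGB image (3, H, W) or an encoder output z (C, H', W').
# smooth_map = local_distance_map(z)
```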
### Cross-Consistency Training
![CCT Approach](https://i.imgur.com/eilPWpq.png)
* Let $\mathcal{D}_l = \{(\mathrm{x}_1^l, y_1), \dots, (\mathrm{x}_n^l, y_n)\}$ represent the $n$ labeled examples and $\mathcal{D}_u = \{ \mathrm{x}_1^u, \dots, \mathrm{x}_m^u \}$ represent the $m$ unlabeled examples, and $\mathrm{x}_i^u$ as the $i$-th unlabeled input image, and $\mathrm{x}_i^l$ as the $i$-th labeled input image with spatial dimensions $H\times W$ and its corresponding pixel-level label $y_i \in \mathbb{R}^{C \times H \times W}$, where $C$ is the number of classes.
* The architecture is composed of a shared encoder $h$ and a main decoder $g$ which constitute the segmentation network $f=g\circ h$. They also introduce a set of auxiliary decoders $g_a^k$ with $k \in [1, K]$.
* While the segmentation network $f$ is trained on the labeled set $\mathcal{D}_l$ in a traditional supervised manner, the auxiliary networks $g_a^k \circ h$ are trained on the unlabeled set $\mathcal{D}_u$ by enforcing a consistency of predictions between the main decoder and the auxiliary decoders.
* Each auxiliary decoder takes as input a perturbed version of the encoder's output and the main decoder is fed the uncorrupted intermediate representation. This way, the representation learning of the encoder $h$ is further enhanced using the unlabeled examples.
* Formally, for a labeled training example $\mathrm{x}_i^l$ and its pixel-level label $y_i$, the network $f$ is trained using a CE based supervised loss $\mathcal{L}_s$
$$
\mathcal{L}_s = \frac{1}{|\mathcal{D}_l|} \sum_{\mathrm{x}_i^l, y_i \in \mathcal{D}_l} \mathrm{H}(y_i, f(\mathrm{x}_i^l))
\tag{1}
$$
* Here, $\mathrm{H}$ is the CE loss. For an unlabeled example $\mathrm{x}_i^u$, an intermediate representation of the input is computed using the shared encoder $\mathrm{z}_i = h(\mathrm{x}_i^u)$. Consider $R$ stochastic perturbation functions, denoted as $p_r$ with $r \in [1, R]$, where one perturbation function can be assigned to multiple auxiliary decoders.
* With various perturbation settings, they generate $K$ perturbed versions $\tilde{\mathrm{z}}_i^k$ of the intermediate representation $\mathrm{z}_i$, so that the $k$-th perturbed version is fed to the $k$-th auxiliary decoder. For notational simplicity, they consider the perturbation function as part of the auxiliary decoder (i.e. $g_a^k$ is seen as $g_a^k \circ p_r$).
* The training objective is to minimize the unsupervised loss $\mathcal{L}_u$, which measures the discrepancy between the main decoder's output and that of the auxiliary decoders
$$
\mathcal{L}_u = \frac{1}{|\mathcal{D}_u| K} \sum_{\mathrm{x}_i^u \in \mathcal{D}_u} \sum_{k=1}^K \text{MSE}(g(\mathrm{z}_i), g_a^k(\mathrm{z}_i))
\tag{2}
$$
* The combined loss $\mathcal{L}$ for consistency based SSL is then computed as
$$
\mathcal{L} = \mathcal{L}_s + \omega_u \mathcal{L}_u
\tag{3}
$$
* Here, $\omega_u$ is an unsupervised loss weighting function. Following this [ICLR '17 work](https://arxiv.org/abs/1610.02242), to avoid relying on the main decoder's initial noisy predictions, $\omega_u$ ramps up from zero along a Gaussian curve to a fixed weight $\lambda_u$.
* At each training iteration, an equal number of examples are sampled from the labeled $\mathcal{D}_l$ and the unlabeled $\mathcal{D}_u$ sets. Note that the unsupervised loss $\mathcal{L}_u$ is not back-propagated through the main decoder $g$; only the labeled examples are used to train $g$ (see the sketch after this list).
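To make Eqs. (1)–(3) and the detached main decoder concrete, here is a minimal PyTorch-style sketch of one training iteration. All names (`encoder`, `main_decoder`, `aux_decoders`) and the default ramp-up length and $\lambda_u$ are illustrative assumptions, not the official implementation; each auxiliary decoder is assumed to apply its own perturbation internally, matching the $g_a^k \circ p_r$ notation above.

```python
import math

import torch
import torch.nn.functional as F

def rampup_weight(step, rampup_steps, lambda_u):
    """Gaussian ramp-up of the unsupervised weight, as in Laine & Aila (ICLR '17)."""
    t = min(step, rampup_steps) / rampup_steps
    return lambda_u * math.exp(-5.0 * (1.0 - t) ** 2)

def train_step(encoder, main_decoder, aux_decoders, x_l, y_l, x_u,
               step, rampup_steps=40_000, lambda_u=30.0):
    """One CCT iteration: (x_l, y_l) is a labeled batch (y_l holds class indices),
    x_u an unlabeled batch of the same size."""
    # Supervised branch (Eq. 1): labeled batch through encoder + main decoder.
    loss_s = F.cross_entropy(main_decoder(encoder(x_l)), y_l)

    # Unsupervised branch (Eq. 2): unlabeled batch through the shared encoder.
    z = encoder(x_u)
    # The main decoder's output is the consistency target; it is detached so
    # that L_u is never back-propagated through the main decoder g.
    with torch.no_grad():
        target = main_decoder(z).softmax(dim=1)
    loss_u = sum(F.mse_loss(aux(z).softmax(dim=1), target) for aux in aux_decoders)
    loss_u = loss_u / len(aux_decoders)

    # Combined loss (Eq. 3) with the ramped-up weight w_u.
    return loss_s + rampup_weight(step, rampup_steps, lambda_u) * loss_u
```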
#### Perturbation Functions
* They propose three types of perturbation functions $p_r$ (code sketches of several of them follow this list):
* Feature-based
* Prediction-based
* Random
* Feature-based Perturbations
    * They consist of either injecting noise into, or dropping some of the activations of, the encoder's output feature map $\mathrm{z}$.
* F-Noise
* They uniformly sample a noise tensor $\mathrm{N} \sim \mathcal{U}(-0.3, 0.3)$ of the same size as $\mathrm{z}$.
* After adjusting its amplitude by multiplying it with $\mathrm{z}$, the noise is then injected into the encoder output $\mathrm{z}$ to get $\tilde{\mathrm{z}} = (\mathrm{z} \odot \mathrm{N}) + \mathrm{z}$.
* This way, the injected noise is proportional to each activation.
* F-Drop
* First, uniformly sample a threshold $\gamma \sim \mathcal{U}(0.6, 0.9)$.
* After summing over the channel dimensions and normalizing the feature map $\mathrm{z}$ to get $\mathrm{z}'$, they generate a binary mask $\mathrm{M}_{\text{drop}} = \{\mathrm{z}' < \gamma\}$, which is used to obtain the perturbed version $\tilde{\mathrm{z}} = \mathrm{z} \odot \mathrm{M}_{\text{drop}}$.
* This way, they mask 10% to 40% of the most active regions in the feature map.
* Prediction-based Perturbations
* They consist of adding perturbations based on the main decoder's prediction $\hat{y} = g(\mathrm{z})$ or that of the auxiliary decoders. They consider masking-based perturbations in addition to adversarial perturbations.
* Guided Masking
* Context relationships are important for complex scene understanding, and the network may be too reliant on these relationships.
    * To limit this reliance, they create 2 perturbed versions of $\mathrm{z}$ by masking the detected objects ($\text{Obj-Msk}$) and the context ($\text{Con-Msk}$).
* Using $\hat{y}$, they generate an object mask $\mathrm{M}_{\text{obj}}$ to mask the detected foreground objects and a context mask $\mathrm{M}_{\text{con}} = 1 - \mathrm{M}_{\text{obj}}$, which are then applied to $\mathrm{z}$.
* Guided Cutout ($\text{G-Cutout}$)
* In order to reduce reliance on specific parts of the objects, and inspired by [Cutout](https://arxiv.org/abs/1708.04552) that randomly masks some parts of the input image, they first find the possible spatial extent (i.e. bounding box) of each detected object using $\hat{y}$.
* Then they zero-out a random crop within each object's bounding box from the corresponding feature map $\mathrm{z}$.
* Intermediate VAT ($\text{I-VAT}$)
* To further push the output distribution to be isotropically smooth around each data point, they use [VAT (TPAMI '18)](https://arxiv.org/abs/1704.03976) as a perturbation function to be applied to $\mathrm{z}$ instead of the unlabeled inputs.
* For a given auxiliary decoder, they find the adversarial perturbation $r_{adv}$ that alters its prediction the most, and inject it into $\mathrm{z}$ to obtain the perturbed $\tilde{\mathrm{z}} = r_{adv} + \mathrm{z}$.
* Random Perturbations ($\text{DropOut}$)
* [Spatial dropout (CVPR '15)](https://openaccess.thecvf.com/content_cvpr_2015/html/Tompson_Efficient_Object_Localization_2015_CVPR_paper.html) is also applied to $\mathrm{z}$ as a random perturbation.
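The masking-style perturbations above can be made concrete with a short PyTorch-style sketch. Tensor shapes, the foreground-class handling, and the crop fraction are assumptions for illustration, not taken from the official code; the prediction $\hat{y}$ is assumed to already be at the feature-map resolution.

```python
import torch
import torch.nn as nn

def f_noise(z):
    """F-Noise: uniform noise N ~ U(-0.3, 0.3), scaled by z and added to z."""
    noise = torch.empty_like(z).uniform_(-0.3, 0.3)
    return z + z * noise

def f_drop(z):
    """F-Drop: mask the most active regions of z (threshold gamma ~ U(0.6, 0.9))."""
    gamma = torch.empty(1, device=z.device).uniform_(0.6, 0.9)
    # Sum over channels and normalize to [0, 1] per sample.
    attention = z.sum(dim=1, keepdim=True)                       # (N, 1, H, W)
    attention = attention / attention.amax(dim=(2, 3), keepdim=True)
    return z * (attention < gamma).float()

def guided_masks(z, y_hat, fg_classes):
    """Obj-Msk / Con-Msk: mask detected objects or the context, using the main
    decoder's prediction y_hat (N, C, H, W); fg_classes lists foreground ids."""
    labels = y_hat.argmax(dim=1, keepdim=True)                   # (N, 1, H, W)
    obj = torch.zeros_like(labels, dtype=z.dtype)
    for c in fg_classes:
        obj = obj + (labels == c).to(z.dtype)
    obj = obj.clamp(max=1.0)
    # y_hat is assumed to match the feature-map resolution; resize otherwise.
    z_obj_masked = z * (1.0 - obj)       # Obj-Msk: drop the detected objects
    z_con_masked = z * obj               # Con-Msk: drop the context
    return z_obj_masked, z_con_masked

def guided_cutout(z, y_hat, fg_classes, crop_frac=0.4):
    """G-Cutout: zero a random crop inside each detected object's bounding box."""
    labels = y_hat.argmax(dim=1)                                 # (N, H, W)
    z = z.clone()
    for n in range(z.size(0)):
        for c in fg_classes:
            ys, xs = torch.nonzero(labels[n] == c, as_tuple=True)
            if ys.numel() == 0:
                continue
            y0, y1 = ys.min().item(), ys.max().item()
            x0, x1 = xs.min().item(), xs.max().item()
            ch = max(1, int((y1 - y0 + 1) * crop_frac))
            cw = max(1, int((x1 - x0 + 1) * crop_frac))
            cy = torch.randint(y0, y1 - ch + 2, (1,)).item()
            cx = torch.randint(x0, x1 - cw + 2, (1,)).item()
            z[n, :, cy:cy + ch, cx:cx + cw] = 0.0
    return z

# Random perturbation: plain spatial (channel-wise) dropout on the feature map.
spatial_dropout = nn.Dropout2d(p=0.5)
```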
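I-VAT can be sketched along the lines of the standard VAT recipe (one power-iteration-style step); the $\xi$ and $\epsilon$ values below are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def i_vat_perturbation(z, aux_decoder, xi=1e-6, eps=2.0):
    """Find r_adv for the intermediate representation z (N, C, H, W) that most
    changes the given auxiliary decoder's prediction, then return z + r_adv."""
    z_d = z.detach()
    with torch.no_grad():
        pred = F.softmax(aux_decoder(z_d), dim=1)

    # Random unit direction d; one gradient step approximates the direction
    # that perturbs the decoder's prediction the most.
    d = torch.randn_like(z_d)
    d = d / (d.flatten(1).norm(dim=1).view(-1, 1, 1, 1) + 1e-8)
    d.requires_grad_(True)

    pred_hat = F.log_softmax(aux_decoder(z_d + xi * d), dim=1)
    adv_dist = F.kl_div(pred_hat, pred, reduction="batchmean")
    grad = torch.autograd.grad(adv_dist, d)[0]

    r_adv = eps * grad / (grad.flatten(1).norm(dim=1).view(-1, 1, 1, 1) + 1e-8)
    return z + r_adv   # r_adv carries no graph; gradients still flow into z
```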
#### Practical Considerations
* At each training iteration, they sample an equal number of labeled and unlabeled examples. Since the labeled set is much smaller, the model iterates over $\mathcal{D}_l$ many more times than over $\mathcal{D}_u$, risking overfitting on the labeled set $\mathcal{D}_l$.
* They propose an annealed version of the bootstrapped-CE ($\text{ab-CE}$) ([CVPR '17](https://arxiv.org/abs/1611.08323)). With an output $f(\mathrm{x}_i^l) \in \mathbb{R}^{C \times H \times W}$ in the form of a probability distribution over the pixels, they compute the supervised loss over the pixels with a probability less than a threshold $\eta$
$$
\mathcal{L}_s = \frac{1}{|\mathcal{D}_l|} \sum_{\mathrm{x}_i^l, y_i \in \mathcal{D}_l} \mathbb{1}\{ f(\mathrm{x}_i^l) < \eta\} \, \mathrm{H}(y_i, f(\mathrm{x}_i^l))
\tag{4}
$$
* To gradually release the supervised training signal as training progresses, the threshold parameter $\eta$ is increased from $\frac{1}{C}$ to $0.9$, where $C$ is the number of output classes (see the sketch below).
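A minimal sketch of $\text{ab-CE}$, under the assumption that the indicator compares the predicted probability of the ground-truth class against $\eta$ and that $\eta$ follows a simple linear annealing schedule; the official schedule and normalization may differ.

```python
import torch
import torch.nn.functional as F

def ab_ce_loss(logits, target, step, total_steps, num_classes):
    """Annealed bootstrapped CE (sketch of Eq. 4): only pixels whose predicted
    probability for the ground-truth class is below eta contribute.
    `target` (N, H, W) is assumed to hold valid class indices."""
    # eta annealed from 1/C up to 0.9 (linear schedule assumed here).
    eta = 1.0 / num_classes + (0.9 - 1.0 / num_classes) * min(step / total_steps, 1.0)

    probs = logits.softmax(dim=1)                                 # (N, C, H, W)
    gt_prob = probs.gather(1, target.unsqueeze(1)).squeeze(1)     # (N, H, W)
    keep = (gt_prob < eta).float()                                # indicator 1{f(x) < eta}

    pixel_ce = F.cross_entropy(logits, target, reduction="none")  # (N, H, W)
    # Normalizing by the number of kept pixels here; Eq. 4 averages over D_l.
    return (keep * pixel_ce).sum() / keep.sum().clamp(min=1.0)
```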
### Exploiting weak-labels
* Check Section 3.3 in the paper for how they incorporate weak-labels in a similar manner.
### Cross-Consistency Training on Multiple Domains
* They extend the proposed framework to a semi-supervised domain adaptation setting. Consider the case of 2 datasets $\{ \mathcal{D}^{(1)}, \mathcal{D}^{(2)} \}$ with partially or fully non-overlapping label spaces, and each contains a set of labeled and unlabeled examples $\mathcal{D}^{(i)} = \{\mathcal{D}_l^{(i)}, \mathcal{D}_u^{(i)} \}$.
* Their assumption is that enforcing a consistency over both unlabeled sets $\mathcal{D}_u^{(1)}$ and $\mathcal{D}_u^{(2)}$ might impose an invariance on the encoder's representations across the 2 domains.
* For this, they add a domain-specific main decoder $g^{(i)}$ and auxiliary decoders $g_a^{k(i)}$ on top of the shared encoder $h$, as shown below.
![CCT for DA](https://i.imgur.com/bSirGFa.png)
* During training, they alternate between the two datasets at each iteration, sampling an equal number of labeled and unlabeled examples from each one, computing the loss in Eq. 3 and training the shared encoder and the corresponding main and auxiliary decoders.
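One reading of this schedule, reusing the single-domain `train_step` sketched earlier (all names are illustrative assumptions, not the official code):

```python
def multi_domain_step(encoder, main_decoders, aux_decoders, batches, optimizer, step):
    """One pass over both domains: `batches` holds one (x_l, y_l, x_u) triple per
    domain, with equal numbers of labeled and unlabeled examples in each."""
    for i, (x_l, y_l, x_u) in enumerate(batches):
        # Eq. 3 with the shared encoder and domain i's main / auxiliary decoders.
        loss = train_step(encoder, main_decoders[i], aux_decoders[i],
                          x_l, y_l, x_u, step)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```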
## Conclusion
* They present Cross-Consistency Training (CCT) for consistency-based semi-supervised semantic segmentation which yields SOTA results.
* Future work may explore other perturbations applied at different levels within the segmentation network.
* Other directions include studying the effectiveness of CCT in other visual tasks and learning settings, such as Unsupervised Domain Adaptation (UDA).