This paper proposes cross-consistency training, where an invariance of the predictions is enforced over different perturbations applied to the outputs of the encoder (in a shared encoder and multiple decoder architecture).
Introduction
Semi-Supervised Learning (SSL) takes advantage of a large amount of unlabeled data and limits the need for labeled examples. The current dominant SSL methods for semantic segmentation either rely on additional weak labels or are based on adversarial training.
Weakly-supervised approaches require weakly labeled examples along with pixel-level labels, and hence they don't exploit the unlabeled data to extract additional training signal.
Methods based on adversarial training exploit the unlabeled data, but can be harder to train.
To address these limitations, they propose a simple consistency-based semi-supervised method for semantic segmentation. The objective of consistency training is to enforce an invariance of the model's prediction to small perturbations applied to the inputs, and make the model robust to such changes.
They consider a shared encoder and a main decoder which are trained using the labeled data. To use unlabeled data, they have auxiliary decoders whose inputs are perturbed versions of the output of the shared encoder.
Their contributions are as follows:
They propose a cross-consistency training (CCT) method where invariance of predictions is enforced over different perturbations injected into the encoder's output.
They conduct an exhaustive study of various types of perturbations.
They extend their approach to use weakly-labeled data, and exploit pixel-level labels across different domains to jointly train the segmentation network.
They compare with the SOTA and analyze the approach through an ablation study.
Methodology
Cluster Assumption in Semantic Segmentation
Check the Cluster Assumption page on Wikipedia. Basically, the assumption is that if points are in the same cluster, they are likely to be of the same class. There may be multiple clusters forming a single class.
It is equivalent to the low-density separation assumption, which states that the decision boundary should lie in a low-density region. To see why, suppose the decision boundary crosses a cluster (a high-density region); then the cluster contains points from two different classes, which violates the cluster assumption.
Note: Above two points are taken from the linked Wikipedia page, and are not in the paper.
A simple way to examine the cluster assumption is to estimate the local smoothness by measuring the local variations between the value of each pixel and its local neighbours.
For this, they compute, at each spatial location, the average Euclidean distance between its value and those of its 8 immediate neighbours, for both the inputs and the hidden representations. Check the paper for exact computation details; a rough sketch follows below.
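Below is a minimal sketch (not the authors' code) of this local-variation measure, assuming a PyTorch layout of (B, C, H, W) and using wrap-around shifts instead of explicit border handling; `local_variation` is a hypothetical helper name.

```python
import torch

def local_variation(x: torch.Tensor) -> torch.Tensor:
    """x: (B, C, H, W) input image or hidden representation.
    Returns a (B, H, W) map of the average Euclidean distance between each
    spatial location and its 8 immediate neighbours."""
    dists = []
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            if dy == 0 and dx == 0:
                continue
            # Shift the map towards one of the 8 neighbouring directions.
            shifted = torch.roll(x, shifts=(dy, dx), dims=(2, 3))
            dists.append(torch.norm(x - shifted, dim=1))  # Euclidean distance over channels
    return torch.stack(dists, dim=0).mean(dim=0)
```

Comparing this map for the raw inputs and for the encoder's outputs is what reveals whether the low-density regions align with the class boundaries.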
They observe that the cluster assumption is violated at the input level since the low density regions do not align with the class boundaries.
For the encoder's outputs, the cluster assumption is maintained where the class boundaries have high average distance, corresponding to low density regions.
These observations motivate the approach to apply perturbations to the encoder's outputs rather than the inputs.
Let $\mathcal{D}_l = \{(x_1^l, y_1), \ldots, (x_n^l, y_n)\}$ represent the labeled examples and $\mathcal{D}_u = \{x_1^u, \ldots, x_m^u\}$ the unlabeled examples, with $x_i^u$ as the $i$-th unlabeled input image and $x_i^l$ as the $i$-th labeled input image of spatial dimensions $H \times W$, together with its corresponding pixel-level label $y_i \in \{1, \ldots, C\}^{H \times W}$, where $C$ is the number of classes.
The architecture is composed of a shared encoder $h$ and a main decoder $g$, which constitute the segmentation network $f = g \circ h$. They also introduce a set of $K$ auxiliary decoders $g_a^k$ with $k \in [1, K]$.
While the segmentation network is trained on the labeled set in a traditional supervised manner, the auxiliary networks are trained on the unlabeled set by enforcing a consistency of predictions between the main decoder and the auxiliary decoders.
Each auxiliary decoder takes as input a perturbed version of the encoder's output and the main decoder is fed the uncorrupted intermediate representation. This way, the representation learning of the encoder is further enhanced using the unlabeled examples.
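As a rough sketch of this layout (not the official implementation), the shared encoder, main decoder and auxiliary decoders could be wrapped as below; `CCTSegmentation` and the encoder/decoder modules and perturbation callables are placeholders.

```python
import torch.nn as nn

class CCTSegmentation(nn.Module):
    """Shared encoder h, main decoder g, and K auxiliary decoders g_a^k,
    each auxiliary decoder paired with a perturbation function."""
    def __init__(self, encoder, main_decoder, aux_decoders, perturbations):
        super().__init__()
        self.encoder = encoder                            # shared encoder h
        self.main_decoder = main_decoder                  # main decoder g
        self.aux_decoders = nn.ModuleList(aux_decoders)   # auxiliary decoders g_a^k
        self.perturbations = perturbations                # one perturbation per auxiliary decoder

    def forward(self, x, unlabeled: bool = False):
        z = self.encoder(x)                               # intermediate representation
        main_out = self.main_decoder(z)                   # main decoder sees the clean z
        if not unlabeled:
            return main_out
        # Each auxiliary decoder is fed a perturbed version of z.
        aux_outs = [dec(p(z)) for dec, p in zip(self.aux_decoders, self.perturbations)]
        return main_out, aux_outs
```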
Formally, for a labeled training example $x_i^l$ and its pixel-level label $y_i$, the network $f$ is trained using a CE-based supervised loss: $\mathcal{L}_s = \frac{1}{|\mathcal{D}_l|} \sum_{(x_i^l, y_i) \in \mathcal{D}_l} \mathbf{H}(y_i, f(x_i^l))$
Here, $\mathbf{H}$ is the CE loss. For an unlabeled example $x_i^u$, an intermediate representation $z_i = h(x_i^u)$ is computed using the shared encoder $h$. Consider $R$ stochastic perturbation functions, denoted as $p_r$ with $r \in [1, R]$, where one perturbation function can be assigned to multiple auxiliary decoders.
With various perturbation settings, they generate $K$ perturbed versions $\tilde{z}_i^k$ of the intermediate representation $z_i$, so that the $k$-th perturbed version is fed to the $k$-th auxiliary decoder. For simplicity of notation, they consider the perturbation function as part of the auxiliary decoder (i.e. $g_a^k$ is seen as $g_a^k \circ p_r$).
The training objective is to minimize the unsupervised loss $\mathcal{L}_u$, which measures the discrepancy between the main decoder's output and that of the auxiliary decoders: $\mathcal{L}_u = \frac{1}{|\mathcal{D}_u| \, K} \sum_{x_i^u \in \mathcal{D}_u} \sum_{k=1}^{K} d\big(g(z_i), g_a^k(z_i)\big)$, where $d$ is a distance measure between the two outputs (e.g. MSE).
The combined loss for consistency-based SSL is then computed as $\mathcal{L} = \mathcal{L}_s + \omega_u \, \mathcal{L}_u$
Here, $\omega_u$ is an unsupervised loss weighting function. Following this ICLR '17 work, to avoid using the main decoder's initial noisy predictions, $\omega_u$ ramps up starting from zero along a Gaussian curve up to a fixed weight $\lambda_u$.
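A minimal sketch of such a ramp-up, assuming the commonly used Gaussian schedule $\exp(-5(1 - t)^2)$ from the ICLR '17 work (the constant and the step-based scheduling are assumptions):

```python
import math

def consistency_weight(step: int, rampup_steps: int, final_weight: float) -> float:
    """Ramp the unsupervised loss weight from 0 up to `final_weight` (lambda_u)
    along a Gaussian curve over the first `rampup_steps` iterations."""
    if step >= rampup_steps:
        return final_weight
    t = step / max(1, rampup_steps)
    return final_weight * math.exp(-5.0 * (1.0 - t) ** 2)
```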
At each training iteration, an equal number of examples are sampled from the labeled and the unlabeled sets. Note that the unsupervised loss is not back-propagated through the main decoder $g$; only the labeled examples are used to train $g$.
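A sketch of one training iteration, building on the `CCTSegmentation` and `consistency_weight` sketches above; using MSE between softmax outputs as the distance $d$ is an assumption, and the `.detach()` is what keeps the unsupervised loss from back-propagating through the main decoder.

```python
import torch
import torch.nn.functional as F

def train_step(model, x_l, y_l, x_u, w_u, optimizer):
    """One CCT iteration on a labeled batch (x_l, y_l) and an unlabeled batch x_u."""
    optimizer.zero_grad()
    # Supervised CE loss on the labeled batch (main decoder only).
    loss_s = F.cross_entropy(model(x_l), y_l)
    # Consistency loss on the unlabeled batch.
    main_out, aux_outs = model(x_u, unlabeled=True)
    target = torch.softmax(main_out, dim=1).detach()   # no gradient through the main decoder
    loss_u = sum(F.mse_loss(torch.softmax(a, dim=1), target) for a in aux_outs) / len(aux_outs)
    loss = loss_s + w_u * loss_u
    loss.backward()
    optimizer.step()
    return loss_s.item(), loss_u.item()
```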
Perturbation Functions
They propose three types of perturbation functions
Feature-based
Prediction-based
Random
Feature-based Perturbations
They consist of either injecting noise into or dropping some of the activations of the encoder's output feature map $z$.
F-Noise
They uniformly sample a noise tensor $N$ of the same size as $z$.
After adjusting its amplitude by multiplying it element-wise with $z$, the noise is injected into the encoder's output to get $\tilde{z} = (z \odot N) + z$.
This way, the injected noise is proportional to each activation.
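A minimal sketch of F-Noise; the uniform noise range used here is an assumption.

```python
import torch

def f_noise(z: torch.Tensor, amplitude: float = 0.3) -> torch.Tensor:
    """Sample uniform noise N of the same size as z, scale it by z so that it is
    proportional to each activation, and add it to z."""
    noise = torch.empty_like(z).uniform_(-amplitude, amplitude)
    return (z * noise) + z
```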
F-Drop
First, they uniformly sample a threshold $\gamma$.
After summing over the channel dimension and normalizing the feature map to get $z'$, they generate a binary mask $M_{drop} = \{z' < \gamma\}$, which is used to obtain the perturbed version $\tilde{z} = z \odot M_{drop}$.
This way, they mask 10% to 40% of the most active regions in the feature map.
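A minimal sketch of F-Drop; the threshold range and the max-based normalization are assumptions.

```python
import torch

def f_drop(z: torch.Tensor, low: float = 0.6, high: float = 0.9) -> torch.Tensor:
    """Zero out the most active spatial regions of z: sum over channels, normalize,
    and keep only the locations below a uniformly sampled threshold gamma."""
    gamma = torch.empty(1, device=z.device).uniform_(low, high)
    attention = z.sum(dim=1, keepdim=True)                               # (B, 1, H, W)
    attention = attention / (attention.amax(dim=(2, 3), keepdim=True) + 1e-8)
    mask = (attention < gamma).float()                                   # drop high-activation regions
    return z * mask
```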
Prediction-based Perturbations
They consist of adding perturbations based on the main decoder's prediction or that of the auxiliary decoders. They consider masking-based perturbations in addition to adversarial perturbations.
Guided Masking
Context relationships are important for complex scene understanding, and the network may be too reliant on these relationships.
To limit this reliance, they create two perturbed versions of $z$: one where the detected objects are masked and one where the context is masked.
Using the main decoder's prediction, they generate an object mask $M_{obj}$ covering the detected foreground objects and a complementary context mask $M_{con}$, which are then applied to $z$ (see the sketch below).
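A minimal sketch of the object/context masking, assuming the prediction has already been resized to $z$'s spatial size and that class 0 is the background (both assumptions).

```python
import torch

def guided_masking(z: torch.Tensor, pred: torch.Tensor, background_class: int = 0):
    """Return the two perturbed versions of z: one with the detected objects masked
    and one with the context masked, based on the main decoder's prediction."""
    labels = pred.argmax(dim=1, keepdim=True)          # (B, 1, H, W) hard prediction
    m_obj = (labels != background_class).float()       # detected foreground objects
    m_con = 1.0 - m_obj                                # complementary context mask
    z_objects_masked = z * m_con                       # keep context, remove objects
    z_context_masked = z * m_obj                       # keep objects, remove context
    return z_objects_masked, z_context_masked
```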
Guided Cutout
In order to reduce reliance on specific parts of the objects, and inspired by Cutout, which randomly masks some parts of the input image, they first find the possible spatial extent (i.e. bounding box) of each detected object using the main decoder's prediction.
Then they zero out a random crop within each object's bounding box from the corresponding feature map $z$, as sketched below.
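A minimal sketch of Guided Cutout under the same assumptions (prediction resized to $z$'s size, background class 0); a real implementation would work on connected components rather than one box per class.

```python
import random
import torch

def guided_cutout(z: torch.Tensor, pred: torch.Tensor, background_class: int = 0) -> torch.Tensor:
    """For each detected object class, zero out a random crop inside its bounding box."""
    z = z.clone()
    labels = pred.argmax(dim=1)                        # (B, H, W)
    for b in range(z.size(0)):
        for cls in labels[b].unique().tolist():
            if cls == background_class:
                continue
            ys, xs = torch.where(labels[b] == cls)     # spatial extent of the detected object
            y0, y1 = ys.min().item(), ys.max().item()
            x0, x1 = xs.min().item(), xs.max().item()
            ch, cw = max(1, (y1 - y0) // 2), max(1, (x1 - x0) // 2)
            cy = random.randint(y0, max(y0, y1 - ch))  # random crop inside the bounding box
            cx = random.randint(x0, max(x0, x1 - cw))
            z[b, :, cy:cy + ch, cx:cx + cw] = 0.0
    return z
```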
Intermediate VAT
To further push the output distribution to be isotropically smooth around each data point, they use VAT (TPAMI '18) as a perturbation function, applied to $z$ instead of the unlabeled inputs.
For a given auxiliary decoder, they find the adversarial perturbation $r_{adv}$ that alters its prediction the most, and inject it into $z$ to obtain the perturbed $\tilde{z} = z + r_{adv}$.
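A minimal sketch of VAT applied to the intermediate representation for one auxiliary decoder; `xi`, `eps` and the single power iteration are assumptions.

```python
import torch
import torch.nn.functional as F

def intermediate_vat(z: torch.Tensor, decoder, xi: float = 1e-6, eps: float = 2.0,
                     iterations: int = 1) -> torch.Tensor:
    """Find, by power iteration, the direction that changes the decoder's prediction
    the most (measured by KL divergence), and inject it into z."""
    z_d = z.detach()
    with torch.no_grad():
        pred = F.softmax(decoder(z_d), dim=1)
    d = torch.randn_like(z_d)
    d = d / (d.flatten(1).norm(dim=1).view(-1, 1, 1, 1) + 1e-8)
    for _ in range(iterations):
        d.requires_grad_(True)
        adv_pred = F.log_softmax(decoder(z_d + xi * d), dim=1)
        dist = F.kl_div(adv_pred, pred, reduction='batchmean')
        grad = torch.autograd.grad(dist, d)[0]
        d = grad / (grad.flatten(1).norm(dim=1).view(-1, 1, 1, 1) + 1e-8)
        d = d.detach()
    return z + eps * d       # perturbed representation fed to the auxiliary decoder
```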
At each training iteration, they sample an equal number of labeled and unlabeled examples. Due to the smaller size of the labeled set, iterations over $\mathcal{D}_l$ will be much more frequent than over $\mathcal{D}_u$, risking overfitting on the labeled set $\mathcal{D}_l$.
To avoid this, they propose an annealed version of the bootstrapped-CE (CVPR '17). With the output in the form of a probability distribution over the pixels, the supervised loss is computed only over the pixels whose predicted probability is less than a threshold $\eta$.
To gradually release the supervised training signal, the threshold $\eta$ is increased during training, starting from $\frac{1}{C}$, where $C$ is the number of output classes, up to a fixed final value.
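A minimal sketch of this annealed bootstrapped CE, interpreting "predicted probability" as the probability of the ground-truth class; the final threshold value and the linear annealing are assumptions.

```python
import torch
import torch.nn.functional as F

def annealed_bootstrapped_ce(logits: torch.Tensor, target: torch.Tensor,
                             step: int, total_steps: int, final_thresh: float = 0.9) -> torch.Tensor:
    """Compute CE only over pixels whose predicted probability for the ground-truth
    class is below a threshold eta, ramped from 1/C towards `final_thresh`."""
    num_classes = logits.size(1)
    t = min(1.0, step / max(1, total_steps))
    eta = (1.0 / num_classes) + t * (final_thresh - 1.0 / num_classes)
    probs = F.softmax(logits, dim=1)
    gt_prob = probs.gather(1, target.unsqueeze(1)).squeeze(1)   # (B, H, W) prob of the true class
    weight = (gt_prob < eta).float()                            # keep only the hard pixels
    ce = F.cross_entropy(logits, target, reduction='none')      # per-pixel CE, (B, H, W)
    return (weight * ce).sum() / (weight.sum() + 1e-8)
```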
Exploiting Weak Labels
Check Section 3.3 in the paper for how they incorporate weak labels in a similar manner.
Cross-Consistency Training on Multiple Domains
They extend the proposed framework to a semi-supervised domain adaptation setting. Consider the case of two datasets with partially or fully non-overlapping label spaces, each containing a set of labeled and unlabeled examples.
Their assumption is that enforcing a consistency over both unlabeled sets and might impose an invariance on the encoder's representations across the 2 domains.
For this, they add a domain-specific main decoder and domain-specific auxiliary decoders on top of the shared encoder $h$ (see the figure in the paper).
During training, they alternate between the two datasets at each iteration, sampling an equal number of labeled and unlabeled examples from each one, computing the loss in Eq. 3 and training the shared encoder and the corresponding main and auxiliary decoders.
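A minimal sketch of this alternating scheme, reusing the hypothetical `train_step` from above with one per-domain model wrapping the shared encoder and that domain's decoders; all names here are placeholders.

```python
import itertools

def train_multi_domain(models, labeled_loaders, unlabeled_loaders, optimizer,
                       weight_fn, steps: int):
    """Alternate between the domains, sampling labeled and unlabeled batches from the
    current domain and updating the shared encoder plus that domain's decoders."""
    iters_l = [itertools.cycle(dl) for dl in labeled_loaders]
    iters_u = [itertools.cycle(du) for du in unlabeled_loaders]
    for step in range(steps):
        domain = step % len(models)                 # alternate between the two datasets
        x_l, y_l = next(iters_l[domain])
        x_u = next(iters_u[domain])
        train_step(models[domain], x_l, y_l, x_u, weight_fn(step), optimizer)
```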
Conclusion
They present Cross-Consistency Training (CCT) for consistency-based semi-supervised semantic segmentation which yields SOTA results.
Future work may explore applying other perturbations at different levels within the segmentation network.
Other directions include studying the effectiveness of CCT in other visual tasks and learning settings such as UDA.