SOTA on semantic segmentation UDA benchmarks (as of 10/04/21). However, their released code uses a DeepLabv3+-style decoder instead of DeepLabv2 (as claimed in the paper), so it is unclear whether the performance improvements are architecture-dependent.
Brief Outline
They use representative prototypes (class feature centroids) to address 2 issues in self-training for UDA in semantic segmentation.
Exploit feature distances from prototypes to estimate the likelihood of pseudo-labels, enabling online correction during training.
Align prototypical assignments based on relative feature distances for 2 different views of the same target, producing a more compact target feature space.
Additionally, distilling the already learned knowledge into a self-supervised pretrained model further boosts performance.
Introduction
Self-training has recently emerged as a simple yet competitive alternative to explicitly aligning the source and target distributions (adversarial alignment) for UDA.
Two key ingredients are lacking in self-training:
Typical practices select pseudo-labels according to a strict confidence threshold. Since high scores are not necessarily correct, the network fails to learn reliable knowledge in the target domain.
Due to the domain gap, the network is prone to produce dispersed features in the target domain. For target data, the closer a sample is to the source distribution, the higher its confidence score tends to be. Hence, data far from the source distribution (i.e. with low scores) will never be considered during training.
This work proposes to denoise pseudo-labels online and to learn a compact target structure, addressing the above 2 issues respectively. They use prototypes, i.e. class-wise feature centroids, to accomplish the 2 tasks:
Rectify pseudo-labels by estimating class-wise likelihoods according to each pixel's relative feature distances to all class prototypes. Prototypes are computed on-the-fly, and thus pseudo-labels are progressively corrected throughout training.
Inspired by Deepcluster (Caron et al. ECCV '18), they learn the intrinsic structure of the target domain. They propose to align soft prototypical assignments for different views of the same target, which produces a more compact target feature space.
They call their method ProDA as they rely heavily on prototypes for DA.
Further, they find that DA can benefit from task-agnostic pretraining. Distilling the knowledge to a self-supervised model (SimCLRv2, NeurIPS '20) further boosts the performance.
Methodology
Preliminaries
Given a source dataset $\mathcal{X}_s = \{x_s\}$ with labels $\mathcal{Y}_s = \{y_s\}$, the aim is to train a segmentation network $h = g \circ f$ (a feature extractor $f$ followed by a classifier $g$) that achieves low risk on the unlabeled target dataset $\mathcal{X}_t = \{x_t\}$, where the $K$ classes are the same across domains.
Typically, source-trained models cannot generalize well to target data. To transfer the knowledge, traditional self-training techniques optimize the categorical cross-entropy (CE) with pseudo-labels $\hat{y}_t$:
$$\ell_{ce}^t = -\sum_{i=1}^{H \times W} \sum_{k=1}^{K} \hat{y}_t^{(i,k)} \log p_t^{(i,k)}$$
Here, $p_t^{(i,k)}$ is the predicted probability of pixel $i$ belonging to class $k$.
Typically, the most probable class predicted by the source n/w is used as the pseudo-label:
$$\hat{y}_t^{(i,k)} = \mathbb{1}\left[k = \arg\max_{k'} p_t^{(i,k')}\right]$$
This conversion from soft predictions to hard labels is denoted by $\xi(\cdot)$. Further, in practice, only pixels whose prediction confidence exceeds a given threshold are used as pseudo-labels (due to noise in predictions).
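A minimal sketch of this thresholded pseudo-labeling step (the threshold value and tensor shapes below are illustrative assumptions, not values from the paper):

```python
import torch
import torch.nn.functional as F

def generate_pseudo_labels(logits: torch.Tensor, threshold: float = 0.9) -> torch.Tensor:
    """Convert soft predictions to hard pseudo-labels, keeping only confident pixels.

    logits: (B, K, H, W) raw outputs of the source-trained network on target images.
    Returns (B, H, W) pseudo-labels, with low-confidence pixels marked as -1 (ignored).
    """
    probs = F.softmax(logits, dim=1)       # soft predictions p_t
    conf, labels = probs.max(dim=1)        # most probable class per pixel (the xi(.) conversion)
    labels[conf < threshold] = -1          # drop pixels below the confidence threshold
    return labels
```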
Prototypical pseudo-label denoising
Updating the pseudo-labels after one full training stage is too late, as the n/w may have already overfitted to the noisy labels. On the other hand, simultaneously updating pseudo-labels and n/w weights is prone to trivial solutions.
The key is to fix the soft pseudo-labels and progressively weight them by class-wise probabilities, with the update in accordance with freshly learned knowledge. Formally, they propose to use the weighted pseudo-labels for self-training:
$$\hat{y}_t^{(i,k)} = \xi\left(\omega_t^{(i,k)}\, p_{t,0}^{(i,k)}\right)$$
Here, $\omega_t^{(i,k)}$ is the proposed weight for modulating the probability (and it changes as training proceeds), whereas $p_{t,0}^{(i,k)}$ is initialized by the source model and remains fixed throughout training (a boilerplate for subsequent refinement).
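A minimal sketch of this rectification step, assuming `init_probs` holds the fixed soft predictions $p_{t,0}$ from the source model and `weights` the modulation term $\omega_t$ defined just below (function and variable names are my own):

```python
import torch

def rectify_pseudo_labels(init_probs: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Weight the fixed source-model predictions by the prototype-based modulation term,
    then convert to hard labels (the xi(.) operation).

    init_probs: (B, K, H, W) soft predictions from the source-trained model (kept fixed).
    weights:    (B, K, H, W) modulation weights omega_t, recomputed as training proceeds.
    """
    rectified = weights * init_probs       # omega_t^(i,k) * p_{t,0}^(i,k)
    return rectified.argmax(dim=1)         # hard pseudo-labels of shape (B, H, W)
```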
They use distances from the prototypes to gradually rectify the pseudo-labels. Let $f(x_t)^{(i)}$ represent the feature of $x_t$ at pixel index $i$. If it is far from the prototype $\eta^{(k)}$ (the feature centroid of class $k$), it is more probable that the learned feature is an outlier; hence, its probability of being classified into the $k$-th category is downweighted.
Concretely, the modulation weight is defined as the softmax over feature distances to the prototypes:
$$\omega_t^{(i,k)} = \frac{\exp\left(-\left\|\tilde{f}(x_t)^{(i)} - \eta^{(k)}\right\| / \tau\right)}{\sum_{k'} \exp\left(-\left\|\tilde{f}(x_t)^{(i)} - \eta^{(k')}\right\| / \tau\right)}$$
Here, $\tilde{f}$ denotes the momentum encoder (He et al. CVPR '20) of the feature extractor $f$, as a reliable feature estimation of $x_t$ is desired, and $\tau$ is the softmax temperature (empirically, $\tau = 1$). In other words, $\omega_t^{(i,k)}$ approximates the trust confidence of pixel $i$ belonging to the $k$-th class.
Note about the momentum encoder: $\tilde{f}$ is architecturally the same as $f$. However, $f$ is updated at every iteration using the gradients from backprop. On the other hand, $\tilde{f}$ is updated as an exponential moving average of the weights of $f$, i.e. $\theta_{\tilde{f}} \leftarrow m\, \theta_{\tilde{f}} + (1 - m)\, \theta_f$. Thus, $\tilde{f}$ cannot be used to backpropagate the loss, while $f$ can.
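A sketch of the weight computation and of the momentum-encoder update under these definitions (the EMA coefficient `m` and the helper names are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def prototype_weights(feat: torch.Tensor, prototypes: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    """Softmax over negative feature-to-prototype distances, i.e. the modulation weight omega_t.

    feat:       (B, C, H, W) features of x_t from the momentum encoder f~.
    prototypes: (K, C) class feature centroids eta^(k).
    Returns (B, K, H, W) weights.
    """
    B, C, H, W = feat.shape
    flat = feat.permute(0, 2, 3, 1).reshape(-1, C)     # (B*H*W, C)
    dist = torch.cdist(flat, prototypes)               # Euclidean distance to every prototype
    w = F.softmax(-dist / tau, dim=1)                  # closer prototype -> higher weight
    return w.reshape(B, H, W, -1).permute(0, 3, 1, 2)

@torch.no_grad()
def update_momentum_encoder(f: torch.nn.Module, f_ema: torch.nn.Module, m: float = 0.999) -> None:
    """EMA update of the momentum encoder: theta_f~ <- m * theta_f~ + (1 - m) * theta_f."""
    for p, p_ema in zip(f.parameters(), f_ema.parameters()):
        p_ema.mul_(m).add_(p, alpha=1.0 - m)
```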
Prototype computation
The proposed method requires computing the prototypes on-the-fly. At the beginning, prototypes are initialized according to the predicted pseudo-labels for the target domain images as:
$$\eta^{(k)} = \frac{\sum_{x_t \in \mathcal{X}_t} \sum_{i} f(x_t)^{(i)}\, \mathbb{1}\left[\hat{y}_t^{(i,k)} = 1\right]}{\sum_{x_t \in \mathcal{X}_t} \sum_{i} \mathbb{1}\left[\hat{y}_t^{(i,k)} = 1\right]}$$
Here, $\mathbb{1}[\cdot]$ is the indicator function. However, such computation over the entire dataset is expensive during training. Thus, they estimate the prototypes as a moving average of the cluster centroids in mini-batches to track the slowly moving prototypes. Formally,
$$\eta^{(k)} \leftarrow \lambda\, \eta^{(k)} + (1 - \lambda)\, \eta'^{(k)}$$
Here, $\eta'^{(k)}$ is the mean feature of class $k$ calculated within the current training batch from the momentum encoder, and $\lambda$ is the momentum coefficient, set to 0.9999.
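A sketch of this mini-batch prototype update (helper names are mine; updating only the classes that actually appear in the batch is an implementation assumption):

```python
import torch

@torch.no_grad()
def batch_class_centroids(feat: torch.Tensor, pseudo_labels: torch.Tensor, num_classes: int):
    """Per-class mean momentum-encoder feature within the current batch (eta'^(k)).

    feat:          (B, C, H, W) features from the momentum encoder.
    pseudo_labels: (B, H, W) hard pseudo-labels, ignored pixels marked as -1.
    Returns (K, C) centroids and a (K,) boolean mask of the classes present in the batch.
    """
    B, C, H, W = feat.shape
    flat = feat.permute(0, 2, 3, 1).reshape(-1, C)
    labels = pseudo_labels.reshape(-1)
    centroids = torch.zeros(num_classes, C, device=feat.device)
    present = torch.zeros(num_classes, dtype=torch.bool, device=feat.device)
    for k in range(num_classes):
        mask = labels == k
        if mask.any():
            centroids[k] = flat[mask].mean(dim=0)
            present[k] = True
    return centroids, present

@torch.no_grad()
def update_prototypes(prototypes, batch_centroids, present, lam: float = 0.9999):
    """Exponential moving average of the prototypes with the fresh batch centroids."""
    prototypes[present] = lam * prototypes[present] + (1 - lam) * batch_centroids[present]
    return prototypes
```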
Pseudo-label training loss
Instead of using the standard CE loss, they use the more robust symmetric cross-entropy (SCE) loss (Wang et al. ICCV '19) to further enhance noise tolerance and stabilize early training. Specifically, they enforce
$$\ell_{sce}^t = \alpha\, \ell_{ce}(p_t, \hat{y}_t) + \beta\, \ell_{ce}(\hat{y}_t, p_t)$$
Here, $\alpha$ and $\beta$ are balancing coefficients.
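A sketch of this loss on dense predictions (the $\alpha$, $\beta$ defaults below are illustrative assumptions, not the paper's settings):

```python
import torch
import torch.nn.functional as F

def symmetric_ce_loss(logits: torch.Tensor, soft_labels: torch.Tensor,
                      alpha: float = 0.1, beta: float = 1.0, eps: float = 1e-8) -> torch.Tensor:
    """Symmetric cross-entropy: alpha * CE(p, y_hat) + beta * CE(y_hat, p).

    logits:      (B, K, H, W) predictions of the current network (p_t).
    soft_labels: (B, K, H, W) rectified soft pseudo-labels (y_hat, treated as constant).
    """
    p = F.softmax(logits, dim=1)
    q = soft_labels.detach()
    ce = -(q * torch.log(p + eps)).sum(dim=1).mean()    # forward CE: pseudo-label as target
    rce = -(p * torch.log(q + eps)).sum(dim=1).mean()   # reverse CE: prediction as target
    return alpha * ce + beta * rce
```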
Why are prototypes useful for pseudo-label denoising?
The prototypes are less sensitive to the outliers (wrong pseudo-labels) which are assumed to be the minority.
The prototypes treat different classes equally regardless of occurrence frequency, which is useful due to class-imbalance in semantic segmentation.
Also, see Fig. 1a for a visual explanation.
Structure learning by enforcing consistency
Pseudo-labels can be denoised when feature extractor generates compact target features. However, due to the domain gap, the generated target features are likely to be dispersed (Fig. 1b).
To achieve compact target features, they aim to learn the underlying structure of the target domain. They use the prototypical assignment under weak augmentation to guide the learning for the strongly augmented view.
Let $\mathcal{T}(x_t)$ and $\mathcal{T}'(x_t)$ respectively denote the weakly and strongly augmented views of $x_t$. They use the momentum encoder to generate a reliable prototypical assignment for $\mathcal{T}(x_t)$, which is:
$$z_t^{(i,k)} = \frac{\exp\left(-\left\|\tilde{f}(\mathcal{T}(x_t))^{(i)} - \eta^{(k)}\right\| / \tau\right)}{\sum_{k'} \exp\left(-\left\|\tilde{f}(\mathcal{T}(x_t))^{(i)} - \eta^{(k')}\right\| / \tau\right)}$$
Similarly, the soft assignment $u_t^{(i,k)}$ for $\mathcal{T}'(x_t)$ is obtained, except that the current trainable feature extractor $f$ is used. Since $z_t$ is more reliable, as its feature comes from the momentum encoder and its input suffers less distortion, they use it to teach $f$ to produce consistent assignments for $\mathcal{T}'(x_t)$.
Hence, they minimize the KL divergence between the 2 prototypical assignments under the 2 views:
$$\ell_{kl}^t = \mathrm{KL}\left(z_t \,\|\, u_t\right)$$
Intuitively, this enforces the n/w to give consistent prototypical labeling for adjacent feature points, resulting in a more compact target feature space.
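A sketch of this consistency term, assuming the two views are spatially aligned so per-pixel assignments can be compared directly (variable names are mine):

```python
import torch
import torch.nn.functional as F

def soft_prototypical_assignment(feat: torch.Tensor, prototypes: torch.Tensor,
                                 tau: float = 1.0) -> torch.Tensor:
    """Soft assignment of every pixel feature to the class prototypes (softmax over -distance)."""
    B, C, H, W = feat.shape
    flat = feat.permute(0, 2, 3, 1).reshape(-1, C)
    dist = torch.cdist(flat, prototypes)
    return F.softmax(-dist / tau, dim=1).reshape(B, H, W, -1).permute(0, 3, 1, 2)

def consistency_loss(feat_weak: torch.Tensor, feat_strong: torch.Tensor,
                     prototypes: torch.Tensor, tau: float = 1.0, eps: float = 1e-8) -> torch.Tensor:
    """KL(z_t || u_t): the momentum-encoder assignment of the weak view teaches the
    trainable-encoder assignment of the strong view.

    feat_weak:   (B, C, H, W) features of the weak view from the momentum encoder (no gradient).
    feat_strong: (B, C, H, W) features of the strong view from the trainable encoder.
    """
    z = soft_prototypical_assignment(feat_weak, prototypes, tau).detach()   # teacher assignment
    u = soft_prototypical_assignment(feat_strong, prototypes, tau)          # student assignment
    return (z * (torch.log(z + eps) - torch.log(u + eps))).sum(dim=1).mean()
```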
The proposed method may suffer from a degeneration issue, i.e. one cluster becomes empty. To amend this, they use a regularization term (Zou et al. ICCV '19) which encourages the output to be evenly distributed to different classes:
$$\ell_{reg}^t = -\sum_{i=1}^{H \times W} \sum_{k=1}^{K} \log p_t^{(i,k)}$$
They train the DA n/w with the following total loss:
$$\ell_{total} = \ell_{ce}^s + \ell_{sce}^t + \gamma_1\, \ell_{kl}^t + \gamma_2\, \ell_{reg}^t$$
Here, $\ell_{ce}^s$ is the standard CE loss on the labeled source data, and $\gamma_1$, $\gamma_2$ are weighting coefficients.
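A sketch of the regularizer and of the combined objective as reconstructed above (the $\gamma$ defaults are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

def regularization_loss(logits_t: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """-sum_k log p_t^(i,k), averaged over pixels: pushes predictions away from
    collapsing onto a single class (Zou et al., ICCV '19 style regularizer)."""
    p = F.softmax(logits_t, dim=1)
    return -torch.log(p + eps).sum(dim=1).mean()

def total_loss(ce_source: torch.Tensor, sce_target: torch.Tensor,
               kl_consistency: torch.Tensor, reg_target: torch.Tensor,
               gamma1: float = 1.0, gamma2: float = 0.1) -> torch.Tensor:
    """Source CE + target SCE + gamma1 * consistency KL + gamma2 * regularizer."""
    return ce_source + sce_target + gamma1 * kl_consistency + gamma2 * reg_target
```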
Distillation to self-supervised model
After training with Eq. 11 converges, they further transfer knowledge from the learned target model to a student model with the same architecture but pretrained in a self-supervised manner.
They initialize the student feature extractor with SimCLRv2 pretrained weights and apply a knowledge distillation (KD) loss (KL divergence loss).
Besides, following the self-training paradigm, the teacher model generates one-hot pseudo-labels to teach the student model.
To prevent the model from forgetting the source domain, source images are also utilized. Altogether, the student model is trained with:
$$\ell^{KD} = \ell_{ce}^s + \ell_{ce}\left(p_t^{\dagger}, \xi(p_t)\right) + \gamma\, \mathrm{KL}\left(p_t \,\|\, p_t^{\dagger}\right)$$
Here, $p_t^{\dagger}$ is the output of the student model and $\gamma$ is a balancing coefficient. In practice, such self-distillation can be applied multiple times after model convergence to further boost the DA performance.
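A sketch of the student objective in this stage (the $\gamma$ default and the function names are my own assumptions):

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits_t: torch.Tensor, teacher_logits_t: torch.Tensor,
                      student_logits_s: torch.Tensor, labels_s: torch.Tensor,
                      gamma: float = 1.0, eps: float = 1e-8) -> torch.Tensor:
    """Source CE + CE on the teacher's one-hot pseudo-labels + gamma * KL(teacher || student).

    student_logits_t, teacher_logits_t: (B, K, H, W) target predictions of student / teacher.
    student_logits_s: (B, K, H, W) source predictions of the student; labels_s: (B, H, W) GT.
    """
    p_teacher = F.softmax(teacher_logits_t, dim=1).detach()
    p_student = F.softmax(student_logits_t, dim=1)

    ce_src = F.cross_entropy(student_logits_s, labels_s)        # keep source knowledge
    hard_pseudo = p_teacher.argmax(dim=1)                       # xi(p_t): one-hot teacher labels
    ce_tgt = F.cross_entropy(student_logits_t, hard_pseudo)     # self-training on teacher labels
    kd = (p_teacher * (torch.log(p_teacher + eps)
                       - torch.log(p_student + eps))).sum(dim=1).mean()
    return ce_src + ce_tgt + gamma * kd
```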