
Notes on "Semi-Supervised Semantic Segmentation with Cross-Consistency Training"

tags: notes segmentation semi-supervised

CVPR '20 paper; Official Code Release

Author: Akshay Kulkarni

Brief Outline

This paper proposes cross-consistency training, which enforces invariance of the predictions over different perturbations applied to the output of the encoder, in a shared-encoder, multiple-decoder architecture.

Introduction

  • Semi-Supervised Learning (SSL) takes advantage of a large amount of unlabeled data and limits the need for labeled examples. The currently dominant SSL approaches for semantic segmentation are weakly-supervised and adversarial-training-based methods, discussed next.
  • Weakly-supervised approaches require weakly labeled examples along with pixel-level labels, and hence they don't exploit the unlabeled data to extract additional training signal.
  • Methods based on adversarial training exploit the unlabeled data, but can be harder to train.
  • To address these limitations, they propose a simple consistency-based semi-supervised method for semantic segmentation. The objective of consistency training is to enforce an invariance of the model's prediction to small perturbations applied to the inputs, and make the model robust to such changes.
  • They consider a shared encoder and a main decoder which are trained using the labeled data. To use unlabeled data, they have auxiliary decoders whose inputs are perturbed versions of the output of the shared encoder.
  • Their contributions are as follows:
    • They propose a cross-consistency training (CCT) method where invariance of predictions is enforced over different perturbations injected into the encoder's output.
    • They conduct an exhaustive study of various types of perturbations.
    • They extend their approach to use weakly-labeled data, and exploit pixel-level labels across different domains to jointly train the segmentation network.
    • They compare with the SOTA and analyze the approach through an ablation study.

Methodology

Cluster Assumption in Semantic Segmentation

  • Check the Cluster Assumption page on Wikipedia. Basically, the assumption is that if points are in the same cluster, they are likely to be of the same class. There may be multiple clusters forming a single class.
  • It is equivalent to the low-density separation assumption, which states that the decision boundary should lie in a low-density region. To see why, suppose the decision boundary crosses a cluster (a high-density region); the cluster would then contain points from two different classes, violating the cluster assumption.

Note: The above two points are taken from the linked Wikipedia page and are not in the paper.

  • A simple way to examine the cluster assumption is to estimate the local smoothness by measuring the local variations between the value of each pixel and its local neighbours.
  • For this, they compute the average Euclidean distance between each spatial location and its 8 immediate neighbours, for both the inputs and the hidden representations (see the sketch after this list). Check the paper for exact computation details.


  • They observe that the cluster assumption is violated at the input level, since the low-density regions do not align with the class boundaries.
  • For the encoder's outputs, the cluster assumption holds: the class boundaries have a high average distance, corresponding to low-density regions.
  • These observations motivate applying the perturbations to the encoder's outputs rather than to the inputs.
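
A minimal sketch of how such a local-variation map could be computed, assuming a feature tensor of shape (C, H, W); the replication padding and the per-pixel Euclidean distance over channels are my assumptions, not the paper's exact procedure.

```python
import torch
import torch.nn.functional as F

def local_variation(x):
    """Average Euclidean distance between each spatial location and its
    8 immediate neighbours, for a tensor x of shape (C, H, W).
    High values indicate large local variation, i.e. low-density regions."""
    c, h, w = x.shape
    # Pad by replication so border pixels also have 8 neighbours.
    padded = F.pad(x.unsqueeze(0), (1, 1, 1, 1), mode="replicate").squeeze(0)
    dists = []
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            if dy == 0 and dx == 0:
                continue
            neigh = padded[:, 1 + dy:1 + dy + h, 1 + dx:1 + dx + w]
            # Euclidean distance over the channel dimension at every pixel.
            dists.append(torch.norm(x - neigh, dim=0))
    return torch.stack(dists, dim=0).mean(dim=0)  # (H, W)

# Usage: compare local variation of the input image and of the encoder output.
image = torch.rand(3, 256, 256)        # hypothetical RGB input
feature_map = torch.rand(512, 32, 32)  # hypothetical encoder output
print(local_variation(image).shape, local_variation(feature_map).shape)
```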

Cross-Consistency Training


  • Let $\mathcal{D}_l = \{(\mathbf{x}_1^l, y_1), \dots, (\mathbf{x}_n^l, y_n)\}$ represent the $n$ labeled examples and $\mathcal{D}_u = \{\mathbf{x}_1^u, \dots, \mathbf{x}_m^u\}$ represent the $m$ unlabeled examples, with $\mathbf{x}_i^u$ the $i$-th unlabeled input image and $\mathbf{x}_i^l$ the $i$-th labeled input image with spatial dimensions $H \times W$ and corresponding pixel-level label $y_i \in \mathbb{R}^{C \times H \times W}$, where $C$ is the number of classes.
  • The architecture is composed of a shared encoder $h$ and a main decoder $g$, which constitute the segmentation network $f = g \circ h$. They also introduce a set of auxiliary decoders $g_a^k$ with $k \in [1, K]$.
  • While the segmentation network $f$ is trained on the labeled set $\mathcal{D}_l$ in a traditional supervised manner, the auxiliary networks $g_a^k \circ h$ are trained on the unlabeled set $\mathcal{D}_u$ by enforcing consistency of predictions between the main decoder and the auxiliary decoders.
  • Each auxiliary decoder takes as input a perturbed version of the encoder's output, while the main decoder is fed the uncorrupted intermediate representation. This way, the representation learning of the encoder $h$ is further enhanced using the unlabeled examples.
  • Formally, for a labeled training example $\mathbf{x}_i^l$ and its pixel-level label $y_i$, the network $f$ is trained using a CE-based supervised loss $\mathcal{L}_s$

$$\mathcal{L}_s = \frac{1}{|\mathcal{D}_l|} \sum_{\mathbf{x}_i^l, y_i \in \mathcal{D}_l} \mathbf{H}\big(y_i, f(\mathbf{x}_i^l)\big) \tag{1}$$

  • Here, $\mathbf{H}$ is the CE loss. For an unlabeled example $\mathbf{x}_i^u$, an intermediate representation of the input is computed using the shared encoder, $\mathbf{z}_i = h(\mathbf{x}_i^u)$. Consider $R$ stochastic perturbation functions, denoted $p_r$ with $r \in [1, R]$, where one perturbation function can be assigned to multiple auxiliary decoders.
  • With various perturbation settings, they generate $K$ perturbed versions $\tilde{\mathbf{z}}_i^k$ of the intermediate representation $\mathbf{z}_i$, so that the $k$-th perturbed version is fed to the $k$-th auxiliary decoder. For simplicity of notation, they consider the perturbation function as part of the auxiliary decoder (i.e., $g_a^k$ is seen as $g_a^k \circ p_r$).
  • The training objective is to minimize the unsupervised loss $\mathcal{L}_u$, which measures the discrepancy between the main decoder's output and that of the auxiliary decoders

$$\mathcal{L}_u = \frac{1}{|\mathcal{D}_u| \, K} \sum_{\mathbf{x}_i^u \in \mathcal{D}_u} \sum_{k=1}^{K} \text{MSE}\big(g(\mathbf{z}_i), g_a^k(\mathbf{z}_i)\big) \tag{2}$$

  • The combined loss $\mathcal{L}$ for consistency-based SSL is then computed as

$$\mathcal{L} = \mathcal{L}_s + \omega_u \mathcal{L}_u \tag{3}$$

  • Here, $\omega_u$ is an unsupervised loss weighting function. Following this ICLR '17 work, to avoid using the initial noisy predictions of the main encoder, $\omega_u$ ramps up from zero along a Gaussian curve to a fixed weight $\lambda_u$.
  • At each training iteration, an equal number of examples is sampled from the labeled set $\mathcal{D}_l$ and the unlabeled set $\mathcal{D}_u$. Note that the unsupervised loss $\mathcal{L}_u$ is not back-propagated through the main decoder $g$; only the labeled examples are used to train $g$.
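
A minimal PyTorch-style sketch of one CCT training iteration combining Eqs. (1)-(3). Here `encoder`, `main_dec`, and `aux_decs` are placeholder modules, the labels are assumed to be integer class indices, and the ramp-up constants and `lambda_u` value are illustrative assumptions rather than the paper's exact settings.

```python
import math
import torch
import torch.nn.functional as F

def rampup_weight(step, rampup_steps, lambda_u):
    """Gaussian ramp-up of the unsupervised weight omega_u from ~0 to lambda_u.
    The exp(-5 * (1 - t)^2) shape follows Laine & Aila (ICLR '17); the exact
    constants used in the paper are an assumption here."""
    if step >= rampup_steps:
        return lambda_u
    phase = 1.0 - step / rampup_steps
    return lambda_u * math.exp(-5.0 * phase * phase)

def cct_step(encoder, main_dec, aux_decs, x_l, y_l, x_u, step,
             rampup_steps=10_000, lambda_u=30.0):
    """One cross-consistency training iteration (sketch).

    aux_decs is a list of auxiliary decoders; each is assumed to apply its own
    perturbation p_r to the shared representation internally (g_a^k o p_r).
    y_l is assumed to hold integer class indices of shape (B, H, W)."""
    # Supervised loss, Eq. (1): labeled batch through encoder + main decoder.
    loss_s = F.cross_entropy(main_dec(encoder(x_l)), y_l)

    # Unsupervised consistency loss, Eq. (2): unlabeled batch.
    z_u = encoder(x_u)
    # The main decoder's prediction is only a target: detaching it keeps L_u
    # from being back-propagated through the main decoder g.
    target = torch.softmax(main_dec(z_u), dim=1).detach()
    loss_u = sum(F.mse_loss(torch.softmax(dec(z_u), dim=1), target)
                 for dec in aux_decs) / len(aux_decs)

    # Combined loss, Eq. (3), with the ramped-up unsupervised weight.
    return loss_s + rampup_weight(step, rampup_steps, lambda_u) * loss_u
```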

Perturbation Functions

  • They propose three types of perturbation functions $p_r$:
    • Feature-based
    • Prediction-based
    • Random
  • Feature-based Perturbations
    • They consist of either injecting noise into or dropping some of the activations of the encoder's output feature map $\mathbf{z}$ (a code sketch of both is shown after this list).
    • F-Noise
      • They uniformly sample a noise tensor $\mathbf{N} \sim \mathcal{U}(-0.3, 0.3)$ of the same size as $\mathbf{z}$.
      • After adjusting its amplitude by multiplying it with $\mathbf{z}$, the noise is injected into the encoder output $\mathbf{z}$ to get $\tilde{\mathbf{z}} = (\mathbf{z} \odot \mathbf{N}) + \mathbf{z}$.
      • This way, the injected noise is proportional to each activation.
    • F-Drop
      • First, uniformly sample a threshold $\gamma \sim \mathcal{U}(0.6, 0.9)$.
      • After summing over the channel dimension and normalizing the feature map $\mathbf{z}$ to get $\mathbf{z}'$, they generate a binary mask $\mathbf{M}_{drop} = \{\mathbf{z}' < \gamma\}$, which is used to obtain the perturbed version $\tilde{\mathbf{z}} = \mathbf{z} \odot \mathbf{M}_{drop}$.
      • This way, they mask 10% to 40% of the most active regions in the feature map.
  • Prediction-based Perturbations
    • They consist of adding perturbations based on the main decoder's prediction $\hat{y} = g(\mathbf{z})$ or those of the auxiliary decoders. They consider masking-based perturbations in addition to adversarial perturbations.
    • Guided Masking
      • Context relationships are important for complex scene understanding, and the network may become too reliant on these relationships.
      • To limit this reliance, they create two perturbed versions of $\mathbf{z}$ by masking the detected objects (Obj-Msk) and the context (Con-Msk).
      • Using $\hat{y}$, they generate an object mask $\mathbf{M}_{obj}$ to mask the detected foreground objects and a context mask $\mathbf{M}_{con} = 1 - \mathbf{M}_{obj}$, which are then applied to $\mathbf{z}$.
    • Guided Cutout (G-Cutout)
      • To reduce reliance on specific parts of the objects, and inspired by Cutout, which randomly masks parts of the input image, they first find the possible spatial extent (i.e., the bounding box) of each detected object using $\hat{y}$.
      • They then zero out a random crop within each object's bounding box in the corresponding feature map $\mathbf{z}$.
    • Intermediate VAT (I-VAT)
      • To further push the output distribution to be isotropically smooth around each data point, they use VAT (TPAMI '18) as a perturbation function applied to $\mathbf{z}$ instead of to the unlabeled inputs.
      • For a given auxiliary decoder, they find the adversarial perturbation $r_{adv}$ that alters its prediction the most, and inject it into $\mathbf{z}$ to obtain the perturbed version $\tilde{\mathbf{z}} = r_{adv} + \mathbf{z}$.
  • Random Perturbations (DropOut)
    • Dropout is applied to the encoder's output $\mathbf{z}$ as a random perturbation.
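
A minimal sketch of the two feature-based perturbations (F-Noise and F-Drop), assuming a batched feature map `z` of shape (B, C, H, W); the per-sample min-max normalization in `f_drop` is an assumption about how the channel-summed map is normalized.

```python
import torch

def f_noise(z, amplitude=0.3):
    """F-Noise: sample N ~ U(-amplitude, amplitude) of the same size as z,
    scale it by z so the perturbation is proportional to each activation,
    and add it back to z."""
    noise = torch.empty_like(z).uniform_(-amplitude, amplitude)
    return z * noise + z

def f_drop(z, low=0.6, high=0.9):
    """F-Drop: drop the most active regions of the feature map.

    Sum z over channels, normalize the result to [0, 1] per sample, and zero
    out all locations whose normalized activation exceeds a threshold
    gamma ~ U(low, high), i.e. roughly the 10-40% most active regions."""
    b = z.size(0)
    attn = z.sum(dim=1, keepdim=True)                   # (B, 1, H, W)
    flat = attn.view(b, -1)
    lo = flat.min(dim=1)[0].view(b, 1, 1, 1)
    hi = flat.max(dim=1)[0].view(b, 1, 1, 1)
    attn = (attn - lo) / (hi - lo).clamp_min(1e-6)      # per-sample [0, 1]
    gamma = torch.empty(b, 1, 1, 1, device=z.device).uniform_(low, high)
    mask = (attn < gamma).float()                       # keep low-activation regions
    return z * mask

# Usage: perturb an encoder output before feeding an auxiliary decoder.
z = torch.randn(2, 512, 32, 32)
print(f_noise(z).shape, f_drop(z).shape)
```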

Practical Considerations

  • At each training iteration, they sample an equal number of labeled and unlabeled examples. Since the labeled set is much smaller, the model iterates over $\mathcal{D}_l$ many more times than over $\mathcal{D}_u$, risking overfitting on the labeled set $\mathcal{D}_l$.
  • They propose an annealed version of the bootstrapped-CE (ab-CE) (CVPR '17). With an output $f(\mathbf{x}_i^l) \in \mathbb{R}^{C \times H \times W}$ in the form of a probability distribution over the pixels, they compute the supervised loss only over the pixels whose predicted probability is less than a threshold $\eta$ (see the sketch after this list)

$$\mathcal{L}_s = \frac{1}{|\mathcal{D}_l|} \sum_{\mathbf{x}_i^l, y_i \in \mathcal{D}_l} \mathbb{1}\big\{f(\mathbf{x}_i^l) < \eta\big\} \, \mathbf{H}\big(y_i, f(\mathbf{x}_i^l)\big) \tag{4}$$

  • To release the supervised training signal, the threshold parameter $\eta$ is gradually increased from $\frac{1}{C}$ to $0.9$, where $C$ is the number of output classes.
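
A minimal sketch of the ab-CE loss in Eq. (4), assuming the indicator keeps pixels whose predicted probability for the ground-truth class is below $\eta$; the logits shape (B, C, H, W) and integer targets are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def ab_ce_loss(logits, target, eta):
    """Annealed bootstrapped CE, Eq. (4) (sketch).

    logits: (B, C, H, W); target: integer class indices of shape (B, H, W);
    eta: scalar threshold ramped from 1/C up to 0.9 during training.
    Only pixels whose predicted probability for the ground-truth class is
    below eta contribute to the loss."""
    probs = torch.softmax(logits, dim=1)
    # Probability assigned to the ground-truth class at each pixel.
    true_prob = probs.gather(1, target.unsqueeze(1)).squeeze(1)   # (B, H, W)
    mask = (true_prob < eta).float()
    pixel_ce = F.cross_entropy(logits, target, reduction="none")  # (B, H, W)
    # Average over the selected (hard) pixels only.
    return (pixel_ce * mask).sum() / mask.sum().clamp_min(1.0)
```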

Exploiting weak-labels

  • Check Section 3.3 in the paper for how they incorporate weak-labels in a similar manner.

Cross-Consistency Training on Multiple Domains

  • They extend the proposed framework to a semi-supervised domain adaptation setting. Consider the case of two datasets $\{\mathcal{D}^{(1)}, \mathcal{D}^{(2)}\}$ with partially or fully non-overlapping label spaces, where each contains a set of labeled and unlabeled examples $\mathcal{D}^{(i)} = \{\mathcal{D}_l^{(i)}, \mathcal{D}_u^{(i)}\}$.
  • Their assumption is that enforcing consistency over both unlabeled sets $\mathcal{D}_u^{(1)}$ and $\mathcal{D}_u^{(2)}$ might impose an invariance of the encoder's representations across the two domains.
  • For this, they add a domain-specific main decoder $g^{(i)}$ and auxiliary decoders $g_a^{k(i)}$ on top of the shared encoder $h$.


  • During training, they alternate between the two datasets at each iteration, sampling an equal number of labeled and unlabeled examples from each one, computing the loss in Eq. 3 and training the shared encoder and the corresponding main and auxiliary decoders.
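
A minimal sketch of such an alternating multi-domain iteration, reusing the `cct_step` sketch from above; the batch interface and the per-domain decoder lists are assumptions, not the paper's exact implementation.

```python
def multi_domain_step(encoder, main_decs, aux_decs, batches, optimizer, step):
    """One iteration over both domains (sketch), reusing cct_step from above.

    main_decs[i] / aux_decs[i]: domain-specific main and auxiliary decoders;
    batches[i]: a (x_l, y_l, x_u) tuple sampled from domain i."""
    for i, (x_l, y_l, x_u) in enumerate(batches):
        # The shared encoder h and the decoders of domain i are updated here.
        loss = cct_step(encoder, main_decs[i], aux_decs[i], x_l, y_l, x_u, step)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```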

Conclusion

  • They present Cross-Consistency Training (CCT) for consistency-based semi-supervised semantic segmentation which yields SOTA results.
  • Future work may explore other perturbations applied at different levels within the segmentation network.
  • Other directions include studying the effectiveness of CCT in other visual tasks and learning settings, such as unsupervised domain adaptation (UDA).