DIstillation with NO labels)
There are two kinds of crops: global (used by the teacher) and local. The student model uses both global and local crops.
This class defines the augmentations.
local_crops_number (8 by default) is the number of local crops, used here:
def __call__(self, image):
    crops = []
    crops.append(self.global_transfo1(image))
    crops.append(self.global_transfo2(image))
    for _ in range(self.local_crops_number):
        crops.append(self.local_transfo(image))
    return crops
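The call logic above can be tried without any image library. The following is a minimal sketch, assuming the three transforms are arbitrary callables; MultiCropSketch and the tagging lambdas are hypothetical stand-ins, not the repository's augmentation class.

```python
class MultiCropSketch:
    """Hypothetical stand-in that mirrors the __call__ logic shown above."""

    def __init__(self, global_transfo1, global_transfo2, local_transfo,
                 local_crops_number=8):
        self.global_transfo1 = global_transfo1
        self.global_transfo2 = global_transfo2
        self.local_transfo = local_transfo
        self.local_crops_number = local_crops_number

    def __call__(self, image):
        # Two global views first, then the local views, in a fixed order.
        crops = [self.global_transfo1(image), self.global_transfo2(image)]
        for _ in range(self.local_crops_number):
            crops.append(self.local_transfo(image))
        return crops


# Placeholder transforms that only tag the input, so the ordering is visible.
augment = MultiCropSketch(
    global_transfo1=lambda img: ("global1", img),
    global_transfo2=lambda img: ("global2", img),
    local_transfo=lambda img: ("local", img),
)
crops = augment("image")
print(len(crops))                      # 2 global + 8 local = 10
print([tag for tag, _ in crops[:3]])   # ['global1', 'global2', 'local']
```

The fixed ordering is the useful property here: with the global views always at the front of the list, later code can pick them out by position.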
Also worth noting are the crop parameters:
# Multi-crop parameters
parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.4, 1.),
help="""Scale range of the cropped image before resizing, relatively to the origin image.
Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
recommand using a wider range of scale ("--global_crops_scale 0.14 1." for example)""")
parser.add_argument('--local_crops_number', type=int, default=8, help="""Number of small
local views to generate. Set this parameter to 0 to disable multi-crop training.
When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." """)
parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4),
help="""Scale range of the cropped image before resizing, relatively to the origin image.
Used for small local view cropping of multi-crop.""")
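To see how these flags behave on the command line, here is a small sketch that re-declares only the three arguments above (help strings omitted) and parses two hypothetical invocations. It assumes nothing beyond the standard argparse module.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.4, 1.))
parser.add_argument('--local_crops_number', type=int, default=8)
parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4))

# No flags given: multi-crop is enabled with the default scale ranges.
defaults = parser.parse_args([])
print(defaults.global_crops_scale, defaults.local_crops_number)

# Multi-crop disabled, with the wider global range the help text recommends.
no_multicrop = parser.parse_args(
    ['--local_crops_number', '0', '--global_crops_scale', '0.14', '1.'])
print(no_multicrop.global_crops_scale, no_multicrop.local_crops_number)
```

One detail worth knowing: with nargs='+', values given on the command line arrive as a list, while an untouched default stays the tuple it was declared as.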
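The teacher temperature follows a warmup schedule: it moves linearly from warmup_teacher_temp to teacher_temp over the first warmup_teacher_temp_epochs epochs and stays constant afterwards:
self.teacher_temp_schedule = np.concatenate((
    np.linspace(warmup_teacher_temp,
                teacher_temp, warmup_teacher_temp_epochs),
    np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
))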
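The shape of this schedule is easier to see with concrete numbers. Below is a dependency-free sketch of the same computation; the function name and the values (0.04 to 0.07 over 30 of 100 epochs) are illustrative assumptions, not settings taken from the code above.

```python
def teacher_temp_schedule(warmup_teacher_temp, teacher_temp,
                          warmup_teacher_temp_epochs, nepochs):
    """Per-epoch teacher temperatures: linear warmup, then constant."""
    n = warmup_teacher_temp_epochs
    span = teacher_temp - warmup_teacher_temp
    # Like np.linspace(start, stop, n), both endpoints are included.
    warmup = [warmup_teacher_temp + span * i / max(n - 1, 1) for i in range(n)]
    constant = [teacher_temp] * (nepochs - n)
    return warmup + constant


schedule = teacher_temp_schedule(0.04, 0.07, 30, 100)
print(len(schedule))                           # one temperature per epoch
print([round(t, 4) for t in schedule[:3]])     # start of the warmup
print([round(t, 4) for t in schedule[28:32]])  # end of warmup, then constant
```

The loss then looks up the value for the current epoch with an index such as schedule[epoch].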
The output of the teacher network is centered with a mean computed over the batch. Each network outputs a K-dimensional feature that is normalized with a temperature softmax over the feature dimension. Their similarity is then measured with a cross-entropy loss.
This happens in these lines (remember that for each image we have two global transformations and, by default, 8 local transformations):
student_out = student_output / self.student_temp
student_out = student_out.chunk(self.ncrops)
# teacher centering and sharpening
temp = self.teacher_temp_schedule[epoch]
teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
teacher_out = teacher_out.detach().chunk(2)
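The excerpt above stops before the cross-entropy itself. The following is a rough single-image sketch of how the pieces could fit together, assuming each teacher view is compared with every student view except the one made from the same crop, and that the first student views are the global ones. The helper names and the temperatures (0.1 and 0.04) are illustrative assumptions, and plain lists stand in for tensors.

```python
import math
import random


def softmax(logits, temp):
    scaled = [x / temp for x in logits]
    m = max(scaled)                       # subtract the max for stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(logits, temp):
    scaled = [x / temp for x in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]


def dino_loss_single_image(student_views, teacher_views, center,
                           student_temp=0.1, teacher_temp=0.04):
    # Teacher: center, then sharpen with a low temperature.
    teacher_probs = [softmax([x - c for x, c in zip(view, center)], teacher_temp)
                     for view in teacher_views]
    student_logp = [log_softmax(view, student_temp) for view in student_views]
    total, n_terms = 0.0, 0
    for iq, q in enumerate(teacher_probs):
        for v, logp in enumerate(student_logp):
            if v == iq:
                continue                  # skip the pair built from the same view
            total += -sum(qi * li for qi, li in zip(q, logp))
            n_terms += 1
    return total / n_terms


random.seed(0)
K = 16


def random_view():
    return [random.gauss(0.0, 1.0) for _ in range(K)]


student_views = [random_view() for _ in range(10)]   # 2 global + 8 local
teacher_views = [random_view() for _ in range(2)]    # global views only
loss = dino_loss_single_image(student_views, teacher_views, center=[0.0] * K)
print(loss)
```

With 2 teacher views and 10 student views this averages over 18 pairs. The real code works on batched tensors, which is why it splits them with chunk.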
self.center is also updated using an EMA in these lines:
def update_center(self, teacher_output):
    """
    Update center used for teacher output.
    """
    batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
    dist.all_reduce(batch_center)
    batch_center = batch_center / (len(teacher_output) * dist.get_world_size())
    # ema update
    self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
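To make the EMA step concrete, here is a single-process sketch on plain Python lists. It drops the dist.all_reduce call (so it assumes one worker), and the momentum of 0.9 and the toy batch are illustrative assumptions.

```python
def update_center(center, teacher_output, center_momentum=0.9):
    """Single-process EMA update of the center (no dist.all_reduce)."""
    batch_size = len(teacher_output)
    # Mean over the batch dimension, one value per output dimension.
    batch_center = [sum(col) / batch_size for col in zip(*teacher_output)]
    return [c * center_momentum + b * (1 - center_momentum)
            for c, b in zip(center, batch_center)]


center = [0.0, 0.0]
teacher_output = [[1.0, 2.0], [3.0, 6.0]]   # batch mean is [2.0, 4.0]
for step in range(3):
    center = update_center(center, teacher_output)
    print(step, [round(c, 3) for c in center])
```

Each step moves the center a fraction (1 - center_momentum) of the remaining distance toward the batch mean, so a momentum close to 1 gives a slowly drifting center.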
We study the complementary role of centering and target sharpening to avoid collapse. There are two forms of collapse: regardless of the input, the model output is uniform along all the dimensions or dominated by one dimension. The centering avoids the collapse induced by a dominant dimension, but encourages a uniform output. Sharpening induces the opposite effect. We show this complementarity by decomposing the cross-entropy H into an entropy h and the Kullback-Leibler divergence (“KL”).
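The decomposition mentioned here is the standard identity H(p, q) = h(p) + KL(p || q), with p standing for the teacher distribution and q for the student distribution (these two symbols are my own labels). A small numerical check, using made-up distributions, might look like this:

```python
import math


def entropy(p):
    return sum(-pi * math.log(pi) for pi in p if pi > 0)


def cross_entropy(p, q):
    return sum(-pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)


def kl_divergence(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


teacher = [0.7, 0.2, 0.1]   # made-up distributions for illustration
student = [0.5, 0.3, 0.2]
H = cross_entropy(teacher, student)
h = entropy(teacher)
kl = kl_divergence(teacher, student)
print(H, h + kl)            # the two values agree up to rounding

# The two collapse modes from the text, as K-dimensional outputs:
K = 4
uniform = [1.0 / K] * K          # entropy log(K), the maximum
one_hot = [1.0, 0.0, 0.0, 0.0]   # entropy 0, the minimum
print(entropy(uniform), math.log(K), entropy(one_hot))
```

Read this way, the entropy term tells the two collapse modes apart: it approaches log(K) for a uniform output and 0 for an output dominated by one dimension.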
TODO: Understand visualize_attention.py here.