# 3. DINO (self-`DI`stillation with `NO` labels)

- 80.1% top-1 on ImageNet with linear evaluation of a ViT-Base backbone
- Self-supervised, without a contrastive loss. A self-distillation setup in which the teacher is an exponential moving average (EMA) of the student.

### Code

#### Data Augmentation

There are two kinds of crops: global (the only ones the teacher sees) and local. The student sees both global and local crops. [This class](https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/main_dino.py#L419) defines the augmentations.

- [Random horizontal flip, color jitter and grayscale](https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/main_dino.py#L421-L428)
- [Two kinds of global crops and one kind of local crop](https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/main_dino.py#L435-L456)
- For each image, we get two global crops (one of each kind) and `local_crops_number` (8 by default) local crops [here](https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/main_dino.py#L460-L464):

```python
def __call__(self, image):
    crops = []
    crops.append(self.global_transfo1(image))
    crops.append(self.global_transfo2(image))
    for _ in range(self.local_crops_number):
        crops.append(self.local_transfo(image))
    return crops
```

Also worth noting are the [crop parameters](https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/main_dino.py#L107-L117):

```python
# Multi-crop parameters
parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.4, 1.),
    help="""Scale range of the cropped image before resizing, relatively to the origin image.
    Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
    recommand using a wider range of scale ("--global_crops_scale 0.14 1." for example)""")
parser.add_argument('--local_crops_number', type=int, default=8, help="""Number of small
    local views to generate. Set this parameter to 0 to disable multi-crop training.
    When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." """)
parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4),
    help="""Scale range of the cropped image before resizing, relatively to the origin image.
    Used for small local view cropping of multi-crop.""")
```

#### Loss function

- Implemented in [this class](https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/main_dino.py#L363)
- Outputs from student and teacher are passed through a temperature softmax. The student temperature is fixed at 0.1; the teacher temperature is lower (0.04 by default), which sharpens the teacher output. It is warmed up linearly from `warmup_teacher_temp` to `teacher_temp` over the first `warmup_teacher_temp_epochs` epochs, because a teacher temperature that is too high at the beginning of training makes it unstable. See [this comment](https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/main_dino.py#L372-L373).
- The teacher temperature schedule is defined [here](https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/main_dino.py#L374-L378):

```python
self.teacher_temp_schedule = np.concatenate((
    np.linspace(warmup_teacher_temp,
                teacher_temp, warmup_teacher_temp_epochs),
    np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
))
```

- From the paper:

> The output of the teacher network is centered with a mean computed over the batch. Each network outputs a K dimensional feature that is normalized with a temperature softmax over the feature dimension. Their similarity is then measured with a cross-entropy loss.
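To make the quoted recipe concrete before walking through the repo code, here is a minimal, self-contained sketch of the loss under the defaults above (two global and eight local crops, K-dimensional outputs concatenated along the batch dimension). The function and argument names are mine, not the repo's; the actual implementation is the `DINOLoss` class linked above.

```python
import torch
import torch.nn.functional as F

def dino_loss(student_output, teacher_output, center,
              student_temp=0.1, teacher_temp=0.04, n_global=2, n_local=8):
    """student_output: (B * (n_global + n_local), K) logits for all crops.
    teacher_output:  (B * n_global, K) logits for the global crops only.
    center:          (1, K) running mean of teacher outputs (updated by EMA, see below)."""
    # sharpen the student with a fixed temperature
    student_logp = F.log_softmax(student_output / student_temp, dim=-1)
    student_chunks = student_logp.chunk(n_global + n_local)

    # center, then sharpen the teacher with a lower temperature; no gradient flows through it
    teacher_p = F.softmax((teacher_output - center) / teacher_temp, dim=-1).detach()
    teacher_chunks = teacher_p.chunk(n_global)

    total_loss, n_terms = 0.0, 0
    for t_idx, t in enumerate(teacher_chunks):
        for s_idx, s in enumerate(student_chunks):
            if s_idx == t_idx:
                continue  # skip the pair where teacher and student see the same global view
            total_loss += torch.sum(-t * s, dim=-1).mean()  # cross-entropy, averaged over the batch
            n_terms += 1
    return total_loss / n_terms

# toy usage: batch of 4 images, small K for illustration (the default head output dim is 65536)
B, K = 4, 256
loss = dino_loss(torch.randn(B * 10, K), torch.randn(B * 2, K), center=torch.zeros(1, K))
```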
In the repo, this happens in [these lines](https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/main_dino.py#L384-L390) (remember that for each image we have two global crops and, by default, 8 local crops):

```python
student_out = student_output / self.student_temp
student_out = student_out.chunk(self.ncrops)

# teacher centering and sharpening
temp = self.teacher_temp_schedule[epoch]
teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
teacher_out = teacher_out.detach().chunk(2)
```

- `self.center` is itself updated with an EMA [in these lines](https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/main_dino.py#L407-L416):

```python
def update_center(self, teacher_output):
    """
    Update center used for teacher output.
    """
    batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
    dist.all_reduce(batch_center)
    batch_center = batch_center / (len(teacher_output) * dist.get_world_size())

    # ema update
    self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
```

- Section 5.3, on avoiding collapse:

> We study the complementarity role of centering and target sharpening to avoid collapse. There are two forms of collapse: regardless of the input, the model output is uniform along all the dimensions or dominated by one dimension. The centering avoids the collapse induced by a dominant dimension, but encourages an uniform output. Sharpening induces the opposite effect. We show this complementarity by decomposing the cross-entropy `H` into an entropy `h` and the Kullback-Leibler divergence (“KL”)

  The decomposition is `H(P_t, P_s) = h(P_t) + D_KL(P_t ‖ P_s)`; a KL equal to zero indicates a constant output, i.e. collapse.

TODO: Understand `visualize_attention.py` [here](https://github.com/facebookresearch/dino/blob/94175993abde84179449d79e22eab7ea28dec14b/visualize_attention.py).

###### tags: `vit`
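One more piece these notes mention at the top but never show in code: the teacher being an EMA of the student. In `main_dino.py` this update happens inside the training loop, with the momentum following a cosine schedule from 0.996 towards 1. The helper below is my own minimal sketch of that update, not the repo's code:

```python
import torch
import torch.nn as nn

@torch.no_grad()
def ema_update_teacher(student: nn.Module, teacher: nn.Module, momentum: float = 0.996):
    """teacher_param <- m * teacher_param + (1 - m) * student_param, applied after each student step."""
    for p_student, p_teacher in zip(student.parameters(), teacher.parameters()):
        p_teacher.data.mul_(momentum).add_(p_student.detach().data, alpha=1 - momentum)
```

No gradients ever flow into the teacher; it only changes through this parameter EMA (plus the `update_center` EMA used for centering its outputs).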