# 2. DeiT (Data-Efficient Image Transformers)

###### tags: `vit`

- Trained on ImageNet only (1.2 million images)
- Competitive performance on ImageNet (84.4% top-1 accuracy)
- In addition to the CLS token, a distillation token is added => it is responsible for predicting the output of a CNN teacher (RegNetY-16GF)
- To make it work on a smaller amount of data, three techniques were used:
    - data augmentation: repeated augmentation, auto-augment, rand-augment, random erasing, mixup, cutmix
    - optimization
    - regularization
- Trained at (224, 224), then fine-tuned at (384, 384). How? The patch size stays the same, so the number of patches grows, and the patch positional embeddings are interpolated to the new grid. Bicubic interpolation is used because it approximately preserves the L2 norm of the interpolated embeddings, keeping them comparable to the regular (non-interpolated) ones; bilinear interpolation would shrink it (see the interpolation sketch below).

### Implementation

#### Model

- `DistilledVisionTransformer` is defined [here](https://github.com/facebookresearch/deit/blob/53c0a07cae34baf6ea994e18a67e59d871fbf332/models.py#L20). It simply adds a [self.dist_token](https://github.com/facebookresearch/deit/blob/53c0a07cae34baf6ea994e18a67e59d871fbf332/models.py#L23) and updates the number of positional embeddings to [2 + num_patches](https://github.com/facebookresearch/deit/blob/53c0a07cae34baf6ea994e18a67e59d871fbf332/models.py#L25) (see the token-handling sketch below).
- [2D image to patch embedding](https://github.com/rwightman/pytorch-image-models/blob/7c67d6aca992f039eece0af5f7c29a43d48c00e4/timm/models/layers/patch_embed.py#L15): an image `(B, C, H, W)` is transformed into a tensor of size `(B, num_patches, embed_dim)`; after the CLS and distillation tokens are prepended, the sequence has size `(B, 2 + num_patches, embed_dim)` (see the patch-embedding sketch below).
- We only care about the outputs of the CLS token and the distillation token, as shown [here](https://github.com/facebookresearch/deit/blob/53c0a07cae34baf6ea994e18a67e59d871fbf332/models.py#L49).
- During inference, we take the mean of both token outputs, as shown [here](https://github.com/facebookresearch/deit/blob/53c0a07cae34baf6ea994e18a67e59d871fbf332/models.py#L59).
- A `Block` is defined [here](https://github.com/rwightman/pytorch-image-models/blob/7c67d6aca992f039eece0af5f7c29a43d48c00e4/timm/models/vision_transformer.py#L232).

#### Losses

- Implemented in [this file](https://github.com/facebookresearch/deit/blob/53c0a07cae34baf6ea994e18a67e59d871fbf332/losses.py): pretty straightforward - a weighted sum of the `base_criterion` loss and the distillation loss (see the loss sketch below).
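### Sketches

The snippets below are minimal, self-contained sketches of the pieces above, written for illustration; class names, helper names, and hyperparameters are my assumptions, not the repos' exact code.

#### Patch embedding

A sketch of the 2D image-to-patch embedding step, following the strided-conv idea in timm's `PatchEmbed` (the class name is reused here for readability, defaults are illustrative):

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and project each to embed_dim."""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # A conv with kernel_size == stride == patch_size is equivalent to
        # slicing out patches and applying one shared linear projection.
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                      # (B, embed_dim, H/P, W/P)
        return x.flatten(2).transpose(1, 2)   # (B, num_patches, embed_dim)

x = torch.randn(2, 3, 224, 224)
print(PatchEmbed()(x).shape)                  # torch.Size([2, 196, 768])
```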
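#### Distillation token and the two heads

A sketch of how the distillation token changes the forward pass: two learned tokens are prepended, the positional embedding covers `2 + num_patches` positions, training returns both head outputs, and inference averages them. `DistilledViT` is a made-up name, and the backbone is a stand-in (a stock `nn.TransformerEncoder` rather than timm's `Block` stack):

```python
import torch
import torch.nn as nn

class DistilledViT(nn.Module):
    def __init__(self, num_patches=196, embed_dim=768, num_classes=1000,
                 depth=12, num_heads=12):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))   # the extra token
        # CLS + distillation + patch positions => 2 + num_patches embeddings
        self.pos_embed = nn.Parameter(torch.zeros(1, 2 + num_patches, embed_dim))
        layer = nn.TransformerEncoderLayer(embed_dim, num_heads, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, depth)   # stand-in for timm Blocks
        self.head = nn.Linear(embed_dim, num_classes)       # supervised head (CLS token)
        self.head_dist = nn.Linear(embed_dim, num_classes)  # head mimicking the CNN teacher

    def forward(self, patch_tokens):                        # (B, num_patches, embed_dim)
        B = patch_tokens.shape[0]
        cls = self.cls_token.expand(B, -1, -1)
        dist = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls, dist, patch_tokens), dim=1) + self.pos_embed
        x = self.blocks(x)
        x_cls, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1])
        if self.training:
            return x_cls, x_dist                            # two outputs, two loss terms
        return (x_cls + x_dist) / 2                         # inference: mean of both heads

model = DistilledViT().eval()
print(model(torch.randn(2, 196, 768)).shape)                # torch.Size([2, 1000])
```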
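#### Interpolating positional embeddings for fine-tuning at 384

A sketch of the resolution change described above: with patch size 16, going from 224 to 384 grows the patch grid from 14×14 to 24×24, so the patch positional embeddings are resized with bicubic interpolation while the CLS and distillation positions are kept as-is. The function name and grid sizes are assumptions:

```python
import torch
import torch.nn.functional as F

def interpolate_pos_embed(pos_embed, old_grid=14, new_grid=24, num_extra_tokens=2):
    """(1, num_extra_tokens + old_grid**2, D) -> (1, num_extra_tokens + new_grid**2, D)."""
    extra = pos_embed[:, :num_extra_tokens]   # CLS + distillation positions: unchanged
    patch = pos_embed[:, num_extra_tokens:]   # (1, old_grid**2, D)
    dim = patch.shape[-1]
    # Reshape to a 2D grid so F.interpolate can resize it like an image.
    patch = patch.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    # Bicubic (not bilinear) approximately preserves the L2 norm of the embeddings.
    patch = F.interpolate(patch, size=(new_grid, new_grid),
                          mode="bicubic", align_corners=False)
    patch = patch.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat((extra, patch), dim=1)

pos = torch.randn(1, 2 + 14 * 14, 768)
print(interpolate_pos_embed(pos).shape)       # torch.Size([1, 578, 768]) = 2 + 24*24
```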
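#### Distillation loss

A sketch of the weighted loss: the base criterion on the CLS head plus a distillation term on the distillation head against the CNN teacher. Shown here in the hard-distillation variant (cross-entropy against the teacher's argmax); the repo also supports a soft (KL-divergence) variant, and the exact signature below is an assumption:

```python
import torch
import torch.nn as nn

class DistillationLoss(nn.Module):
    def __init__(self, base_criterion, teacher, alpha=0.5):
        super().__init__()
        self.base_criterion = base_criterion   # e.g. nn.CrossEntropyLoss()
        self.teacher = teacher                 # frozen CNN teacher, e.g. RegNetY-16GF
        self.alpha = alpha                     # weight between the two terms

    def forward(self, inputs, outputs, labels):
        out_cls, out_dist = outputs            # the two head outputs from the student
        base_loss = self.base_criterion(out_cls, labels)
        with torch.no_grad():                  # the teacher is not trained
            teacher_labels = self.teacher(inputs).argmax(dim=1)
        # Hard distillation: treat the teacher's predictions as ground-truth labels.
        dist_loss = nn.functional.cross_entropy(out_dist, teacher_labels)
        return (1 - self.alpha) * base_loss + self.alpha * dist_loss
```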