- Trained on ImageNet only (1.2 million images)
- Competitive performance on ImageNet (84.4%)
- In addition to the CLS token, a distillation token was added => responsible for predicting the output of a CNN teacher model (RegNetY-16GF)
- To make it work on a smaller amount of data, three techniques were used:
- data augmentation: repeated augmentation, auto-augment, rand-augment, random erasing, mixup, cutmix
- optimization: AdamW with carefully tuned hyper-parameters
- regularization: trained at (224, 224), then fine-tuned at (384, 384); the architecture is unchanged, only the positional embeddings are interpolated to the larger patch grid
- bicubic interpolation was used so that the L2 norm of the interpolated positional embeddings stays close to that of the original ones (bilinear interpolation would shrink it)
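The positional-embedding resizing for higher-resolution fine-tuning can be sketched as below. This is a minimal illustration, not the repo's actual function (the helper name and signature are assumptions): the CLS/distillation embeddings are kept as-is, and the per-patch embeddings are reshaped into a 2D grid and interpolated bicubically.

```python
import torch
import torch.nn.functional as F

def interpolate_pos_embed(pos_embed, num_extra_tokens, new_grid_size):
    """Resize learned positional embeddings to a larger patch grid.

    pos_embed: (1, num_extra_tokens + old_num_patches, dim).
    Hypothetical helper for illustration only.
    """
    extra = pos_embed[:, :num_extra_tokens]       # CLS (+ distillation) embeddings
    patch_pos = pos_embed[:, num_extra_tokens:]   # one embedding per patch
    dim = patch_pos.shape[-1]
    old_grid = int(patch_pos.shape[1] ** 0.5)
    # (1, old_n, dim) -> (1, dim, old_grid, old_grid) for spatial interpolation
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    # bicubic interpolation approximately preserves the L2 norm of the vectors
    patch_pos = F.interpolate(patch_pos, size=(new_grid_size, new_grid_size),
                              mode="bicubic", align_corners=False)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid_size ** 2, dim)
    return torch.cat([extra, patch_pos], dim=1)
```

For a 16-pixel patch size, going from 224 to 384 input resolution grows the grid from 14x14 to 24x24 patches.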
Implementation
Model
- DistilledVisionTransformer is defined here. It simply adds a self.dist_token and updates the number of positional embeddings to 2 + num_patches.
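The token bookkeeping described above can be sketched as follows; this is a simplified stand-in (class and attribute names other than dist_token/pos_embed are illustrative), assuming a patch embedding that already yields (B, num_patches, embed_dim).

```python
import torch
import torch.nn as nn

class DistilledTokens(nn.Module):
    """Minimal sketch of the extra-token setup in DistilledVisionTransformer."""
    def __init__(self, num_patches, embed_dim):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))  # extra token
        # 2 + num_patches: CLS + distillation + one embedding per patch
        self.pos_embed = nn.Parameter(torch.zeros(1, 2 + num_patches, embed_dim))

    def forward(self, x):  # x: (B, num_patches, embed_dim)
        B = x.shape[0]
        cls = self.cls_token.expand(B, -1, -1)
        dist = self.dist_token.expand(B, -1, -1)
        x = torch.cat([cls, dist, x], dim=1)  # prepend both special tokens
        return x + self.pos_embed
```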
- 2D image to patch embedding: an image of shape (B, C, H, W) is transformed into a tensor of size (B, 2 + num_patches, embed_dim)
- We only care about the CLS token and distillation token outputs, as shown here.
- During inference, we take the mean of both tokens' predictions, as shown here.
- A Block is defined here.
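The two-head readout and the inference-time averaging can be sketched as below; function and argument names are illustrative, assuming head and head_dist are the linear classifiers attached to the CLS and distillation token outputs.

```python
import torch

def classify(x_cls, x_dist, head, head_dist, training=False):
    # x_cls / x_dist: final-layer outputs of the CLS and distillation tokens
    logits_cls = head(x_cls)
    logits_dist = head_dist(x_dist)
    if training:
        # training keeps both sets of logits so each gets its own loss
        return logits_cls, logits_dist
    # inference: average the two predictions
    return (logits_cls + logits_dist) / 2
```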
Losses
- Implemented in this file: pretty straightforward, a weighted sum of the base_criterion loss and the distillation loss.
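A sketch of the weighted loss, assuming the hard-distillation variant (the distillation head is trained against the teacher's argmax); argument names are illustrative, not the repo's exact signature.

```python
import torch
import torch.nn.functional as F

def distilled_loss(base_criterion, logits_cls, logits_dist, teacher_logits,
                   target, alpha=0.5):
    """Weighted sum of the base loss and a hard-distillation loss."""
    # supervised loss on the CLS head
    base_loss = base_criterion(logits_cls, target)
    # hard distillation: the distillation head must match the teacher's argmax
    teacher_labels = teacher_logits.argmax(dim=1)
    dist_loss = F.cross_entropy(logits_dist, teacher_labels)
    return (1 - alpha) * base_loss + alpha * dist_loss
```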