# 4. BEiT (Bidirectional Encoder representation from Image Transformers)

- Another self-supervised image representation model.

> Key idea: During pre-training, we randomly mask some proportion of image patches, and feed the corrupted input to Transformer. The model learns to recover the visual tokens of the original image, instead of the raw pixels of masked patches.

- Image patches are flattened into vectors and linearly projected.

> In our experiments, we split each 224 × 224 image into a 14 × 14 grid of image patches, where each patch is 16 × 16.

- An image is broken down into patches, and each patch has two representations: a patch representation from the Transformer encoder, and a visual token from a fixed vocabulary (produced by the dVAE tokenizer).

### Code

#### Model (taken from [here](https://github.com/microsoft/unilm/blob/f18eda3856ba16abee622f3a839cd3246183216c/beit/run_beit_pretraining.py))

- The BEiT model is created using `timm`:

```python
from timm.models import create_model

# '--model', default='deit_base_patch16_224'
def get_model(args):
    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
        use_shared_rel_pos_bias=args.rel_pos_bias,
        use_abs_pos_emb=args.abs_pos_emb,
        init_values=args.layer_scale_init_value,
    )
    return model


model = get_model(args)
```

- The dVAE tokenizer is either DALL-E's dVAE or a customized one ([code here](https://github.com/microsoft/unilm/blob/f18eda3856ba16abee622f3a839cd3246183216c/beit/utils.py#L480-L512)):

```python
import os

import torch

# Dalle_VAE and DiscreteVAE are defined in modeling_discrete_vae.py (see the TODO below).
def create_d_vae(weight_path, d_vae_type, image_size, device):
    if d_vae_type == "dall-e":
        return get_dalle_vae(weight_path, image_size, device)
    elif d_vae_type == "customized":
        return get_d_vae(weight_path, image_size, device)
    else:
        raise NotImplementedError()


def get_dalle_vae(weight_path, image_size, device):
    vae = Dalle_VAE(image_size)
    vae.load_model(model_dir=weight_path, device=device)
    return vae


def get_d_vae(weight_path, image_size, device):
    NUM_TOKENS = 8192
    NUM_LAYERS = 3
    EMB_DIM = 512
    HID_DIM = 256

    state_dict = torch.load(
        os.path.join(weight_path, "pytorch_model.bin"), map_location="cpu"
    )["weights"]

    model = DiscreteVAE(
        image_size=image_size,
        num_layers=NUM_LAYERS,
        num_tokens=NUM_TOKENS,
        codebook_dim=EMB_DIM,
        hidden_dim=HID_DIM,
    ).to(device)

    model.load_state_dict(state_dict)
    return model
```

TODO: Understand how `Dalle_VAE` and `DiscreteVAE` are defined [in this file](https://github.com/microsoft/unilm/blob/f18eda3856ba16abee622f3a839cd3246183216c/beit/modeling_discrete_vae.py).

#### Data

- [This class](https://github.com/microsoft/unilm/blob/f18eda3856ba16abee622f3a839cd3246183216c/beit/masking_generator.py#L29) generates the patch masks randomly (a minimal sketch of the idea follows this list).
- It is then used in the [data augmentation for BEiT](https://github.com/microsoft/unilm/blob/f18eda3856ba16abee622f3a839cd3246183216c/beit/datasets.py#L71-L75) along with other transforms.
- Another noteworthy transform is [RandomResizedCropAndInterpolationWithTwoPic](https://github.com/microsoft/unilm/blob/f18eda3856ba16abee622f3a839cd3246183216c/beit/transforms.py#L67). It returns two transformed versions of each image: one for the patches fed to the encoder, and one for the dVAE tokenizer.
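For intuition about what the masking generator returns, here is a minimal sketch. It is not the repo's implementation (which uses blockwise masking with aspect-ratio constraints); the class name `SimpleRandomMaskingGenerator` and the choice of 75 masked patches out of 196 (roughly 40%) are illustrative assumptions only.

```python
import torch


# Illustrative stand-in for the repo's masking generator (NOT its actual code):
# it returns a boolean mask over the 14 x 14 patch grid, where True = patch is masked.
class SimpleRandomMaskingGenerator:
    def __init__(self, input_size=14, num_masking_patches=75):
        self.height = self.width = input_size
        self.num_patches = self.height * self.width
        self.num_masking_patches = num_masking_patches

    def __call__(self):
        # Pick `num_masking_patches` random positions and mark them as masked.
        mask = torch.zeros(self.num_patches, dtype=torch.bool)
        idx = torch.randperm(self.num_patches)[: self.num_masking_patches]
        mask[idx] = True
        return mask.reshape(self.height, self.width)


gen = SimpleRandomMaskingGenerator()
bool_masked_pos = gen()              # shape (14, 14); 75 of 196 patches masked
print(bool_masked_pos.sum().item())  # 75
```

In the actual pipeline, a mask like this is produced per sample alongside the two image views, and later flattened to index the patch sequence (see the loss snippet below).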
The pretraining dataset is then built with this augmentation:

```python
def build_beit_pretraining_dataset(args):
    transform = DataAugmentationForBEiT(args)
    print("Data Aug = %s" % str(transform))
    return ImageFolder(args.data_path, transform=transform)
```

#### Loss

- [This](https://github.com/microsoft/unilm/blob/f18eda3856ba16abee622f3a839cd3246183216c/beit/engine_for_pretraining.py#L49-L52) is how the labels for the prediction are created from the dVAE tokens: only the codebook indices at the masked positions are kept as targets, and the model's outputs are scored against them with cross-entropy.

```python
with torch.no_grad():
    # Tokenize the image with the dVAE and keep only the token ids at masked positions.
    input_ids = d_vae.get_codebook_indices(images).flatten(1)
    bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)
    labels = input_ids[bool_masked_pos]

with torch.cuda.amp.autocast():
    outputs = model(samples, bool_masked_pos=bool_masked_pos, return_all_tokens=False)
    loss = nn.CrossEntropyLoss()(input=outputs, target=labels)
```

Questions:

- Where is the head defined that maps BEiT outputs to logits over the visual-token vocabulary? (A hedged sketch of what such a head could look like is at the end of this note.)
- How to extract the attention maps for different objects, as shown in the paper?

###### tags: `vit`
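On the first question: this is not the repo's definition, just a minimal sketch of the kind of head that would be consistent with the loss snippet above, namely a single linear layer from the encoder's hidden size to the 8192-entry visual-token vocabulary, applied only at masked positions. The names `embed_dim`, `vocab_size`, and `lm_head`, and the toy shapes, are assumptions for illustration.

```python
import torch
import torch.nn as nn

# Hedged sketch (not the repo's code): a masked-image-modeling head is typically a
# single linear layer from the encoder's hidden size to the visual-token vocabulary.
embed_dim, vocab_size = 768, 8192          # 8192 matches NUM_TOKENS of the dVAE above
lm_head = nn.Linear(embed_dim, vocab_size)

# Toy shapes: batch of 2 images, 196 patches each, the first 75 patches "masked".
hidden = torch.randn(2, 196, embed_dim)                # encoder outputs for all patches
bool_masked_pos = torch.zeros(2, 196, dtype=torch.bool)
bool_masked_pos[:, :75] = True

logits = lm_head(hidden[bool_masked_pos])                   # (150, vocab_size): masked positions only
labels = torch.randint(0, vocab_size, (logits.shape[0],))   # dVAE codebook ids would go here
loss = nn.CrossEntropyLoss()(logits, labels)
print(loss.item())
```

Scoring only the masked positions is consistent with `return_all_tokens=False` in the loss snippet, since `outputs` there is compared directly against `labels = input_ids[bool_masked_pos]`.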