
4. BEiT (Bidirectional Encoder representation from Image Transformers)

  • Another self-supervised image representation model

Key idea: During pre-training, we randomly mask some proportion of image patches and feed the corrupted input to the Transformer. The model learns to recover the visual tokens of the original image, instead of the raw pixels of the masked patches (a simplified sketch of the masking step is given after the list below).

  • Image patches are flattened into vectors and linearly projected.

In our experiments, we split each 224 × 224 image into a 14 × 14 grid of image patches, where each patch is 16 × 16.
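
To sanity-check the arithmetic, here is a minimal sketch (plain PyTorch, not the BEiT code) that splits a batch of 224 × 224 images into 16 × 16 patches and linearly projects each flattened patch:

import torch
import torch.nn as nn

images = torch.randn(2, 3, 224, 224)                       # toy batch of 2 RGB images

patch_size, embed_dim = 16, 768
# Extract non-overlapping 16x16 patches: (B, 3*16*16, 14*14) = (B, 768, 196)
patches = nn.functional.unfold(images, kernel_size=patch_size, stride=patch_size)
patches = patches.transpose(1, 2)                          # (B, 196, 768): 196 flattened patches per image

proj = nn.Linear(3 * patch_size * patch_size, embed_dim)   # linear projection of each flattened patch
patch_embeddings = proj(patches)                           # (B, 196, 768): the Transformer's input sequence
print(patch_embeddings.shape)                              # torch.Size([2, 196, 768])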

  • An image is broken down into patches, and each image has two views: patch embeddings fed to the Transformer encoder, and visual tokens from a fixed vocabulary produced by the dVAE tokenizer.
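
A simplified sketch of the masking step, assuming uniform random masking of roughly 40% of the patches (an assumed ratio, and the paper actually uses blockwise masking), just to illustrate the shapes involved:

import torch

num_patches = 14 * 14        # 196 patches per image, matching the grid above
mask_ratio = 0.4             # assumed ratio, for illustration only

def random_patch_mask(batch_size, num_patches, mask_ratio):
    """Return a (batch_size, num_patches) boolean mask; True marks a masked patch."""
    num_masked = int(num_patches * mask_ratio)
    scores = torch.rand(batch_size, num_patches)
    idx = scores.argsort(dim=1)[:, :num_masked]            # indices of the patches to mask
    mask = torch.zeros(batch_size, num_patches, dtype=torch.bool)
    mask.scatter_(1, idx, True)
    return mask

bool_masked_pos = random_patch_mask(batch_size=2, num_patches=num_patches, mask_ratio=mask_ratio)
print(bool_masked_pos.shape, bool_masked_pos.sum(dim=1))   # torch.Size([2, 196]), 78 masked per image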

Code

Model (taken from here)

  • The BEiT model is created using timm's create_model:
from timm.models import create_model

# '--model', default='deit_base_patch16_224'
def get_model(args):
    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
        use_shared_rel_pos_bias=args.rel_pos_bias,
        use_abs_pos_emb=args.abs_pos_emb,
        init_values=args.layer_scale_init_value,
    )
    return model

model = get_model(args)
  • The dVAE tokenizer is either DALL-E's pre-trained dVAE or a customized dVAE (code here):
import os
import torch

# Choose the tokenizer: the pre-trained DALL-E dVAE or a custom-trained DiscreteVAE.
def create_d_vae(weight_path, d_vae_type, image_size, device):
    if d_vae_type == "dall-e":
        return get_dalle_vae(weight_path, image_size, device)
    elif d_vae_type == "customized":
        return get_d_vae(weight_path, image_size, device)
    else:
        raise NotImplementedError()


def get_dalle_vae(weight_path, image_size, device):
    vae = Dalle_VAE(image_size)
    vae.load_model(model_dir=weight_path, device=device)
    return vae


def get_d_vae(weight_path, image_size, device):
    NUM_TOKENS = 8192
    NUM_LAYERS = 3
    EMB_DIM = 512
    HID_DIM = 256

    state_dict = torch.load(os.path.join(weight_path, "pytorch_model.bin"), map_location="cpu")["weights"]

    model = DiscreteVAE(
        image_size=image_size,
        num_layers=NUM_LAYERS,
        num_tokens=NUM_TOKENS,
        codebook_dim=EMB_DIM,
        hidden_dim=HID_DIM,
    ).to(device)

    model.load_state_dict(state_dict)
    return model

TODO: Understand how Dalle_VAE and DiscreteVAE are defined in this file.
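
A hedged usage sketch of the tokenizer (the path and image size below are placeholders, not the repo's defaults): build the dVAE and map a batch of images to discrete visual-token indices via get_codebook_indices, which is exactly what the loss computation below relies on.

d_vae = create_d_vae(
    weight_path="/path/to/tokenizer_weights",   # placeholder path
    d_vae_type="dall-e",                        # or "customized"
    image_size=112,                             # assumed resolution of the tokenizer's input view
    device=torch.device("cuda"),
)

with torch.no_grad():
    # images: the tokenizer-view batch, e.g. (B, 3, 112, 112)
    visual_tokens = d_vae.get_codebook_indices(images)   # discrete token ids, one per patch position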

Data

from torchvision.datasets import ImageFolder

def build_beit_pretraining_dataset(args):
    transform = DataAugmentationForBEiT(args)
    print("Data Aug = %s" % str(transform))
    return ImageFolder(args.data_path, transform=transform)
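
A hedged sketch of how this dataset feeds a DataLoader; the assumption here is that DataAugmentationForBEiT returns, per sample, a patch-view image for the encoder, a tokenizer-view image for the dVAE, and a patch mask (and that args carries the usual fields such as data_path and batch_size):

from torch.utils.data import DataLoader

dataset = build_beit_pretraining_dataset(args)
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                    num_workers=8, drop_last=True)

for (samples, images, bool_masked_pos), _ in loader:   # ImageFolder also yields a class label, unused here
    # samples:         patch-view images fed to the BEiT encoder
    # images:          tokenizer-view images fed to the dVAE
    # bool_masked_pos: mask over the 14 x 14 patch grid (flattened and cast to bool in the loss code)
    break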

Loss

  • This is how the labels for the prediction are created using dVAE tokens:
with torch.no_grad():
    # Tokenize the uncorrupted images with the frozen dVAE; the visual tokens at the
    # masked positions become the classification targets.
    input_ids = d_vae.get_codebook_indices(images).flatten(1)
    bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)
    labels = input_ids[bool_masked_pos]

with torch.cuda.amp.autocast():
    # With return_all_tokens=False the model returns logits only for the masked positions,
    # so the cross-entropy is computed over the masked patches only.
    outputs = model(samples, bool_masked_pos=bool_masked_pos, return_all_tokens=False)
    loss = nn.CrossEntropyLoss()(input=outputs, target=labels)
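
To make the indexing above concrete, a toy stand-in with random tensors (8192-entry visual-token vocabulary, matching the dVAE above):

import torch
import torch.nn as nn

B, N, V = 2, 196, 8192                      # batch size, patches per image, visual-token vocabulary
input_ids = torch.randint(0, V, (B, N))     # stand-in for d_vae.get_codebook_indices(images).flatten(1)
bool_masked_pos = torch.rand(B, N) < 0.4    # stand-in for the flattened boolean patch mask

labels = input_ids[bool_masked_pos]         # (num_masked,) token ids of every masked patch in the batch
outputs = torch.randn(labels.shape[0], V)   # stand-in for the logits the model returns at masked positions

loss = nn.CrossEntropyLoss()(input=outputs, target=labels)
print(labels.shape, outputs.shape, loss.item())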

Questions:

  • Where is the head defined that maps BEiT outputs to visual-token logits?
  • How can the attention maps for different objects be extracted, as shown in the paper?
tags: vit