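Key idea: during pre-training, we randomly mask a proportion of the image patches and feed the corrupted input to the Transformer. Instead of reconstructing the raw pixels of the masked patches, the model learns to predict the visual tokens of the original image at those positions. These tokens are the discrete codebook indices produced by the image tokenizer (the dVAE below).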
In our experiments, we split each 224 × 224 image into a 14 × 14 grid of image patches, where each patch is 16 × 16 pixels (224 / 16 = 14 patches per side).
The backbone is created with `timm` (PyTorch Image Models):
from timm.models import create_model
# CLI default for the backbone name: '--model', default='deit_base_patch16_224'
def get_model(args):
    """Create the Transformer backbone named by ``args.model`` via timm.

    The model is built without pretrained weights. The drop-path rate,
    relative/absolute position-embedding switches and layer-scale init value
    are taken from the parsed command-line ``args``.
    """
    print(f"Creating model: {args.model}")
    backbone_options = dict(
        pretrained=False,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
        use_shared_rel_pos_bias=args.rel_pos_bias,
        use_abs_pos_emb=args.abs_pos_emb,
        init_values=args.layer_scale_init_value,
    )
    return create_model(args.model, **backbone_options)
# Build the backbone at module level.
# NOTE(review): assumes a parsed-arguments object ``args`` already exists in
# this scope; it is not defined in these notes — TODO confirm.
model = get_model(args)
def create_d_vae(weight_path, d_vae_type, image_size, device):
    """Build the discrete-VAE image tokenizer selected by ``d_vae_type``.

    Args:
        weight_path: Location of the tokenizer weights, forwarded unchanged to
            the selected loader.
        d_vae_type: ``"dall-e"`` to load via ``get_dalle_vae`` or
            ``"customized"`` to load via ``get_d_vae``.
        image_size: Image size forwarded to the loader.
        device: Device forwarded to the loader.

    Returns:
        Whatever the selected loader returns (the loaded tokenizer).

    Raises:
        NotImplementedError: If ``d_vae_type`` is not one of the supported
            values.
    """
    if d_vae_type == "dall-e":
        return get_dalle_vae(weight_path, image_size, device)
    if d_vae_type == "customized":
        return get_d_vae(weight_path, image_size, device)
    # Name the offending value so a typo in the config is easy to spot.
    raise NotImplementedError(
        f"Unsupported d_vae_type {d_vae_type!r}; expected 'dall-e' or 'customized'."
    )
def get_dalle_vae(weight_path, image_size, device):
    """Construct a ``Dalle_VAE`` for ``image_size`` and load its weights.

    ``weight_path`` is handed to ``Dalle_VAE.load_model`` as ``model_dir``,
    together with ``device``. The same tokenizer object is returned.
    """
    tokenizer = Dalle_VAE(image_size)
    tokenizer.load_model(model_dir=weight_path, device=device)
    return tokenizer
def get_d_vae(weight_path, image_size, device, *, num_tokens=8192, num_layers=3,
              codebook_dim=512, hidden_dim=256):
    """Build a ``DiscreteVAE`` tokenizer and load its checkpoint.

    The checkpoint file ``<weight_path>/pytorch_model.bin`` is loaded onto the
    CPU. Its ``"weights"`` entry is used as the state dict for a
    ``DiscreteVAE`` that has been moved to ``device``.

    Args:
        weight_path: Directory containing ``pytorch_model.bin``.
        image_size: Passed to ``DiscreteVAE`` as ``image_size``.
        device: Device the model is moved to before the weights are loaded.
        num_tokens: Passed as ``num_tokens`` (default 8192).
        num_layers: Passed as ``num_layers`` (default 3).
        codebook_dim: Passed as ``codebook_dim`` (default 512).
        hidden_dim: Passed as ``hidden_dim`` (default 256).

    The architecture keywords must match the checkpoint. The defaults
    reproduce the previously hard-coded configuration.

    Returns:
        The ``DiscreteVAE`` with the checkpoint weights loaded.
    """
    checkpoint_file = os.path.join(weight_path, "pytorch_model.bin")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(checkpoint_file, map_location="cpu")["weights"]
    model = DiscreteVAE(
        image_size=image_size,
        num_layers=num_layers,
        num_tokens=num_tokens,
        codebook_dim=codebook_dim,
        hidden_dim=hidden_dim,
    ).to(device)
    model.load_state_dict(state_dict)
    return model
TODO: Understand how the `Dalle_VAE` and `DiscreteVAE` classes used above are defined.
def build_beit_pretraining_dataset(args):
    """Return an ``ImageFolder`` over ``args.data_path`` with BEiT augmentation.

    The transform is built from ``args`` by ``DataAugmentationForBEiT`` and is
    printed once, so the run log records the augmentation pipeline.
    """
    augmentation = DataAugmentationForBEiT(args)
    print(f"Data Aug = {augmentation!s}")
    dataset = ImageFolder(args.data_path, transform=augmentation)
    return dataset
# One pre-training step (fragment): tokenize the images with the dVAE to get
# target token ids, then have the model predict those ids at the masked
# patch positions.
# NOTE(review): d_vae, images, samples, bool_masked_pos, model and nn come from
# the surrounding training loop, which is not shown here. Indentation was
# reconstructed; this assumes the loss is computed inside the autocast block,
# as in the upstream BEiT training loop — confirm.
with torch.no_grad():
    # Tokenizer targets are computed without tracking gradients.
    # flatten(1) keeps dim 0 and merges the rest — presumably yielding
    # (batch, num_patches) token ids; TODO confirm get_codebook_indices' shape.
    input_ids = d_vae.get_codebook_indices(images).flatten(1)
    bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)
    # Boolean-mask indexing keeps only the masked positions as a 1-D tensor
    # (assumes the mask has the same shape as input_ids).
    labels = input_ids[bool_masked_pos]

# NOTE(review): torch.cuda.amp.autocast() is deprecated in recent PyTorch in
# favour of torch.amp.autocast("cuda").
with torch.cuda.amp.autocast():
    # With return_all_tokens=False the model presumably returns logits only for
    # the masked positions, aligned with `labels` — confirm in the model code.
    outputs = model(samples, bool_masked_pos=bool_masked_pos, return_all_tokens=False)
    loss = nn.CrossEntropyLoss()(input=outputs, target=labels)
Questions:
- ViT (Vision Transformer)