{
blocks.2.mlp_norm,
blocks.5.mlp_norm,
blocks.8.mlp_norm
}
{
blocks.2.attn.proj,
blocks.5.attn.proj,
blocks.8.attn.proj
}
In the previous setup, we split the model into 4 stages across 4 devices. Now, let's try splitting the model into more or fewer stages. Please compare the speedup of 2-, 3-, 4-, and 6-stage pipelines.
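A minimal sketch of one way to pick the cut points, assuming stages are formed by slicing the 12 transformer blocks into contiguous, equally sized chunks (the helper stage_boundaries is ours, not part of the original setup):

def stage_boundaries(num_blocks, num_stages):
    """Return (start, end) block indices for each pipeline stage (assumes even divisibility)."""
    per_stage = num_blocks // num_stages          # 12 blocks -> 6 / 4 / 3 / 2 blocks per stage
    return [(i * per_stage, (i + 1) * per_stage) for i in range(num_stages)]

for n in (2, 3, 4, 6):
    print(n, "stages:", stage_boundaries(len(model.blocks), n))

The patch embedding and the classification head would presumably stay with the first and last stage, as in the earlier 4-stage split.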
Methodologies:
https://github.com/VainF/Torch-Pruning/blob/master/examples/transformers/prune_timm_vit.py
Hyperparameters:
BATCH_SIZE = 16
LEARNING_RATE = 5e-4
NUM_EPOCH = 15
1. Re-implement the forward function:
# Imports used throughout the following steps.
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
from tqdm import tqdm

# Here we re-implement the forward function of timm.models.vision_transformer.Attention,
# because the original forward function requires the input and output channels to be identical.
def forward(self, x):
    """https://github.com/huggingface/pytorch-image-models/blob/054c763fcaa7d241564439ae05fbe919ed85e614/timm/models/vision_transformer.py#L79"""
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    q, k = self.q_norm(q), self.k_norm(k)

    if self.fused_attn:
        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p,
        )
    else:
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v

    x = x.transpose(1, 2).reshape(B, N, -1)  # original implementation: x = x.transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x
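To see why the single changed line matters, here is an illustrative check with made-up sizes (a 3-head attention whose head_dim is shrunk from 64 to 48, mimicking what feature-dim pruning does); it is not taken from the actual pruned model:

attn = timm.models.vision_transformer.Attention(dim=192, num_heads=3)            # head_dim = 64
attn.forward = forward.__get__(attn, timm.models.vision_transformer.Attention)
attn.head_dim = 48                                    # pretend head_dim was pruned 64 -> 48
attn.qkv = nn.Linear(192, 3 * 3 * 48)                 # q/k/v for 3 heads with the pruned head_dim
attn.proj = nn.Linear(3 * 48, 192)                    # proj input (144) != embedding dim (192)
x = torch.randn(2, 197, 192)
print(attn(x).shape)                                  # torch.Size([2, 197, 192])

With the original reshape(B, N, C), the same call would fail, because the attention output has 144 channels per token while C is still 192.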
2. Prepare a pruner:
imp = tp.importance.GroupHessianImportance()
example_image = next(iter(train_loader))[0][:1].to(device)  # keep the batch dimension: (1, 3, H, W)
ignored_layers = [model.head]

# Patch every attention block with the re-implemented forward and record its head count.
num_heads = {}
for m in model.modules():
    if isinstance(m, timm.models.vision_transformer.Attention):
        m.forward = forward.__get__(m, timm.models.vision_transformer.Attention)
        num_heads[m.qkv] = m.num_heads
pruner = tp.pruner.MetaPruner(
    model,
    example_image,
    global_pruning=True,       # if False, a uniform pruning ratio is assigned to every layer
    importance=imp,            # importance criterion for parameter selection
    pruning_ratio=0.15,        # target pruning ratio
    ignored_layers=ignored_layers,
    num_heads=num_heads,       # number of heads in self-attention
    prune_num_heads=False,     # reduce num_heads by pruning entire heads (default: False)
    prune_head_dims=False,     # reduce head_dim by pruning feature dims of each head (default: True)
    head_pruning_ratio=0.0,    # ratio of heads to remove; only takes effect when prune_num_heads=True (default: 0.0)
    round_to=2,                # round the number of kept channels to a multiple of 2
)
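Before pruning, it can help to record the dense model's cost so the 15% target can be checked afterwards; Torch-Pruning ships a counter for this:

base_macs, base_params = tp.utils.count_ops_and_params(model, example_image)
print(f"Base MACs: {base_macs / 1e9:.2f} G, Base Params: {base_params / 1e6:.2f} M")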
3. Prune the model:
if isinstance(imp, (tp.importance.GroupTaylorImportance, tp.importance.GroupHessianImportance)):
    model.zero_grad()
    if isinstance(imp, tp.importance.GroupHessianImportance):
        imp.zero_grad()
    print("Accumulating gradients for pruning...")
    for k, (imgs, lbls) in enumerate(train_loader):
        if k >= 10:  # a few batches are enough to estimate importance
            break
        imgs = imgs.to(device)
        lbls = lbls.to(device)
        output = model(imgs)
        if isinstance(imp, tp.importance.GroupHessianImportance):
            # Per-sample losses are backpropagated one by one so the importance can accumulate their gradients.
            loss = torch.nn.functional.cross_entropy(output, lbls, reduction='none')
            for l in loss:
                model.zero_grad()
                l.backward(retain_graph=True)
                imp.accumulate_grad(model)
        elif isinstance(imp, tp.importance.GroupTaylorImportance):
            loss = torch.nn.functional.cross_entropy(output, lbls)
            loss.backward()

for g in pruner.step(interactive=True):
    g.prune()
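After pruning, the attention modules still carry their pre-pruning num_heads/head_dim attributes, so the referenced prune_timm_vit.py example updates them from the pruner before the model is used again. A sketch following that example (exact attribute names depend on the Torch-Pruning version):

# Sync the attention metadata with the pruned qkv layers (mirrors the referenced example).
for m in model.modules():
    if isinstance(m, timm.models.vision_transformer.Attention):
        m.num_heads = pruner.num_heads[m.qkv]
        m.head_dim = m.qkv.out_features // (3 * m.num_heads)

# Compare cost against the baseline recorded before pruning.
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_image)
print(f"MACs: {base_macs / 1e9:.2f} G -> {pruned_macs / 1e9:.2f} G, "
      f"Params: {base_params / 1e6:.2f} M -> {pruned_params / 1e6:.2f} M")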
4. Fine-tuning:
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, NUM_EPOCH)
criterion = nn.CrossEntropyLoss()
best_acc = 0
for epoch_num in tqdm(range(1, NUM_EPOCH + 1)):
    train_one_epoch(model, criterion, optimizer, train_loader, device, scheduler)
    acc = evaluate_model(model, test_loader, device)
    print(f"epoch {epoch_num}:", acc)
    if acc > best_acc:
        torch.save(model, "prune_iter.pth")  # save the whole module: a pruned state_dict no longer matches the original architecture
        best_acc = acc
5. Pruned model accuracy test:
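A minimal sketch of the final check, assuming the same evaluate_model helper used during fine-tuning (weights_only=False is needed on newer PyTorch versions because the whole module was pickled above):

pruned_model = torch.load("prune_iter.pth", map_location=device, weights_only=False)
pruned_model.eval()
acc = evaluate_model(pruned_model, test_loader, device)
print(f"Pruned model test accuracy: {acc:.4f}")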