---
title: Is incorporating attention rollout beneficial for pooling in Vision Transformers?
---

# Is incorporating attention rollout beneficial for pooling in Vision Transformers?

In this work, I trained two Vision Transformer (ViT) models from scratch on the CIFAR-100 dataset. The first is a baseline ViT using standard self-attention, and the second is a modified ViT that integrates attention rollout into its forward pass, as described by Abnar and Zuidema (2020).

As per my understanding, Vision Transformers treat an image as a sequence of patch embeddings and use a special classification token ([CLS]) to aggregate information from all patches via the self-attention mechanism. The output embedding of this [CLS] token is passed to a classifier to predict the image’s class. In the modified model, I incorporated attention rollout, which multiplies the attention matrices of all transformer layers (after adding identity matrices to account for residual connections) to compute the overall influence of each input patch on the final output. This rolled-out attention is used to derive the image representation instead of relying solely on the [CLS] token.

### Key Objectives and Steps for the first leg of this work:

* **Model Definitions** – Define a baseline ViT architecture and a modified ViT that uses attention rollout in its forward pass.
* **Training Setup** – Prepare CIFAR-100 data and train both models from scratch on all 100 classes, using a lightweight configuration suitable for limited hardware (I used the free version of Google Colab with a T4 GPU).
* **Optimization** – Use standard PyTorch training loops with appropriate optimizations (smaller model and batch size, and optional mixed precision) to accommodate hardware constraints.
* **Evaluation** – Provide a script to load the trained models and compute cosine similarities between the learned representations of sample test images from four classes (boy, man, table, chair), demonstrating how to extract and compare feature embeddings.

<!--
# Preliminary Implementation & Results

## Setup and Data Preparation

First, I set up the environment and loaded the CIFAR-100 dataset using torchvision. I used PyTorch for model implementation and training, and moved the models and data to the Colab free-tier T4 GPU. For preprocessing, I converted images to tensors and normalized them with CIFAR-100’s per-channel mean and standard deviation. I also applied light augmentation (random cropping and horizontal flips), which helps generalization on CIFAR-100, but otherwise kept the pipeline simple for clarity.
The code below shows the setup:

```
import torch
import torchvision
import torchvision.transforms as transforms

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),    # random crop with padding (augmentation)
    transforms.RandomHorizontalFlip(),       # random horizontal flip (augmentation)
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5071, 0.4865, 0.4409),   # CIFAR-100 mean and std
                         std=(0.2673, 0.2564, 0.2762))
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5071, 0.4865, 0.4409),
                         std=(0.2673, 0.2564, 0.2762))
])

train_dataset = torchvision.datasets.CIFAR100(root="./data", train=True, download=True, transform=train_transform)
test_dataset = torchvision.datasets.CIFAR100(root="./data", train=False, download=True, transform=test_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
```

**Note**: I chose a batch size of 64 to fit in limited GPU memory, since I am on the Google Colab free tier. I can adjust this (e.g., to 32 or 128) based on hardware capacity later (once I get access to the university GPU server). The normalization values are the per-channel mean and std for CIFAR-100 images. I applied simple augmentations (random crop/flip) on the training data to help the models generalize.

## Baseline ViT Model Definition

The baseline model follows the standard ViT architecture. Each image is split into patches, which are linearly projected into an embedding space. A learned [CLS] token embedding is prepended to the sequence of patch embeddings at each forward pass. I added learned positional embeddings to retain patch position information. The model consists of a stack of Transformer encoder blocks (self-attention + feed-forward layers with residual connections). The [CLS] token attends to all patch embeddings through these layers, effectively learning a global representation of the image. After the last encoder block, the output embedding of the [CLS] token is fed into a linear classification head to predict the class label.

Below is the implementation of a minimal ViT block and the baseline ViT model. I used a smaller configuration suitable for CIFAR-100 and limited hardware: patch size 4 (so 8 × 8 = 64 patches for a 32 × 32 image), embedding dimension 128, 4 attention heads, and 6 transformer layers. LayerNorm is applied before each attention and MLP (Pre-LN architecture), and I also included dropout for regularization. This model is intentionally kept small to ensure training is feasible on a free Colab T4 GPU.
``` import torch.nn as nn import math class ViTBlock(nn.Module): """Transformer encoder block: LayerNorm -> Multi-head Self-Attention -> Add & Norm -> MLP -> Add & Norm.""" def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1): super().__init__() self.norm1 = nn.LayerNorm(embed_dim) self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True) self.drop1 = nn.Dropout(dropout) self.norm2 = nn.LayerNorm(embed_dim) # MLP consists of two linear layers with GELU non-linearity hidden_dim = int(embed_dim * mlp_ratio) self.mlp = nn.Sequential( nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, embed_dim), nn.Dropout(dropout), ) def forward(self, x, return_attn=False): # Self-attention with pre-LayerNorm x_norm = self.norm1(x) # Query, Key, Value are all x_norm (self-attention) attn_out, attn_weights = self.attn(x_norm, x_norm, x_norm, need_weights=return_attn, average_attn_weights=not return_attn) x = x + self.drop1(attn_out) # Feed-forward network with pre-LayerNorm x_norm2 = self.norm2(x) mlp_out = self.mlp(x_norm2) x = x + mlp_out if return_attn: # attn_weights shape: (batch, num_heads, seq_len, seq_len) return x, attn_weights return x class ViTBaseline(nn.Module): def __init__(self, image_size=32, patch_size=4, num_classes=100, embed_dim=128, depth=6, num_heads=4, dropout=0.1): super().__init__() assert image_size % patch_size == 0, "Image size must be divisible by patch size" self.patch_size = patch_size num_patches = (image_size // patch_size) ** 2 # e.g., (32/4)^2 = 64 patches self.embed_dim = embed_dim # Patch embedding: project each patch to embed_dim (using a Conv layer for efficiency) self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size) # Class token and positional embedding self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) self.pos_drop = nn.Dropout(dropout) # Transformer encoder blocks self.blocks = nn.ModuleList([ViTBlock(embed_dim, num_heads, mlp_ratio=4.0, dropout=dropout) for _ in range(depth)]) self.norm = nn.LayerNorm(embed_dim) # Classification head self.head = nn.Linear(embed_dim, num_classes) # Initialize weights nn.init.trunc_normal_(self.pos_embed, std=0.02) nn.init.trunc_normal_(self.cls_token, std=0.02) def forward(self, x): B = x.shape[0] # Patch embedding: flatten image to patches and project to embed_dim patch_embeddings = self.patch_embed(x) # Rearrange to (B, num_patches, embed_dim) patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2) # Prepare class token and add to patch embeddings cls_tokens = self.cls_token.expand(B, -1, -1) x_seq = torch.cat([cls_tokens, patch_embeddings], dim=1) # Add positional embeddings x_seq = x_seq + self.pos_embed[:, :x_seq.size(1), :] x_seq = self.pos_drop(x_seq) # Pass through Transformer encoder blocks for block in self.blocks: x_seq = block(x_seq, return_attn=False) # Apply final LayerNorm x_seq = self.norm(x_seq) # Class token output cls_out = x_seq[:, 0] logits = self.head(cls_out) return logits def forward_features(self, x): """Utility method to get the feature vector (CLS embedding) before the classification head.""" # Similar to forward, but returns the CLS embedding instead of class scores B = x.shape[0] patch_embeddings = self.patch_embed(x).flatten(2).transpose(1, 2) cls_tokens = self.cls_token.expand(B, -1, -1) x_seq = torch.cat([cls_tokens, patch_embeddings], dim=1) x_seq = x_seq + self.pos_embed[:, :x_seq.size(1), :] 
x_seq = self.pos_drop(x_seq) for block in self.blocks: x_seq = block(x_seq, return_attn=False) x_seq = self.norm(x_seq) cls_emb = x_seq[:, 0] # final CLS embedding return cls_emb ``` In this code, ViTBlock implements one transformer encoder layer with multi-head self-attention and a two-layer MLP. I used nn.MultiheadAttention for convenience, with batch_first=True so that inputs are of shape (batch, seq_len, embed_dim). The block returns attention weights only if return_attn=True (this will be used in the rollout model). The ViTBaseline class builds the patch embedding layer (using a convolution to split and project patches), prepends the cls_token, adds positional encodings, and then applies a sequence of ViTBlocks. The forward method returns class logits, while forward_features returns the final [CLS] token embedding (useful for extracting features without the classification layer, e.g., for similarity computations later). ## Modified ViT model with Attention Rollout The modified model, ViTRollout, uses the same building blocks (patch embedding, [CLS] token, transformer layers) but integrates the attention rollout computation into its forward pass. Attention rollout (Abnar & Zuidema, 2020) is a post-hoc technique to quantify the influence of each input token on the output by propagating attention through the layers (Check - https://jacobgil.github.io/deeplearning/vision-transformer-explainability). In practice, we compute the rollout by: (a) extracting the attention weight matrices from each transformer block, (b) adding an identity matrix to each to account for residual connections (so each token retains some of its own information), \(c) averaging across multiple attention heads at each layer, and (d) multiplying these modified attention matrices together, layer by layer. The result is an attention flow matrix from the input tokens to the output tokens. We normalize the attention at each step so that each token’s attention weights sum to 1, maintaining it as a proper probability distribution through the layers. Based on my reading about vision transformers (https://paperswithcode.com/method/vision-transformer), in a ViT with a [CLS] token, the first row of the final rollout matrix gives the overall attention weight that the final [CLS] output receives from each initial token (the [CLS] itself and all patches). For our modified model’s output, we ignore the contribution of the initial [CLS] token (which carries no image information) and focus on the patch contributions. Essentially, we obtain a weight for each image patch indicating its importance to the final classification. We then compute a weighted sum of the final layer patch embeddings using these rollout weights to produce a representation vector for the image, and feed that to the classifier. This replaces the direct use of the final [CLS] embedding. By doing so, the model’s prediction is explicitly based on an aggregation of patch features weighted by the multi-layer attention flow originating from each patch. This integration provides an interpretable mechanism, since the weights indicate which patches (image regions) were most influential, and it slightly alters how the model learns to pool information from patches. Below is the implementation of ViTRollout. 
It reuses the same transformer blocks but collects attention weights at each layer to perform the rollout computation: ``` class ViTRollout(nn.Module): def __init__(self, image_size=32, patch_size=4, num_classes=100, embed_dim=128, depth=6, num_heads=4, dropout=0.1): super().__init__() assert image_size % patch_size == 0 num_patches = (image_size // patch_size) ** 2 # Patch embedding, class token, pos embedding (same as baseline) self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size) self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) self.pos_drop = nn.Dropout(dropout) # Transformer blocks (will retrieve attention weights) self.blocks = nn.ModuleList([ViTBlock(embed_dim, num_heads, mlp_ratio=4.0, dropout=dropout) for _ in range(depth)]) self.norm = nn.LayerNorm(embed_dim) self.head = nn.Linear(embed_dim, num_classes) # Initialize weights nn.init.trunc_normal_(self.pos_embed, std=0.02) nn.init.trunc_normal_(self.cls_token, std=0.02) def forward(self, x): B = x.shape[0] # Embedding + CLS token + positions (same as baseline) patch_embeddings = self.patch_embed(x).flatten(2).transpose(1, 2) # (B, num_patches, embed_dim) cls_tokens = self.cls_token.expand(B, 1, -1) # (B, 1, embed_dim) x_seq = torch.cat([cls_tokens, patch_embeddings], dim=1) # (B, 1+num_patches, embed_dim) x_seq = x_seq + self.pos_embed[:, :x_seq.size(1), :] x_seq = self.pos_drop(x_seq) # Forward through transformer layers, collect attention from each attn_mats = [] # to store attention matrices for rollout for block in self.blocks: x_seq, attn_weights = block(x_seq, return_attn=True) # get attention weights per layer attn_mats.append(attn_weights) # Apply final LayerNorm to outputs x_seq = self.norm(x_seq) # shape: (B, 1+num_patches, embed_dim) # Compute Attention Rollout # Start with identity matrix for shape (seq_len x seq_len) seq_len = x_seq.size(1) # this is 1 + num_patches I = torch.eye(seq_len, device=x_seq.device).expand(B, seq_len, seq_len) # (B, seq_len, seq_len) # Initialize rollout matrix as identity rollout = I for attn_weights in attn_mats: # attn_weights shape: (B, num_heads, seq_len, seq_len) # Average heads A = attn_weights.mean(dim=1) # (B, seq_len, seq_len) A_hat = A + I # add identity for residual connection # Normalize each row of A_hat row_sum = A_hat.sum(dim=2, keepdim=True) # sum over source tokens A_hat = A_hat / row_sum # Multiply into rollout matrix rollout = torch.bmm(A_hat, rollout) # batch matrix multiplication # Now rollout (B, seq_len, seq_len) contains overall influence of each initial token on each final token. # Use rollout for weighted patch aggregation: # We take the first row (final CLS token) attention distribution over initial tokens: # rollout[:, 0, :] is attention from each initial token to the final CLS. 
patch_weights = rollout[:, 0, 1:] # ignore index 0 (initial CLS itself), shape: (B, num_patches) # Normalize the patch weights to sum to 1 (excluding CLS token) patch_weights = patch_weights / patch_weights.sum(dim=1, keepdim=True) # Compute weighted sum of **final-layer patch embeddings** using these weights patch_embs = x_seq[:, 1:, :] # final normalized embeddings of patches, shape: (B, num_patches, embed_dim) # Weight the patches by patch_weights and sum: rep = (patch_weights.unsqueeze(-1) * patch_embs).sum(dim=1) # shape: (B, embed_dim) # Classification head on the representation logits = self.head(rep) # (B, num_classes) return logits def forward_features(self, x): # Similar to forward, but returns the representation vector (rep) before classification B = x.shape[0] patch_embeddings = self.patch_embed(x).flatten(2).transpose(1, 2) cls_tokens = self.cls_token.expand(B, 1, -1) x_seq = torch.cat([cls_tokens, patch_embeddings], dim=1) x_seq = x_seq + self.pos_embed[:, :x_seq.size(1), :] x_seq = self.pos_drop(x_seq) attn_mats = [] for block in self.blocks: x_seq, attn_weights = block(x_seq, return_attn=True) attn_mats.append(attn_weights) x_seq = self.norm(x_seq) seq_len = x_seq.size(1) I = torch.eye(seq_len, device=x_seq.device).expand(B, seq_len, seq_len) rollout = I for attn_weights in attn_mats: A = attn_weights.mean(dim=1) A_hat = A + I A_hat = A_hat / A_hat.sum(dim=2, keepdim=True) rollout = torch.bmm(A_hat, rollout) patch_weights = rollout[:, 0, 1:] patch_weights = patch_weights / patch_weights.sum(dim=1, keepdim=True) patch_embs = x_seq[:, 1:, :] rep = (patch_weights.unsqueeze(-1) * patch_embs).sum(dim=1) # final representation vector return rep ``` Below is the break down the key parts of ViTRollout.forward: * I created the patch embeddings and prepend the class token just like in the baseline. Then, we forward through each ViTBlock with return_attn=True to collect the attention weight matrices from each layer. Each attn_weights has shape (B, heads, seq_len, seq_len), where seq_len = 1 + num_patches (including the class token). We store these in attn_mats. * After obtaining all layers’ attention, we compute the rollout matrix. We start with an identity matrix of shape (seq_len, seq_len) representing initial direct influence (each token 100% influences itself). For each layer’s attention matrix A (averaged over heads), we add identity to get $A^{\prime} = A + I$, normalize each row of $A^{\prime}$, and then multiply it with the current rollout matrix. This effectively composes the attentions across layers: after processing all layers, rollout gives the cumulative attention from any initial token $j$ to any final token $i$. In particular, rollout[:, 0, :] is a vector of how much each initial token contributes to the final class token (index 0 in the output). * We exclude the [CLS] token’s self-contribution (rollout[:,0,0]) and take the rest rollout[:,0,1:] as the attention distribution over the image patches. These patch weights are normalized to sum to 1. Then we obtain the final patch embeddings (patch_embs) from x_seq[:,1:,:] (after the last LayerNorm, i.e., the final state of each patch token). We compute a weighted sum of these patch embeddings using the attention rollout weights. The resulting rep vector (dimension 128) is our attention-rollout aggregated representation of the image. This is passed to the linear head to produce class logits. The forward_features method similarly returns the representation vector rep before the classification layer. 
With this design, the modified model’s predictions are based on an explicit aggregation of patch features, with weights determined by the model’s entire attention structure. This gives an interpretable output (you could inspect patch_weights to see which patches were most important for a given prediction). ## Training the models from scratch I trained both models on CIFAR-100, and used cross-entropy loss for the 100-class classification and the Adam optimizer (with a moderate weight decay, since Transformers often benefit from AdamW to regularize weights). I trained for 50 epochs. ``` import torch.nn.functional as F baseline_model = ViTBaseline().to(device) rollout_model = ViTRollout().to(device) criterion = nn.CrossEntropyLoss() optim_baseline = torch.optim.AdamW(baseline_model.parameters(), lr=1e-3, weight_decay=1e-4) optim_rollout = torch.optim.AdamW(rollout_model.parameters(), lr=1e-3, weight_decay=1e-4) num_epochs = 50 def train_one_model(model, optimizer, name="Model"): model.train() for epoch in range(1, num_epochs+1): running_loss = 0.0 for images, labels in train_loader: images, labels = images.to(device), labels.to(device) optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() avg_loss = running_loss / len(train_loader) if epoch % 10 == 0: print(f"{name} Epoch [{epoch}/{num_epochs}] - Average Loss: {avg_loss:.4f}") print("Training Baseline ViT...") train_one_model(baseline_model, optim_baseline, name="BaselineViT") torch.save(baseline_model.state_dict(), "vit_baseline_cifar100.pth") print("Training Rollout ViT...") train_one_model(rollout_model, optim_rollout, name="RolloutViT") torch.save(rollout_model.state_dict(), "vit_rollout_cifar100.pth") ``` ``` Training Baseline ViT... BaselineViT Epoch [10/50] - Average Loss: 2.6091 BaselineViT Epoch [20/50] - Average Loss: 2.0749 BaselineViT Epoch [30/50] - Average Loss: 1.6643 BaselineViT Epoch [40/50] - Average Loss: 1.3389 BaselineViT Epoch [50/50] - Average Loss: 1.1102 Training Rollout ViT... RolloutViT Epoch [10/50] - Average Loss: 2.3617 RolloutViT Epoch [20/50] - Average Loss: 1.7240 RolloutViT Epoch [30/50] - Average Loss: 1.2512 RolloutViT Epoch [40/50] - Average Loss: 0.9200 RolloutViT Epoch [50/50] - Average Loss: 0.6847 ``` This script trains each model for 50 epochs (printing the average loss every 10 epochs as a simple progress indicator). On a single GPU, 50 epochs with these small models and CIFAR-100 (50k training images) is reasonable. ## Computing the Cosine Similarity The below code loads two ViT models (baseline and rollout), extracts and normalizes feature embeddings for the “man,” “boy,” “table,” and “chair” test images, computes their pairwise cosine similarity matrices, and prints them. 
``` baseline_model = ViTBaseline().to(device) baseline_model.load_state_dict(torch.load("vit_baseline_cifar100.pth", map_location=device)) baseline_model.eval() rollout_model = ViTRollout().to(device) rollout_model.load_state_dict(torch.load("vit_rollout_cifar100.pth", map_location=device)) rollout_model.eval() classes_of_interest = ["man", "boy", "table", "chair"] class_to_idx = test_dataset.class_to_idx indices = [class_to_idx[c] for c in classes_of_interest] test_examples = [] example_labels = [] for img, label in test_dataset: if label in indices: if label not in example_labels: test_examples.append(img) example_labels.append(label) if len(test_examples) == len(indices): break ordered_examples = [] for cls_label in indices: idx = example_labels.index(cls_label) ordered_examples.append(test_examples[idx]) ordered_examples = torch.stack(ordered_examples).to(device) with torch.no_grad(): feats_base = baseline_model.forward_features(ordered_examples) feats_roll = rollout_model.forward_features(ordered_examples) # shape (4, embed_dim) # Normalize the feature vectors feats_base = F.normalize(feats_base, p=2, dim=1) feats_roll = F.normalize(feats_roll, p=2, dim=1) # Compute cosine similarity matrix for each model cos_sim_base = feats_base @ feats_base.T # 4x4 matrix of cosine sims cos_sim_roll = feats_roll @ feats_roll.T # 4x4 matrix for rollout model # Print cosine similarities for each pair of classes print("Cosine similarity (Baseline ViT) for classes:") for i, ci in enumerate(classes_of_interest): for j, cj in enumerate(classes_of_interest): print(f" {ci:6s} vs {cj:6s}: {cos_sim_base[i,j].item():.3f}") print("\nCosine similarity (Rollout ViT) for classes:") for i, ci in enumerate(classes_of_interest): for j, cj in enumerate(classes_of_interest): print(f" {ci:6s} vs {cj:6s}: {cos_sim_roll[i,j].item():.3f}") ``` ## Understanding the results Below is the result of cosine similarities I got- ``` Cosine similarity (Baseline ViT) for classes: man vs man : 1.000 man vs boy : 0.335 man vs table : 0.016 man vs chair : -0.141 boy vs man : 0.335 boy vs boy : 1.000 boy vs table : -0.053 boy vs chair : 0.086 table vs man : 0.016 table vs boy : -0.053 table vs table : 1.000 table vs chair : 0.357 chair vs man : -0.141 chair vs boy : 0.086 chair vs table : 0.357 chair vs chair : 1.000 Cosine similarity (Rollout ViT) for classes: man vs man : 1.000 man vs boy : 0.333 man vs table : -0.031 man vs chair : -0.175 boy vs man : 0.333 boy vs boy : 1.000 boy vs table : -0.020 boy vs chair : -0.042 table vs man : -0.031 table vs boy : -0.020 table vs table : 1.000 table vs chair : 0.325 chair vs man : -0.175 chair vs boy : -0.042 chair vs table : 0.325 chair vs chair : 1.000 ``` Based on my analysis, attention‐rollout does nudge down the off-diagonal (inter-class) cosine similarities. 
For the six unique off-diagonal pairs (man–boy, man–table, man–chair, boy–table, boy–chair, table–chair), here’s a quick comparison:

| Pair | Baseline ViT | Rollout ViT | Δ (Rollout – Baseline) |
| -------------- | ------------ | ----------- | ---------------------- |
| man vs boy | 0.335 | 0.333 | –0.002 |
| man vs table | 0.016 | –0.031 | –0.047 |
| man vs chair | –0.141 | –0.175 | –0.034 |
| boy vs table | –0.053 | –0.020 | +0.033 |
| boy vs chair | 0.086 | –0.042 | –0.128 |
| table vs chair | 0.357 | 0.325 | –0.032 |

- **Average absolute off-diagonal similarity**
  - **Baseline ViT:** \[(0.335 + 0.016 + 0.141 + 0.053 + 0.086 + 0.357) / 6\] ≈ **0.165**
  - **Rollout ViT:** \[(0.333 + 0.031 + 0.175 + 0.020 + 0.042 + 0.325) / 6\] ≈ **0.154**

That’s roughly a **6% relative reduction** in average inter-class similarity.

### What this means

- **Slight decorrelation:** Rollout tends to push different classes a bit further apart (lower cosine), which is what we want for better class discrimination.
- **Bottom line:** attention-rollout does “help” in the sense of slightly reducing inter-class overlap, but it’s not a dramatic improvement here, which can be attributed to the limitations I faced while training the models.

## Things to Note

1) Colab - https://colab.research.google.com/drive/1Cwfg9MpXGsdZARgO4gUWU3yiTSY2JCOf?usp=sharing
2) Since this was a basic ViT architecture trained with limited hardware and a short schedule, the models did not reach state-of-the-art accuracy, but I believe this is a good starting point.
-->

# Iteration 2

The previous implementation was very preliminary, given my limited experience with vision transformers. Its noteworthy limitations were:

1. Extremely small evaluation slice (4 classes × 1 image) → not statistically meaningful.
2. Single seed, small model (128-d, 6 layers), and short schedule (50 epochs, no LR decay) constrain performance.
3. Cosine similarities on a handful of points can be unstable and sensitive to sampling.

So in this iteration, I went for a much better, more productionised implementation, available at https://github.com/shreejeetsahay/attention-rollout-work. Let’s walk through what I did.

## Goal

As I mentioned, the previous implementation was a toy probe. This time I wanted to move from that toy probe to a scalable, reproducible pipeline that cleanly compares ViT (CLS pooling) vs ViT with attention-rollout pooling, and evaluates the embeddings with class-level metrics.

## Model Design

As you can see in the vit_rollout.py file, this time I implemented a single ViT class with a mode switch:

* baseline: representation = final [CLS] token.
* rollout: representation = attention-rollout–weighted sum of final-layer patch embeddings.

Rollout is computed per layer by averaging heads, adding the identity (residual path), row-normalising, and then multiplying across layers to get an influence matrix. I drop the initial [CLS] column, re-normalise the patch weights, and pool the patches accordingly (a code sketch of this readout appears at the end of this write-up). The backbone is kept identical across modes (patch=4, embed=128, depth=6, heads=4, Pre-LN, dropout=0.1) to isolate the effect of pooling.

## Data & loaders

* I used CIFAR-100 with standard augmentation: RandomCrop(32, padding=4) + RandomHorizontalFlip, plus per-channel normalisation.
* get_loaders() returns the train/test loaders and the raw test_ds; I use test_ds to build class-balanced subsets for evaluation.

## Training regimen

* Optimiser/loss: AdamW with cross-entropy.
* Schedule: linear warm-up → cosine decay (LambdaLR); defaults: epochs=200, lr_max=5e-4, lr_min=5e-6, warm-up = 5 epochs.
* AMP on CUDA via torch.cuda.amp to improve speed and reduce VRAM usage.
* Checkpoints are saved and reused as baseline.pth and rollout.pth (retraining is skipped unless --retrain is set).

## Evaluation setup

* collect_feats() builds a class-balanced subset (configurable via --classes and --per_class), forwards it in batches, and concatenates the features.
* I L2-normalise features before the distance-based metrics to make scales comparable.
* Metrics (a code sketch of these appears just before the cross-dataset comparison):
  - **k-NN Accuracy (k = 5, cosine)**
    - Classifies each embedding by its 5 nearest neighbours using cosine distance.
    - **Higher is better** → local neighbourhoods align with labels.
  - **Intra-class Distance (centroid-based)**
    - On **L2-normalised** features, compute the mean Euclidean distance from samples to their class centroid; average over classes.
    - **Lower is better** → tighter clusters.
  - **Inter-class Distance (centroid-based, cosine)**
    - Compute pairwise centroid distances as 1 − cos(c_i, c_j), averaged over all class pairs.
    - **Higher is better** → class means are farther apart.
  - **Inter / Intra Ratio**
    - ratio = inter / intra.
    - **Higher is better** → separation outweighs within-class spread.
  - **Silhouette Score (Euclidean)**
    - For each point, (b − a) / max(a, b), where a = mean distance to points in the same class and b = mean distance to points in the nearest other class.
    - Ranges from **−1 to 1**; **higher is better**. Near **0** → weak/overlapping clusters.
* All metrics are cast from torch floats to Python floats and written to results.json.

## Visualization

* For a chosen test image, I compute the rollout patch weights, upsample the 8×8 grid to 32×32, and overlay it on the de-normalised image (see the sketch after the CIFAR-10 heatmap comparison below).
* I save a side-by-side figure (heatmap_compare.png) for Baseline (rollout visualised for reference only) and Rollout (weights actually used for pooling).

## Why I went for these choices

The simple reason: I am new to this topic, so I took suggestions from Deep Research and incorporated them, knowing my initial work had gaps. From what I understand:

1. The mode switch ensures the only change between models is the pooling mechanism.
2. Warm-up + cosine decay stabilises early updates and improves late-stage convergence versus a flat LR.
3. L2-normalisation avoids misleading Euclidean scales and aligns with cosine-based retrieval.
4. Balanced sampling and multiple metrics provide a more meaningful picture than a 4-image cosine matrix.

## Understanding the results

### Training

![tlossc100](https://hackmd.io/_uploads/S1FpvQRwxg.png)

As the figure above shows, rollout converges faster and to a lower loss (CE ≈ 0.009 vs 0.072). Lower loss doesn’t always translate 1-to-1 into k-NN gains, but it shows the model is using its capacity well.

## Evaluation

### 10-class probe, 100 images/class

First, let’s look at the results for the 10-class probe with 100 images per class:

| metric (10-class probe, 100 images / class) | **Baseline** | **Rollout** | What it means |
| ------------------------------------------- | ------------ | ----------- | ------------- |
| k-NN @ k = 5 | **0.877** | 0.862 | Baseline wins by ≈ 1.5 pp – a small gap that could be noise. |
| Inter-class distance (cosine) | 0.869 | **0.982** | Rollout pushes class centroids farther apart (good). |
| Intra-class distance | 0.855 | 0.876 | Slightly looser clusters for rollout (expected: patch-weighted pooling injects variance). |
| Ratio (inter / intra) | 1.02 | **1.12** | Net separation improves despite the looser clusters. |
| Silhouette | 0.067 | 0.062 | Virtually unchanged – both embeddings form weak but visible clusters. |

The key takeaway here is that rollout now yields better centroid separation but a tiny drop in retrieval accuracy.

#### Heatmap comparison

Both maps below highlight the central fruit, but rollout’s saliency is more concentrated and symmetric, especially on the orange. That matches the higher inter-class distance: the model is weighting the truly class-defining region a bit more heavily.

![heatmap_compare](https://hackmd.io/_uploads/rksbOXRvlg.png)

### 100-class probe, 100 images/class

After the above probe, I computed the same metrics for all 100 classes with 100 images per class, i.e., the full 10,000-image test split. Here is what we got:

| metric (10,000 test imgs) | **Baseline** | **Rollout** | Δ (roll – base) | Interpretation |
| ------------------------- | ------------ | ----------- | --------------- | -------------- |
| k-NN @ k = 5 | **0.616** | 0.588 | –2.8 pp | Retrieval accuracy dropped—the CLS embedding still works slightly better when the label space is large. |
| Intra-class dist | 0.859 | 0.883 | ↑ 0.024 | Rollout clusters became a bit looser (patch weighting introduces variation). |
| Inter-class dist | 0.861 | **0.956** | ↑ 0.095 | Centroids are farther apart (good). |
| Inter / Intra ratio | 1.00 | **1.08** | ↑ 0.08 | Net separation improves despite looser clusters. |
| Silhouette | 0.003 | ≈ 0 | ~0 | 100 very small clusters in 128-D space → silhouette is near zero for both; the change is negligible. |

#### Interpretation on CIFAR-100

* Strategic view:
  * Roll-out weighting pushes class means apart (good for a linear head or centroid-based classifiers) but adds variance within each class, which hurts sample-level retrieval like k-NN.
* Why the variance?
  * From my understanding, the patch-weighted vector changes more from image to image—if different patches dominate in different examples of the same class, the embeddings scatter.
* Is it “beneficial”?
  * For global separability and lower classification loss, yes.
  * For neighbour-based retrieval, or for embeddings that must stay tight, not really—we have not achieved that here.

# Iteration 3

In this iteration, we went ahead and tested the above on the standard CIFAR-10 and SVHN datasets. The code has been updated accordingly in the repo: https://github.com/shreejeetsahay/attention-rollout-work/tree/main

## CIFAR-10

### Training

Training on CIFAR-10 gave the training loss curve shown below:

![cifar10losscurve](https://hackmd.io/_uploads/SyQdpm0vll.png)

| Phase | Observation | What it suggests |
| ----- | ----------- | ---------------- |
| **Epoch 1-10 (warm-up)** | Roll-out drops from 1.84 → 0.91, baseline 1.87 → 1.00. Roll-out is ≈ 9 % lower at epoch 10. | Patch-weighted pooling helps the optimiser take larger effective steps from the start. |
| **Epoch 10-40** | Both curves fall in parallel; the absolute gap widens (≈ 0.05 CE at E40). | Advantage persists—no sign that rollout slows gradient flow. |
| **Epoch 40-120** | Gap is stable: rollout remains 10-20 % lower (e.g., 0.185 vs 0.230 at E80). | Gains are not just early luck; they last through mid-training while the LR decays. |
| **Epoch 120-200 (tail)** | Final losses: **0.0145** (roll-out) vs **0.0335** (baseline) → ≈ −56 %. | Roll-out still extracts extra margin even in the low-LR regime. |

So, as on CIFAR-100, roll-out converges faster and to a markedly lower loss on CIFAR-10—roughly halving the final cross-entropy. The improvement is consistent (no crossover), indicating that attention-weighted patch pooling gives the classifier head a cleaner, more linearly separable feature space right from the start. Whether that lower loss translates into better downstream metrics (accuracy, retrieval) still depends on the intra-class variance trade-off, but purely as an optimiser target, roll-out is decisively easier to fit on this dataset.

### Metrics (10-class probe, 100 images/class)

| metric | Baseline | Roll-out | Δ |
| ------ | -------- | -------- | - |
| **Intra-class dist** | 0.7167 | **0.6937** | ↓ 0.023 (tighter clusters) |
| **Inter-class dist** | 1.0003 | **1.0628** | ↑ 0.062 (centroids farther apart) |
| **Inter / Intra ratio** | 1.40 | **1.53** | ↑ 0.13 |
| **k-NN @ 5** | 0.845 | **0.850** | +0.5 pp |
| **Silhouette** | **0.143** | 0.134 | –0.009 |

Now, let’s interpret the table:

1. **Stronger separation overall.** Roll-out simultaneously shrinks within-class spread and pushes class centroids apart, lifting the inter / intra ratio from 1.40 → 1.53.
2. **Retrieval matches the improvement.** k-NN accuracy ticks up (+0.5 pp), confirming that the tighter, better-separated space helps nearest-neighbour classification.
3. **The slight silhouette dip is not a concern.** Silhouette drops a hair because it weighs both clusters and their nearest neighbours; with tighter clusters, the nearest-cluster distance sometimes decreases too. The magnitude (−0.009) is minor compared with the gains in the other metrics.

Hence, on CIFAR-10, attention-rollout not only converges faster during training but also yields quantitatively stronger embeddings—better class compactness and better separation—without the variance penalty we saw on CIFAR-100.

### Heatmap Comparison

![heatmap_compare_cifar10](https://hackmd.io/_uploads/HkfIZVCDlg.png)

| | Baseline | Roll-out |
| --- | -------- | -------- |
| **Saliency focus** | High-weight patches are spread across the sky and water as well as the airplane fuselage. | Heat concentrates on the airplane body and wing; background sky / water patches cool down. |
| **Background leakage** | Several top-row (sky) and bottom-row (water) cells are warm → background influences the CLS output. | Those same cells turn blue/green → background contribution reduced. |
| **Sharpness** | Warm areas bleed into neighbours. | Crisper hot region around the airplane; colder elsewhere. |

Roll-out pooling makes the model rely more on the actual airplane pixels and less on the surrounding background, which is consistent with the quantitative gains (tighter intra-class, higher inter-class, +0.5 pp k-NN).
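The overlays above are produced roughly as described in the Visualization section: reshape the 64 rollout patch weights into an 8×8 grid, upsample to 32×32, and blend with the de-normalised image. Below is a minimal sketch of that step; the names img, patch_weights, and overlay_rollout are illustrative and not the repo’s exact API.

```
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

def overlay_rollout(img, patch_weights, mean, std, grid=8, out_path="heatmap.png"):
    """Overlay a (grid x grid) rollout weight map on a normalised image tensor of shape (3, H, W)."""
    # De-normalise the image back to [0, 1] for display.
    mean = torch.tensor(mean).view(3, 1, 1)
    std = torch.tensor(std).view(3, 1, 1)
    img = (img.cpu() * std + mean).clamp(0, 1)

    # Reshape the patch weights to a (grid x grid) map and upsample to the image size.
    heat = patch_weights.detach().cpu().view(1, 1, grid, grid)
    heat = F.interpolate(heat, size=img.shape[-2:], mode="bilinear", align_corners=False).squeeze()
    heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)  # rescale to [0, 1]

    fig, ax = plt.subplots(figsize=(3, 3))
    ax.imshow(img.permute(1, 2, 0).numpy())         # base image
    ax.imshow(heat.numpy(), cmap="jet", alpha=0.5)  # rollout heatmap on top
    ax.axis("off")
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)

# Example call (CIFAR-100 statistics; swap in the dataset's own mean/std for CIFAR-10 or SVHN):
# overlay_rollout(img, patch_weights,
#                 mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762))
```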
## SVHN Dataset

### Training

For SVHN, the training loss curve is shown below:

![output](https://hackmd.io/_uploads/SJnTXEAwex.png)

| Epoch range | Baseline CE (→) | Roll-out CE (→) | Gap (roll − base) at end | Read |
| ----------- | --------------- | --------------- | ------------------------ | ---- |
| **1–10 (warm-up)** | 2.22 → **0.6729** | 2.21 → **0.6447** | **−0.028** | Roll-out gets a small early lead. |
| **10–40** | 1.0002 → **0.2595** | 0.9147 → **0.2684** | **+0.009** | Curves nearly overlap; baseline a hair better by E40. |
| **40–80** | 0.2595 → **0.1551** | 0.2684 → **0.1574** | **+0.002** | Essentially tied through mid-training. |
| **80–120** | 0.1551 → **0.0815** | 0.1574 → **0.0772** | **−0.004** | Roll-out nudges ahead again. |
| **120–200 (tail)** | 0.0815 → **0.0296** | 0.0772 → **0.0234** | **−0.006** | Roll-out finishes lower; the absolute gap is modest. |

On SVHN the two models train almost identically; roll-out is slightly faster early and ends a bit lower, but the differences are small.

### Metrics (10-class probe, 2,600 images)

| metric | Baseline | Roll-out | Δ |
| ------ | -------- | -------- | - |
| **Intra-class dist** | **0.562** | 0.602 | ↑ 0.041 (clusters looser) |
| **Inter-class dist** | 1.003 | **1.025** | ↑ 0.022 (centroids a bit farther) |
| **Inter / Intra ratio** | **1.79** | 1.70 | ↓ 0.09 |
| **k-NN @ 5 (cosine)** | **0.967** | 0.965 | –0.2 pp |
| **Silhouette** | **0.350** | 0.310 | ↓ 0.040 |

1. **Separation vs. compactness.** Roll-out nudges centroid distances up, but inflates within-class scatter even more, so the inter / intra ratio drops from 1.79 → 1.70.
2. **Retrieval essentially unchanged.** k-NN@5 slips by only 0.2 pp—a statistical tie. The added variance doesn’t hurt neighbour consistency because SVHN digits are already highly separable.

On SVHN, attention-rollout provides no clear advantage. The background is plain and the digits are centred, so the CLS token already captures the salient region; patch-weighting adds variance without lifting useful separation.

### Heatmap

![heatmap_compare_svhn](https://hackmd.io/_uploads/rkTBpVAPxe.png)

As you can see above, rollout changes the attention distribution only subtly: it smooths out extreme spikes and lights up the entire digit outline. Because SVHN digits are centred and isolated, this refinement neither helps nor hurts the downstream metrics—exactly what we saw in the tables above.
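Before moving to the cross-dataset comparison, here is a minimal sketch of how the centroid-based metrics and the cosine k-NN accuracy reported in these tables could be computed from L2-normalised features. The names feats, labels, and the two functions are illustrative; the repo’s exact implementation may differ (e.g., the silhouette score can come from sklearn.metrics.silhouette_score).

```
import torch
import torch.nn.functional as F

def centroid_metrics(feats, labels):
    """Intra-/inter-class centroid distances and their ratio on L2-normalised features."""
    feats = F.normalize(feats, dim=1)
    classes = labels.unique()
    centroids = torch.stack([feats[labels == c].mean(0) for c in classes])

    # Intra-class: mean Euclidean distance from each sample to its class centroid, averaged over classes.
    intra = torch.stack([
        (feats[labels == c] - centroids[i]).norm(dim=1).mean()
        for i, c in enumerate(classes)
    ]).mean()

    # Inter-class: mean pairwise cosine distance (1 - cos) between class centroids.
    c_norm = F.normalize(centroids, dim=1)
    cos = c_norm @ c_norm.T
    off_diag = ~torch.eye(len(classes), dtype=torch.bool)
    inter = (1 - cos[off_diag]).mean()

    return {"intra": intra.item(), "inter": inter.item(), "ratio": (inter / intra).item()}

def knn_accuracy(feats, labels, k=5):
    """Leave-one-out k-NN accuracy with cosine similarity and majority vote."""
    feats = F.normalize(feats, dim=1)
    sims = feats @ feats.T
    sims.fill_diagonal_(-float("inf"))         # never count a sample as its own neighbour
    nn_idx = sims.topk(k, dim=1).indices       # indices of the k most similar samples
    preds = labels[nn_idx].mode(dim=1).values  # majority label among the neighbours
    return (preds == labels).float().mean().item()
```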
## Iteration 2 and 3 interpretation

| dataset | visual scene | what roll-out changes | net result |
| ------- | ------------ | --------------------- | ---------- |
| **CIFAR-100 – 100 fine-grained natural classes** | cluttered; class-defining cues vary from picture to picture | pushes centroids farther apart **and** inflates the spread inside each class | linear head benefits (lower CE, higher inter-class distance) but k-NN drops because the extra scatter hurts local neighbourhoods |
| **CIFAR-10 – 10 coarse natural classes** | single object, limited backgrounds | pushes centroids apart **and** makes clusters slightly tighter | everything improves: lower CE, higher inter / intra ratio, small but noticeable lift in k-NN |
| **SVHN – centred digits, uniform background** | very simple | almost no change beyond a tiny increase in variance | training loss difference is negligible; k-NN and centroid metrics are effectively unchanged |

As for the heatmaps, we see that where the background is visually prominent (fruit on a table, an airplane in the sky), roll-out concentrates weight on the object and cools the background. In SVHN, where the digit already dominates the frame, both pooling methods highlight the same central region, so the saliency maps—and the metrics—are almost identical.

### Conclusion from iterations 2 and 3

If the data contain varied scenes in which the class-defining region can be crowded or partially occluded by background (CIFAR-100, CIFAR-10), attention-rollout is worth enabling: it converges faster and gives better linear separability, and on medium-granularity problems like CIFAR-10 it even helps neighbour retrieval. If the object is always centred and the background is uniform (SVHN-style data), the ordinary [CLS] token already captures everything important, so roll-out adds complexity without delivering a measurable benefit.

# Explanation of rollout mechanism used here

The figure below contrasts the baseline and rollout mechanisms; an explanation of rollout follows.

![IMG_3778](https://hackmd.io/_uploads/HyQbv6cOel.jpg)

Before explaining the rollout mechanism, it is important to know what tensors we get from each layer:

1. Each transformer block can optionally return its raw attention weights **w** with shape (B, H, S, S), where B = batch size, H = heads, and S = sequence length (1 CLS token + N patch tokens).
2. When mode == 'rollout' is selected, forward_features asks every block for w and stores the results in a list **attn**.

If you look at the code of the "_rollout" method in the ViT class, we implement the rollout mechanism from the Abnar and Zuidema (2020) paper: the per-layer attentions are turned into an “influence” matrix as follows:

1. Head average: for each layer ℓ, average across heads: A_ℓ = mean_over_heads(w_ℓ) → shape (B, S, S).
2. Add the residual path (identity): A_ℓ ← A_ℓ + I. This models the residual connection: a token can also “keep” its own content.
3. Row-normalize: A_ℓ[i] = A_ℓ[i] / sum_row(A_ℓ[i]). Now each row is a probability distribution over the tokens it draws information from.
4. Multiply across layers (rollout): initialize P = I; for each layer in order, P = A_ℓ · P. After all layers, P (B, S, S) captures indirect influence paths (e.g., CLS→patch via multiple hops).
This is the “rollout” part: it composes the routing over the whole stack.

Now, how is this rollout computation used in forward_features? First, we turn the CLS→patch influence computed above into pooling weights:

- Take the CLS row of the final influence matrix and drop the CLS column: W = P[:, 0, 1:]  # (B, N_patches)
- Normalize across patches: W = W / W.sum(-1, keepdim=True)

Once we have these pooling weights, we produce the representation used by the classifier:

- Get the final-layer patch embeddings: patches = x[:, 1:, :] (after the last block + LayerNorm).
- Take a weighted sum with the rollout weights: rep = (W.unsqueeze(-1) * patches).sum(1)  # shape = (B, embed_dim)
- Feed rep to the classifier instead of the final [CLS] token embedding used in the baseline (i.e., we use rep rather than x[:, 0]).

One important implementation detail: in the forward_features method, attentions are collected with attn.append(w.detach()). That means the pooling weights W do not receive gradients directly (they are computed from the current attentions but treated as constants in backprop). Gradients still flow through the patch embeddings being pooled, so training does differ between the two modes (CLS vs. rollout pooling), but there is no extra loss or gradient pushing the attentions to change specifically for rollout. It is a pure readout change, not an auxiliary objective (no auxiliary loss function). Detaching keeps rollout strictly readout-only, i.e., the model still learns via cross-entropy on the pooled representation, but we avoid long gradient paths through the product of per-layer attention matrices. If .detach() were removed, gradients would flow through that chain, implicitly pressuring the attention patterns, which can destabilize training and muddy a clean apples-to-apples comparison with the CLS baseline.
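To make this concrete, here is a minimal, self-contained sketch of the readout described above (not the repo’s exact code; rollout_pool, attn, and x are illustrative names):

```
import torch

def rollout_pool(attn, x):
    """attn: list of per-layer attention weights (B, H, S, S); x: final token embeddings (B, S, D)."""
    B, S, _ = x.shape
    I = torch.eye(S, device=x.device).expand(B, S, S)
    P = I.clone()                                 # start from the identity (each token influences itself)
    for w in attn:
        A = w.detach().mean(dim=1)                # average heads; detach keeps rollout a pure readout
        A = A + I                                 # add the residual path
        A = A / A.sum(dim=-1, keepdim=True)       # row-normalise to a distribution
        P = torch.bmm(A, P)                       # compose routing across layers
    W = P[:, 0, 1:]                               # CLS row of the influence matrix, CLS column dropped
    W = W / W.sum(dim=-1, keepdim=True)           # re-normalise over patches
    rep = (W.unsqueeze(-1) * x[:, 1:, :]).sum(dim=1)  # weighted sum of final-layer patch embeddings
    return rep                                    # (B, D), fed to the classification head
```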