# 🌟 **Vision Transformer (ViT) Tutorial – Part 3: Pretraining, Transfer Learning & Real-World Applications**

**#VisionTransformer #TransferLearning #HuggingFace #ImageNet #FineTuning #AI #DeepLearning #ComputerVision #Transformers #ModelZoo**

---

## πŸ”Ή **Table of Contents**

1. [Recap of Part 2](#recap-of-part-2)
2. [Why Pretraining Matters: The Power of Scale](#why-pretraining-matters-the-power-of-scale)
3. [Pretrained ViT Models: ViT-Base, ViT-Large, ViT-Huge](#pretrained-vit-models-vit-base-vit-large-vit-huge)
4. [Using Hugging Face Transformers for ViT](#using-hugging-face-transformers-for-vit)
5. [Loading Pretrained ViT from Model Zoo](#loading-pretrained-vit-from-model-zoo)
6. [Transfer Learning: Adapting ViT to Custom Datasets](#transfer-learning-adapting-vit-to-custom-datasets)
7. [Fine-Tuning Strategies: Full, Partial, and Feature Extraction](#fine-tuning-strategies-full-partial-and-feature-extraction)
8. [Case Study: Fine-Tuning ViT on CIFAR-100](#case-study-fine-tuning-vit-on-cifar-100)
9. [Visualizing Attention Rollout & Token Merging](#visualizing-attention-rollout--token-merging)
10. [Comparing ViT, DeiT, and Hybrid Models](#comparing-vit-deit-and-hybrid-models)
11. [Optimizing ViT for Inference Speed](#optimizing-vit-for-inference-speed)
12. [Common Pitfalls in Transfer Learning](#common-pitfalls-in-transfer-learning)
13. [Visualizing Transfer Learning Pipeline (Diagram)](#visualizing-transfer-learning-pipeline-diagram)
14. [Summary & What’s Next in Part 4](#summary--whats-next-in-part-4)

---

## πŸ” **1. Recap of Part 2**

In **Part 2**, we:

- Built a **Vision Transformer from scratch** in PyTorch.
- Implemented **patch embedding**, **multi-head attention**, and **transformer blocks**.
- Trained a small ViT on **CIFAR-10**.
- Learned that **ViT underperforms CNNs on small datasets** without pretraining.
- Visualized training dynamics and debugged common issues.

Now, in **Part 3**, we unlock ViT’s true potential: **pretraining at scale** and **transfer learning**. You’ll learn how to:

- Use **pretrained ViT models** from Hugging Face.
- **Fine-tune** ViT on custom datasets.
- Visualize **attention rollout**.
- Optimize for **speed and efficiency**.

Let’s go!

---

## πŸš€ **2. Why Pretraining Matters: The Power of Scale**

In **Part 2**, our ViT only reached ~75% on CIFAR-10 β€” far below ResNet’s ~95%. But in the original ViT paper, **ViT-Huge reached 88.5% top-1 accuracy on ImageNet** when pretrained on **JFT-300M** (300 million images), outperforming the strongest CNNs of the time.

> πŸ’‘ **Key Insight**:
> **ViT needs large-scale pretraining to unlock its capacity.**

### πŸ“ˆ Scaling Laws: Data vs Performance

![Scaling Laws](https://aicompetence.org/wp-content/uploads/2025/05/654lmk.webp)

*(Image: ViT scales better with data than CNNs β€” its accuracy keeps improving as the pretraining dataset grows)*

> βœ… ViT is **data-hungry** but **highly scalable**. This is why **transfer learning** is essential.

---

## πŸ—οΈ **3. Pretrained ViT Models: ViT-Base, ViT-Large, ViT-Huge**

Google released several ViT variants pretrained on **ImageNet-21k** and **ImageNet-1k**.

| Model | Patch Size | Image Size | Params | Top-1 Acc (ImageNet) |
|------|-----------|------------|--------|------------------------|
| **ViT-Base/16** | 16x16 | 224x224 | 86M | 77.9% |
| **ViT-Large/16** | 16x16 | 224x224 | 307M | 76.5% |
| **ViT-Huge/14** | 14x14 | 224x224 | 632M | 78.5% |

> βœ… **ViT-Base/16** is the most commonly used. Notice that the larger models don’t automatically win in this table: the accuracy depends heavily on the pretraining corpus, and the bigger variants only pull ahead with more pretraining data (ViT-Huge/14 reaches 88.5% when pretrained on JFT-300M). A quick way to check these model sizes yourself is sketched below.
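If you want to sanity-check those parameter counts locally, here is a minimal sketch using `timm` (introduced in Section 5). The model names (`vit_*`, `vit_base_patch16_224`) are timm's identifiers, and exact counts can differ slightly between ports:

```python
import timm

# List a few of the pretrained ViT variants that timm ships (pattern match on model names)
print(timm.list_models('vit_*', pretrained=True)[:10])

# Count the parameters of ViT-Base/16 -- should land around 86M
model = timm.create_model('vit_base_patch16_224', pretrained=False)
n_params = sum(p.numel() for p in model.parameters())
print(f"ViT-Base/16: {n_params / 1e6:.1f}M parameters")
```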
These pretrained checkpoints are available via:

- **Hugging Face Hub**
- **Google Research GitHub**
- **TorchVision (newer versions)**

---

## πŸ“¦ **4. Using Hugging Face Transformers for ViT**

[Hugging Face](https://huggingface.co) provides a unified API for ViT.

### βœ… Install

```bash
pip install transformers torch torchvision
```

### βœ… Load Pretrained ViT

```python
from transformers import ViTImageProcessor, ViTForImageClassification
import torch

# Load processor (handles resizing and normalization)
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

# Load model
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
```

> βœ… Automatically downloads weights and config.

---

### βœ… Inference on a Single Image

```python
from PIL import Image
import requests

# Load image
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# Preprocess
inputs = processor(images=image, return_tensors="pt")

# Predict
with torch.no_grad():
    logits = model(**inputs).logits

# Get predicted class
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
```

> βœ… Output: `"cat"` or `"Egyptian cat"`

---

## 🧩 **5. Loading Pretrained ViT from Model Zoo**

You can also use **TorchVision** (ViT models ship with recent versions):

```python
import torchvision.models as models

# Requires a recent torchvision; the weights API replaces the deprecated pretrained=True
vit = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)
```

Or load ports of **Google’s official checkpoints** through `timm`:

```python
# Using timm (another popular library)
import timm

model = timm.create_model('vit_base_patch16_224', pretrained=True)
```

> βœ… `timm` supports 100+ ViT variants. Install with:

```bash
pip install timm
```

---

## πŸ” **6. Transfer Learning: Adapting ViT to Custom Datasets**

Transfer learning means:

1. Start with a **pretrained ViT** (trained on ImageNet).
2. Replace the final classification head.
3. **Fine-tune** on your dataset.

### βœ… Use Case: Medical Image Classification

You have 5,000 X-ray images (Pneumonia vs Normal). You don’t have enough data to train ViT from scratch β€” but you can **fine-tune a pretrained ViT**.

---

## πŸ› οΈ **7. Fine-Tuning Strategies: Full, Partial, and Feature Extraction**

### βœ… **Strategy 1: Full Fine-Tuning**

Update **all layers**.

```python
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels=100,                  # e.g., CIFAR-100; the new head is randomly initialized
    ignore_mismatched_sizes=True
)

# Unfreeze all parameters
for param in model.parameters():
    param.requires_grad = True
```

> βœ… Best performance, but slow and needs lots of data.

---

### βœ… **Strategy 2: Partial Fine-Tuning**

Only fine-tune the **last few layers**.

```python
# Freeze everything
for param in model.parameters():
    param.requires_grad = False

# Unfreeze the last two transformer blocks + the classification head
for param in model.vit.encoder.layer[-2:].parameters():
    param.requires_grad = True
for param in model.classifier.parameters():
    param.requires_grad = True
```

> βœ… Faster, less overfitting.

---

### βœ… **Strategy 3: Feature Extraction**

Use ViT as a **fixed feature extractor**.

```python
import torch.nn as nn

# Replace the classifier with an identity so the model outputs the [CLS] representation
model.classifier = nn.Identity()

# Forward pass to extract features
with torch.no_grad():
    features = model(**inputs).logits   # shape (1, 768) for ViT-Base

# Train a small classifier on top (num_classes = your dataset's class count)
clf = nn.Linear(768, num_classes)
```

> βœ… Fastest, but lower accuracy.

---

## πŸ§ͺ **8. Case Study: Fine-Tuning ViT on CIFAR-100**

Let’s fine-tune **ViT-Base** on **CIFAR-100** (100 classes, 32x32 images). The setup the following snippets rely on is sketched below.
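Here is a minimal setup sketch with the imports and the `device` handle the case-study snippets assume (the variable names are this tutorial's conventions, not part of the transformers API):

```python
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor, ViTForImageClassification

# Fine-tuning ViT-Base on CPU is impractically slow, so use a GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```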
### βœ… Problem: ViT expects 224x224

We must **resize images**.

```python
from transformers import ViTImageProcessor

# ViTImageProcessor (successor of ViTFeatureExtractor) carries the checkpoint's normalization stats
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

# Custom transform: upsample 32x32 CIFAR images to the 224x224 input ViT expects
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])
```

---

### βœ… Load CIFAR-100

```python
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

trainloader = DataLoader(trainset, batch_size=16, shuffle=True)  # Small batch due to memory
testloader = DataLoader(testset, batch_size=16)
```

---

### βœ… Initialize Model

```python
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels=100,
    ignore_mismatched_sizes=True
).to(device)

# Freeze the backbone, then unfreeze the last 4 transformer blocks.
# The new classification head sits outside model.vit, so it stays trainable.
for param in model.vit.parameters():
    param.requires_grad = False
for param in model.vit.encoder.layer[-4:].parameters():
    param.requires_grad = True
```

---

### βœ… Training Loop (Simplified)

```python
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5)

for epoch in range(10):
    model.train()
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        # Passing labels makes the model compute the cross-entropy loss itself
        outputs = model(pixel_values=images, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
```

> βœ… After 10 epochs: ~85% accuracy (vs ~50% from scratch).

---

## πŸ” **9. Visualizing Attention Rollout & Token Merging**

### πŸ”Ή **Attention Rollout**

Shows how attention spreads across the image through layers. Uses the idea:

> "If token A attends to token B, it 'inherits' B’s attention."

At each layer the attention matrix is mixed with the identity (to account for the residual connection), and the per-layer matrices are multiplied together:

$$
\tilde{A}^{(l)} = \tfrac{1}{2}\left(A^{(l)} + I\right), \qquad
R = \tilde{A}^{(L)}\,\tilde{A}^{(L-1)} \cdots \tilde{A}^{(1)}
$$

where $A^{(l)}$ is the (head-fused) attention matrix at layer $l$.

---

### βœ… Code (Simplified)

```python
import torch

def rollout(attentions, head_fusion="mean"):
    """attentions: list of per-layer tensors of shape (batch, heads, tokens, tokens)."""
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attn in attentions:
            # Fuse the heads into a single attention matrix
            if head_fusion == "mean":
                fused = attn.mean(dim=1)
            else:  # "max"
                fused = attn.max(dim=1).values
            # Mix in the identity for the residual connection, then re-normalize the rows
            fused = 0.5 * fused + 0.5 * torch.eye(fused.size(-1))
            fused = fused / fused.sum(dim=-1, keepdim=True)
            # Multiply into the running product (rollout across layers)
            result = torch.matmul(fused, result)
    return result
```

---

### πŸ–ΌοΈ Attention Rollout Example

![Attention Rollout](https://www.researchgate.net/publication/324055383/figure/fig4/AS:675263335657482@1538006733951/Qualitative-results-of-attention-transition-We-visualize-the-predicted-heatmap-on-the.png)

*(Image: Heatmap showing attention focused on object regions)*

> βœ… The model learns to **attend to relevant parts** like eyes, wheels, or wings.

---

## πŸ” **10. Comparing ViT, DeiT, and Hybrid Models**

| Model | Key Idea | Advantage | Use Case |
|------|---------|----------|---------|
| **ViT** | Pure transformer | Global context | Large datasets |
| **DeiT** | **D**ata-**e**fficient **I**mage **T**ransformer | Trains on ImageNet without extra data | Medium datasets |
| **Hybrid (e.g., BoTNet)** | CNN + Transformer | Local + global | Object detection |
| **MobileViT** | Lightweight ViT | Fast on mobile | Edge devices |
| **Twins-SVT** | Spatially separable attention | Faster inference | Real-time apps |

> βœ… **DeiT** adds a **distillation token** so a CNN teacher can guide training (loading it is sketched below).
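If you want to try DeiT yourself, it loads much like ViT. A minimal sketch, assuming the `facebook/deit-base-distilled-patch16-224` checkpoint and the `DeiTForImageClassificationWithTeacher` class from transformers (double-check the model card if these names have changed):

```python
from transformers import AutoImageProcessor, DeiTForImageClassificationWithTeacher

ckpt = 'facebook/deit-base-distilled-patch16-224'  # assumed checkpoint name
deit_processor = AutoImageProcessor.from_pretrained(ckpt)
deit = DeiTForImageClassificationWithTeacher.from_pretrained(ckpt)

# `image` is the PIL image from Section 4; .logits combines the class-token and distillation-token predictions
inputs = deit_processor(images=image, return_tensors="pt")
with torch.no_grad():
    pred = deit(**inputs).logits.argmax(-1).item()
print(deit.config.id2label[pred])
```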
---

### πŸ“Š Performance Comparison (ImageNet)

![ViT vs DeiT vs CNN](https://miro.medium.com/v2/resize:fit:1232/1*2CW60TnErZ8-zyF-iKKl5Q.png)

*(Image: DeiT matches ViT with less data)*

---

## ⚑ **11. Optimizing ViT for Inference Speed**

ViT is **computationally heavy** because self-attention scales quadratically with the number of tokens:

$$
\text{Complexity} = O(N^2 \cdot D)
$$

where $N$ = number of patches, $D$ = embedding size.

### βœ… Optimization Techniques

| Technique | How It Helps |
|---------|-------------|
| **Model Pruning** | Remove unimportant attention heads |
| **Quantization** | Convert weights to FP16 or INT8 |
| **Knowledge Distillation** | Train small student from large teacher |
| **Patch Merging** | Reduce $N$ in deeper layers |
| **Efficient Attention** | Use Linformer, Performer, or FlashAttention |

---

### βœ… Example: FP16 Inference

```python
# FP16 needs a GPU; convert both the weights and the input batch to float16
model = model.half().to(device).eval()
pixel_values = inputs["pixel_values"].half().to(device)

with torch.no_grad():
    logits = model(pixel_values).logits
```

> βœ… Roughly 2x faster and about half the memory on GPUs with FP16 support.

---

## ⚠️ **12. Common Pitfalls in Transfer Learning**

### ❌ **Pitfall 1: Not Resizing Images**

ViT expects **224x224**. Feeding 32x32 images directly produces the wrong number of patches and breaks the pretrained position embeddings.

βœ… Always **resize or crop**.

---

### ❌ **Pitfall 2: Using Wrong Normalization**

Different checkpoints expect different normalization: the classic ImageNet stats are `mean=[0.485, 0.456, 0.406]`, `std=[0.229, 0.224, 0.225]`, while some ViT checkpoints (including `google/vit-base-patch16-224`) use `mean = std = [0.5, 0.5, 0.5]`. Using mismatched stats (e.g., CIFAR statistics) hurts convergence.

βœ… Use `ViTImageProcessor` so you get the exact values the checkpoint was trained with.

---

### ❌ **Pitfall 3: High Learning Rate**

Pretrained models are sensitive.

βœ… Use a **low LR** (1e-5 to 5e-5).

---

### ❌ **Pitfall 4: Not Freezing Early Layers**

Fine-tuning all layers on small data β†’ overfitting.

βœ… Freeze early layers, fine-tune the last few.

---

## πŸ–ΌοΈ **13. Visualizing Transfer Learning Pipeline (Diagram)**

![Transfer Learning Pipeline](https://production-media.paperswithcode.com/social-images/UhPqfdxgjZGSAsbC.png)

```
Pretrained ViT (ImageNet)
        ↓
Remove Classifier Head
        ↓
Add New Head (e.g., 100 classes)
        ↓
Freeze Early Layers
        ↓
Fine-Tune on Custom Dataset
        ↓
Optimized for Inference
```

> πŸ” This is how ViT powers real-world applications.

---

## 🏁 **14. Summary & What’s Next in Part 4**

### βœ… **What You’ve Learned in Part 3**

- Why **pretraining** is essential for ViT.
- How to load **pretrained ViT** from Hugging Face.
- **Transfer learning** strategies: full, partial, feature extraction.
- Fine-tuned ViT on **CIFAR-100** with resizing.
- Visualized **attention rollout**.
- Compared **ViT, DeiT, and hybrid models**.
- Optimized for **speed and efficiency**.

---

### πŸ”œ **What’s Coming in Part 4: Vision Transformers for Object Detection, Segmentation & Video**

In the next part, we’ll explore:

- πŸ–ΌοΈ **DETR**: Transformer for **object detection**.
- 🎨 **Segmenter**: ViT for **semantic segmentation**.
- πŸŽ₯ **Video Swin Transformer**: For **video classification**.
- πŸ”„ **MAE (Masked Autoencoder)**: Self-supervised pretraining.
- 🧩 **Multimodal Models**: CLIP, Flamingo.
- πŸ§ͺ **Training ViT from Scratch with MAE**.

> πŸ“Œ **#DETR #Segmenter #VideoTransformer #MAE #SelfSupervised #Multimodal**

---

## πŸ™Œ Final Words

You’ve now mastered **real-world Vision Transformer applications**.

> πŸ’¬ **"Pretraining is not a shortcut β€” it’s a paradigm shift. ViT learns general visual understanding, then specializes."**

In **Part 4**, we’ll go beyond classification and explore how Transformers are revolutionizing **detection, segmentation, and video**.

---

πŸ“Œ **Pro Tip**: Always check the **Hugging Face Model Hub** before training from scratch. A quick programmatic search (sketched below) often turns up a checkpoint close to your task.
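A minimal sketch of that search with the `huggingface_hub` client. This assumes a reasonably recent `huggingface_hub` release where `list_models` accepts `search`, `sort`, `direction`, and `limit` and where model entries expose an `id` attribute; check the library docs if your version differs:

```python
from huggingface_hub import list_models

# Five most-downloaded models whose name matches "vit"
for m in list_models(search="vit", sort="downloads", direction=-1, limit=5):
    print(m.id)
```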
πŸ” **Share this guide** to help others leverage **pretrained vision models**. --- βœ… **You're now ready for Part 4!** We're entering the world of **Transformers beyond classification**. #VisionTransformer #TransferLearning #HuggingFace #FineTuning #DeepLearning #AI #ComputerVision #Transformers #ModelZoo #AttentionIsAllYouNeed