<style> p { text-align: left; } .reveal h5 { font-weight: normal; } div.alert { padding: 1px 25px; } .reveal { font-size: 30px; } .reveal section img { border: none; box-shadow: none; } .reveal pre { box-shadow: none; width: 100%; } .reveal pre code { font-size: 0.9em; line-height: 1.4em; padding: 20px; } table { font-size: 0.8em !important; } </style>

<!-- Some image -->
<img src='https://docs.flyte.org/en/latest/_static/flyte_circle_gradient_1_4x4.png' width="150px" style="border: none; box-shadow: none;"/>

<br/>

## Training Parallelized, GPU-accelerated PyTorch Models on Flyte\*

<br/>

#### Flyte OSS Sync 09/21/2021

<br/>

###### \*Single-Node Training

<!-- Self reference link -->
###### https://hackmd.io/@nielsbantilan/flyte-oss-20210921

---

## Outline

<!-- Some structure -->
- Model Development Lifecycle ♻️
- Scaling a Model 🐜 📏 🏔
- PyTorch Parallelism 🔥 🔥
- Demo ✨ 💻 ✨
- Takeaways 🎁 🛍
- Next Steps 👟 👟

---

## Model Development Lifecycle ♻️

_By "model development" I mean the process of training a model in the lab, setting aside dataset collection and production deployment for the moment._

:::info
**Assumption**: You have a fairly high-quality dataset. In the supervised setting, for example, this means you're fairly confident that your labels appropriately capture the concepts you're trying to model.
:::

---

### The Bias Variance Trade-off

- **Bias**: how far a model's predictions are, on average, from the ground truth.
- **Variance**: how much a model's predictions change in the face of perturbations in the training data.

<img src="https://www.educative.io/api/edpresso/shot/6668977167138816/image/5033807687188480" width="70%">

Source: https://www.educative.io/edpresso/overfitting-and-underfitting <!-- .element: style="font-size: 14px; text-align: center;" -->

---

### The Bias Variance Trade-off

There's a theoretical "sweet spot" that minimizes total error by balancing bias against variance.
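
For squared-error loss, the sweet spot can be stated more precisely. The following is a sketch of the standard decomposition, assuming the data are generated as $y = f(x) + \varepsilon$ with zero-mean noise of variance $\sigma^2$ and a learned predictor $\hat{f}$ (these symbols are introduced here for illustration and are not part of the original slides):

$$
\mathbb{E}\big[(y - \hat{f}(x))^2\big] = \big(\mathbb{E}[\hat{f}(x)] - f(x)\big)^2 + \mathbb{E}\big[(\hat{f}(x) - \mathbb{E}[\hat{f}(x)])^2\big] + \sigma^2
$$

The three terms are the squared bias, the variance, and the irreducible noise, with the expectation taken over training sets and noise. Increasing model capacity typically lowers the first term and raises the second, so the total is smallest somewhere in between.
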

<img src="https://upload.wikimedia.org/wikipedia/commons/thumb/9/9f/Bias_and_variance_contributing_to_total_error.svg/1200px-Bias_and_variance_contributing_to_total_error.svg.png" alt="Bias and variance contributing to total error.svg" width="70%">

Source: https://en.wikipedia.org/wiki/Bias%E2%80%93variance_tradeoff <!-- .element: style="font-size: 14px; text-align: center;" -->

---

### The Bias Variance Trade-off

#### In practice, you assess overfitting by:

- splitting your data into a training set and a test set
- iteratively updating the model on the training set
- plotting training set loss and test set loss over time

<img src="https://i.stack.imgur.com/m0NAK.png" width="40%">

Source: https://stackoverflow.com/questions/44909134/how-to-avoid-overfitting-on-a-simple-feed-forward-network <!-- .element: style="font-size: 14px; text-align: center;" -->

---

#### Phase 1: Small Model, Small Data

```mermaid
flowchart LR
    X[Start] --> A
    A(Train Model) --> B{Overfit?};
    B --> |No| C(Fix implementation)
    C --> A
    B --> |Yes| Y[Phase 2]
    style X fill:#fff2b2,stroke:#333
    style Y fill:#fff2b2,stroke:#333
```

##### Can my model overfit a small subset of the data distribution? <!-- .element: class="fragment" -->

---

#### Phase 2: Medium Model, Full Dataset

```mermaid
flowchart LR
    X[Phase 1] --> D(Train Model)
    D --> E{Overfit?}
    E --> |Yes| F(Try Smaller Model, Regularization)
    F --> D
    E --> |No| Y[Phase 3]
    style X fill:#fff2b2,stroke:#333
    style Y fill:#fff2b2,stroke:#333
```

##### Can my model fit the training set while still generalizing to the test set?
<!-- .element: class="fragment" -->

---

#### Phase 3: Large Model, Full Dataset

```mermaid
flowchart LR
    X[Phase 3] --> G(Train model)
    G --> H{Overfit?}
    H --> |Yes| F(Try Smaller Model, Regularization)
    F --> G
    H --> |No| I(Scale Up)
    I --> J(Tune Hyperparameters)
    I --> K(Increase Model Size)
    style X fill:#fff2b2,stroke:#333,stroke-width:4px
    style I fill:#80fcdf,stroke:#333,stroke-width:4px
    style J fill:#80fcdf,stroke:#333,stroke-dasharray: 5 5
    style K fill:#80fcdf,stroke:#333,stroke-width:4px
```

##### Can I improve the performance of my model with some set of hyperparameters? <!-- .element: class="fragment" -->

##### **Can I improve the performance by training a larger model?** <!-- .element: class="fragment" -->

---

## Scaling a model 🐜 📏 🏔

Note:
In this workshop I'll give a broad introduction to two forms of parallelism: data parallelism and model parallelism.

---

### Data Parallelism

```mermaid
flowchart LR
    subgraph Model Replicas
        M0[Model 0]
        M1[Model 1]
        M2[Model 2]
        M3[Model 3]
    end
    D0[(Data batch 0)] --> M0
    D1[(Data batch 1)] --> M1
    D2[(Data batch 2)] --> M2
    D3[(Data batch 3)] --> M3
    M0 --> G0([Gradients 0])
    M1 --> G1([Gradients 1])
    M2 --> G2([Gradients 2])
    M3 --> G3([Gradients 3])
    G0 --> S[Combine]
    G1 --> S
    G2 --> S
    G3 --> S
    S --update--> UM0[Model 0]
    style M0 fill:#80fcdf,stroke:#333,stroke-width:2px
    style UM0 fill:#80fcdf,stroke:#333,stroke-width:2px
```

Note:
- When to use? The model fits on a single GPU and can process at least a single data point there.
- Make a copy of the model on each GPU
- Each copy processes its own batch of data
- Gradients from all copies are combined
- Gradient updates are applied to the weights on the host device
- Updated weights are synced across all copies

---

### Model Parallelism

```mermaid
flowchart LR
    subgraph Model
        M0[Weights 0]
        M1[Weights 1]
        M2[Weights 2]
        M3[Weights 3]
    end
    D0[(Data Batch)] --> M0
    M0 --> M1
    M1 --> M2
    M2 --> M3
    M3 --> G([Gradients])
    G --update--> M[Model]
    style M fill:#feffdd,stroke:#787a0b
```

Note:
- When to use?
  - The model doesn't fit on a single GPU
  - The model on a single GPU can't process even a single data point
- Split up the weights across multiple GPUs
- A single batch of data passes through the weight partitions sequentially, following the model architecture
- Gradient updates are applied to the weights across the GPUs

---

## PyTorch Parallelism 🔥 🔥

Anatomy of a PyTorch training script:

```python
import torch
import torch.nn.functional as F
from torch.optim import SGD
from torch.utils.data import Dataset, DataLoader

class Model(torch.nn.Module): ...                # 1. define model

dataloader = DataLoader(Dataset(...))            # 2. load data
model = Model()                                  # 3. initialize model
opt = SGD(model.parameters(), ...)               # 4. specify optimizer

for epoch in range(n_epochs):                    # 5. for some number of epochs
    for (X, y) in dataloader:                    # 6. iterate through data
        loss = F.cross_entropy(model(X), y)      # 7. compute loss
        loss.backward()                          # 8. backpropagate errors
        opt.step()                               # 9. gradient update
        opt.zero_grad()
        ...                                      # 10. collect metrics

torch.save(model.state_dict(), "model.pt")       # 11. save model
```

---

### `torch.nn.DataParallel`

```python
# suppose we have 4 GPUs; the wrapped module must live on device_ids[0]
model = torch.nn.DataParallel(module=Model().to("cuda:0"), device_ids=[0, 1, 2, 3])
# ✨ and you're done! ✨

for (X, y) in dataloader:
    output = model(X)
    ...

# make sure to access the module attribute to save weights
torch.save(model.module.state_dict(), "model.pt")
```

##### `DataParallel` splits the batch across all specified GPU devices <!-- .element: class="fragment" -->

##### Performs the forward pass on each GPU <!-- .element: class="fragment" -->

##### Concatenates the outputs into a single tensor <!-- .element: class="fragment" -->

##### Gradients from each device are summed into the original module <!-- .element: class="fragment" -->

---

### `torch.nn.parallel.DistributedDataParallel`

```python
# move the model to this process's GPU, then wrap it
model = torch.nn.parallel.DistributedDataParallel(Model().to(rank), device_ids=[rank])
# but there's more to do 😅
```

##### Spawn `N` processes, one per GPU, with ranks `0` to `N - 1`, using `torch.multiprocessing` <!-- .element: class="fragment" -->

##### Each process needs to invoke `torch.distributed.init_process_group` before running the training routine <!-- .element: class="fragment" -->

##### Write code carefully, being aware that `rank == 0` is the host process <!-- .element: class="fragment" -->

---

### `torch.nn.parallel.DistributedDataParallel`

##### Spawn `N` processes

```python
# suppose we have 4 GPUs
WORLD_SIZE = 4

import torch.multiprocessing

torch.multiprocessing.spawn(
    trainer,
    args=(WORLD_SIZE, ),
    nprocs=WORLD_SIZE,  # number of processes
    join=True,  # blocking join on all processes
)
```

##### What's `trainer`? <!-- .element: class="fragment" -->

---

### `torch.nn.parallel.DistributedDataParallel`

##### Each process needs to invoke `torch.distributed.init_process_group`

```python
import os
import torch.distributed

def trainer(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

    # define model: move it to this process's GPU, then wrap it
    model = torch.nn.parallel.DistributedDataParallel(Model().to(rank), device_ids=[rank])

    # and so on...
    for (X, y) in dataloader:
        output = model(X)
        ...
```

##### What are some gotchas?
<!-- .element: class="fragment" -->

---

### `torch.nn.parallel.DistributedDataParallel`

##### Write code carefully, being aware that `rank == 0` is the host process

```python
def trainer(rank, world_size):
    ...

    # log metrics in the main process
    for (X, y) in dataloader:
        output = model(X)
        loss = ...
        if rank == 0:
            logger.info(f"training loss: {loss}")

    # only save the model in the main process; `.module` unwraps the DDP wrapper
    if rank == 0:
        torch.save(model.module.state_dict(), "model.pt")
```

---

### PyTorch Data Parallelism Summary

**`DataParallel`** : model wrapper to easily achieve data parallelism <!-- .element: class="fragment" -->

- **pro**: one-liner change
- **con**: performance overhead: uses multithreading, so it is subject to the GIL
<!-- .element: class="fragment" -->

**`DistributedDataParallel`** : model wrapper with additional script modifications <!-- .element: class="fragment" -->

- **pro**: uses multiprocessing, each GPU has its own process => better performance
- **con**: additional setup code required
<!-- .element: class="fragment" -->

---

### Things to think about:

**Model:** <!-- .element: class="fragment" -->

- Dropout consumes additional memory
- Weight floating-point precision: more precision -> more memory (e.g. `float32` weights take twice the memory of `float16`)
- When performing operations only on the host process, use `torch.distributed.barrier()` to block other processes from continuing execution.
<!-- .element: class="fragment" -->

**Data:** <!-- .element: class="fragment" -->

- Download your dataset on `rank 0` first to avoid race conditions.
- What batch size fits in memory?
<!-- .element: class="fragment" -->

**Metrics:** <!-- .element: class="fragment" -->

- In general, collect metrics on `rank 0` unless you need process-specific metrics
<!-- .element: class="fragment" -->

---

## Demo ✨ 💻 ✨

#### Training on the MNIST dataset

<img src="https://upload.wikimedia.org/wikipedia/commons/2/27/MnistExamples.png" width="70%">

Source: https://en.wikipedia.org/wiki/MNIST_database <!-- .element: style="font-size: 14px; text-align: center;" -->

Note:
- Flytesnacks MNIST example
- Single GPU example
  - Use Weights and Biases to track metrics
- Multi GPU example
  - Download the dataset on `rank 0` first to avoid race conditions
  - Only call `wandb.init()` in the `rank 0` process.
- Showcase some basic profiling of models
  - Utilization as a function of batch size, model size, etc.

---

## Takeaways 🎁 🛍

##### Start small to ensure correct model implementation and characterize convergence/overfitting dynamics <!-- .element: class="fragment" -->

##### Scale up to reduce iteration time with GPU accelerators <!-- .element: class="fragment" -->

##### Before scaling up, _know why you're scaling up!_ <!-- .element: class="fragment" -->

Note:
- My model can overfit a small subset of the data
- A small model is underfitting on the whole dataset
- Therefore, a larger model is likely to fit the data better, but it needs tuning to make sure it doesn't overfit.

---

## Next Steps 👟 👟

##### Horovod MPI Operator for Multi-node Training

##### Improve cost-effectiveness with interruptible instances

---

## Announcing

##### Flyte OSS Sync ⚡️ Lightning Talks! ⚡️

##### https://forms.gle/ZTWGeUCUyY9u2rjy9
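
---

### Appendix: Data Parallelism in Plain Python

The "Combine" step from the Data Parallelism slide can be illustrated without GPUs. The block below is a minimal sketch, not PyTorch's implementation: it assumes a 1-D linear model with mean-squared-error loss, equal-sized shards, and gradient averaging as the combine rule. The names `mse_gradient` and `data_parallel_step` are hypothetical and exist only for this example.

```python
# Minimal sketch (pure Python, no GPUs): data-parallel gradient averaging
# for a 1-D linear model y_hat = w * x with mean-squared-error loss.
# All function names are hypothetical and chosen for illustration.

def mse_gradient(w, batch):
    """Gradient of mean((w * x - y) ** 2) with respect to w over one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def data_parallel_step(w, batch, n_replicas, lr):
    """One SGD step: shard the batch, compute per-replica gradients, average them."""
    # assumes len(batch) is divisible by n_replicas; leftover examples are dropped
    shard_size = len(batch) // n_replicas
    shards = [batch[i * shard_size:(i + 1) * shard_size] for i in range(n_replicas)]
    # each "replica" holds the same weight w and sees only its own shard
    replica_grads = [mse_gradient(w, shard) for shard in shards]
    # the "Combine" node in the diagram: average gradients across replicas
    combined = sum(replica_grads) / n_replicas
    # a single update is applied, so every replica ends up with the same weight
    return w - lr * combined


# toy data generated from y = 3x; 8 examples split evenly across 4 replicas
data = [(float(x), 3.0 * x) for x in range(1, 9)]

w_parallel = data_parallel_step(0.0, data, n_replicas=4, lr=0.01)
w_single = 0.0 - 0.01 * mse_gradient(0.0, data)
```

With equal-sized shards, the average of the per-shard gradients equals the gradient over the full batch, so `w_parallel` and `w_single` agree up to floating-point error. This is why data parallelism changes throughput rather than the update itself.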