<style>
p {
text-align: left;
}
.reveal h5 {
font-weight: normal;
}
div.alert {
padding: 1px 25px;
}
.reveal {
font-size: 30px;
}
.reveal section img {
border: none;
box-shadow: none;
}
.reveal pre {
box-shadow: none;
width: 100%;
}
.reveal pre code {
font-size: 0.9em;
line-height: 1.4em;
padding: 20px;
}
table {
font-size: 0.8em !important;
}
</style>
<!-- Some image -->
<img src='https://docs.flyte.org/en/latest/_static/flyte_circle_gradient_1_4x4.png' width="150px" style="border: none; box-shadow: none;"/>
<br/>
## Training Parallelized, GPU-accelerated PyTorch Models on Flyte\*
<br/>
#### Flyte OSS Sync 09/21/2021
<br/>
###### \*Single-Node Training
<!-- Self reference link -->
###### https://hackmd.io/@nielsbantilan/flyte-oss-20210921
---
## Outline
<!-- Some structure -->
- Model Development Lifecycle ♻️
- Scaling a Model 🐜 📏 🏔
- PyTorch Parallelism 🔥 🔥
- Demo ✨ 💻 ✨
- Takeaways 🎁 🛍
- Next Steps 👟 👟
---
## Model Development Lifecycle ♻️
_By "model development" I mean the process of training a model in the lab, setting aside dataset collection and production deployment for the moment._
:::info
**Assumption**: You have a fairly high-quality dataset. In the supervised setting, for example, this means you're pretty confident that your labels appropriately capture the concepts you're trying to model.
:::
---
### The Bias Variance Trade-off
- **Bias**: the systematic error in a model's predictions, i.e. how far off they are, on average, from the ground truth.
- **Variance**: how much a model's predictions fluctuate in response to perturbations of the training data.
<img src="https://www.educative.io/api/edpresso/shot/6668977167138816/image/5033807687188480" width="70%">
Source: https://www.educative.io/edpresso/overfitting-and-underfitting
<!-- .element: style="font-size: 14px; text-align: center;" -->
---
### The Bias Variance Trade-off
There's a theoretical "sweet spot" in model complexity where the combined error from bias and variance is minimized.
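For squared-error loss, this trade-off is often written as the standard decomposition of expected test error (stated here for reference; $\sigma^2$ is the irreducible noise):
$$
\mathbb{E}\big[(y - \hat{f}(x))^2\big] = \mathrm{Bias}\big[\hat{f}(x)\big]^2 + \mathrm{Var}\big[\hat{f}(x)\big] + \sigma^2
$$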
<img src="https://upload.wikimedia.org/wikipedia/commons/thumb/9/9f/Bias_and_variance_contributing_to_total_error.svg/1200px-Bias_and_variance_contributing_to_total_error.svg.png" alt="Bias and variance contributing to total error.svg" width="70%">
Source: https://en.wikipedia.org/wiki/Bias%E2%80%93variance_tradeoff
<!-- .element: style="font-size: 14px; text-align: center;" -->
---
### The Bias Variance Trade-off
#### In practice, you assess overfitting by:
- splitting your data into a training set and a test set
- iteratively updating the model on the training set
- plotting training set loss and test set loss over time (see the sketch on the next slide)
<img src="https://i.stack.imgur.com/m0NAK.png" width="40%">
Source: https://stackoverflow.com/questions/44909134/how-to-avoid-overfitting-on-a-simple-feed-forward-network
<!-- .element: style="font-size: 14px; text-align: center;" -->
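---
### The Bias Variance Trade-off
A minimal sketch of that loop, assuming hypothetical `model`, `opt`, `loss_fn`, `n_epochs`, and the two dataloaders are defined elsewhere:
```python
import torch

train_losses, test_losses = [], []
for epoch in range(n_epochs):
    model.train()
    running = 0.0
    for X, y in train_loader:                 # update the model on the training set
        loss = loss_fn(model(X), y)
        loss.backward()
        opt.step()
        opt.zero_grad()
        running += loss.item()
    train_losses.append(running / len(train_loader))
    model.eval()
    with torch.no_grad():                     # evaluate on the test set without updating weights
        test_losses.append(
            sum(loss_fn(model(X), y).item() for X, y in test_loader) / len(test_loader)
        )
# a test loss that rises while the training loss keeps falling is the signature of overfitting
```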
---
#### Phase 1: Small Model, Small Data
```mermaid
flowchart LR
X[Start] --> A
A(Train Model) --> B{Overfit?};
B --> |No| C(Fix implementation)
C --> A
B --> |Yes| Y[Phase 2]
style X fill:#fff2b2,stroke:#333
style Y fill:#fff2b2,stroke:#333
```
##### Can my model overfit a small subset of the data distribution?
<!-- .element: class="fragment" -->
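A quick way to run this check (a sketch with hypothetical `full_dataset`, `model`, and `opt` objects): train repeatedly on a tiny fixed subset and confirm the loss can be driven to near zero.
```python
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset

# ~100 fixed examples; if the model can't overfit these, suspect an implementation bug
small_loader = DataLoader(Subset(full_dataset, range(100)), batch_size=16, shuffle=True)
for _ in range(200):
    for X, y in small_loader:
        loss = F.cross_entropy(model(X), y)
        loss.backward()
        opt.step()
        opt.zero_grad()
```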
---
#### Phase 2: Medium Model, Full Dataset
```mermaid
flowchart LR
X[Phase 1] --> D(Train Model)
D --> E{Overfit?}
E --> |Yes| F(Try Smaller Model, Regularization)
F --> D
E --> |No| Y[Phase 3]
style X fill:#fff2b2,stroke:#333
style Y fill:#fff2b2,stroke:#333
```
##### Can my model fit the training set without overfitting on the test set?
<!-- .element: class="fragment" -->
---
#### Phase 3: Large Model, Full Dataset
```mermaid
flowchart LR
X[Phase 2] --> G(Train Model)
G --> H{Overfit?}
H --> |Yes| F(Try Smaller Model, Regularization)
F --> G
H --> |No| I(Scale Up)
I --> J(Tune Hyperparameters)
I --> K(Increase Model Size)
style X fill:#fff2b2,stroke:#333,stroke-width:4px
style I fill:#80fcdf,stroke:#333,stroke-width:4px
style J fill:#80fcdf,stroke:#333,stroke-dasharray: 5 5
style K fill:#80fcdf,stroke:#333,stroke-width:4px
```
##### Can I improve the performance of my model with some set of hyperparameters?
<!-- .element: class="fragment" -->
##### **Can I improve the performance by training a larger model?**
<!-- .element: class="fragment" -->
---
## Scaling a model 🐜 📏 🏔
Note:
In this workshop I'll do a broad introduction to two forms of parallelism: data parallelism, and model parallelism.
---
### Data Parallelism
```mermaid
flowchart LR
subgraph Model Replicas
M0[Model 0]
M1[Model 1]
M2[Model 2]
M3[Model 3]
end
D0[(Data batch 0)] --> M0
D1[(Data batch 1)] --> M1
D2[(Data batch 2)] --> M2
D3[(Data batch 3)] --> M3
M0 --> G0([Gradients 0])
M1 --> G1([Gradients 1])
M2 --> G2([Gradients 2])
M3 --> G3([Gradients 3])
G0 --> S[Combine]
G1 --> S
G2 --> S
G3 --> S
S --update--> UM0[Model 0]
style M0 fill:#80fcdf,stroke:#333,stroke-width:2px
style UM0 fill:#80fcdf,stroke:#333,stroke-width:2px
```
Note:
- When to use? Model fits on a single GPU and can process at least a single data point.
- Make a copy of the model on each GPU
- Each copy processes its own batch of data
- Gradients are accumulated
- Gradient updates are applied to the weights on the host device
- Weights are synced across all replicas
---
### Model Parallelism
```mermaid
flowchart LR
subgraph Model
M0[Weights 0]
M1[Weights 1]
M2[Weights 2]
M3[Weights 3]
end
D0[(Data Batch)] --> M0
M0 --> M1
M1 --> M2
M2 --> M3
M3 --> G([Gradients])
G --update--> M[Model]
style M fill:#feffdd,stroke:#787a0b
```
Note:
- When to use? The model doesn't fit on a single GPU, or a single GPU can't process even one data point.
- Split the weights across multiple GPUs
- A single batch of data is processed sequentially according to the model architecture
- Gradient updates are applied to the weights across the GPUs
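---
### Model Parallelism
A minimal sketch of naive model parallelism, assuming two GPUs and a hypothetical two-layer model (larger setups typically use pipeline- or tensor-parallel libraries instead):
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoGPUModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.part1 = nn.Linear(1024, 512).to("cuda:0")  # first half of the weights lives on GPU 0
        self.part2 = nn.Linear(512, 10).to("cuda:1")    # second half lives on GPU 1

    def forward(self, x):
        x = F.relu(self.part1(x.to("cuda:0")))
        return self.part2(x.to("cuda:1"))               # activations move between devices

model = TwoGPUModel()
opt = torch.optim.SGD(model.parameters(), lr=0.01)
X = torch.randn(32, 1024)
y = torch.randint(0, 10, (32,), device="cuda:1")
loss = F.cross_entropy(model(X), y)
loss.backward()                                         # gradients flow back across both GPUs
opt.step()
```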
---
## PyTorch Parallelism 🔥 🔥
Anatomy of a PyTorch training script:
```python
import torch
import torch.nn.functional as F
from torch.optim import SGD
from torch.utils.data import Dataset, DataLoader
class Model(torch.nn.Module): ... # 1. define model
dataloader = DataLoader(Dataset(...)) # 2. load data
model = Model() # 3. initialize model
opt = SGD(model.parameters(), ...) # 4. specify optimizer
for epoch in range(n_epochs): # 5. for some number of epochs
    for (X, y) in dataloader:                # 6. iterate through data
        loss = F.cross_entropy(model(X), y)  # 7. compute loss
        loss.backward()                      # 8. backpropagate errors
        opt.step()                           # 9. gradient update
        opt.zero_grad()
    ...                                      # 10. collect metrics
torch.save(model.state_dict(), "model.pt") # 11. save model
```
---
### `torch.nn.DataParallel`
```python
# suppose we have 4 GPUs
model = torch.nn.DataParallel(module=Model(), device_ids=[0, 1, 2, 3])
# ✨ and you're done! ✨
for (X, y) in dataloader:
    output = model(X)
    ...
# make sure to access the module attribute to save weights
torch.save(model.module.state_dict(), "model.pt")
```
##### `DataParallel` splits the batch across all specified GPU devices
<!-- .element: class="fragment" -->
##### Performs the forward pass on each GPU
<!-- .element: class="fragment" -->
##### Concatenates the outputs into a single tensor
<!-- .element: class="fragment" -->
##### Gradients from each device are summed into the original module
<!-- .element: class="fragment" -->
---
### `torch.nn.parallel.DistributedDataParallel`
```python
model = torch.nn.parallel.DistributedDataParallel(Model().to(rank), device_ids=[rank])
# but there's more to do 😅
```
##### Spawn `N` processes, one per GPU, with ranks `0` to `N - 1`, using `torch.multiprocessing`
<!-- .element: class="fragment" -->
##### Each process needs to invoke `torch.distributed.init_process_group` before running the training routine
<!-- .element: class="fragment" -->
##### Write code carefully, being aware that `rank == 0` is the host process
<!-- .element: class="fragment" -->
---
### `torch.nn.parallel.DistributedDataParallel`
##### Spawn `N` processes
```python
# suppose we have 4 GPUs
WORLD_SIZE = 4
import torch.multiprocessing
torch.multiprocessing.spawn(
    trainer,
    args=(WORLD_SIZE, ),
    nprocs=WORLD_SIZE, # number of processes
    join=True, # blocking join on all processes
)
```
##### What's `trainer`?
<!-- .element: class="fragment" -->
---
### `torch.nn.parallel.DistributedDataParallel`
##### Each process needs to invoke `torch.distributed.init_process_group`
```python
import os
import torch.distributed

def trainer(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
    # define model, moving it onto this process's GPU
    model = torch.nn.parallel.DistributedDataParallel(Model().to(rank), device_ids=[rank])
    # and so on...
    for (X, y) in dataloader:
        output = model(X)
        ...
```
##### What are some gotchas?
<!-- .element: class="fragment" -->
---
### `torch.nn.parallel.DistributedDataParallel`
##### Write code carefully, being aware that `rank == 0` is the host process
```python
def trainer(rank, world_size):
    ...
    # log metrics in the main process
    for (X, y) in dataloader:
        output = model(X)
        loss = ...
        if rank == 0:
            logger.info(f"training loss: {loss}")
    # only save the model in the main process
    if rank == 0:
        torch.save(model.module.state_dict(), "model.pt")
```
---
### PyTorch Data Parallelism Summary
**`DataParallel`**: a model wrapper for easily achieving data parallelism
<!-- .element: class="fragment" -->
- **pro**: one-liner change
- **con**: performance overhead, since it uses multithreading and is subject to the Python GIL
<!-- .element: class="fragment" -->
**`DistributedDataParallel`**: a model wrapper that requires additional script modifications
<!-- .element: class="fragment" -->
- **pro**: uses multiprocessing, so each GPU gets its own process => better performance
- **con**: additional setup code required
<!-- .element: class="fragment" -->
---
### Things to think about:
**Model:**
<!-- .element: class="fragment" -->
- Dropout adds memory overhead during training
- Model weight precision: higher-precision floats -> more memory
- When performing operations only on the host process, use `torch.distributed.barrier()` to keep the other processes from running ahead.
<!-- .element: class="fragment" -->
**Data:**
<!-- .element: class="fragment" -->
- Download your dataset on `rank 0` first to avoid race conditions (see the sketch on the next slide).
- What batch size fits in GPU memory?
<!-- .element: class="fragment" -->
**Metrics:**
<!-- .element: class="fragment" -->
- In general, collect metrics on `rank 0` unless you need process-specific metrics
<!-- .element: class="fragment" -->
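---
### Avoiding Download Race Conditions
A minimal sketch of the `rank 0`-downloads-first pattern from the previous slide, using `torch.distributed.barrier()` (torchvision's MNIST dataset here is just an illustrative stand-in):
```python
import torch.distributed as dist
from torchvision import datasets

def get_dataset(rank):
    if rank == 0:
        dataset = datasets.MNIST("./data", download=True)    # only the host process downloads
    dist.barrier()                                            # other ranks wait until rank 0 is done
    if rank != 0:
        dataset = datasets.MNIST("./data", download=False)    # now it's safe to read from disk
    return dataset
```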
---
## Demo ✨ 💻 ✨
#### Training on the MNIST dataset
<img src="https://upload.wikimedia.org/wikipedia/commons/2/27/MnistExamples.png" width="70%">
Source: https://en.wikipedia.org/wiki/MNIST_database
<!-- .element: style="font-size: 14px; text-align: center;" -->
Note:
- Flytesnacks MNIST example
- Single GPU example
- Use Weights and Biases to track metrics
- Multi GPU example
- Download dataset on `rank 0` first to avoid race conditions
- Only `wandb.init()` in the `rank 0` process.
- Showcase some basic profiling of models
- Utilization as a function of batch size, model size, etc.
---
## Takeaways 🎁 🛍
##### Start small to ensure correct model implementation and characterize convergence/overfitting dynamics
<!-- .element: class="fragment" -->
##### Scale up to reduce iteration time with GPU accelerators
<!-- .element: class="fragment" -->
##### Before scaling up, _know why you're scaling up!_
<!-- .element: class="fragment" -->
Note:
- My model can overfit a small subset of the data
- A small model underfits the whole dataset
- Therefore, a larger model is likely to fit the data better, but we need to tune it to make sure it doesn't overfit.
---
## Next Steps 👟 👟
##### Horovod MPI Operator for Multi-node Training
##### Improve cost-effectiveness with interruptible instances
---
## Announcing
##### Flyte OSS Sync ⚡️ Lightning Talks! ⚡️
##### https://forms.gle/ZTWGeUCUyY9u2rjy9
{"metaMigratedAt":"2023-06-16T10:39:53.468Z","metaMigratedFrom":"YAML","title":"OSS Sync - Training Parallelized, GPU-accelerated Pytorch Models on Flyte","breaks":true,"description":"View the slide with \"Slide Mode\".","slideOptions":"{\"theme\":\"white\",\"fragments\":true,\"spotlight\":{\"enabled\":false}}","contributors":"[{\"id\":\"e8a53bc9-5dce-45e1-8810-ea537e4b141b\",\"add\":19368,\"del\":7542}]"}