![](https://i.imgur.com/zJXMQCp.png)

PyTorch-Ignite distributed module
===

###### tags: `PyTorch-Ignite` `distributed training`

:::info
Summary
1. **Introduction**
2. **PyTorch-Ignite Distributed**
    1. Focus on the use of PyTorch-Ignite's `auto_*` methods
3. **Use Case**
4. **Code Snippets**
    1. PyTorch-Ignite vs DDP vs Horovod vs XLA
5. **Running Distributed Code**
:::

Introduction
===

PyTorch offers a distributed communication submodule for writing and running parallel applications on multiple devices and machines.

The native interface provides commonly used collective operations and lets you address multi-CPU and multi-GPU computations seamlessly using the [torch DistributedDataParallel](https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html) module and the well-known [mpi](https://www.open-mpi.org/), [gloo](https://github.com/facebookincubator/gloo) and [nccl](https://github.com/NVIDIA/nccl) backends. Moreover, PyTorch on XLA (Accelerated Linear Algebra) devices, like TPUs, is now supported with the [pytorch/xla](https://github.com/pytorch/xla) package. However, writing distributed training code that works on both GPUs and TPUs is not a trivial task, due to some API specificities.

[PyTorch-Ignite's](https://github.com/pytorch/ignite) [`ignite.distributed`](https://pytorch.org/ignite/distributed.html) (`idist`) submodule, introduced in version [v0.4.0 (July 2020)](https://github.com/pytorch/ignite/releases/tag/v0.4.0.post1), quickly turns your sequential code into its distributed version. You can thus run the same version of your code seamlessly across all supported backends:

- backends from the native torch distributed configuration: `nccl`, `gloo`, `mpi`
- [Horovod](https://horovod.readthedocs.io/en/stable/) backend with `gloo` or `nccl` support
- XLA on TPUs via [pytorch/xla](https://github.com/pytorch/xla)

In this blog post we will compare native distributed training code written with multiple frameworks against the PyTorch-Ignite API. We focus on the ease of use of the `auto_*` methods, which adapt your model, optimizer and data loaders to the existing distributed configuration. Code snippets, as well as commands for running all the scripts, are given.

We will also cover several ways of spawning processes, via the native `torch.multiprocessing.spawn` as well as via multiple distributed launchers, to highlight how PyTorch-Ignite's `idist` handles them all without any changes to the code, in particular:

- [torch.multiprocessing.spawn](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn)
- [torch.distributed.launch](https://pytorch.org/docs/stable/distributed.html#launch-utility)
- [horovodrun](https://horovod.readthedocs.io/en/stable/running_include.html)
- [slurm](https://slurm.schedmd.com/srun.html)

Note that more information on launcher experiments can be found [here](https://github.com/sdesrozis/why-ignite).

:fire: PyTorch-Ignite Unified Distributed API
===

Calling the APIs of multiple distributed frameworks requires writing framework-specific code, and these modifications can be tedious, especially if you would like to test your code on different hardware configurations. PyTorch-Ignite's `idist` does all of this work for you through its high-level helper methods.

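To make this concrete, here is a minimal sketch of a training function written with these helpers, assuming a toy linear model and a random tensor dataset (the model, dataset and configuration below are placeholders, not the use case studied later). The rest of this section details each helper.

```python
# Minimal sketch on toy data: the same function runs unmodified in a
# non-distributed setting or with the "nccl", "gloo", "horovod" or "xla-tpu" backends.
import torch
from torch import nn
from torch.utils.data import TensorDataset

import ignite.distributed as idist
from ignite.engine import Engine


def training(local_rank, config):
    device = idist.device()  # device associated with the current backend/process

    # auto_dataloader adds a DistributedSampler and adapts the batch size if needed
    dataset = TensorDataset(torch.randn(1024, 10), torch.randn(1024, 1))
    train_loader = idist.auto_dataloader(dataset, batch_size=config["batch_size"])

    # auto_model moves the model to the target device and wraps it (DDP, Horovod, ...)
    model = idist.auto_model(nn.Linear(10, 1))
    criterion = nn.MSELoss()
    # auto_optim wraps the optimizer when the backend requires it (Horovod, XLA)
    optimizer = idist.auto_optim(torch.optim.SGD(model.parameters(), lr=0.01))

    def train_step(engine, batch):
        x, y = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        return loss.item()

    Engine(train_step).run(train_loader, max_epochs=1)


if __name__ == "__main__":
    # backend=None runs serially; pass a backend name to run distributed
    with idist.Parallel(backend=None) as parallel:
        parallel.run(training, {"batch_size": 32})
```
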
## :mag: Focus on the helper `auto_*` methods:

- [auto_model()](https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_model)

  This method adapts the model for non-distributed and available distributed configurations. It wraps [`torch.nn.parallel.DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) and [`horovod.torch.broadcast_parameters`](https://horovod.readthedocs.io/en/stable/api.html#horovod.torch.broadcast_parameters), and takes care of the complexity of placing the model on the target device for all the available backends. Additionally, it is also compatible with [NVIDIA/apex](https://github.com/NVIDIA/apex) via the model returned by [`amp.initialize`](https://nvidia.github.io/apex/amp.html?#apex.amp.initialize).

- [auto_optim()](https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_optim)

  This method adapts the optimizer for non-distributed and available distributed configurations seamlessly. Specifically, for an XLA distributed configuration, it creates a new class that inherits from the provided optimizer's class and overrides its `step()` method with an [`xm.optimizer_step`](https://pytorch.org/xla/release/1.8/index.html#torch_xla.core.xla_model.optimizer_step)-based implementation. For Horovod's distributed configuration, the optimizer is wrapped with Horovod's DistributedOptimizer and its state is broadcast from rank 0 to all other processes; in this case `auto_optim()` replaces [`horovod.torch.DistributedOptimizer`](https://horovod.readthedocs.io/en/stable/api.html#horovod.torch.DistributedOptimizer).

- [auto_dataloader()](https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader)

  This method adapts the data loading logic for non-distributed and available distributed configurations seamlessly on the target devices. Specifically, it wraps [`torch.utils.data.distributed.DistributedSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler), [`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) and [`torch_xla.distributed.parallel_loader.MpDeviceLoader`](https://github.com/pytorch/xla/blob/815197139b94e5655ed6b347f48864e73dc73011/torch_xla/distributed/parallel_loader.py#L171). Additionally, `auto_dataloader()` automatically scales the batch size according to the distributed configuration context, resulting in a general way of loading sample batches on multiple devices.

Use case
===

For this use case we take realistic training code: a `torchvision` `wide_resnet50_2` model trained on a random dataset. This example is an adaptation of the PyTorch code snippet found in the official [Horovod documentation](https://horovod.readthedocs.io/en/stable/pytorch.html). Note that these code snippets are deliberately less complex than complete training code, in order to focus on the API specificities of each of the distributed backends (if you would like to see PyTorch-Ignite in a production-grade example, [see here](https://github.com/pytorch/ignite/tree/master/examples/contrib/cifar10)).

Code snippets
===

These code snippets highlight the API specificities of each of the distributed backends on the same use case, as compared to the `idist` API. Torch native code is available for DDP, Horovod, and XLA/TPU devices.

For XLA, you can run the samples via this [notebook](https://colab.research.google.com/drive/1I1MvDG4ikiGIg8DJH1ZWBbvlL43ttH8z?usp=sharing) in order to reproduce the experiments on TPUs.

PyTorch-Ignite's unified code snippet can be run with the standard PyTorch backends like `gloo` and `nccl`, as well as with Horovod and with XLA for TPU devices. Note that the code is less verbose; however, the user still has full control of the training loop.

The complete source code of these experiments can be found [here](https://github.com/fco-dv/idist-snippets).

### Torch native Distributed Data Parallel - Horovod - XLA/TPUs

<div>
<table>
<tr>
<th> PyTorch-Ignite </th>
<th> PyTorch DDP </th>
<th> Horovod </th>
<th> Torch XLA </th>
</tr>
<tr>
<td>

```python
import ...


def training(rank, config):
    # Specific ignite.distributed
    print(
        idist.get_rank(),
        ": run with config:",
        config,
        "- backend=",
        idist.backend(),
        "- world size",
        idist.get_world_size(),
    )
    print(idist.get_rank(), " with seed ", torch.initial_seed())

    device = idist.device()

    # Data preparation:
    dataset = ...

    # Specific ignite.distributed
    train_loader = idist.auto_dataloader(dataset, batch_size=config["batch_size"])

    # Model, criterion, optimizer setup
    model = idist.auto_model(wide_resnet50_2(num_classes=100))
    criterion = NLLLoss()
    optimizer = idist.auto_optim(SGD(model.parameters(), lr=0.01))

    # Training loop log param
    log_interval = config["log_interval"]

    def _train_step(engine, batch):
        data = batch[0].to(device)
        target = batch[1].to(device)
        ...
        output = model(data)
        ...
        loss_val = ...
        return loss_val

    # Running the _train_step function on whole batch_data iterable only once
    trainer = Engine(_train_step)

    # Specific PyTorch-Ignite
    trainer.run(train_loader, max_epochs=1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("PyTorch-Ignite - idist")
    parser.add_argument("--backend", type=str, default="nccl")
    parser.add_argument("--nproc_per_node", type=int)
    ...
    # Specific ignite.distributed
    with idist.Parallel(backend=args_parsed.backend, **spawn_kwargs) as parallel:
        parallel.run(_mp_train, config)
```

</td>
<td>

```python
import ...


def training(rank, world_size, backend, config):
    # Specific torch.distributed
    dist.init_process_group(
        backend, init_method="tcp://0.0.0.0:2233", world_size=world_size, rank=rank
    )
    print(dist.get_rank(), ": run with config:", config, " - backend=", backend)

    device = None
    if backend == "nccl":
        torch.cuda.set_device(rank)
        device = "cuda"
    else:
        device = "cpu"

    # Data preparation
    dataset = ...

    # Specific torch.distributed
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=int(config["batch_size"] / world_size),
        num_workers=1,
        sampler=train_sampler,
    )

    # Model, criterion, optimizer setup
    model = wide_resnet50_2(num_classes=100).to(device)
    criterion = NLLLoss()
    optimizer = SGD(model.parameters(), lr=0.01)

    # Specific torch.distributed
    if backend == "nccl":
        model = DDP(model, device_ids=[rank])
    elif backend == "gloo":
        model = DDP(model)

    # Training loop log param
    log_interval = config["log_interval"]

    def _train_step(batch_idx, data, target):
        data = data.to(device)
        target = target.to(device)
        ...
        output = model(data)
        ...
        loss_val = ...
        return loss_val

    # Running _train_step for n_epochs
    n_epochs = 1
    for epoch in range(n_epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            _train_step(batch_idx, data, target)

    # Specific torch.distributed
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Torch Native - DDP")
    parser.add_argument("--backend", type=str, default="nccl")
    parser.add_argument("--nproc_per_node", type=int, default=2)
    ...
    args = (args_parsed.nproc_per_node, args_parsed.backend, config)
    # Specific torch.distributed
    start_processes(
        training, args=args, nprocs=args_parsed.nproc_per_node, start_method="spawn"
    )
```

</td>
<td>

```python
import ...


def training(world_size, backend, config):
    # Specific hvd
    hvd.init()
    print(hvd.local_rank(), ": run with config:", config, " - backend=", backend)

    device = None
    if backend == "nccl":
        # Pin GPU to be used to process local rank (one GPU per process)
        # Specific hvd
        torch.cuda.set_device(hvd.local_rank())
        device = "cuda"
    else:
        device = "cpu"

    # Data preparation
    dataset = ...

    # Specific hvd
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=hvd.size(), rank=hvd.rank()
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=int(config["batch_size"] / hvd.size()),
        num_workers=1,
        sampler=train_sampler,
    )

    # Model, criterion, optimizer setup
    model = wide_resnet50_2(num_classes=100).to(device)
    criterion = NLLLoss().to(device)
    optimizer = SGD(model.parameters(), lr=0.001)

    # Specific hvd
    # Add Horovod Distributed Optimizer
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters()
    )

    # Specific hvd
    # Broadcast parameters from rank 0 to all other processes.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # Training loop log param
    log_interval = config["log_interval"]

    def _train_step(batch_idx, data, target):
        data = data.to(device)
        target = target.to(device)
        ...
        output = model(data)
        ...
        loss_val = ...
        return loss_val

    # Running _train_step for n_epochs
    n_epochs = 1
    for epoch in range(n_epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            _train_step(batch_idx, data, target)

    # Specific hvd
    hvd.shutdown()


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Torch Native - Horovod")
    parser.add_argument("--backend", type=str, default="gloo")
    parser.add_argument("--nproc_per_node", type=int, default=2)
    ...
    args = (args_parsed.nproc_per_node, args_parsed.backend, config)
    # Specific hvd
    run(training, args=args, use_gloo=True, np=args_parsed.nproc_per_node)
```

</td>
<td>

```python
import ...


def training(rank, world_size, backend, config):
    # Specific xla
    print(xm.get_ordinal(), ": run with config:", config, "- backend=", backend)

    device = xm.xla_device()
    print(xm.get_ordinal(), " with seed ", torch.initial_seed())

    # Data preparation
    dataset = ...

    # Specific xla
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=int(config["batch_size"] / xm.xrt_world_size()),
        num_workers=1,
        sampler=train_sampler,
    )

    # Specific xla
    para_loader = pl.MpDeviceLoader(train_loader, device)

    # Model, criterion, optimizer setup
    model = wide_resnet50_2(num_classes=100).to(device)
    criterion = NLLLoss()
    optimizer = SGD(model.parameters(), lr=0.01)

    # Training loop log param
    log_interval = config["log_interval"]

    def _train_step(batch_idx, data, target):
        data = data
        target = target
        ...
        output = model(data)
        ...
        loss_val = ...
        return loss_val

    # Running _train_step for n_epochs
    n_epochs = 1
    for epoch in range(n_epochs):
        for batch_idx, (data, target) in enumerate(para_loader):
            _train_step(batch_idx, data, target)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Torch Native - XLA")
    parser.add_argument("--backend", type=str, default="xla-tpu")
    parser.add_argument("--nproc_per_node", type=int, default=8)
    ...
    args = (args_parsed.nproc_per_node, args_parsed.backend, config)
    # Specific xla
    xmp.spawn(training, args=args, nprocs=args_parsed.nproc_per_node)
```

</td>
</tr>
</table>
</div>

---

Running Distributed Code
===

PyTorch-Ignite's `idist` also unifies the way distributed code is launched, and makes the distributed configuration setup easier with the [`ignite.distributed.launcher.Parallel` (`idist Parallel`)](https://pytorch.org/ignite/distributed.html#ignite.distributed.launcher.Parallel) context manager. This context manager can either spawn `nproc_per_node` child processes (passed as a script argument) and initialize a processing group according to the provided backend, or rely on tools like `torch.distributed.launch`, `slurm` and `horovodrun` and, in a generic way, only initialize the processing group given the `backend` argument.

### With `torch.multiprocessing.spawn`

Here `idist Parallel` uses the native `torch.multiprocessing.spawn` method under the hood in order to run the distributed configuration. In this case, `nproc_per_node` is passed as a spawn argument.

- Running multiple distributed configurations with the same code:

```commandline
# Running with gloo
python -u ignite_idist.py --nproc_per_node 2 --backend gloo

# Running with nccl
python -u ignite_idist.py --nproc_per_node 2 --backend nccl

# Running with horovod with gloo controller (gloo or nccl support)
python -u ignite_idist.py --backend horovod --nproc_per_node 2

# Running on xla/tpu
python -u ignite_idist.py --backend xla-tpu --nproc_per_node 8 --batch_size 32
```

### With Distributed launchers

PyTorch-Ignite's `idist Parallel` context manager is also compatible with multiple distributed launchers.

#### With torch.distributed.launch

Here we use the `torch.distributed.launch` script in order to spawn the processes:

```commandline
python -m torch.distributed.launch --nproc_per_node 2 --use_env ignite_idist.py --backend gloo
```

#### With horovodrun

```commandline
horovodrun -np 2 -H hostname1:8,hostname2:8 python ignite_idist.py --backend horovod
```

:::warning
In order to run this example and skip the installation procedure, you can pull one of PyTorch-Ignite's [docker images with pre-installed Horovod](https://github.com/pytorch/ignite/blob/master/docker/hvd/Dockerfile.hvd-base). It includes Horovod with the `gloo` controller and `nccl` support.

```commandline
docker run --gpus all -it -v $PWD:/workspace/project --network=host --shm-size 16G pytorchignite/hvd-vision:latest /bin/bash
cd project
...
```
:::

#### With slurm

The same result can be achieved by using `slurm` without a single modification of the code:

```commandline
srun --nodes=2 --ntasks-per-node=2 --job-name=pytorch-ignite --time=00:01:00 --partition=gpgpu --gres=gpu:2 --mem=10G python ignite_idist.py --backend nccl
```
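
In all of these cases the Python entry point stays the same. As a reminder of what the launch commands above assume, here is a condensed, hypothetical sketch of the `__main__` block of `ignite_idist.py` (argument handling is simplified and the training function is reduced to a placeholder; the real one is shown in the snippet table above):

```python
# Condensed sketch of the ignite_idist.py entry point assumed by the commands above
# (simplified argument handling; `training` stands in for the full snippet).
import argparse

import ignite.distributed as idist


def training(local_rank, config):
    # placeholder for the training function shown in the PyTorch-Ignite snippet
    print(idist.get_rank(), "/", idist.get_world_size(), "- config:", config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("PyTorch-Ignite - idist")
    parser.add_argument("--backend", type=str, default="nccl")
    parser.add_argument("--nproc_per_node", type=int, default=None)
    args_parsed = parser.parse_args()

    # With --nproc_per_node, idist.Parallel spawns the child processes itself;
    # without it, the script expects an external launcher (torch.distributed.launch,
    # horovodrun, slurm) and only initializes the processing group for the backend.
    spawn_kwargs = {}
    if args_parsed.nproc_per_node is not None:
        spawn_kwargs["nproc_per_node"] = args_parsed.nproc_per_node

    config = {"batch_size": 64}

    with idist.Parallel(backend=args_parsed.backend, **spawn_kwargs) as parallel:
        parallel.run(training, config)
```
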
### Comparison with Torch native run methods

In order to run the same training loop on different backends without `idist`, you would have to use the different native torch snippets and associate a specific launch method with each of them. Here is how you could have done that:

#### Torch native DDP

- Run the `torch native` snippet with different backends:

```commandline
# Running with gloo
python -u tn_ddp.py --nproc_per_node 2 --backend gloo

# Running with nccl
python -u tn_ddp.py --nproc_per_node 2 --backend nccl
```

#### Horovod

- Run the `horovod native` snippet with the `gloo` controller and `nccl`/`gloo` support:

```commandline
# Running with gloo support
python -u tn_hvd.py --nproc_per_node 2

# Running with nccl support
python -u tn_hvd.py --backend nccl --nproc_per_node 2
```

#### XLA/TPU devices

- Run the `torch xla native` snippet on XLA/TPU devices with:

```commandline
# Run with a default of 8 processes
python -u tn_xla.py
```