# Training GNNs using multiple GPUs (Preview)
This tutorial explains how to train Graph Neural Networks (GNNs) on large graphs using multiple GPUs. Readers should already be familiar with the basic concepts of GNNs, what message passing is, and how to implement a GNN in DGL. In this tutorial, you will learn:
* how to train GNNs by sampling from a large graph,
* how to speed up the training by parallelizing it on multiple GPUs,
* and how to implement more advanced GNNs on a sampled graph.
## Starting from the basics: full graph training
Recall that a GNN defines the following message passing computation:
$$
\begin{aligned}
m_{v}^{(l)} &= \bigoplus_{u\in\mathcal{N}(v)}M^{(l)}(h_u^{(l-1)},h_v^{(l-1)},h_{uv}^{(l-1)}) \\
h_v^{(l)} &= U^{(l)}(h_v^{(l-1)}, m_v^{(l)})
\end{aligned}
$$
where $M^{(l)}$, $\bigoplus$ and $U^{(l)}$ are the message, reduce and update functions, respectively, responsible for calculating, aggregating and consuming messages to produce new node representations.
This tutorial uses the following instantiation of the above equations:
$$
\begin{aligned}
m_{v}^{(l)} &= \frac{1}{|\mathcal{N}(v)|}\sum_{u\in\mathcal{N}(v)}h_u^{(l-1)} \\
h_v^{(l)} &= \sigma (W^{(l)}m_v^{(l)})
\end{aligned}
$$
Here, $M^{(l)}$ simply uses the source node feature as the message, $\bigoplus$ averages the messages, and $U^{(l)}$ applies an affine transformation followed by a non-linear activation function. The following code implements this model in DGL with the PyTorch backend.
```python
import torch as th
import torch.nn as nn
import dgl.function as fn

class NodeUpdate(nn.Module):
    def __init__(self, in_feats, out_feats, activation=None):
        super(NodeUpdate, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, nodes):
        h = nodes.data['h']
        h = self.linear(h)
        if self.activation is not None:
            h = self.activation(h)
        return {'h': h}

class GNN(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 **kwargs):
        super(GNN, self).__init__(**kwargs)
        self.n_layers = n_layers
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(
            NodeUpdate(in_feats, n_hidden, activation))
        # hidden layers
        for i in range(1, n_layers - 1):
            self.layers.append(
                NodeUpdate(n_hidden, n_hidden, activation))
        # output layer
        self.layers.append(
            NodeUpdate(n_hidden, n_classes))

    def forward(self, g):
        h = g.ndata['features']
        for i, layer in enumerate(self.layers):
            h = self.dropout(h)
            g.ndata['h'] = h
            # average neighbor features, then apply the NodeUpdate module
            g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'), layer)
            h = g.ndata['h']
        return h
```
You can apply the model to the Reddit graph dataset following the semi-supervised setting proposed in the research paper [Inductive Representation Learning on Large Graphs](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf). In this dataset, the nodes are Reddit posts and two posts are connected by an edge if the same user has commented on both. The dataset also provides features for the post nodes and their subreddit categories as labels. During training, only part of the node labels is revealed, and the goal is to predict the labels of the remaining nodes. The code below shows the complete training pipeline, including data preparation, graph construction, model definition and the training loop.
```python
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
import dgl
from dgl.data import RedditDataset

def run():
    dev_id = 0
    # Number of GCN layers
    L = 2
    # Number of hidden units of a fully connected layer
    n_hidden = 64
    # Dropout probability
    dropout = 0.2
    # Number of epochs
    num_epochs = 100

    # Prepare data
    data = RedditDataset(self_loop=True)
    train_nid = th.LongTensor(np.nonzero(data.train_mask)[0])
    features = th.Tensor(data.features)
    in_feats = features.shape[1]
    labels = th.LongTensor(data.labels)
    n_classes = data.num_labels

    # Construct graph
    g = dgl.DGLGraph(data.graph, readonly=True)
    g.ndata['features'] = features

    # Define model and optimizer
    model = GNN(in_feats, n_hidden, n_classes, L, F.relu, dropout)
    model = model.to(dev_id)
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(dev_id)
    optimizer = optim.Adam(model.parameters(), lr=0.03)
    th.cuda.synchronize()

    # Move labels and features to the GPU
    batch_labels = labels[train_nid].to(dev_id)
    g.ndata['features'] = g.ndata['features'].to(dev_id)

    # Training loop
    for epoch in range(num_epochs):
        # forward
        pred = model(g)[train_nid]
        # compute loss
        loss = loss_fcn(pred, batch_labels)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
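With everything wrapped in one function, launching full graph training is just a matter of calling `run()`. A minimal sketch, assuming a single CUDA-capable GPU (device 0) is available:
```python
if __name__ == '__main__':
    # Full graph training on GPU 0; this requires the whole Reddit graph
    # and its features to fit into one GPU's memory.
    run()
```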
## Mini-batch training on graphs
The training method above assumes that the entire graph fits into the memory of one GPU, which is not practical for real-world data. To address this problem, we instead perform optimization iteratively on sampled subsets of the graph called mini-batches. However, unlike mini-batch training in traditional computer vision and natural language processing, samples drawn from graph-structured data are not *independent and identically distributed* (i.i.d.), so they require different treatment.

The basic idea of mini-batch training on graphs is to let each mini-batch also contain the subgraph it depends on. In the picture above, if the current mini-batch contains node #1, which we call a *seeded node*, the batch also includes the neighbors that contribute to the representation of the seeded node. If the GNN model has two layers, the dependency forms a DAG with two levels. However, computing over this complete DAG can be very costly when the GNN model is deep or the graph has the small-world property, both of which lead to a large number of nodes in the first level. Therefore, many research papers ([FastGCN](https://arxiv.org/abs/1801.10247), [Variance Reduction](https://arxiv.org/abs/1710.10568), etc.) have proposed to down-sample the DAG so that its overall size is manageable.
In DGL, this DAG structure is called `NodeFlow`, and we provide a convenient interface, `NeighborSampler`, to generate it from a large graph. The following code creates a sampler for the Reddit dataset.
```python
import dgl
from dgl.data import RedditDataset
from dgl.contrib.sampling import NeighborSampler

def run_minibatch():
    dev_id = 0
    # Number of GCN layers
    L = 2
    # Number of hidden units of a fully connected layer
    n_hidden = 64
    # Dropout probability
    dropout = 0.2
    # Number of epochs
    num_epochs = 100

    # Prepare data
    data = RedditDataset(self_loop=True)
    train_nid = th.LongTensor(np.nonzero(data.train_mask)[0])
    features = th.Tensor(data.features)
    in_feats = features.shape[1]
    labels = th.LongTensor(data.labels)
    n_classes = data.num_labels

    # Construct graph
    g = dgl.DGLGraph(data.graph, readonly=True)
    g.ndata['features'] = features

    # Create sampler
    sampler = NeighborSampler(
        g,                     # the graph structure
        seed_nodes=train_nid,  # the set of nodes to draw samples from
        batch_size=1000,       # batch size
        num_neighbors=10,      # sampling fan-out of each layer
        neighbor_type='in',    # use in-edge neighbors
        shuffle=True,          # shuffle the seeded nodes at each epoch
        num_hops=L,            # number of layers
        num_workers=4)
```
A sampler behaves like an iterator that draws batches of samples from the provided seeded nodes until the set is exhausted. Because only part of the node labels is revealed during training, the set of seeded nodes is exactly the set of labeled training nodes.
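To get a feel for what the sampler yields, you can iterate over it and inspect the resulting `NodeFlow` objects. The sketch below only uses `NodeFlow` methods that appear later in this tutorial (`num_layers`, `layer_parent_nid`, `copy_from_parent`) and simply prints how many nodes each layer of the sampled DAG contains:
```python
def inspect_sampler(sampler):
    # Draw one mini-batch and look at the sampled DAG (NodeFlow).
    for nf in sampler:
        nf.copy_from_parent()  # pull node features from the parent graph
        for i in range(nf.num_layers):
            # layer_parent_nid maps layer-i nodes back to ids in the full graph
            print('layer {}: {} nodes'.format(i, len(nf.layer_parent_nid(i))))
        # the last layer holds the seeded nodes of this mini-batch
        print('seeded nodes:', nf.layer_parent_nid(-1)[:5])
        break
```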
The following code shows the training loop using a sampler.
```python
def run_minibatch():
    ...
    # Define model and optimizer
    model = GNNMinibatch(in_feats, n_hidden, n_classes, L, F.relu, dropout)
    model = model.to(dev_id)
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(dev_id)
    optimizer = optim.Adam(model.parameters(), lr=0.03)
    th.cuda.synchronize()

    # Training loop
    for epoch in range(num_epochs):
        for nf in sampler:
            # get the features and labels of the current mini-batch
            nf.copy_from_parent()
            nf.layers[0].data['features'] = nf.layers[0].data['features'].to(dev_id)
            batch_nids = nf.layer_parent_nid(-1)
            batch_labels = labels[batch_nids].to(dev_id)
            # forward
            pred = model(nf)
            # compute loss
            loss = loss_fcn(pred, batch_labels)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```
The differences between mini-batch training and full graph training are:
* Full graph training has a single loop over epochs, while mini-batch training has two nested loops; the inner loop iterates over the mini-batches generated by the sampler.
* At the beginning of the inner loop, `nf.copy_from_parent()` extracts the features used by the current mini-batch and stores them in the `NodeFlow` data structure.
* `nf.layer_parent_nid(-1)` returns the node ids of the last layer w.r.t. the full graph, i.e., the seeded nodes of the current mini-batch.
The following code defines the GNN model working on a mini-batch.
```python
class GNNMinibatch(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 **kwargs):
        super(GNNMinibatch, self).__init__(**kwargs)
        self.n_layers = n_layers
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(
            NodeUpdate(in_feats, n_hidden, activation))
        # hidden layers
        for i in range(1, n_layers - 1):
            self.layers.append(
                NodeUpdate(n_hidden, n_hidden, activation))
        # output layer
        self.layers.append(
            NodeUpdate(n_hidden, n_classes))

    def forward(self, nf):
        h = nf.layers[0].data['features']
        for i, layer in enumerate(self.layers):
            h = self.dropout(h)
            nf.layers[i].data['h'] = h
            # message passing on the sampled block between layer i and i+1
            nf.block_compute(i, fn.copy_u('h', 'm'), fn.mean('m', 'h'), layer)
            h = nf.layers[i + 1].data['h']
        return h
```
Compared with the GNN model for full graph training, most of the code is the same. The differences are in the `forward` function:
* The input is read from `nf.layers[i]` and the output is written to `nf.layers[i+1]`.
* `nf.block_compute(i, fn.copy_u('h', 'm'), fn.mean('m', 'h'), layer)` replaces `g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'), layer)`. `nf.block_compute` triggers message passing computation on all the sampled edges between layer $i$ and $i+1$. The input node features are read from layer $i$ while the outputs are written to layer $i+1$.
## Multi-GPU mini-batch training
To speed up training, we can further divide the DAG into independent sub-DAGs and dispatch them to multiple GPUs for parallel training. The figure below demonstrates one example: a mini-batch containing two seeded nodes, #1 and #8, is split into two smaller mini-batches, each containing one seeded node.

The following code shows how to implement this parallel training using PyTorch's multiprocessing package. Essentially, it launches multiple processes and puts each in charge of training on one GPU. To read more about this practice, please refer to PyTorch's [notes on multi-GPU training best practices](https://pytorch.org/docs/stable/notes/multiprocessing.html#best-practices-and-tips).
```python
import dgl
from dgl.data import RedditDataset
from dgl.contrib.sampling import NeighborSampler

def run_minibatch_mp(dev_id, nprocs, data):
    # Number of GNN layers
    L = 2
    # Number of hidden units of a fully connected layer
    n_hidden = 64
    # Dropout probability
    dropout = 0.2
    # Number of epochs
    num_epochs = 100

    # Setup multi-process
    dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
        master_ip='127.0.0.1', master_port='12345')
    th.distributed.init_process_group(
        backend="nccl",
        init_method=dist_init_method,
        world_size=nprocs,
        rank=dev_id)
    th.set_num_threads(4)

    # Unpack data
    train_nid, in_feats, labels, n_classes, g = data
    # Split train_nid so that each process trains on its own portion
    train_nid = th.split(train_nid, len(train_nid) // nprocs)[dev_id]

    # Create sampler
    sampler = NeighborSampler(
        g,                     # the graph structure
        seed_nodes=train_nid,  # the set of nodes to draw samples from
        batch_size=1000,       # batch size
        num_neighbors=10,      # sampling fan-out of each layer
        neighbor_type='in',    # use in-edge neighbors
        shuffle=True,          # shuffle the seeded nodes at each epoch
        num_hops=L,            # number of layers
        num_workers=4)

    # Define model and optimizer
    model = GNNMinibatch(in_feats, n_hidden, n_classes, L, F.relu, dropout)
    model = model.to(dev_id)
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(dev_id)
    optimizer = optim.Adam(model.parameters(), lr=0.03)
    th.cuda.synchronize()

    # Training loop
    for epoch in range(num_epochs):
        for nf in sampler:
            # get the features and labels of the current mini-batch
            nf.copy_from_parent()
            nf.layers[0].data['features'] = nf.layers[0].data['features'].to(dev_id)
            batch_nids = nf.layer_parent_nid(-1)
            batch_labels = labels[batch_nids].to(dev_id)
            # forward
            pred = model(nf)
            # compute loss
            loss = loss_fcn(pred, batch_labels)
            # backward
            optimizer.zero_grad()
            loss.backward()
            # aggregate gradients across GPUs
            for param in model.parameters():
                if param.requires_grad and param.grad is not None:
                    th.distributed.all_reduce(param.grad.data,
                                              op=th.distributed.ReduceOp.SUM)
            optimizer.step()
```
The above code shares most of its logic with the mini-batch training code, with a few changes:
* At the beginning of the function, set up the multi-process environment with `th.distributed.init_process_group`.
* The seeded nodes (i.e., `train_nid` in this example) are split equally so that each GPU generates samples from its own portion.
* In the inner training loop, aggregate the gradients computed on multiple GPUs with `th.distributed.all_reduce` (see the sketch below for an alternative that lets PyTorch handle this step).
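As a side note, the manual all-reduce over `model.parameters()` can also be delegated to PyTorch's `DistributedDataParallel` wrapper, which synchronizes gradients automatically during `backward()`. This is not what the script in this tutorial does; the following is just a sketch of an alternative (note that `DistributedDataParallel` averages gradients across processes, whereas the manual loop above sums them):
```python
from torch.nn.parallel import DistributedDataParallel

# Wrap the model after moving it to the target GPU; gradients are then
# synchronized across processes automatically during loss.backward().
model = GNNMinibatch(in_feats, n_hidden, n_classes, L, F.relu, dropout).to(dev_id)
model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
# The training loop stays the same, minus the explicit all_reduce calls.
```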
Use the `torch.multiprocessing` package to launch multiple processes, each executing the `run_minibatch_mp` function in parallel.
```python
# Prepare reddit data
data = RedditDataset(self_loop=True)
train_nid = th.LongTensor(np.nonzero(data.train_mask)[0])
features = th.Tensor(data.features)
in_feats = features.shape[1]
labels = th.LongTensor(data.labels)
n_classes = data.num_labels
# Construct graph
g = dgl.DGLGraph(data.graph, readonly=True)
g.ndata['features'] = features
# Pack data
data = train_nid, in_feats, labels, n_classes, g
n_gpus = 4 # 4 GPUs
mp = th.multiprocessing
mp.spawn(run_minibatch_mp, args=(n_gpus, data), nprocs=n_gpus)
```
Note that the code above loads and constructs the graph in the main process. The spawned worker processes share the memory space with the parent process.
## More advanced GNN variant
Recall that the example GNN we used so far is as follows:
$$
\begin{aligned}
m_{v}^{(l)} &= \frac{1}{|\mathcal{N}(v)|}\sum_{u\in\mathcal{N}(v)}h_u^{(l-1)} \\
h_v^{(l)} &= \sigma (W^{(l)}m_v^{(l)})
\end{aligned}
$$
The new embedding of node $v$ depends only on the previous embeddings of its neighbor nodes, which can be summarized by the following equation:
$$
h_v^{(l)} = f_\theta^{(l)}(\{h_u^{(l-1)} | u\in\mathcal{N}(v)\})
$$
Sometimes, we wish $h_v^{(l)}$ to also leverage the node's own embedding from the previous layer, resulting in the following equation:
$$
h_v^{(l)} = f_\theta^{(l)}(\{h_u^{(l-1)} | u\in\mathcal{N}(v)\}\cup \{h_v^{(l-1)}\})
$$
As an example, the popular GraphSAGE model (from the paper [Inductive Representation Learning on Large Graphs](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf)) is formulated by the equation below.
$$
h_v^{(l)} = \sigma \left(W^{(l)}\left( h_v^{(l-1)} || \frac{1}{ |\mathcal{N}(v)|}\sum_{u\in\mathcal{N}(v)}h_u^{(l-1)}\right) \right)
$$
Expanding this recursive function, we get the following dependency graph.

It indicates that at each step $i$, instead of having only the nodes at layer $i+1$ receive messages, all nodes in layers $i+1$ through $l$ should receive messages so that every node's own embedding stays up to date. The implementation is as follows.
```python
class SAGENodeUpdate(nn.Module):
    def __init__(self, in_feats, out_feats, activation=None):
        super(SAGENodeUpdate, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, nodes):
        # concatenate the node's own embedding with the aggregated neighbor embedding
        h = th.cat([nodes.data['h'], nodes.data['h_n']], dim=1)
        h = self.linear(h)
        if self.activation is not None:
            h = self.activation(h)
        return {'h_new': h}

class SAGE(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 **kwargs):
        super(SAGE, self).__init__(**kwargs)
        self.n_layers = n_layers
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(
            SAGENodeUpdate(in_feats * 2, n_hidden, activation))
        # hidden layers
        for i in range(1, n_layers - 1):
            self.layers.append(
                SAGENodeUpdate(n_hidden * 2, n_hidden, activation))
        # output layer
        self.layers.append(
            SAGENodeUpdate(n_hidden * 2, n_classes))

    def forward(self, nf):
        for i in range(nf.num_layers):
            nf.layers[i].data['h'] = nf.layers[i].data['features']
            nf.layers[i].data['h_new'] = nf.layers[i].data['features']
        for i in range(len(self.layers)):  # step i
            # trigger message passing on all blocks from i to the last one
            for j in range(i, len(self.layers)):
                nf.layers[j].data['h'] = self.dropout(nf.layers[j].data['h'])
                nf.block_compute(j, fn.copy_u('h', 'm'), fn.mean('m', 'h_n'), self.layers[i])
            # update the embeddings of all layers from i+1 to the last one
            for j in range(i, len(self.layers)):
                nf.layers[j + 1].data['h'] = nf.layers[j + 1].data['h_new']
        return nf.layers[-1].data['h']
```
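If you swap `GNNMinibatch` for `SAGE` in the mini-batch training loop above, note that `SAGE.forward` reads `features` from every `NodeFlow` layer, not just layer 0, so every layer's features must live on the training GPU. Below is a minimal sketch of the adjusted inner loop, assuming the same variables (`sampler`, `labels`, `loss_fcn`, `optimizer`, `dev_id`) as in `run_minibatch` above; only the feature-copying part changes:
```python
for nf in sampler:
    nf.copy_from_parent()
    # SAGE reads 'features' on every NodeFlow layer, so move them all to the GPU
    for i in range(nf.num_layers):
        nf.layers[i].data['features'] = nf.layers[i].data['features'].to(dev_id)
    batch_nids = nf.layer_parent_nid(-1)
    batch_labels = labels[batch_nids].to(dev_id)
    pred = model(nf)  # model = SAGE(in_feats, n_hidden, n_classes, L, F.relu, dropout).to(dev_id)
    loss = loss_fcn(pred, batch_labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```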
## Results
You can access the full training script [here](https://github.com/dmlc/dgl/tree/multi-gpu/examples/pytorch/sampling/multi-gpu). We tested it on an EC2 p3.16xlarge instance with 8 NVIDIA V100 GPUs. Please follow the link for details about the setup and running instructions. The experiment shows that DGL scales linearly up to 4 GPUs and achieves up to a 6.9x speedup on 8 GPUs.
