Try   HackMD

How to implement a GNN

import torch
from torch_scatter import scatter

We will use this graph:

*(Image unavailable. The graph is the one defined in the code below: 5 nodes and the 6 directed edges 0→1, 0→4, 2→0, 2→3, 3→1 and 2→4.)*

# Example graph: 5 nodes and 6 directed edges.
num_nodes = 5
num_edge_types = 3  # not referenced elsewhere in this walkthrough

# Each inner pair is one edge; transposing gives edge_index its
# [2, number_of_edges] layout (row 0 and row 1 hold the two endpoints).
edge_index = torch.tensor(
    [[0, 1], [0, 4], [2, 0], [2, 3], [3, 1], [2, 4]],
    dtype=torch.long,
).t()

# Derived from the data so it can never drift out of sync with edge_index.
num_edges = edge_index.size(1)

edge_index
​​​​tensor([[0, 0, 2, 2, 3, 2],
​​​​        [1, 4, 0, 3, 1, 4]])
embed_dim = 10  # width of every node embedding

# Random node features: one row of length embed_dim per node.
x = torch.randn((num_nodes, embed_dim))
x.shape
​​​​torch.Size([5, 10])

Define the MessagePassing class

import inspect
import torch.nn as nn


class MessagePassing(nn.Module):
    """Minimal message-passing base class.

    Subclasses implement ``message`` and ``update``. Their argument names are
    read once in ``__init__``; ``propagate`` then fills each argument from the
    keyword arguments it receives:

    * a name ending in ``_i`` gets ``kwargs[name[:-2]][edge_index[0]]``;
    * a name ending in ``_j`` gets ``kwargs[name[:-2]][edge_index[1]]``;
    * any other name gets ``kwargs[name]`` unchanged.

    The per-edge messages are then reduced onto the nodes in ``edge_index[0]``.

    NOTE(review): this is the reverse of PyTorch Geometric's default
    ("source_to_target") flow, where ``x_j`` is gathered from ``edge_index[0]``
    and messages are aggregated at ``edge_index[1]``. The rest of this
    walkthrough relies on the convention implemented here.
    """

    # Reductions accepted by ``propagate``; each is a valid ``reduce`` for scatter.
    _AGGREGATIONS = ('sum', 'mean', 'min', 'max')

    def __init__(self):
        super(MessagePassing, self).__init__()

        # Argument names of the subclass hooks, as lists of `str`.
        # getfullargspec keeps the bound `self`, so drop it; for `update` also
        # drop the first positional `out`, which propagate always supplies.
        self.message_args = inspect.getfullargspec(self.message)[0][1:]
        self.update_args = inspect.getfullargspec(self.update)[0][2:]

    def propagate(self, aggr, edge_index, **kwargs):
        """Gather node features onto edges, build messages, and aggregate them.

        Args:
            aggr: Reduction passed to ``scatter``: ``'sum'``, ``'mean'``,
                ``'min'`` or ``'max'``.
            edge_index: Integer tensor of shape ``[2, number_of_edges]``.
            **kwargs: Values looked up by the argument names of ``message`` and
                ``update`` (e.g. ``x`` for ``x_i`` / ``x_j``). ``edge_index`` is
                also made available under that name.

        Returns:
            Whatever ``self.update`` returns for the aggregated tensor.

        Raises:
            ValueError: if ``aggr`` is not a supported reduction.
            KeyError: if a hook argument has no matching keyword argument.
        """
        # Explicit check: an `assert` would silently disappear under `python -O`.
        if aggr not in self._AGGREGATIONS:
            raise ValueError(
                f'aggr must be one of {self._AGGREGATIONS}, got {aggr!r}'
            )
        kwargs['edge_index'] = edge_index

        # Number of output rows. Without it, scatter sizes the output from the
        # largest index in edge_index[0], which can drop trailing nodes.
        dim_size = None
        message_args = []
        for arg in self.message_args:
            if arg.endswith(('_i', '_j')):          # will almost always be `x_i` or `x_j`
                node_feats = kwargs[arg[:-2]]        # e.g. `x`, shape [number_of_nodes, embed_dim]
                dim_size = node_feats.size(0)
                row = 0 if arg.endswith('_i') else 1
                # One gathered row of node features per edge.
                message_args.append(node_feats[edge_index[row]])
            else:
                message_args.append(kwargs[arg])

        update_args = [kwargs[arg] for arg in self.update_args]

        out = self.message(*message_args)
        # Reduce the per-edge messages onto the nodes listed in edge_index[0].
        out = scatter(out, edge_index[0], dim=0, dim_size=dim_size, reduce=aggr)
        return self.update(out, *update_args)

    def message(self):
        """Build per-edge messages; must be overridden by subclasses."""
        raise NotImplementedError

    def update(self):
        """Transform the aggregated messages; must be overridden by subclasses."""
        raise NotImplementedError

Create an `nn.Module` that inherits from `MessagePassing`

class MyGNN(MessagePassing):
    """Toy GNN whose messages are the gathered `x_j` rows scaled elementwise by 2.

    The tensors produced by the latest `message` and `update` calls are kept on
    the instance so later cells can inspect them.
    """

    def __init__(self, embed_dim=embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # Populated by the most recent forward pass.
        self.message_output = None
        self.update_output = None

    def message(self, x_j, a_random_argument):
        '''
        Multiply every gathered row by `a_random_argument`.

        `x_j` has shape (num_edges, embed_dim); `a_random_argument` is any extra
        value forwarded by `propagate` -- here a tensor of shape `(embed_dim,)`.
        '''
        print(f'\tInside message method. `x_j.size()` is {x_j.size()}')
        scale_row = a_random_argument.view(1, -1)   # broadcast across all edges
        self.message_output = x_j * scale_row
        return self.message_output

    def update(self, out):
        """Identity update: cache and return the aggregated messages."""
        print(f'\tInside update method. `out.size()` is {out.size()}')
        self.update_output = out
        return self.update_output

    def forward(self, x, edge_index):
        """Run one round of mean-aggregated message passing."""
        doubling = torch.ones(self.embed_dim) * 2
        return self.propagate('mean', edge_index, x=x, a_random_argument=doubling)


model = MyGNN()
model_output = model(x, edge_index)
model_output.size()
​​​​	Inside message method. `x_j.size()` is torch.Size([6, 10])
​​​​	Inside update method. `out.size()` is torch.Size([5, 10])


​​​​torch.Size([5, 10])
model.message_output  # per-edge messages cached by MyGNN.message on the last forward pass (one row per edge)
​​​​tensor([[-3.9773, -1.1126, -0.0140,  3.8670,  1.2074, -1.7092,  0.3560,  3.2655,
​​​​         -2.3335, -4.1227],
​​​​        [-1.3679, -3.1507, -1.1292,  1.1108, -2.8639,  1.0764, -1.2942, -2.1353,
​​​​         -1.2333, -0.1250],
​​​​        [ 0.2759, -3.5196, -0.0355, -0.4968,  0.6988,  0.3293, -1.4655,  1.2332,
​​​​         -0.5728, -0.9543],
​​​​        [ 0.1292,  0.8452, -1.5253,  4.7322, -0.6898, -0.2914, -0.8749,  0.1460,
​​​​          0.2724,  1.6101],
​​​​        [-3.9773, -1.1126, -0.0140,  3.8670,  1.2074, -1.7092,  0.3560,  3.2655,
​​​​         -2.3335, -4.1227],
​​​​        [-1.3679, -3.1507, -1.1292,  1.1108, -2.8639,  1.0764, -1.2942, -2.1353,
​​​​         -1.2333, -0.1250]])
model.update_output  # value cached by MyGNN.update: the messages after mean-aggregation (one row per node)
​​​​tensor([[-2.6726, -2.1317, -0.5716,  2.4889, -0.8283, -0.3164, -0.4691,  0.5651,
​​​​         -1.7834, -2.1239],
​​​​        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
​​​​          0.0000,  0.0000],
​​​​        [-0.3209, -1.9417, -0.8967,  1.7820, -0.9516,  0.3714, -1.2115, -0.2520,
​​​​         -0.5112,  0.1769],
​​​​        [-3.9773, -1.1126, -0.0140,  3.8670,  1.2074, -1.7092,  0.3560,  3.2655,
​​​​         -2.3335, -4.1227],
​​​​        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
​​​​          0.0000,  0.0000]])

Now let's reproduce these results manually

Before calling `message`, `propagate` copies node embeddings onto the edges, producing one row per edge.

This can be done in two ways:

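  1. For each edge, copy the embedding of its source node (row 0 of `edge_index`). This happens when `x_i` is an argument of the `message` method.
  2. For each edge, copy the embedding of its target node (row 1 of `edge_index`). This happens when `x_j` is an argument of the `message` method.

Note that this naming is the reverse of PyTorch Geometric's default convention, in which `x_j` holds source-node features and messages are aggregated at the target nodes.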

*(Image unavailable. It presumably illustrated how node embeddings are copied onto the edges, as described above.)*

edge_index  # shown again for reference: `_i` arguments gather from row 0, `_j` arguments from row 1
​​​​tensor([[0, 0, 2, 2, 3, 2],
​​​​        [1, 4, 0, 3, 1, 4]])
torch.equal(model.message_output, 2 * x[edge_index[1]])  # x_j rows come from edge_index[1]
​​​​True

Now let's change the argument of `message` from `x_j` to `x_i`

class MySecondGNN(MessagePassing):
    """Same toy GNN as before, but `message` asks for `x_i` instead of `x_j`.

    With this MessagePassing implementation `x_i` is gathered from
    `edge_index[0]`, so each edge now carries its row-0 node's embedding * 2.
    """

    def __init__(self, embed_dim=embed_dim):
        super(MySecondGNN, self).__init__()
        self.embed_dim = embed_dim
        # Populated by the most recent forward pass for inspection.
        self.message_output = None
        self.update_output = None

    def message(self, x_i, a_random_argument):
        '''
        `x_i` has shape (num_edges, embed_dim)
        `a_random_argument` can be any argument. In this case it is a tensor of shape `(embed_dim,)`
        '''
        # NOTE: the label still reads `x_j` so it matches the recorded cell
        # output; the size printed is that of `x_i`.
        print(f'\tInside message method. `x_j.size()` is {x_i.size()}')
        output = x_i * a_random_argument.view(1,-1)
        self.message_output = output
        return output

    def update(self, out):
        """Identity update: cache and return the aggregated messages."""
        print(f'\tInside update method. `out.size()` is {out.size()}')
        output = out
        self.update_output = output
        return output

    def forward(self, x, edge_index):
        """Run one round of mean-aggregated message passing."""
        return self.propagate('mean', edge_index, x=x, a_random_argument=torch.ones(self.embed_dim)*2)


model = MySecondGNN()

model_output = model(x, edge_index)
model_output.size()
​​​​	Inside message method. `x_j.size()` is torch.Size([6, 10])
​​​​	Inside update method. `out.size()` is torch.Size([5, 10])





​​​​torch.Size([5, 10])
torch.equal(model.message_output, 2 * x[edge_index[0]])  # x_i rows come from edge_index[0]
​​​​True

Now let's look at what happens before `update` is called: the messages are moved back from the edges to the nodes.

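Note that this is not one-to-one: a node can have many edges, so it can receive messages from many edges. We therefore need a way to aggregate these messages, which is exactly what the `aggr` argument (passed to `scatter` as `reduce`) controls. In this implementation, each message is aggregated at the node in row 0 of `edge_index`.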

edge_index  # row 0 is the index propagate passes to scatter, i.e. where each message is aggregated
​​​​tensor([[0, 0, 2, 2, 3, 2],
​​​​        [1, 4, 0, 3, 1, 4]])

Messages received by node 0:

# Rebuild the MySecondGNN messages by hand: x_i gathers row 0 of edge_index.
message_output = 2 * x[edge_index[0]]
# Edges 0 and 1 both have node 0 in row 0, so node 0 averages those two messages.
torch.equal(model.update_output[0], message_output[[0, 1]].mean(dim=0))
​​​​True

Messages received by node 2:

torch.equal(model.update_output[2], message_output[[2, 3, 5]].mean(dim=0))  # edges 2, 3, 5 have node 2 in row 0
​​​​True

We can aggregate the messages for all nodes at once using `scatter`:

scatter(
    message_output,
    edge_index[0],       # destination node of each message
    dim=0,
    dim_size=num_nodes,  # keep a row for nodes that receive nothing
    reduce='mean',
)
​​​​tensor([[ 0.2759, -3.5196, -0.0355, -0.4968,  0.6988,  0.3293, -1.4655,  1.2332,
​​​​         -0.5728, -0.9543],
​​​​        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
​​​​          0.0000,  0.0000],
​​​​        [-0.5084,  2.1698,  2.5749, -1.7214, -2.0338,  0.1578, -0.6327, -0.6667,
​​​​          2.7229,  2.0444],
​​​​        [ 0.1292,  0.8452, -1.5253,  4.7322, -0.6898, -0.2914, -0.8749,  0.1460,
​​​​          0.2724,  1.6101],
​​​​        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
​​​​          0.0000,  0.0000]])
torch.equal(
    model.update_output,
    scatter(message_output, edge_index[0], dim=0, dim_size=num_nodes, reduce='mean'),
)
​​​​True
tags: pytorch, gnn