tensor([[0, 0, 2, 2, 3, 2],
[1, 4, 0, 3, 1, 4]])
torch.Size([5, 10])
The `MessagePassing` class
A custom `nn.Module` that inherits `MessagePassing`:
Inside message method. `x_j.size()` is torch.Size([6, 10])
Inside update method. `out.size()` is torch.Size([5, 10])
torch.Size([5, 10])
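The sizes printed above (6 rows inside `message`, one per edge; 5 rows inside `update`, one per node) follow from the order in which `propagate` runs its steps: lift node features onto edges, call `message`, aggregate per node, call `update`. The real class needs `torch_geometric`. What follows is a minimal dependency-free sketch of that call order. It assumes the default `source_to_target` flow and sum aggregation, and it uses plain lists instead of tensors. `ToyMessagePassing` is a hypothetical stand-in, not the library class.

```python
class ToyMessagePassing:
    """Hypothetical stand-in for torch_geometric.nn.MessagePassing (sum aggregation only)."""

    def propagate(self, edge_index, x):
        src, dst = edge_index
        # Lift: one row per edge, holding the features of the edge's source node
        x_j = [x[s] for s in src]
        msgs = self.message(x_j)
        # Aggregate: start from one zero row per node, then sum messages at each target
        out = [[0.0] * len(x[0]) for _ in x]
        for msg, d in zip(msgs, dst):
            out[d] = [a + b for a, b in zip(out[d], msg)]
        return self.update(out)

    def message(self, x_j):
        print(f"Inside message method. x_j has {len(x_j)} rows")
        return x_j

    def update(self, aggr_out):
        print(f"Inside update method. out has {len(aggr_out)} rows")
        return aggr_out


edge_index = [[0, 0, 2, 2, 3, 2],
              [1, 4, 0, 3, 1, 4]]
# 5 nodes with 2 features each; node n's features are [n, n] so sums are easy to trace
x = [[float(n)] * 2 for n in range(5)]
out = ToyMessagePassing().propagate(edge_index, x)
print(out)  # [[2.0, 2.0], [3.0, 3.0], [0.0, 0.0], [2.0, 2.0], [2.0, 2.0]]
```

Node 2 is never a target in this `edge_index`, so its row stays zero. Node 1 receives messages from nodes 0 and 3, so its sum is 3.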
tensor([[-3.9773, -1.1126, -0.0140, 3.8670, 1.2074, -1.7092, 0.3560, 3.2655,
-2.3335, -4.1227],
[-1.3679, -3.1507, -1.1292, 1.1108, -2.8639, 1.0764, -1.2942, -2.1353,
-1.2333, -0.1250],
[ 0.2759, -3.5196, -0.0355, -0.4968, 0.6988, 0.3293, -1.4655, 1.2332,
-0.5728, -0.9543],
[ 0.1292, 0.8452, -1.5253, 4.7322, -0.6898, -0.2914, -0.8749, 0.1460,
0.2724, 1.6101],
[-3.9773, -1.1126, -0.0140, 3.8670, 1.2074, -1.7092, 0.3560, 3.2655,
-2.3335, -4.1227],
[-1.3679, -3.1507, -1.1292, 1.1108, -2.8639, 1.0764, -1.2942, -2.1353,
-1.2333, -0.1250]])
tensor([[-2.6726, -2.1317, -0.5716, 2.4889, -0.8283, -0.3164, -0.4691, 0.5651,
-1.7834, -2.1239],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[-0.3209, -1.9417, -0.8967, 1.7820, -0.9516, 0.3714, -1.2115, -0.2520,
-0.5112, 0.1769],
[-3.9773, -1.1126, -0.0140, 3.8670, 1.2074, -1.7092, 0.3560, 3.2655,
-2.3335, -4.1227],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000]])
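Before `message` runs, `propagate` moves node features from nodes onto edges. This can be done in two ways:
- `x_i` is an argument of the `message` method (features of the node that receives each message)
- `x_j` is an argument of the `message` method (features of the node that sends each message)
tensor([[0, 0, 2, 2, 3, 2],
[1, 4, 0, 3, 1, 4]])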
True
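The two ways of lifting node features onto edges can be sketched as plain indexing. This minimal sketch assumes PyG's default `flow="source_to_target"`: row 0 of `edge_index` holds the senders (`j`) and row 1 holds the receivers (`i`). Lists stand in for tensors.

```python
edge_index = [[0, 0, 2, 2, 3, 2],
              [1, 4, 0, 3, 1, 4]]
# Node n gets features [10n, 10n + 1], so every lifted row can be traced to its node
x = [[10 * n, 10 * n + 1] for n in range(5)]

src, dst = edge_index
x_j = [x[j] for j in src]  # equivalent to x[edge_index[0]] for a tensor x
x_i = [x[i] for i in dst]  # equivalent to x[edge_index[1]]

print(len(x_j), len(x_i))  # 6 6  (one row per edge either way)
print(x_j[0], x_i[0])      # [0, 1] [10, 11]  (edge 0 goes from node 0 to node 1)
```

A node that appears in several edges contributes the same row several times. Here nodes 1 and 4 are each the target of two edges, so their rows repeat in `x_i`.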
Changing the argument of `message` to `x_i`:
Inside message method. `x_j.size()` is torch.Size([6, 10])
Inside update method. `out.size()` is torch.Size([5, 10])
torch.Size([5, 10])
True
Note that the mapping from edges to nodes is not one-to-one: a node can have many edges, so it can receive messages from many edges. We therefore need a way to aggregate these messages, and this is exactly what `aggr` does, via `scatter`.
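Aggregation can be sketched as a scatter-style reduction. Messages (one per edge) are grouped by the node index they are sent to and reduced to one row per node. This is a plain-Python sketch, not the `torch_scatter` or PyG implementation. The `reduce` names mirror the common `"sum"`, `"mean"` and `"max"` choices. Nodes that receive no message get a zero row, consistent with the zero rows in the outputs above.

```python
def scatter(src, index, num_nodes, reduce="sum"):
    """Reduce the rows of `src` into `num_nodes` rows, grouped by `index`."""
    dim = len(src[0]) if src else 0
    groups = [[] for _ in range(num_nodes)]
    for row, node in zip(src, index):
        groups[node].append(row)
    out = []
    for rows in groups:
        if not rows:
            out.append([0.0] * dim)  # no incoming messages: zero row
        elif reduce == "sum":
            out.append([sum(col) for col in zip(*rows)])
        elif reduce == "mean":
            out.append([sum(col) / len(rows) for col in zip(*rows)])
        elif reduce == "max":
            out.append([max(col) for col in zip(*rows)])
        else:
            raise ValueError(f"unknown reduce: {reduce}")
    return out


edge_index = [[0, 0, 2, 2, 3, 2],
              [1, 4, 0, 3, 1, 4]]
x = [[1.0], [2.0], [3.0], [4.0], [5.0]]  # 5 nodes, 1 feature each
x_j = [x[j] for j in edge_index[0]]       # messages = sender features
print(scatter(x_j, edge_index[1], num_nodes=5, reduce="mean"))
# [[3.0], [2.5], [0.0], [3.0], [2.0]]
```

If you instead scatter `x_j` by `edge_index[0]`, each message goes back to its own sender. This matches the later `scatter` check. With `reduce="mean"`, nodes 0, 2 and 3 recover their own features, and nodes 1 and 4, which send nothing, stay zero.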
tensor([[0, 0, 2, 2, 3, 2],
[1, 4, 0, 3, 1, 4]])
0: True
2: True
scatter
tensor([[ 0.2759, -3.5196, -0.0355, -0.4968, 0.6988, 0.3293, -1.4655, 1.2332,
-0.5728, -0.9543],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[-0.5084, 2.1698, 2.5749, -1.7214, -2.0338, 0.1578, -0.6327, -0.6667,
2.7229, 2.0444],
[ 0.1292, 0.8452, -1.5253, 4.7322, -0.6898, -0.2914, -0.8749, 0.1460,
0.2724, 1.6101],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000]])
True