ββββtensor([[0, 0, 2, 2, 3, 2],
ββββ [1, 4, 0, 3, 1, 4]])
ββββtorch.Size([5, 10])
MessagePassing
classnn.Module
that inherits MessagePassing
ββββ Inside message method. `x_j.size()` is torch.Size([6, 10])
ββββ Inside update method. `out.size()` is torch.Size([5, 10])
ββββtorch.Size([5, 10])
ββββtensor([[-3.9773, -1.1126, -0.0140, 3.8670, 1.2074, -1.7092, 0.3560, 3.2655,
ββββ -2.3335, -4.1227],
ββββ [-1.3679, -3.1507, -1.1292, 1.1108, -2.8639, 1.0764, -1.2942, -2.1353,
ββββ -1.2333, -0.1250],
ββββ [ 0.2759, -3.5196, -0.0355, -0.4968, 0.6988, 0.3293, -1.4655, 1.2332,
ββββ -0.5728, -0.9543],
ββββ [ 0.1292, 0.8452, -1.5253, 4.7322, -0.6898, -0.2914, -0.8749, 0.1460,
ββββ 0.2724, 1.6101],
ββββ [-3.9773, -1.1126, -0.0140, 3.8670, 1.2074, -1.7092, 0.3560, 3.2655,
ββββ -2.3335, -4.1227],
ββββ [-1.3679, -3.1507, -1.1292, 1.1108, -2.8639, 1.0764, -1.2942, -2.1353,
ββββ -1.2333, -0.1250]])
ββββtensor([[-2.6726, -2.1317, -0.5716, 2.4889, -0.8283, -0.3164, -0.4691, 0.5651,
ββββ -1.7834, -2.1239],
ββββ [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
ββββ 0.0000, 0.0000],
ββββ [-0.3209, -1.9417, -0.8967, 1.7820, -0.9516, 0.3714, -1.2115, -0.2520,
ββββ -0.5112, 0.1769],
ββββ [-3.9773, -1.1126, -0.0140, 3.8670, 1.2074, -1.7092, 0.3560, 3.2655,
ββββ -2.3335, -4.1227],
ββββ [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
ββββ 0.0000, 0.0000]])
message
, propagate
moves messages from nodes to edges.This can be done in two ways:
x_i
is an argument of message
methodx_j
is an argument of message
methodββββtensor([[0, 0, 2, 2, 3, 2],
ββββ [1, 4, 0, 3, 1, 4]])
ββββTrue
message
to x_i
ββββ Inside message method. `x_j.size()` is torch.Size([6, 10])
ββββ Inside update method. `out.size()` is torch.Size([5, 10])
ββββtorch.Size([5, 10])
ββββTrue
Note that this is not one to one - a node can have many edges. As a result it can receive messages from many edges. Thus we need a way to aggregate these messages. This is exactly what the aggr
function does in scatter
.
ββββtensor([[0, 0, 2, 2, 3, 2],
ββββ [1, 4, 0, 3, 1, 4]])
0
:ββββTrue
2
:ββββTrue
scatter
ββββtensor([[ 0.2759, -3.5196, -0.0355, -0.4968, 0.6988, 0.3293, -1.4655, 1.2332,
ββββ -0.5728, -0.9543],
ββββ [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
ββββ 0.0000, 0.0000],
ββββ [-0.5084, 2.1698, 2.5749, -1.7214, -2.0338, 0.1578, -0.6327, -0.6667,
ββββ 2.7229, 2.0444],
ββββ [ 0.1292, 0.8452, -1.5253, 4.7322, -0.6898, -0.2914, -0.8749, 0.1460,
ββββ 0.2724, 1.6101],
ββββ [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
ββββ 0.0000, 0.0000]])
ββββTrue
pytorch
, gnn