###### tags: `autoglab` `cathay` `GNN` `TGN`
# Temporal Graph Network (TGN)
## Introduction
TGN 是一種具有 **encoder-decoder** 結構的 GNN,會先將 graph 上每個 node/edge feature 轉換成 embedding,再輸入 MLP 進行預測。
傳統 GNN 使用靜態圖 (static graph),靜態圖能反應當下資訊,但無法反應過去歷史,有很大的侷限性。**TGN 為連續時間的動態圖 (continuous-time dynamic graphs),graph 的演化是以 event 為驅動,能反應歷史,在每個時間點,node 都會有相應的狀態 (memory),進而形成隨時間演化的 embedding**。
## Background
有兩種主要的做法來處理動態圖:
1. Discrete-time dynamic graphs (DTDG)
將動態圖轉成多個靜態圖快照的時間序列
2. Continuos-time dynamic graphs (CTDG)
動態圖是由依時序排列的 events 來描述。Events 包含 node addition/deletion、edge addition/deletion、node/edge feature transformations
## Preliminary
1. Graph $\mathcal{G}$ 是由一系列帶有時間戳的 events 組成
\begin{align}
\mathcal{G} = \{x(t_1), x(t_2), ... \}
\end{align}
其中,$0 \leq t_1 \leq t_2 \leq ...$,$x(t)$ 為一個 event。
2. Events 有兩種類型:
* **node-wise event $\mathbf{v}_i(t)$**
若 node $i$ 未出現過,則新增 node $i$,並賦予相應的 feature vector $\mathbf{v}_i(t)$,否則,更新 $\mathbf{v}_i(t)$。
* **interaction event $\mathbf{e}_{ij}(t)$**
Node $i$ 與 node $j$ 之間會形成 (directed) temporal edge。
3. 定義
\begin{align}
\mathcal{V}(T)
&= \{
i: \exists \mathbf{v}_i(t) \in \mathcal{G},
t \in T
\},
\\
\mathcal{E}(T)
&= \{
(i, j): \exists e_{ij}(t) \in \mathcal{G},
t \in T
\},
\\
\mathcal{n}_i(T)
&= \{j: (i,j) \in \mathcal{E}(T)\},
\\
\mathcal{G}(t)
&= (\mathcal{V}[0,t], \mathcal{E}[0,t]).
\end{align}
## Model
TGN 將一個連續時間的動態圖 encode 成各種時間戳下的 embedding
\begin{align}
\mathbf{Z}(t)
= \left(
\mathbf{z}_1(t), \ldots, \mathbf{z}_{n(t)}(t)
\right)
\end{align}
TGN 由以下關鍵 modules 組成:
### Memory
Memory 是 TGN 的核心。Node $i$ 有對應的 state vector $\mathbf{s}_i(t)$,當牽涉到 node $i$ 的 events (node-wise or interaction) 發生時,state $\mathbf{s}_i(t)$ 將會被更新。類似 RNN 的 hidden state,state $\mathbf{s}_i(t)$ 可表示該 node 從古到今的歷史。
當新的 node 產生時,該 node 的 state 會初始化為 $\mathbf{0}$。另外,即使模型在訓練完畢之後,也會不斷更新 state。
### Message Functions
當牽涉到 node $i$ 的 events 發生時,會計算出 messages 好更新 state $\mathbf{s}_i(t)$。
根據 event 的類型,有以下兩種計算 message 的方式:
1. node-wise event
\begin{align}
\mathbf{m}_i(t)
= \text{msg}_n
\left(
s_i(t^-), t, \mathbf{v}_i(t)
\right)
\end{align}
2. interaction event
\begin{align}
\mathbf{m}_i(t)
= \text{msg}_s
\left(
\mathbf{s}_i(t^-), s_j(t^-),
\Delta t, \mathbf{e}_{ij}(t)
\right)
\\
\mathbf{m}_j(t)
= \text{msg}_d
\left(
\mathbf{s}_j(t^-), s_i(t^-),
\Delta t, \mathbf{e}_{ij}(t)
\right)
\end{align}
$\text{msg}_n$、$\text{msg}_s$、$\text{msg}_d$ 均是 learnable。
這篇論文將三個 message functions 簡單設為 concatenation。
### Message Aggregator
一個 batch 會採樣一個 time interval,因此某個 node $i$ 可能會在同個 batch 牽涉到多個 events,這時就要將 node $i$ 產生的多條message $\mathbf{m}_i(t_1), \ldots, \mathbf{m}_i(t_b)$ 濃縮成一個 message。
\begin{align}
\bar{\mathbf{m}}_i(t)
= \text{agg}
\left(
\mathbf{m}_i(t_1),
\ldots,
\mathbf{m}_i(t_b)
\right)
\end{align}
$agg$ 為 aggregation function,論文提出兩種簡單的做法:
1. most recent message:只保留最新的 message
2. mean message:將所有 message 做平均
### Memory Updater
一旦某個 event 與某 node 有關,該 node 的 state 將會被更新:
\begin{align}
\mathbf{s}_i(t)
= \text{mem}
\left(
\bar{\mathbf{m}_i}(t),
\mathbf{s}_i(t^-)
\right)
\end{align}
若 event 為 node-wise,則只更新該 node;若為 interaction event,則兩個 nodes 均被更新。
$\text{mem}$ 為可以是 LSTM 或 GRU 等,此時 input 為 $\bar{\mathbf{m}_i}(t)$,hidden state 為 $\mathbf{s}_i(t^-)$。
### Embedding
Embedding 用來產生 node 在時間 $t$ 的 embedding $z_i(t)$,目的是防止 memory staleness (記憶過時) 問題。即使某 node 長時間不參與任何 event,也能通過其 neighbors 來持續更新 node state。
不失一般性,可以定義 embedding 如下:
\begin{align}
\mathbf{z}_i(t)
= \text{emb}(i,t)
= \sum_{j \in \mathcal{N}_i^k([0,t])}
h
\left(
\mathbf{s}_i(t),
\mathbf{s}_j(t),
\mathbf{e}_{ij},
\mathbf{v}_i(t),
\mathbf{v}_j(t)
\right)
\end{align}
其中,$h$ 是 learnable。有以下幾種實作方式:
1. Identity (id)
直接將 state 當作 embedding。
\begin{align}
\text{emb}(i,t) = \mathbf{s}_i(t)
\end{align}
2. Time projection (time)
令 $\mathbf{w}$ 為一 learnable parameter,定義 embedding 為
\begin{align}
\text{emb}(i,t)
= \left(
\mathbf{1} + \Delta t \mathbf{w}
\right)
\odot \mathbf{s}_i(t)
\end{align}
其中 $\odot$ 為 element-wise product。
3. Temporal Graph Attention (attn)
使用 Transformer 來計算 embedding。若 Transformer 有 $L$ 層,便可捕捉到 $L$-hop temporal neighborhood。
令在時間戳 $t$ 時
* node $i$ 在前一層的 embedding 為 $\mathbf{h}_i^{(l-1)}(t)$
* node $i$ 的 $N$ 個 neighborhood 在前一層的 embedding 為 $\left\{ \mathbf{h}_1^{(l-1)}(t), \ldots, \mathbf{h}_N^{(l-1)}(t) \right\}$
* node $i$ 在時間 $t$ 以前與 neighbor 有 event $\mathbf{e}_{i1}(t_1), ..., \mathbf{e}_{iN}(t_N)$
計算如下
\begin{align}
&\mathbf{h}_i^{(l)}(t)
= \text{MLP}^{(l)}
\left(
\mathbf{h}_i^{(l-1)}(t) \|
\tilde{h}_i^{(l)}(t)
\right),
\\
&\tilde{\mathbf{h}}_i^{(l)}(t)
= \text{MultiheadAttention}^{(l)}
\left(
\mathbf{q}^{(l)}(t),
\mathbf{K}^{(l)}(t),
\mathbf{V}^{(l)}(t)
\right),
\\
&\mathbf{q}^{(l)}(t)
= \mathbf{h}_i^{(l-1)}(t) \| \phi(0),
\\
&\mathbf{K}^{(l)}(t)
= \mathbf{V}^{(l)}(t)
= \mathbf{C}^{(l)}(t),
\\
&\mathbf{C}^{(l)}(t)
= \left[
\mathbf{h}_1^{(l-1)}(t) \|
\mathbf{e}_{i1}(t_1) \|
\phi(t-t_1),
\ldots,
\mathbf{h}_N^{(l-1)}(t) \|
\mathbf{e}_{iN}(t_1) \|
\phi(t-t_N)
\right],
\\
&\mathbf{z}_i(t) = \mathbf{h}^{(L)}_i(t)
\end{align}
其中,$\phi(\cdot)$ 為 time encoding,$\mathbf{h}_j^{(0)}(t)= \mathbf{s}_j(t) + \mathbf{v}_j(t)$。
4. Temporal Graph Sum (sum)
\begin{align}
&\mathbf{h}^{(l)}_i(t)
= \mathbf{W}^{(l)}_2
\left(
\mathbf{h}^{(l-1)}_i(t) \|
\tilde{\mathbf{h}}^{(l)}_i(t)
\right)
\\
&\tilde{\mathbf{h}}^{(l)}_i(t)
= \text{ReLU}
\left(
\sum_{j \in \mathcal{N}_i[0, t]}
\mathbf{W}^{(l)}_1
\left(
\mathbf{h}^{(l-1)}_j
\|
\mathbf{e}_{ij}
\|
\phi(t-t_j)
\right)
\right)
\\
&\mathbf{z}_i(t) = \mathbf{h}^{(L)}_i(t)
\end{align}
## Training
TGN 可以應用於 node classification (semi-supervised) 或 edge predition (self-supervised) 等,以下以後者做說明。
<center>
<img src=https://i.imgur.com/1UE6ncF.png
width=500>
<br>
<br>
</center>
1. 依照時間順序來採樣 batch,具體來說,下一個 batch 的 time interval 比前一個 batch 晚
2. 使用前一個 batch 的資料 (raw messages) 計算出當前的 message,避免資訊洩漏問題。
以上圖為例,batch 有包含在時間戳 $t_5$、$t_6$ 的 2 個 events,牽涉到 node 1, 2, 3。在做 forward propagation 之前,會從 **Raw Message Store** 撈取 node 1, 2, 3 之前 ($t_1$、$t_2$、$t_3$、$t_4$) 的 raw messages。Raw Message Store 會存放 message function 的 inputs,包括 $\mathbf{v}_i(t)$、$\mathbf{s}_i(t^-)$、$\mathbf{s}_j(t^-)$、$\mathbf{e}_{ij}(t)$ 等。
當前 ($t_5$、$t_6$) 的 raw messages 亦被儲存至 Raw Message Store,供下一個 batch 使用。
3. 依循公式並更新 states,以及計算出 node embeddings。
4. Decoder 將 node embeddings pairs 轉換成 edge probabilities,預測當前兩兩 nodes 之間是否存在 edge,最終計算 loss 並更新 TGN。
5. Batch 不宜太大,否則會導致 batch 裡較晚的資料從 Raw Message Store 拿到過時的 raw messages。通常 batch 大小設為 $200$ 能在記憶過時與運算速度間取得好的 trade-off。

## Results
## Reference
1. https://arxiv.org/abs/2006.10637
2. https://github.com/twitter-research/tgn
3. https://github.com/pyg-team/pytorch_geometric
4. https://towardsdatascience.com/temporal-graph-networks-ab8f327f2efe
## Programming
1. **NeighborFinder 類**
+ `find_before(self, src_idx, cut_time)`:找出 `cut_time` 以前所有 `src_idx` 對應的 node 有關的事件,返回
+ `node_to_neighbors`
+ `node_to_edge_idxs`
+ `node_to_edge_timestamps`
+ `get_temporal_neighbor(self, source_nodes, timestamps, n_neighbors)`:找出 `timestamps` 以前所有 `source_nodes` 對應的事件,最後,每個 node 會挑選出 `n_neighbors` 個事件,有兩種挑選方法:
+ uniform
+ most recent interactions
返回`neighbors`, `edge_idxs`, `edge_times`。
2. **RandEdgeSampler 類**
3. **LastMessageAggregator** 類
+ `aggregate(self, node_ids, messages)`:先通過 `node_ids` 找出 `unique_node_ids`,通過 `messages` 找出沒有 message 的 node 並過濾,最後找出 `unique_node_ids` 對應的 last message,返回 `to_update_node_ids`, `unique_messages`, `unique_timestamps`。
4. **RNNMemoryUpdater 類**
+ `update_memory(self, unique_node_ids, unique_messages, timestamps)`:更新 `self.memory` 的 `last_update` 為 `timestamps`,並使用 `self.memory_updater` 來更新 `self.memory`。
5. **TemporalAttentionLayer 類**
+ `forward(self, src_node_features, src_time_features, neighbors_features,
neighbors_time_features, edge_features, neighbors_padding_mask)`:如上公式所述,會計算出新的 embedding,最後使用 merger,會進行 skip connection,返回 `attn_output`, `attn_output_weights`