###### tags: `autoglab` `cathay` `GNN` `TGN` # Temporal Graph Network (TGN) ## Introduction TGN 是一種具有 **encoder-decoder** 結構的 GNN,會先將 graph 上每個 node/edge feature 轉換成 embedding,再輸入 MLP 進行預測。 傳統 GNN 使用靜態圖 (static graph),靜態圖能反應當下資訊,但無法反應過去歷史,有很大的侷限性。**TGN 為連續時間的動態圖 (continuous-time dynamic graphs),graph 的演化是以 event 為驅動,能反應歷史,在每個時間點,node 都會有相應的狀態 (memory),進而形成隨時間演化的 embedding**。 ## Background 有兩種主要的做法來處理動態圖: 1. Discrete-time dynamic graphs (DTDG) 將動態圖轉成多個靜態圖快照的時間序列 2. Continuos-time dynamic graphs (CTDG) 動態圖是由依時序排列的 events 來描述。Events 包含 node addition/deletion、edge addition/deletion、node/edge feature transformations ## Preliminary 1. Graph $\mathcal{G}$ 是由一系列帶有時間戳的 events 組成 \begin{align} \mathcal{G} = \{x(t_1), x(t_2), ... \} \end{align} 其中,$0 \leq t_1 \leq t_2 \leq ...$,$x(t)$ 為一個 event。 2. Events 有兩種類型: * **node-wise event $\mathbf{v}_i(t)$** 若 node $i$ 未出現過,則新增 node $i$,並賦予相應的 feature vector $\mathbf{v}_i(t)$,否則,更新 $\mathbf{v}_i(t)$。 * **interaction event $\mathbf{e}_{ij}(t)$** Node $i$ 與 node $j$ 之間會形成 (directed) temporal edge。 3. 定義 \begin{align} \mathcal{V}(T) &= \{ i: \exists \mathbf{v}_i(t) \in \mathcal{G}, t \in T \}, \\ \mathcal{E}(T) &= \{ (i, j): \exists e_{ij}(t) \in \mathcal{G}, t \in T \}, \\ \mathcal{n}_i(T) &= \{j: (i,j) \in \mathcal{E}(T)\}, \\ \mathcal{G}(t) &= (\mathcal{V}[0,t], \mathcal{E}[0,t]). \end{align} ## Model TGN 將一個連續時間的動態圖 encode 成各種時間戳下的 embedding \begin{align} \mathbf{Z}(t) = \left( \mathbf{z}_1(t), \ldots, \mathbf{z}_{n(t)}(t) \right) \end{align} TGN 由以下關鍵 modules 組成: ### Memory Memory 是 TGN 的核心。Node $i$ 有對應的 state vector $\mathbf{s}_i(t)$,當牽涉到 node $i$ 的 events (node-wise or interaction) 發生時,state $\mathbf{s}_i(t)$ 將會被更新。類似 RNN 的 hidden state,state $\mathbf{s}_i(t)$ 可表示該 node 從古到今的歷史。 當新的 node 產生時,該 node 的 state 會初始化為 $\mathbf{0}$。另外,即使模型在訓練完畢之後,也會不斷更新 state。 ### Message Functions 當牽涉到 node $i$ 的 events 發生時,會計算出 messages 好更新 state $\mathbf{s}_i(t)$。 根據 event 的類型,有以下兩種計算 message 的方式: 1. node-wise event \begin{align} \mathbf{m}_i(t) = \text{msg}_n \left( s_i(t^-), t, \mathbf{v}_i(t) \right) \end{align} 2. interaction event \begin{align} \mathbf{m}_i(t) = \text{msg}_s \left( \mathbf{s}_i(t^-), s_j(t^-), \Delta t, \mathbf{e}_{ij}(t) \right) \\ \mathbf{m}_j(t) = \text{msg}_d \left( \mathbf{s}_j(t^-), s_i(t^-), \Delta t, \mathbf{e}_{ij}(t) \right) \end{align} $\text{msg}_n$、$\text{msg}_s$、$\text{msg}_d$ 均是 learnable。 這篇論文將三個 message functions 簡單設為 concatenation。 ### Message Aggregator 一個 batch 會採樣一個 time interval,因此某個 node $i$ 可能會在同個 batch 牽涉到多個 events,這時就要將 node $i$ 產生的多條message $\mathbf{m}_i(t_1), \ldots, \mathbf{m}_i(t_b)$ 濃縮成一個 message。 \begin{align} \bar{\mathbf{m}}_i(t) = \text{agg} \left( \mathbf{m}_i(t_1), \ldots, \mathbf{m}_i(t_b) \right) \end{align} $agg$ 為 aggregation function,論文提出兩種簡單的做法: 1. most recent message:只保留最新的 message 2. mean message:將所有 message 做平均 ### Memory Updater 一旦某個 event 與某 node 有關,該 node 的 state 將會被更新: \begin{align} \mathbf{s}_i(t) = \text{mem} \left( \bar{\mathbf{m}_i}(t), \mathbf{s}_i(t^-) \right) \end{align} 若 event 為 node-wise,則只更新該 node;若為 interaction event,則兩個 nodes 均被更新。 $\text{mem}$ 為可以是 LSTM 或 GRU 等,此時 input 為 $\bar{\mathbf{m}_i}(t)$,hidden state 為 $\mathbf{s}_i(t^-)$。 ### Embedding Embedding 用來產生 node 在時間 $t$ 的 embedding $z_i(t)$,目的是防止 memory staleness (記憶過時) 問題。即使某 node 長時間不參與任何 event,也能通過其 neighbors 來持續更新 node state。 不失一般性,可以定義 embedding 如下: \begin{align} \mathbf{z}_i(t) = \text{emb}(i,t) = \sum_{j \in \mathcal{N}_i^k([0,t])} h \left( \mathbf{s}_i(t), \mathbf{s}_j(t), \mathbf{e}_{ij}, \mathbf{v}_i(t), \mathbf{v}_j(t) \right) \end{align} 其中,$h$ 是 learnable。有以下幾種實作方式: 1. Identity (id) 直接將 state 當作 embedding。 \begin{align} \text{emb}(i,t) = \mathbf{s}_i(t) \end{align} 2. Time projection (time) 令 $\mathbf{w}$ 為一 learnable parameter,定義 embedding 為 \begin{align} \text{emb}(i,t) = \left( \mathbf{1} + \Delta t \mathbf{w} \right) \odot \mathbf{s}_i(t) \end{align} 其中 $\odot$ 為 element-wise product。 3. Temporal Graph Attention (attn) 使用 Transformer 來計算 embedding。若 Transformer 有 $L$ 層,便可捕捉到 $L$-hop temporal neighborhood。 令在時間戳 $t$ 時 * node $i$ 在前一層的 embedding 為 $\mathbf{h}_i^{(l-1)}(t)$ * node $i$ 的 $N$ 個 neighborhood 在前一層的 embedding 為 $\left\{ \mathbf{h}_1^{(l-1)}(t), \ldots, \mathbf{h}_N^{(l-1)}(t) \right\}$ * node $i$ 在時間 $t$ 以前與 neighbor 有 event $\mathbf{e}_{i1}(t_1), ..., \mathbf{e}_{iN}(t_N)$ 計算如下 \begin{align} &\mathbf{h}_i^{(l)}(t) = \text{MLP}^{(l)} \left( \mathbf{h}_i^{(l-1)}(t) \| \tilde{h}_i^{(l)}(t) \right), \\ &\tilde{\mathbf{h}}_i^{(l)}(t) = \text{MultiheadAttention}^{(l)} \left( \mathbf{q}^{(l)}(t), \mathbf{K}^{(l)}(t), \mathbf{V}^{(l)}(t) \right), \\ &\mathbf{q}^{(l)}(t) = \mathbf{h}_i^{(l-1)}(t) \| \phi(0), \\ &\mathbf{K}^{(l)}(t) = \mathbf{V}^{(l)}(t) = \mathbf{C}^{(l)}(t), \\ &\mathbf{C}^{(l)}(t) = \left[ \mathbf{h}_1^{(l-1)}(t) \| \mathbf{e}_{i1}(t_1) \| \phi(t-t_1), \ldots, \mathbf{h}_N^{(l-1)}(t) \| \mathbf{e}_{iN}(t_1) \| \phi(t-t_N) \right], \\ &\mathbf{z}_i(t) = \mathbf{h}^{(L)}_i(t) \end{align} 其中,$\phi(\cdot)$ 為 time encoding,$\mathbf{h}_j^{(0)}(t)= \mathbf{s}_j(t) + \mathbf{v}_j(t)$。 4. Temporal Graph Sum (sum) \begin{align} &\mathbf{h}^{(l)}_i(t) = \mathbf{W}^{(l)}_2 \left( \mathbf{h}^{(l-1)}_i(t) \| \tilde{\mathbf{h}}^{(l)}_i(t) \right) \\ &\tilde{\mathbf{h}}^{(l)}_i(t) = \text{ReLU} \left( \sum_{j \in \mathcal{N}_i[0, t]} \mathbf{W}^{(l)}_1 \left( \mathbf{h}^{(l-1)}_j \| \mathbf{e}_{ij} \| \phi(t-t_j) \right) \right) \\ &\mathbf{z}_i(t) = \mathbf{h}^{(L)}_i(t) \end{align} ## Training TGN 可以應用於 node classification (semi-supervised) 或 edge predition (self-supervised) 等,以下以後者做說明。 <center> <img src=https://i.imgur.com/1UE6ncF.png width=500> <br> <br> </center> 1. 依照時間順序來採樣 batch,具體來說,下一個 batch 的 time interval 比前一個 batch 晚 2. 使用前一個 batch 的資料 (raw messages) 計算出當前的 message,避免資訊洩漏問題。 以上圖為例,batch 有包含在時間戳 $t_5$、$t_6$ 的 2 個 events,牽涉到 node 1, 2, 3。在做 forward propagation 之前,會從 **Raw Message Store** 撈取 node 1, 2, 3 之前 ($t_1$、$t_2$、$t_3$、$t_4$) 的 raw messages。Raw Message Store 會存放 message function 的 inputs,包括 $\mathbf{v}_i(t)$、$\mathbf{s}_i(t^-)$、$\mathbf{s}_j(t^-)$、$\mathbf{e}_{ij}(t)$ 等。 當前 ($t_5$、$t_6$) 的 raw messages 亦被儲存至 Raw Message Store,供下一個 batch 使用。 3. 依循公式並更新 states,以及計算出 node embeddings。 4. Decoder 將 node embeddings pairs 轉換成 edge probabilities,預測當前兩兩 nodes 之間是否存在 edge,最終計算 loss 並更新 TGN。 5. Batch 不宜太大,否則會導致 batch 裡較晚的資料從 Raw Message Store 拿到過時的 raw messages。通常 batch 大小設為 $200$ 能在記憶過時與運算速度間取得好的 trade-off。 ![](https://i.imgur.com/Ed4LZPd.png) ## Results ## Reference 1. https://arxiv.org/abs/2006.10637 2. https://github.com/twitter-research/tgn 3. https://github.com/pyg-team/pytorch_geometric 4. https://towardsdatascience.com/temporal-graph-networks-ab8f327f2efe ## Programming 1. **NeighborFinder 類** + `find_before(self, src_idx, cut_time)`:找出 `cut_time` 以前所有 `src_idx` 對應的 node 有關的事件,返回 + `node_to_neighbors` + `node_to_edge_idxs` + `node_to_edge_timestamps` + `get_temporal_neighbor(self, source_nodes, timestamps, n_neighbors)`:找出 `timestamps` 以前所有 `source_nodes` 對應的事件,最後,每個 node 會挑選出 `n_neighbors` 個事件,有兩種挑選方法: + uniform + most recent interactions 返回`neighbors`, `edge_idxs`, `edge_times`。 2. **RandEdgeSampler 類** 3. **LastMessageAggregator** 類 + `aggregate(self, node_ids, messages)`:先通過 `node_ids` 找出 `unique_node_ids`,通過 `messages` 找出沒有 message 的 node 並過濾,最後找出 `unique_node_ids` 對應的 last message,返回 `to_update_node_ids`, `unique_messages`, `unique_timestamps`。 4. **RNNMemoryUpdater 類** + `update_memory(self, unique_node_ids, unique_messages, timestamps)`:更新 `self.memory` 的 `last_update` 為 `timestamps`,並使用 `self.memory_updater` 來更新 `self.memory`。 5. **TemporalAttentionLayer 類** + `forward(self, src_node_features, src_time_features, neighbors_features, neighbors_time_features, edge_features, neighbors_padding_mask)`:如上公式所述,會計算出新的 embedding,最後使用 merger,會進行 skip connection,返回 `attn_output`, `attn_output_weights`