# [All in One: Multi-Task Prompting for Graph Neural Networks](https://arxiv.org/abs/2307.01504) ### 作者 - Xiangguo Sun - Hong Cheng - Jia Li - Bo Liu - Jihong Guan ### KEYWORDS pre-training; prompt tuning; graph neural networks ### 名詞解釋 Multi-task Prompting:透過prompting learning來在不需要透過微調或調整架構訓練的情況下延伸道不同任務場景 a language prompt:在input text後面再加上一段文字 ## ABSTRACT - pre-train與fine-tuning目前已經成為了graph類的任務標準工作(或研究)流程 - 而這一類研究過程成功地補足許多應用程式在graph類註釋不足的狀況 - 然而node level, edge level, and graph level不同層級的任務分期度高,用同樣的模型去訓練全部不只成效不佳、甚至會互相影響降低彼此的成效在訓練(導致negative transfer) - 藉著NLP(也就是自然語言處理)類型在LLM這類模型和prompting learning上運用先驗知識來輔助不同類型的NLP任務上的成功啟發,研究者試著研究利用prompting輔助graphs類的任務來填補pre-train模型與多樣化的graph任務 - 這篇paper提出新的multi-task prompting(多工提示)方法來訓練graph model - 首先他們在`prompt token,token structure和inserting pattern` 上``統一``了graph prompt與language prompt的格式,藉此使得graph與natural language的prompting兩個不同結構可以共同協作 - 為了更進一步縮短graph任務與作為sota(state-of-the-art)的pre-train策略,研究者進一步研究了graph應用的任務空間(不太懂,推測是指可以處理的範疇所在);同時也將下游任務統整為graph-level - 研究者提出``meta-learning(元學習)``以更有效的初始化學習multi-task prompt過程,其提示的架構(framework)也會更加可靠且模型能力也會更加廣泛 - 後面的實驗進一步證明了這一點 ## 1 INTRODUCTION - [GNNs(Graph neural networks )](https://en.wikipedia.org/wiki/Graph_neural_network)目前被廣泛應用於像是社交網路、異常值檢測或是網路分析等任務上 - 除此之外,新的研究趨勢是將GNN使用在特定任務上 - 傳統的監督學習很仰賴graph的label且對於現實世界的應用上明顯不足 - 另一個短版是對於訓練樣本以外的資料會出現Overfitting - 為了應對以上問題,大多都以pretrain和fine-tuning來應對,即在現有資料(或方便取得的資料上)進行預訓練之後再在新領域上微調先前育訓練模型的最後一層 - 雖然前面提及看似有很好的成效(預訓練可以映射相連節點到潛在空間中並且使得其距離縮短)然而下游任務不只有binary edge prediction,還有其他資料結構層級像是 node-level (e.g., node multi-class classification)、graph-level tasks (e.g., graph classification);轉移前面的模型到這類任務(像是multi-class node classification)上,會需要在更高維度的參數中搜尋額外的node分類標籤,這種tuning可能會在連結不同節點過程中導致negative transfer - 因此可以說這類型的Tuning不只是相對沒效且還要花費額外資源去學習將node的embedding映射到graph上 ### 前面是在介紹GNN與其短版,在不同的資料結構上成效不佳;同時間也對應到前者所說的將不同的資料型態、任務等以統一規格輸入 - 上述問題的對應方法則是從“pretraining and fine-tuning”擴展到“pre-training, prompting, and fine-tuning” - Prompt learning在NLP上被證實有用且可以使自然語言的預訓練模型上被廣泛應用到不同領域的文字任務上 - 具體來說,a language prompt是在input text後面再加上一段文字 - 舉例而言,在“KDD2023 will witness many high-quality papers.I feel so [MASK]”中如果沒有進一步調整參數的話,[MASK]基於預訓練的資料內的知識,會更高機率地被預測為“excited”而不是“upset” - 藉此,下游任務的目標可以自然地因為預訓練時所給予的目標和知識`對齊`(應該是指說可以正確地使問題對應到答案) - 由前者language prompt的成功啟發,研究者希望可以利用這一點在graph上 - 圖一可以看到prompt tuning在graph這個領域上需要一些輕量的prompt(中文的話可以翻譯成提示)、右邊是在凍結參數的情況下以prompt去重構下游任務成預訓練任務(老實說沒做過我也不知道是怎麼搞)使得下游任務可以高效地微調甚至不須微調即可應用 - 其中few-shot是最有效的 ![image](https://hackmd.io/_uploads/SkxNGz7Xkl.png) ### 這邊在說把language prompt成功和想要應用到graph的任務上跟圖示 - 然而實際上的問題是graph的prompt設計比自然語言的prompt設計更加棘手 - 首先,是自然語言的prompt本身通常是在input texts後面加上可學習的向量或預設的短句,參照圖二,上方自然語言的prompt只需要考慮上下文即可;而graph的prompt則不只需要參考prompt的上下文還需要知道`如何組織prompt的token以及將這些token插入進原始的graph中` - 這兩者都是未被定義的問題 ![image](https://hackmd.io/_uploads/B1sytXmXkx.png) - 再者,如何將下游任務調和成預訓練任務也是難題。在NLP領域中通常會把針對語言模型進行遮罩訓練在轉移其應用到各類任務上(舉例question answering [22], sentiment classification [17]. The underlying support [21]);這類任務子空間上常常有大量的overlapping(轉換成白話文的話就是在不同的任務中,文句所使用的問答詞語、文法上常常有語言結構上的相似性或語句或詞語重疊);然而,當前並有研究知道在graph上是否有類似自然語言這樣子的文法或詞句類型上的特性 - 如何適當地決定預訓練任務並且制定下游任務以增進模型處理更廣泛任務的能力相當地困難 - 目前研究graph prompt的研究相當地少且只能使用特定pretext(不知道怎麼翻譯)來處理單一類型任務(像是node classification) - 以上種種離多工處理不同層級的任務還有很長一段路 - 最後,要找到好用的prompt並非易事,且對多工任務的設定上的prompt initialization很敏感 - [14, 38]就是在NLP上探討手工的內容或某些任務相關的離散特徵來進行prompt initialization - 依然不足以應對新型的任務 ### 以上prompt設計棘手和無法應對新類型的任務的問題可能比多工graph領域的其他問題更麻煩,畢竟graph作為特徵的任務涉及很廣泛的領域(如果沒辦法讓模型學會"應變"不同類型任務會很麻煩) ### 1.Final Present Work 1. 為了處理第一個問題,研究者將language prompt與graph prompt統一格式(為了應用NLP領域的思維在graph上),且從prompt tokens,token structures,prompt inserting patterns上設計了graph的 prompt 2. 為了處理第二個問題,研究者研究了graph的任務子空間後透過[induced graphs](https://zh.wikipedia.org/zh-tw/%E5%AF%BC%E5%87%BA%E5%AD%90%E5%9B%BE)將node-level and edge-level tasks重構成graph-level 3. 處理第三個問題,他們介紹了meta-learning的寄宿來解決multiple tasks以更好地從prompt中學習 - 經過審慎評估後,實驗結果證明有效。其優勢為: - 統一格式(Sectuion 3.3) - 有效重構node-level and edge-level tasks to graph-level tasks以連結到 pre-training pretexts (section 3.2) - 引入meta-learning技術以從好用的prompt中學習增進多工的表現 (section 3.4). - 仔細分析自身方法(section 3.5)且確認其有效性 (section 4). ## 2 BACKGROUND 1.Graph Neural Networks - GNNs目前在許多graph的應用上有很強的表達能力 - 舉例來說有效的神經網路架構有graph attention Network (GAT) [32], graph convolution network (GCN) [34], Graph Transformer [25] - 而近來的研究也有在試圖提升graph learning的適應性以及更有效地transfer到其他新領域上的,使得很多研究偏向預訓練而非傳統的監督式學習 2. Graph Pre-training - Graph上的預訓練意圖降低人工標註新任務的成本 - 有效的預訓練策略包含GCA[40], edge-level pretext like edge prediction [13], and graph-level contrastive learning such as GraphCL [36] and SimGRACE [35]. - 其中GraphCL最小化同一graph不同數據增強(augmentations)下生成的graph-level representation之間的距離(在同一個graph使用不同的augmentation下轉換之後的兩個新的graph之間的距離會被最小化以實現對比) - 而SimGRACE試著擾亂graph模型的參數空間並縮小同一個graph在不同擾亂(perturbation)之後多者之間的距離 - 以上方法皆被證實比graph knowledge learning 有效且為本篇paper預設的預訓練策略 3. Prompt Learning & Motivations - 直覺地說,graph-level的預訓練與自然語言的遮罩訓練(像BERT訓練時會把語句中某個字換成[MASK])有著相似之處:對其兩個graph views(其可能是由node/edge/feature mask or other perturbations產生)和預測語句中的空白非常地相似 - 那為什麼不用相似的prompt format for graph來強化GNN的generalization能力? - 比起透過adaptive task head來對預訓練模型進行微調,prompt learning透過重新調整輸入數據的表達方式以更符合"預訓練"的目的 - NLP內現在已經有許多被證實有效的prompt技術(像是hand-crafted prompts like GPT-3 [3], discrete prompts like [7, 26], and trainable prompts in the continuous spaces like [16, 19]),但卻少有應用在graph上的prompt研究;研究者只有找到GPPT試圖設計給graph的prompts,很不幸的是內容非常侷限且對多工的需求上略顯不足 ## 3 MULTI-TASK PROMPTING ON GRAPHS <!-- 我等等來看看 3.1 ~ 3.3 的東西,如果哪裡有錯的話可以隨時跟我說 by 笙宏 --> OKay! ### 3.1 Overview of Our Framework - 目的 - 在於將 prompt graph 與原本的 input graph 連接上,藉此有效的縮短 pre-training strategy 與 downstream tasks 之間在 task space 之間的差距,進而減緩將從 pre-training 學到的東西遷移到其他 domain 或 task 的難度。 - 概要 1. 將 downstream tasks 經由訂定的公式將其轉換成對應的 graph-level tasks 2. 透過 prompt graph(with learnable tokens, inner structures, and adaptive inserting patterns) 來縮短不同 level 的 task 在 space 中的差距 3. 建立一個 meta-learning 流程,藉此學習一個通用度更高的 graph prompt 來應用在多種任務的情境下 ### 3.2 Reformulating Downstream Tasks ![image](https://hackmd.io/_uploads/SyMDmr7Xkl.png) #### 3.2.1 為甚麼要做 Reformulating - 為了解決前面提到的不同 level 的 task 在 task spaces 中的差距問題,透過 reformulating 將不同的 task 轉變為更 general 的形式,藉此縮短他們的差距 - 在語言中不同的任務領域相對關係有著蠻大的合集(好比說預測某個句子的空缺詞語也可以用問答的方式表達),而graph上不同任務領域之間的對應關係就相對複雜了,因此需要藉由重構來將任務變成更為通用的形式 - 舉例來說edge-level task跟the node-level task要用同樣地方式進行處理蠻牽強的畢竟資料型態上操作就有蠻大的不同了,這也限制了模型的表現,甚至會有負遷移(也就是越學越笨的意思) - 同樣的問題也出現在研究者的“pre-training,prompting,and fine-tuning”架構上,因此需要進一步藉由轉換成更通用的資料形式來縮短不同資料結構層級上的差距 #### 3.2.2 為甚麼是轉成 graph-level - 有了前面的動機之後,研究者找到了如上圖3b不同任務領域之間的關係圖(上圖可見任務之間是上下包住彼此的) - node-level,edge-level,graph-level三者之間最通用的資料型態是graph(像是node-level的操作有改變/新增/刪除 節點特徵;edge-level有新增/刪除 edge的操作都可以以graph的形式進行處理,像是graph中如果要刪除內部比較小的subgraph,也可以視為刪除節點或邊) - 這篇 Paper 發現,在 node-level 以及 edge-level 的操作在 graph-level 上都有類似的操作 - node-level: 新增 / 刪除 node - edge-level: 新增 / 刪除 edge - graph-level: 刪除 subgraph => 刪除多個 node 以及 edge - 從上述舉例可發現在 graph-level 上所做的操作可以同時涵蓋到 node 以及 edge level 上的操作 - 透過上述的概念可想像出 Figure 3(b) 的樣子,發現 graph-level tasks 在空間上具有很大的涵蓋率,因此也更適合將 node-level 或 edge-level 的 task 轉換到 graph-level task #### 3.2.3如何重構成下游任務 ![image](https://hackmd.io/_uploads/SJpcXBQ7kx.png) - 藉由 𝜏-ego network建構,subgraph會透過連結鄰近節點保存原先的node結構和語意上下文 - 轉換之後,像node分類任務也可以轉換為graph分類任務 - 有無成對的node可以當作是為正或為負的edge連結兩者 - 對於沒有不同權重的graph,the 𝜏 distance等同𝜏-hop length - 而有不同權重的graph,the 𝜏 distance則可以視為最短路徑的長度 - the induced graph 可以用許多有效的演算法計算 - 首先定義一個距離參數 $\tau$ - node-level -> graph-level - 針對要被操作的 node,以此 node 為中心,將在 $\tau$ 範圍內的所有 node 納入一個 subgraph - edge-level -> graph-level - 針對要被操作的 edge,以 edge 兩端的 node 為中心,將在 $\tau$ 範圍內的所有 node 納入一個 subgraph - 如果 edge 上有權重的話,則根據權重來判斷 node 之間的距離 - 反之則看 node 之間有幾條 edge 連接來判斷距離 - 或是也可以直接當作每條 edge 上的權重都是 1,這樣出來的結果應該是相等的 ### 3.3 Prompt Graph Design #### 3.3.1 建構 prompt graph 時需要以下三個元素 1. prompt token 包含了有著與輸入文字或節點相同大小轉換成向量的prompting資訊 2. token structure 指的是不同tokens之間連結的方式,以NLP的語句中舉例的話;其語句就是呈現線性的tokens;graph內的tokens則可能是非線性、比NLP類的連結更加複雜 3. inserting pattern 為如何將prompt加到input data上,NLP上即是將prompt用字串相加的方式加在input data後面就好;graph上則是沒有明確要加在哪,因此要決定加在哪更加困難 接著來定義一下 input graph 與 prompt graph 的表示法 - input graph - graph 為 $\mathcal{G=(V, E)}$ - graph 中的 N 個 node 為 $\mathcal{V}=\{\mathcal{v_1, v_2, ..., v_\mathrm{N}}\}$,每個 node $\mathcal{v}_i$ 都會有一個對應的 feature vector $x_i \in \mathbb{R}^{1 \times d}$ - graph 中的 edges 則為 $\mathcal{E = \{(v_i, v_j) \ | \ v_i, v_j \in V\}}$ - prompt graph - garph 為 $\mathcal{G_p=(P, S)}$ - graph 中的 P 個 token 為 $\mathcal{P}=\{\mathcal{p_1, p_2, ..., p_{|P|}}\}$,每個 token $\mathcal{p}_i$ 都會有一個對應的 token vector $\mathrm{p}_i \in \mathbb{R}^{1 \times d}$ - graph 中的 edges 則為 $\mathcal{S = \{(p_i,p_j) \ | \ p_i, p_j \in P\}}$ - 一般來說,token總個數會遠小於隱藏層的數量 #### 3.3.2 Prompt token - prompt token 的數量具有以下限制 - 需小於 input graph 中 node 的數量 - 需小於 pre-trained model 中 hidden layer 的數量 - 在確定好 prompt token 的內容後,將其 vector 嵌入至原本 input graph 的特定幾個 node 的 feature vector 當中 - 假設我們要將第 j 個 token 的 vector 嵌入至第 i 個 node,則可以表示為 - $\hat{x_i} = x_i + \mathrm{p}_j$ - 嵌入完後再將植入完 prompt 的 feature vector 送入 pre-trained model 做進一步地處理 #### 3.3.3 Token structure - 與NLP prompt不同的是,prompt graph是implicit 研究者提出設計三種方式去表示prompt token structure 1. 建立一個 tunable parameter,這個參數是一個集合,代表各個 token 之間有 edge 連接的可能性 $$ \mathcal{A} = \bigcup^{|\mathcal{P}| - 1}_{ \substack{ i=1 \\ j=i+1 } } \{a_{ij}\} $$ 2. 拿兩個 token 的 vector 去算內積,得出的結果再帶入 sigmoid function,如果有超過定義的門檻值的話,則代表兩個 token 之間有連結,反之則沒有 3. 當S為空集合時直接假設tokens彼此獨立不連結 #### 3.3.4 Inserting pattern 定義 $\psi$為將prompt graph新增到input graph 的inserting function,會得到圖 $\mathcal{G_m = \psi(G, G_p)}$,其定義為prompt tokens和input graph nodes的內積後再進行對應到權重的計算,像是 $$ \hat{x}_i = x_i + \sum_{k=1}^{|\mathcal{P}|}w_{ik}p_k $$ 其中$w_ik$為修剪多餘連結的權重值 $$ w_{ik} = \begin{cases} \sigma(p_k \cdot x_i^T), & \text{if } \sigma(p_k \cdot x_i^T) > \delta \\ 0, & \text{otherwise} \end{cases} $$ 在這個 function 中,會做以下計算 - 計算 $\mathcal{G}$ 中每個節點的與 prompt graph token 的關聯程度與權重,並判斷該關聯程度是否超過給定的門檻值,有則權重等同於該關聯程度,反之則為 0 - 將每個節點的 feacture vector 加上權重乘以對應 token 的 vector ### 3.4 Multi-task Prompting via Meta Learning #### 3.4.1 Constructiog meta prompting tasks - $\tau_i$作為第i個任務與其相關的資料$\mathcal{D}^s_{\tau_i}$和提問資料$\mathcal{D}^q_{\tau_i}$ - 以圖分類任務為例, $\mathcal{D}^s_{\tau_i}$,$\mathcal{D}^q_{\tau_i}$內包含了標註過(也就是分類標籤)的graph - 以節點分類任務為例,從各個節點生成的induce graph,會將graph label對齊到node label,其graph會作為$\mathcal{D}^s_{\tau_i}$和$\mathcal{D}^q_{\tau_i}$ - 以邊分類任務來說,從edge induced graph,其edge label最多會當成兩個端點 #### 3.4.2 Applying Meta-learning to Graph Prompting - $\theta$為prompt參數 - $\pi^*$為預訓練時背後(backbone應該翻譯成骨幹?)的固定參數 - $\phi$則是tasker(依照前文推斷處理任務給出output的部分) - $f_{\theta,\phi|\pi^*}$為prompt graph的pipeline($\theta$)、pre-trained model($\pi^*$,fixed)、下游任務處理器(downstream tasker)($\phi$)的標註 - $\mathcal{L}_D(f)$為pipeline $f$在資料$\mathcal{D}$的task loss (loss function),$\alpha$為學習率 - 參數更新的公式如下 $$ \theta_i^k = \theta_i^{k-1} - \alpha \nabla_{\theta_i^{k-1}} \mathcal{L}_{\mathcal{D}_{\tau_i}^S} \left( f_{\theta_i^{k-1}, \phi_i^{k-1}} | \pi^* \right) $$ $$ \phi_i^k = \phi_i^{k-1} - \alpha \nabla_{\phi_i^{k-1}} \mathcal{L}_{\mathcal{D}_{\tau_i}^S} \left( f_{\theta_i^{k-1}, \phi_i^{k-1}} | \pi^* \right) $$ - 最一開始的參數設定如下 $$\theta^0_i = \theta$$ $$\phi^0_i = \phi$$ - 這部分的目標是有效地學到初始為meta-prompting tasks設定的($\theta$,$\phi$) - 要學到最終的目標為 $$ \theta^*, \phi^* = \arg\min_{\theta, \phi} \sum_{\tau_i \in \mathcal{T}} \mathcal{L}_{\mathcal{D}_{\tau_i}^q} \left( f_{\theta_i, \phi_i} | \pi^* \right) $$ - 如果有上過機器學習理論的同學可能會發現這邊跟機器學習中的線性回歸的更新計算方式、最終目標很像 - 裡面的$\mathcal{T}$為Task也就是訓練資料中的任務 - 進一步推導的話最終結果的$\theta$可以如下表示 $$ \begin{align*} \theta &\leftarrow \theta - \beta \cdot g_{\theta}^{\text{second}} \\ &= \theta - \beta \cdot \sum_{\tau_i \in \mathcal{T}} \nabla_{\theta} \mathcal{L}_{\mathcal{D}_{\tau_i}^q} \left( f_{\theta_i, \phi_i} | \pi^* \right) \\ &= \theta - \beta \cdot \sum_{\tau_i \in \mathcal{T}} \nabla_{\theta_i} \mathcal{L}_{\mathcal{D}_{\tau_i}^q} \left( f_{\theta_i, \phi_i} | \pi^* \right) \cdot \nabla_{\theta} (\theta_i) \\ &= \theta - \beta \cdot \sum_{\tau_i \in \mathcal{T}} \nabla_{\theta_i} \mathcal{L}_{\mathcal{D}_{\tau_i}^q} \left( f_{\theta_i, \phi_i} | \pi^* \right) \cdot \left( \mathbf{I} - \alpha \mathbf{H}_{\theta} \left( \mathcal{L}_{\mathcal{D}_{\tau_i}^S} \left( f_{\theta_i, \phi_i} | \pi^* \right) \right) \right) \end{align*} $$ - 其中$H_{\theta}(\mathcal{L})$為[Hessian matrix](https://zh.wikipedia.org/zh-tw/%E9%BB%91%E5%A1%9E%E7%9F%A9%E9%99%A3),可以表示為$$ \left( \mathbf{H}_{\theta}(\mathcal{L}) \right)_{ij} = \frac{\partial^2 \mathcal{L}}{\partial \theta_i \partial \theta_j}; $$ - 前面都在說$\theta$怎麼更新沒有提到$\phi$,而$\phi也可以用一樣的方式進行更新 - 在prompt learning的領域裡,the task head也就是 the answering function,連結了重構的下游任務的promtpt與答案 - the answering function 可以是可調式或手工式的樣板 - 這部分 Section3.5有一個簡單有效且不需要任何tunable task head的手工prompt回答樣板 #### 3.4.3 Overall Learning Process - 為了增加學習的穩定性,訓練時的的multi-task episodes每個episode內涵蓋了batch tasks(內有結點分類任務("$n$")、邊分類任務("$l$")和圖分類任務("$g$")) - $\mathcal{E}_i = \left( \mathcal{T}_{\mathcal{E}_i}, \mathcal{L}_{\mathcal{E}_i}, \mathcal{S}_{\mathcal{E}_i}, \mathcal{Q}_{\mathcal{E}_i} \right)$為multi-task pisode - $\mathcal{T}_{\mathcal{E}_i} = \left\{ \mathcal{T}_{\mathcal{E}_i}^{(g)}, \mathcal{T}_{\mathcal{E}_i}^{(n)}, \mathcal{T}_{\mathcal{E}_i}^{(\ell)} \right\}$為task batch - 子集為$\mathcal{T}_{\mathcal{E}_i}^{(\triangleleft)} = \left\{ \tau_{\triangleleft1}, \cdots, \tau_{\triangleleft t_{\triangleleft}} \right\}$ - Loss function定義為$\mathcal{L}_{\mathcal{E}_i} = \left\{ \mathcal{L}^{(g)}, \mathcal{L}^{(n)}, \mathcal{L}^{(\ell)} \right\}$ - 相關資料(Supporting Data)$\mathcal{S}_{\mathcal{E}_i} = \left\{ \mathcal{S}_{\mathcal{E}_i}^{(g)}, \mathcal{S}_{\mathcal{E}_i}^{(n)}, \mathcal{S}_{\mathcal{E}_i}^{(\ell)} \right\}$ - 其子集為$\mathcal{S}_{\mathcal{E}_i}^{(\triangleleft)} = \left\{ \mathcal{D}_{\tau_1^{(\triangleleft)}}^{s}, \cdots, \mathcal{D}_{\tau_{t^{(\triangleleft)}}}^{s} \right\}$ - query資料為$\mathcal{Q}_{\mathcal{E}_i} = \left\{ \mathcal{Q}_{\mathcal{E}_i}^{(g)}, \mathcal{Q}_{\mathcal{E}_i}^{(n)}, \mathcal{Q}_{\mathcal{E}_i}^{(\ell)} \right\}$ - 其子集為$\mathcal{S}_{\mathcal{E}_i}^{(\triangleleft)} = \left\{ \mathcal{D}_{\tau_{\triangleleft_1}}^{q}, \cdots, \mathcal{D}_{\tau_{\triangleleft_{t_\triangleleft}}}^{q} \right\}$ - 以上定義完成後,multi-task prompting如Algorithm 1那樣呈現,將node/edge/graph class以二元分類任務處理,以便使不同結構的任務可以共用task head;這篇的方法也可以處理分類以外的其他任務(詳見Appendix A) ![image](https://hackmd.io/_uploads/B1ThG4NQ1l.png) ### 3.5 Why It Works? #### 3.5.1 Connection to Existing Work - 先前的研究GPPT有提及graph prompt,他們以edge prediction作為預訓練的pretext並且藉由設計標註過的tockens再加到original graph上重構node classification成pretext - 這種複合的graph會被再次送進預訓練模型裡面進行連結到labeled node的各個節點預測 - 對本研究來說這個是個特例,比較起來,這一篇的研究graph裡面只包含各個孤立的tokens(每個都會對應到特定的node類別) - 本篇以及GPPT的差別可以分為三者 1. GPPT對原始的Graph操作並不靈活 2. GPPT只能應用在節點分類 3. GPPT只支援edge prediction任務且與其他graph類的預訓練策略不相容(GraphCL [36],UGRAPHEMB [2], SimGRACE [35] etc.) - 下面會進一步探討 靈活度、效率和相容性 #### 3.5.2 Flexibility - prompting本身是用來讓input data跟pretext連結起來 - 因此資料本身可進行的操作會是影響模型表現的瓶頸 - $g$為graph-level的轉換(舉例“changing node features”,“adding or removing edges/subgraphs” etc.) - $\psi$是參數凍結的預訓練graph model - 對任何[鄰接矩陣](https://zh.wikipedia.org/zh-tw/%E9%82%BB%E6%8E%A5%E7%9F%A9%E9%98%B5) $A$ 和節點特徵矩陣 $X$ 有 $\mathcal{G}$,Fang et al.已經證實可以學到合適的prompt token $p^*$並使公式成立$\phi^*(\mathbf{A}, \mathbf{X} + \mathbf{p}^*) = \phi^*(g(\mathbf{A}, \mathbf{X})) + O_{p\phi}$ - 這邊的涵義代表我們可以學到套在原始graph上的token以模仿graph上的操作 - $O_{p{\phi}}$代表經過操作後的graph(manipulated graph)和prompting graph的誤差邊界(error bound) - 誤差邊界本身與模型(unchangable)的非線性層還有learned prompt(changable)的品質有關 - 有望縮小 advanced prompt scheme的範圍(也就是更快找到好用的prompt) - 這篇的研究者將獨立的token擴展到prompt graph以透過內部可學習的結構來操作多個prompt tokens(白話文是透過模型學會並且操作prompt token) - 而非隨意在等式$\phi^*(\mathbf{A}, \mathbf{X} + \mathbf{p}^*) = \phi^*(g(\mathbf{A}, \mathbf{X})) + O_{p\phi}$中插入 - "$X+p^*$"代表prompt token應該被加到原始graph的每個節點上 - 插入的模式(the insertion pattern)是高度客製化(自定義化?)的 - $\psi(\mathcal{G,G_p})$為Section 3.3的插入模式 - $\mathcal{G}$為原始的graph - $\mathcal{G}^*_p$為prompt graph - 前面的等式可以進一步擴展成$\phi^*\left(\psi(G, G_p^*)\right) = \phi^*\left(g(\mathbf{A}, \mathbf{X})\right) + O_{p\phi}^*$ - 藉由有效地微調,新的誤差邊界$O^*_{p\phi}$可能在經過有效訓練遠小於$O_{p\phi}$ #### 3.5.3 Efficiency - 假設input graph有N個節點和M個邊,而prompt token和m條edges - 而graph model有L層且所有層數理最大的維度為d - prompt graph的參數複雜度為$O(nd)$ - 對比來說,某些graph models(像是GAT)的參數複雜度通常為$O(LKd^2+LKd)$已生成節點的embeddings和額外$O(Kd)$的參數以舉得整體graph的embedding(K為multi-head 的數量,head數?) - 這些參數甚至會大於其他GNN網路(像是graph transformer) - 在本篇中的prompt learning架構裡,我們只需要調整參數凍結的預訓練模型的prompt來使得訓練過程比傳統的傳輸、微調收斂過程更快就好 - 在時間複雜度上,一般graph model(GCN)需要$O(LNd^2+LMd+Nd)$來產生節點的embedding(via message passing and then obtain the whole grap representation(e.g., 𝑂(𝑁𝑑) for summation pooling).) -藉由將prompt 插入原始的graph,整體時間複雜度會變成$O(L(n+N)d^2 + L(m+M)d + (n+N)d)$,比起原始的時間複雜度來說額外的時間消耗只佔了$O(Lnd^2 + Lmd + nd) \text{ where } n \ll d, n \ll N, m \ll M$ - 除了前面參數效率上和時間消耗上的差異之外,這篇研究也占了更小的記憶體消耗 - 以節點分類為例,運行模型本身比其其他GNN模型只需要快取prompt 的參數而非整體。比起要將整格graph餵進去graph model裡面,這邊只需要把個別節點的induced graph餵進去就好,這樣做所消耗記憶體遠小於把原始的graph送進去(以現實世界來說,本來就不是每個節點的資訊都很重要,因此本篇的方法在沒有額外的節點需要預測的話就不需要再跑額外的資源了) #### 3.5.4 Compatibility相容性 - 比起GPPT只能使用二元的edge(也就是有無edge)作為標籤(pretext)進行預測,還有只能使用在節點分類的任務上;本篇研究可以應用到節點、邊、圖各種層級的任務上而且只需要挑整一下就能適應各種資料結構了 - 比起傳統的模型使用需要調整task head,本篇研究更專注在操作輸入資料而不是下游任務 - 以Section 4.3為例,在研究模型的可轉移性(也就是把這個模型應用到其他領域上的任務時的表現)只調整了prompt,sourced task head - 這篇甚至可以選擇特定的標籤之後再自定義prompt的細節,也不需要調整task head - 以下為範例 > **Prompt without Task Head Tuning:** Pretext: GraphCL [36], a graph contrastive learning task that tries to maximize the agreement between a pair of views from the same graph. **Downstream Tasks:** node/edge/graph classification. **Prompt Answer:** node classification. Assume there are 𝑘 categories for the nodes. We design the prompt graph with 𝑘 sub-graphs (a.k.a sub-prompts) where each sub-graph has 𝑛 tokens. Each sub-graph corresponds to one node category.Then we can generate 𝑘 graph views for all input graphs.We classify the target node with label ℓ (ℓ = 1, 2, · · · , 𝑘) if the ℓ-th graph view is closest to the induced graph. It is similar to edge/graph classification. - 有趣的是縮短prompt成對應到節點類別和把原始的graph代換調induced graph獨立的token之後,本篇研究可以變相地變成GPPT(這邊不再贅述) ## 4 EVALUATION 在此節中,作者將比較論文中提出的方法與其他方式在不同 level 的 task 中的表現,並得出以下結果 1. 在 few-shot learning 的情境下,於多任務中的性能表現 2. 在不同 domain 或 task 遷移的時候,具有什麼樣的表現 3. prompt token, token structure, inserting pattern 這三個 component 對效能具有什麼樣的影響 4. 與 traditional approach 相比,論文中提出的方法提升的多少性能 5. 在對 graph 進行操作時具有什麼樣的表現 ### 4.1 Experimental Settings #### 4.1.1 Datasets - 以下是研究中使用的dataset ##### Table 1: Statistics of datasets | Dataset | #Nodes | #Edges | #Features | #Labels | |-----------|---------|--------------|-----------|---------| | Cora | 2,708 | 5,429 | 1,433 | 7 | | CiteSeer | 3,327 | 9,104 | 3,703 | 6 | | Reddit | 232,965 | 23,213,838 | 602 | 41 | | Amazon | 13,752 | 491,722 | 767 | 10 | | Pubmed | 19,717 | 88,648 | 500 | 3 | - 而為了處理edge level 或graph level的任務,研究者從原始資料中取樣(取邊跟子圖,其中Edge的標籤由兩個端點決定;子圖的標籤則由內部大多數節點決定:舉例來說,如果某節點有$c_1,c_2,c_3$三種類別,則Edge-level則至少會有$c_1,c_2,c_3,c_1c_2,c_2c_3,c_1c_3$這幾種類別,也就是$C^n_2$種) - 研究者也額外針對graph標籤和連結(link)標籤預測跑過其他特別類型的資料集 - 這邊的graph標籤和link標籤是內建的而非由其他資料計算而來(詳見Appendix A) #### 4.1.2 Approaches - 比較以下三種方式的性能 - Supervised methods - 單純使用以下模型進行訓練 - GAT, GCN, Graph Transformer(GT) - Pre-training with fine-tunning - 使用 SimGrace 或 GraphCL 訓練 GAT, GCN, GT 等 GNN Model - Self-supervised learning - Prompt methods (本論文中提出的方式) - 使用 SimGrace 或 GraphCL 訓練 GAT, GCN, GT 等 GNN Model - 針對 input graph 加上 prompt graph - 並對 downstream task 做轉換成 graph-level task #### 4.1.3 Implementations - GNN的層數為2;隱藏層數為100 - 為了研究不同graph資料的可轉移性,研究者使用[SVD](https://zh.wikipedia.org/zh-tw/%E5%A5%87%E5%BC%82%E5%80%BC%E5%88%86%E8%A7%A3)減少初始的特徵數量到100維 - Prompt graph的token數量為10 - 探討token數量影響的部分在Section4.4,範圍為1-20 - 全部實驗的優化器皆是使用[Adam optimizer](https://arxiv.org/abs/1412.6980) - 學習率(Learning rate被設定為)0.001(大多資料集並非全部) - 在meta learning 中所有node-level, edge-level, and graph-level任務在meta-training和meta-testing上的比例為1:1 - 更多細節都在Appendix A裡面,包含更多資料集和迴歸或link prediction任務的實驗結果 #### 4.1.1 Datasets - Cora, CiteSeet, Reddit, Amazon, Pubmed - 為了利用這些 Dataset 做 edge-level 與 graph-level tasks,在 edge 以及 subgraph 上加上了 label - edge-level: 取兩端的 node 的 label - graph-level: 取 subgraph 中擁有多數 node 的 label <--這邊我寫好了--> #### 4.1.2 Approaches - Supervised methods - 單純使用以下模型進行訓練 - GAT, GCN, Graph Transformer(GT) - Pre-training with fine-tunning - 使用 SimGrace 或 GraphCL 訓練 GAT, GCN, GT 等 GNN Model - Self-supervised learning - Prompt methods (本論文中提出的方式) - 使用 SimGrace 或 GraphCL 訓練 GAT, GCN, GT 等 GNN Model - 針對 input graph 加上 prompt graph - 並對 downstream task 做轉換成 graph-level task #### 4.1.3 ### 4.2 Multi-Task Performance with Few-shot Learning Settings (RQ1) - Prompt Learning被拿來與其他主流的訓練方式比較,使用的是node-level, edge-level, and graph-level tasksunder the few-shot setting - 實驗重複5次並且取平均後結果在Table 2,Table 12,Table 13 - 結果上監督式學習難以跟其他訓練方式抗衡 - 因為監督式學習很需要先標註的資料而在few shot裡這件事非常受限也導致了很爛的結果 - 相對來說預訓練包含了很多先驗知識,沒那麼依賴資料的標籤 - 然而這個問題還蠻難處理的,預訓練方法和微調模型的過程需要小心處理以處理;而這部分的努力依然不確定是否有效,預訓練策略和下游任務之間的鴻溝依然很大,GNN模型依然很難轉移到Multi-tasking上(Section 4.3會進一步討論) - 在預訓練方法上,不同資料結構的表現有所提升 - from 1.10% to 8.81% on node-level tasks, - from 1.28% to 12.26% on edge-level tasks, - from 0.14% to 10.77% on graph-level tasks - 其中node-level在GPPT與本篇研究的比較在Table 2 - 要注意的是兩者的設定差很多 - GPPT在few shots把30%-50%的標籤換成[MASK] - 而本篇研究提出另一個更具挑戰性的問題: 在減少標籤的狀況下模型會表現如何? - 因此實驗時每一個類別都只有100筆有標籤的樣本 - 這個設定讓 - Cora資料只有25%有標籤; - CiteSeer資料只有18%有標籤; - Reddit資料大概只有1.7%有標籤; - Amazon資料大概只有7.3%有標籤; - Pubmed資料大概只有1.5%有標籤; ### 4.3 Transferability Analysis (RQ2) - 為了評估transferability(可轉移性),研究中比較了hard transfer和微調過的版本 - Here the hard transfer method means we seek the source task model which has the same task head as the target task and then we directly conduct the model inference on the new task. - The fine-tune method means we load the source task model and then tune the task head for the new task. - 評估可轉移性的觀點如下 1. 模型轉移到同一領域的不同任務後表現如何? 2. 模型轉移到其他領域後的表現如何? #### 4.3.1 Transferability to Different Level Tasks. - 首先研究者拿Amazon資料集訓練GNN後再讓模型處理Graph level 跟node level;隨後再評估在edge level上的表現 - 無論哪個任務都被作為二元分類任務,為該類為正其餘皆為負 - 結果在Table 3 - 觀察如下 1. prompt method成效十分顯著且判斷依據相當合理 2. hard transfer表現不好,原因是source task上的source class跟 target task的target class差異太大 - 這甚至會導致負遷移 - 大多數案例中微調都能使模型輸出有意義的結果 - 但依然可能會有負遷移的問題 3. graph-level task比node-level task 在edge-level target上會有更好的適應性 -以上符合先前直覺性的結果in Figure 3 (section 3.2). #### 4.3.2 Transferability to Different Domains. - 研究者將Amazon and PubMed作為來源領域,再將這些領域的模型載入並且去跑Cora後輸出其表現狀況 - 由於不同資料集的輸入特徵維度不同,所以實驗時使用SVD來統一輸入的特徵到100維 - 結果在Table 4 - 可以發現 prompt能讓模型有很好的可轉移性 ### 4.4 Ablation Study (RQ3) - 這邊比較了原本的實驗方法和以下四個變體 #### “w/o meta” - prompt method但不包含meta-learning #### “w/o h”(section 3.5.4) - our method 但不包含 task head tuning, #### “w/o token structure” - 有prompt而所有token都各別獨立彼此沒有內部連結 #### “w/o inserting” - 有prompt但prompt tokens與the input graphs之間沒有任何across link - 結果在Figure 5 - 可以發現 meta-learning 跟 token structure對於最終結果很有貢獻 - the inserting pattern between a prompt graph and the input graph扮演很重要的腳色 - 如先前提及the prompt based method緩解了預訓練和微調時預訓練模型與task head之間的鴻溝 - prompt graph可以改進微調成效 - 如Figure 5提及,“w/o h”變體依然可以有還不錯的結果 - 顯示了連結上下游任務的能力 ### 4.5 Efficiency Analysis (RQ4) - Figure 6呈現出了增加 token number與模型表現之間的關係 - 在有限的token內模型可以給出令人滿意的結果 - 代表模型的複雜度小也能給初還不錯的結果 - 結果在Table 5 - 代表研究者的方法蠻有用的 ##### Table 5: Tunable parameters comparison **RED (%): average reduction of the prompt method to others** | Methods | Cora | CiteSeer | Reddit | Amazon | Pubmed | RED (%) | |---------|----------|----------|---------|---------|---------|---------| | GAT | ~155K | ~382K | ~75K | ~88K | ~61K | 95.4↓ | | GCN | ~154K | ~381K | ~75K | ~88K | ~61K | 95.4↓ | | GT | ~615K | ~1.52M | ~286K | ~349K | ~241K | 98.8↓ | | prompt | ~7K | ~19K | ~3K | ~4K | ~3K | — | - 如Figure 7呈現 the prompt-based method會比傳統的預訓練和監督式學習方法收斂得更快 ### 4.6 Flexibility on Graph Transformation (RQ5) #### Table 5: Tunable parameters comparison. RED (%): average reduction of the prompt method to others. - 先前Section 3.5.2,資料轉換的靈活性是prompt-based methods的瓶頸 - 這邊研究者通過刪除結點、刪除邊和遮住特徵來操作graph,會後再繼續用先前的公式計算誤差邊界 - 這邊比較the original error with the naive prompt mentioned in Equation 5和and our prompt graph with 3, 5, and10 tokens. - 如Table 6所示,研究者的designed prompt graph大幅降低the original graph and the manipulated graph. - 證實研究者的方法足以使多樣的graph轉換且可以更進一步地使下游任務的表現進步 - 這點可以在圖視覺化後觀察到 - 如圖8所示,預訓練模型的graph表示和prompted graph的graph表示比起來,其節點類別的分辨率較低 ## 5 CONCLUSION - 這篇paper在探討使用few shots進行graph prompts的Multi-task - 提出了新的方法重構不同層級的任務並且統一格式 - 應用meta-learning技術設計了有效的prompt graph - 評出了成效並證明有效性 ## A APPENDIX - 本篇做實驗的[程式碼](https://anonymous.4open.science/r/mpg) ### Additional Datasets - 除了主要實驗中的資料集,Table 7有更多資料集評估本架構的成效 - ENZYMES和ProteinsFull為molecule/protein(分子/蛋白質)的資料集, 用在 graph-level classification - Movielens和 QM9則個別是 edge-level and graph-level regression上進行評估 - Movielens包含電影和用戶評分,每個edge代表從0到5的 score value - QM9是分子的資料集還有19個迴歸目標,作為graph-level multi-output regression - PersonalityCafe和Facebook 用於 link prediction - 社交網路類資料集的 edges代表 following/quoting relations. ### Multi-label v.s. Multi-class Classification - 主要的實驗中將分類任務當成multi-label problem,這邊則是將用multi-class的設定去跑 - 如Table 8所示,prompt-based method比其他好 ### Additional Graph-level Classification - 這邊評估graph label不被結點屬性影響的graph-level classification - 如Table 9所示,研究方法比multi-class graph classification有效,尤其在few-shot setting ### Edge/Graph-level Regression - 這邊用MAE (mean absolute error) and MSE (mean squared error)評估graphlevel (QM9) and edge-level (MovieLens) datasets - 這邊餵給模型100 shot的edge induced graphs - 結果在Table 10 - 可以觀察到prompt-based methods比其他的好 ### Link Prediction - link prediction也在graph的學習領域被廣泛研究 - 這邊edge分成三個部分 1. 80% of the edges are for message passing only 2. 10% of the rest edges as the supervision training set. 3. the rest edges as the testing set - 對training set跟test set的每個edge都被當成正樣本並取樣非鄰接的節點為負樣本 - 根據第一步份的edges生成節點對(node pairs)的edge-induced graph - 如果節點對有為正的edge則graph的標籤為正 - 極端測試,指取樣100個為正的edge作為training set - 在測試階段每個為正的edge會有100個為負的edge - 這邊用[MRR](https://en.wikipedia.org/wiki/Mean_reciprocal_rank)評估表現 - 這邊還分別計算Hit Ratio@ 1, 5, 10(也就是不同命中率)時的狀況 - Table 11的結果顯示rompt-based method有用