# Multitask Prompt Tuning Enables Parameter-Efficient Transfer Learning

###### tags: `RL Group meeting`

## Outline
- Introduction
- Approach
- Experiment
- Conclusion

## Introduction
- Prompt tuning (PT), which prepends tunable continuous prompt vectors to the input, has emerged as a promising approach for parameter-efficient transfer learning with PLMs.
- We propose **multitask prompt tuning (MPT)**, which first learns a single transferable prompt by distilling knowledge from multiple task-specific source prompts.
- We then learn multiplicative **low-rank** updates to this shared prompt to efficiently adapt it to each downstream target task.
- We decompose the soft prompt of each source task into the Hadamard (element-wise) product of a shared matrix and a low-rank task-specific matrix.
- This decomposition is more effective than simply sharing the prompt matrix across all tasks.

![](https://hackmd.io/_uploads/HktmMotDh.png)

## Approach
- Our goal is to learn a single soft prompt over the source tasks $\mathcal{S}$ that can be adapted to each target task $\mathcal{T}_i$ in a parameter-efficient way.
- Simply training a single soft prompt on $\mathcal{S}$ and then fine-tuning it on each $\mathcal{T}_i$ is sub-optimal, as it can fail to leverage commonalities across source tasks while minimizing interference at the same time.
- MPT aims to compress task-shared knowledge in $\mathcal{S}$ into a single prompt matrix $\phi_{\mathcal{S}}$ via knowledge distillation, improving performance on $\mathcal{T}$ while filtering out task-specific information that is less useful for transfer.
- Prompt tuning randomly initializes a small number of learnable prompt vectors $\boldsymbol{P}$ that are prepended to the input embeddings of the PLM, while the model parameters $\Theta$ stay frozen:
$$
\mathcal{L}_{\mathrm{PLM}}=-\sum_i \log P\left(\boldsymbol{y}_i \mid \boldsymbol{x}_i ; \Theta, \boldsymbol{P}\right)
$$

### Multitask prompt tuning
- Prompt matrices for the source tasks are decomposed into a task-shared matrix and a low-rank task-specific matrix (prompt decomposition), where the former is shared across all tasks.
- This decomposition into shared and task-specific components is learned through knowledge distillation.
- Once learned, the shared prompt matrix is adapted to a downstream target task via low-rank multiplicative updates.

#### Prompt decomposition

![](https://hackmd.io/_uploads/rJfw2_2Pn.png)

- Let $\boldsymbol{P}^* \in \mathbb{R}^{l \times d}$ denote the shared prompt across all tasks.
- Let $\boldsymbol{u}_k \in \mathbb{R}^l, \boldsymbol{v}_k \in \mathbb{R}^d$ be the task-specific vectors for each task $k$.
$$
\widehat{\boldsymbol{P}}_k=\boldsymbol{P}^* \circ \boldsymbol{W}_k=\boldsymbol{P}^* \circ\left(\boldsymbol{u}_k \otimes \boldsymbol{v}_k^T\right)
$$

#### Prompt distillation
- We first obtain a teacher prompt $\boldsymbol{P}_k^{(\text{teacher})}$ for the $k$-th source task by conventional prompt tuning.
- We then randomly initialize a corresponding student prompt $\widehat{\boldsymbol{P}}_k$, where all student prompts share $\boldsymbol{P}^*$ and have their own task-specific vectors as described above.
- We use distillation to transfer cross-task knowledge into the shared prompt matrix; a sketch of the decomposed parameterization follows.
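- A minimal PyTorch sketch of the decomposed (student) prompt parameterization described above (hypothetical names and initializations, not the authors' implementation; $l = 100$ and $d = 768$ are illustrative sizes, with $d$ matching T5-Base's hidden dimension):

```python
import torch
import torch.nn as nn

class DecomposedPrompt(nn.Module):
    """Task-shared prompt P* plus a rank-one task-specific update (u_k, v_k) per source task."""

    def __init__(self, num_source_tasks: int, l: int = 100, d: int = 768):
        super().__init__()
        # Task-shared prompt P* in R^{l x d}
        self.shared = nn.Parameter(torch.randn(l, d) * 0.02)
        # Task-specific vectors u_k in R^l and v_k in R^d for each source task k
        self.u = nn.Parameter(torch.ones(num_source_tasks, l))
        self.v = nn.Parameter(torch.ones(num_source_tasks, d))

    def forward(self, k: int) -> torch.Tensor:
        # W_k = u_k v_k^T is rank one; the student prompt is the Hadamard product P* ∘ W_k.
        w_k = torch.outer(self.u[k], self.v[k])    # shape (l, d)
        return self.shared * w_k                   # student prompt, prepended to the input embeddings

# Usage: one shared P* across, e.g., 6 source tasks; each task gets its own rank-one update.
prompt = DecomposedPrompt(num_source_tasks=6)
p_hat_3 = prompt(3)   # student prompt for source task k = 3, shape (100, 768)
```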
- The first distillation loss matches the output probability distributions of the students and teachers, where the output probabilities are smoothed with a temperature $T$:
$$
\mathcal{L}_{\text{Logits}}=\sum_{k \in|\mathcal{S}|} \sum_{\left(\boldsymbol{x}_i, \boldsymbol{y}_i\right) \in \mathcal{S}_k} \operatorname{KL}\left[P\left(\boldsymbol{y}_i \mid \boldsymbol{x}_i ; \Theta, \boldsymbol{P}_k^{(\text{teacher})}\right) \| P\left(\boldsymbol{y}_i \mid \boldsymbol{x}_i ; \Theta, \widehat{\boldsymbol{P}}_k\right)\right]
$$
$$
p_j=\frac{1}{Z} \exp \left(z_j / T\right)
$$
- The second loss is a mean squared error on the teacher and student hidden states:
$$
\mathcal{L}_{\text{Hidden}}=\sum_{k \in|\mathcal{S}|} \sum_{\left(\boldsymbol{x}_i, \boldsymbol{y}_i\right) \in \mathcal{S}_k}\left(\boldsymbol{H}_{k, i}-\boldsymbol{H}_{k, i}^{(\text{teacher})}\right)^2
$$
- The total loss for training the student source prompts:
$$
\mathcal{L}_{\text{Total}}=\mathcal{L}_{\text{PLM}}+\lambda\left(\mathcal{L}_{\text{Logits}}+\mathcal{L}_{\text{Hidden}}\right)
$$

### Source training
1. The teacher prompts for all source tasks are first pretrained individually through vanilla prompt tuning.
2. We then perform multitask training on $\mathcal{S}=\{\mathcal{S}_1, \ldots, \mathcal{S}_\kappa\}$ to jointly learn the single shared prompt via the knowledge-distillation loss above.
- For each batch of multitask samples, we first randomly select a number $K$ from $[2, \kappa]$, then randomly choose $K$ tasks from $\mathcal{S}$ and their corresponding samples to form the mini-batch.

### Target adaptation
- We initialize the target prompt for target task $\mathcal{T}_t$ as the Hadamard product of the shared prompt matrix and the task-specific low-rank prompt matrix, and optimize it with the regular task loss (see the sketch at the end of this note).

#### Parameter efficiency
- The total number of tunable parameters for a single target task is $(l \times d) + (l + d)$.
- After training, this can be further compressed into a single prompt matrix of size $l \times d$.
- For a group of target tasks, the total number of tunable parameters is $(l \times d) + (l + d)\tau$, where $\tau$ is the number of target tasks.

## Experiment
- We mainly experiment with the publicly available pretrained T5-Base model (220M parameters).
- Full-dataset adaptation

![](https://hackmd.io/_uploads/SJ9mkraPn.png)

![](https://hackmd.io/_uploads/B1bp1Bpv3.png)

- Few-shot adaptation

![](https://hackmd.io/_uploads/HJHjbB6vn.png)

- Natural language generation tasks

![](https://hackmd.io/_uploads/B1Fo7BaD3.png)

- Prompt decomposition and distillation

![](https://hackmd.io/_uploads/B18vVH6v2.png)

## Conclusion
- MPT learns a single transferable prompt by **decomposing and distilling knowledge** from multiple source tasks and their task-specific source prompts.
- MPT decomposes each task prompt as the Hadamard product of a shared prompt matrix and a rank-one task-specific matrix.
- The **shared component** is then transferred to target tasks and adapted with further tuning.
- We find that this approach enables parameter-efficient transfer learning to downstream target tasks across diverse NLP benchmarks.

## Appendix
- Vanilla prompt transfer (PoT) first trains a soft prompt $u_s$ on a source task and then uses the trained prompt to initialize the prompt for a target task.
- The prompt initialized with $u_s$ is then further fine-tuned on the target task to obtain the task-specific target prompt $u_t$.

![](https://hackmd.io/_uploads/SJFRl5pD3.png)

## Reference
[Multitask Prompt Tuning Enables Parameter-Efficient Transfer Learning](https://openreview.net/pdf?id=Nk2pDtuhTq)
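- As referenced in the target-adaptation section, here is a minimal, hypothetical PyTorch sketch of the target-side parameterization and its tunable-parameter count (not the authors' code; `shared` is randomly initialized only to keep the sketch self-contained, whereas MPT initializes it from the source-trained $\boldsymbol{P}^*$; $l = 100$ and $d = 768$ are the illustrative sizes used in the earlier sketch):

```python
import torch
import torch.nn as nn

l, d = 100, 768                                    # prompt length and hidden size (illustrative)

# P*: in MPT this is transferred from source training; random init keeps the sketch runnable.
shared = nn.Parameter(torch.randn(l, d) * 0.02)
u_t = nn.Parameter(torch.ones(l))                  # task-specific vector u_t
v_t = nn.Parameter(torch.ones(d))                  # task-specific vector v_t

# Target prompt = P* ∘ (u_t ⊗ v_t^T); P*, u_t, and v_t are tuned with the regular
# task loss while the PLM itself stays frozen.
target_prompt = shared * torch.outer(u_t, v_t)

# Tunable parameters per target task: (l * d) + (l + d) = 76,800 + 868 = 77,668.
# After training, target_prompt can be stored as a single l x d matrix.
print(shared.numel() + u_t.numel() + v_t.numel())  # 77668
```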