# Multitask Prompt Tuning Enables Parameter-Efficient Transfer Learning
###### tags: `RL Group meeting`
## Outline
- Introduction
- Approach
- Experiment
- Conclusion
## Introduction
- Prompt tuning (PT), which prepends tunable continuous prompt vectors to the input, has emerged as a promising approach for parameter-efficient transfer learning with pretrained language models (PLMs).
- We propose **multitask prompt tuning (MPT)**, which first learns a single transferable prompt by distilling knowledge from multiple task-specific source prompts.
- We learn multiplicative **low-rank** updates to this shared prompt to efficiently adapt it to each downstream target task.
- We decompose the soft prompt of each source task into a multiplication of a shared matrix and a low-rank task-specific matrix.
- This decomposition is more effective than simply sharing the prompt matrix across all tasks.

## Approach
- Given a set of source tasks $\mathcal{S}=\{\mathcal{S}_1, \dots, \mathcal{S}_\kappa\}$, our goal is to learn a single soft prompt over $\mathcal{S}$ that can be adapted to each target task $\mathcal{T}_i$ in a parameter-efficient way.
- Simply training a single soft prompt on $\mathcal{S}$ and then fine-tuning it on each $\mathcal{T}_i$ is sub-optimal: it can fail to leverage commonalities across source tasks while minimizing interference at the same time.
- MPT aims to compress task-shared knowledge in $\mathcal{S}$ into a single prompt matrix $\phi_S$ via knowledge distillation, improving performance on target tasks while filtering out task-specific information that is less useful for transfer.
- Prompt tuning randomly initializes a small number of learnable prompt vectors $\boldsymbol{P}$ that are prepended to the input embeddings of the PLM, while the model parameters $\Theta$ stay frozen:
$$
\mathcal{L}_{\mathrm{PLM}}=-\sum_i \log P\left(\boldsymbol{y}_i \mid \boldsymbol{x}_i ; \Theta, \boldsymbol{P}\right)
$$
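The mechanics of prompt tuning can be sketched in a few lines of NumPy. This is a minimal illustration with made-up shapes: a real implementation would feed `plm_input` into a frozen PLM and backpropagate only into `P`.

```python
import numpy as np

rng = np.random.default_rng(0)

l, d = 4, 8          # prompt length and embedding dimension (illustrative)
seq_len = 5          # number of input tokens

# Learnable prompt P -- the only trainable parameters; PLM weights stay frozen.
P = rng.normal(scale=0.02, size=(l, d))

# Frozen input embeddings for one example, shape (seq_len, d).
x_emb = rng.normal(size=(seq_len, d))

# Prompt tuning simply prepends P to the token embeddings before the frozen PLM.
plm_input = np.concatenate([P, x_emb], axis=0)

print(plm_input.shape)  # (9, 8), i.e. (l + seq_len, d)
```

Only the $l \times d$ entries of `P` receive gradients, which is what makes the method parameter-efficient.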
### Multitask prompt tuning
- Prompt matrices for the source tasks are decomposed into a task-shared matrix and a low-rank task-specific matrix (prompt decomposition), where the former is shared across all tasks.
- This decomposition into shared and task-specific components is learned through knowledge distillation.
- Once learned, the shared prompt matrix is adapted to a downstream target task via low-rank multiplicative updates.
#### Prompt decomposition

- Let $\boldsymbol{P}^* \in \mathbb{R}^{l \times d}$ denote the shared prompt across all tasks.
- Let $\boldsymbol{u}_k \in \mathbb{R}^l, \boldsymbol{v}_k \in \mathbb{R}^d$ be the task-specific vectors for each task $k$.
$$
\widehat{\boldsymbol{P}}_k=\boldsymbol{P}^* \circ \boldsymbol{W}_k=\boldsymbol{P}^* \circ\left(\boldsymbol{u}_k \otimes \boldsymbol{v}_k^T\right)
$$
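The decomposition above can be written as a short NumPy sketch (dimensions are illustrative): the task prompt $\widehat{\boldsymbol{P}}_k$ is the Hadamard product of the shared matrix $\boldsymbol{P}^*$ and the rank-one matrix $\boldsymbol{u}_k \otimes \boldsymbol{v}_k^T$.

```python
import numpy as np

rng = np.random.default_rng(0)
l, d = 4, 8  # prompt length and embedding dimension (illustrative)

P_star = rng.normal(size=(l, d))   # shared prompt P*, common to all source tasks
u_k = rng.normal(size=(l,))        # task-specific column vector u_k
v_k = rng.normal(size=(d,))        # task-specific row vector v_k

W_k = np.outer(u_k, v_k)           # rank-one task-specific matrix u_k (x) v_k^T
P_hat_k = P_star * W_k             # Hadamard product P* o W_k

print(P_hat_k.shape)                # (4, 8)
print(np.linalg.matrix_rank(W_k))   # 1 -- W_k is rank one
```

Per task, only the $l + d$ entries of `u_k` and `v_k` are task-specific; everything else lives in the shared `P_star`.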
#### Prompt distillation
- We first obtain a teacher prompt $\boldsymbol{P}_k^{(\text{teacher})}$ for the $k$-th source task by conventional prompt tuning.
- We then randomly initialize a corresponding student prompt $\widehat{\boldsymbol{P}}_k$, where all student prompts share $\boldsymbol{P}^*$ and have their own task-specific vectors as described above.
- We use distillation to transfer cross-task knowledge into the shared prompt matrix.
- The loss is to match the output probability distributions of students and teachers:
$$
\mathcal{L}_{\text {Logits }}=\sum_{k \in|\mathcal{S}|} \sum_{\left(\boldsymbol{x}_i, \boldsymbol{y}_i\right) \in \mathcal{S}_k} \operatorname{KL}\left[P\left(\boldsymbol{y}_i \mid \boldsymbol{x}_i ; \Theta, \boldsymbol{P}_k^{(\text {teacher })}\right) \| P\left(\boldsymbol{y}_i \mid \boldsymbol{x}_i ; \Theta, \widehat{\boldsymbol{P}}_k\right)\right]
$$
$$
p_j=\frac{1}{Z} \exp \left(z_j / T\right)
$$
- Here $z_j$ is the logit for class $j$, $T$ is the distillation temperature controlling the smoothness of the output distribution (applied to both teacher and student), and $Z$ is the normalizer.
- Mean squared loss on teacher model hidden states:
$$
\mathcal{L}_{\text {Hidden }}=\sum_{k \in|\mathcal{S}|} \sum_{\left(\boldsymbol{x}_i, \boldsymbol{y}_i\right) \in \mathcal{S}_k}\left(\boldsymbol{H}_{k, i}-\boldsymbol{H}_{k, i}^{(\text {teacher })}\right)^2
$$
- The total loss function for training student source prompts:
$$
\mathcal{L}_{\text {Total }}=\mathcal{L}_{\text {PLM }}+\lambda\left(\mathcal{L}_{\text {Logits }}+\mathcal{L}_{\text {Hidden }}\right)
$$
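The three terms of the total loss can be sketched as follows. This is a minimal NumPy illustration with random stand-in logits and hidden states; `T`, `lam`, and `L_plm` are placeholder values, not the paper's settings.

```python
import numpy as np

def softmax(z, T=1.0):
    # Temperature-scaled softmax: p_j = exp(z_j / T) / Z
    z = z / T
    z = z - z.max(axis=-1, keepdims=True)  # for numerical stability
    p = np.exp(z)
    return p / p.sum(axis=-1, keepdims=True)

def kl(p, q, eps=1e-12):
    # KL[p || q] between two categorical distributions
    return np.sum(p * (np.log(p + eps) - np.log(q + eps)))

rng = np.random.default_rng(0)
n_classes, hidden = 10, 16

teacher_logits = rng.normal(size=(n_classes,))
student_logits = rng.normal(size=(n_classes,))
H_teacher = rng.normal(size=(hidden,))
H_student = rng.normal(size=(hidden,))

T, lam = 2.0, 0.9   # temperature and loss weight (illustrative values)

# KL between teacher and student output distributions (L_Logits)
L_logits = kl(softmax(teacher_logits, T), softmax(student_logits, T))
# Mean squared error on hidden states (L_Hidden)
L_hidden = np.mean((H_student - H_teacher) ** 2)
# Stand-in for the task loss L_PLM of the student
L_plm = 1.0

L_total = L_plm + lam * (L_logits + L_hidden)
print(L_total > 0)  # True: all three terms are non-negative here
```

In training, only the shared matrix and the task-specific vectors receive gradients from `L_total`; the PLM and the teacher prompts are fixed.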
### Source training
1. The teacher prompts for all source tasks are pretrained individually through vanilla prompt tuning.
2. We then perform multitask training on $\mathcal{S} = \{\mathcal{S}_1, \dots, \mathcal{S}_\kappa\}$ to jointly learn the single shared prompt via the knowledge-distillation loss.
- For each batch of multitask samples, we first randomly sample a number $K$ from $[2, \kappa]$, then randomly choose $K$ tasks from $\mathcal{S}$ and their corresponding samples to constitute the mini-batch.
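The task-sampling step can be sketched as below (`sample_task_batch` is a hypothetical helper, not from the paper):

```python
import random

def sample_task_batch(num_tasks, rng=random):
    # Pick K uniformly from [2, kappa], then choose K distinct source tasks.
    K = rng.randint(2, num_tasks)          # inclusive on both ends
    return rng.sample(range(num_tasks), K)

kappa = 6
random.seed(0)
tasks = sample_task_batch(kappa)
print(2 <= len(tasks) <= kappa)  # True
```

Samples from the chosen tasks would then be drawn to fill the mini-batch.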
### Target adaptation
- We initialize the target prompt for target task $\mathcal{T}_t$ as the Hadamard product of the shared prompt matrix and the task-specific low-rank prompt matrix, and optimize it with the regular task loss.
#### Parameter-efficiency
- The total number of tunable parameters for a single target task is $(l \times d) + (l + d)$.
- After training, the decomposition can be collapsed into a single prompt matrix of size $l \times d$.
- For a group of target tasks, the total number of tunable parameters is $(l \times d) + (l + d)\tau$, where $\tau$ is the number of target tasks.
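These counts can be checked with a small sketch. The helper names are ours, and $l = 100$, $d = 768$ are example values (a typical prompt length and T5-Base's hidden size):

```python
def mpt_target_params(l, d, num_targets=1):
    # Shared prompt: l*d parameters; each target task adds u (l) and v (d).
    return l * d + (l + d) * num_targets

def mpt_deployed_size(l, d):
    # After training, P* o (u (x) v^T) collapses into one l-by-d matrix per task.
    return l * d

l, d = 100, 768
print(mpt_target_params(l, d))   # 77668
print(mpt_deployed_size(l, d))   # 76800
```

The overhead per additional target task is only $l + d$ parameters, versus $l \times d$ for a fully independent prompt.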
## Experiment
- We mainly experiment using the publicly available pretrained T5-Base model with 220M parameters.
- Full-dataset adaptation
- Few-shot adaptation
- Natural language generation tasks
- Prompt decomposition and distillation
## Conclusion
- MPT learns a single transferable prompt by **decomposing and distilling knowledge** from multiple source tasks and their task-specific source prompts.
- MPT decomposes the task prompt as the Hadamard product of a shared prompt matrix and a rank-one task-specific matrix.
- The **shared component** is then transferred and adapted to target tasks for further tuning.
- We found this approach enables parameter-efficient transfer learning to target downstream tasks across diverse NLP benchmarks.
## Appendix
- Vanilla prompt transfer (PoT) first trains a soft prompt $u_s$ on one source task and then uses the trained prompt to initialize the prompt for a target task.
- The prompt initialized with $u_s$ is then further fine-tuned on the target task to obtain the task-specific target prompt $u_t$.

## Reference
[Multitask Prompt Tuning Enables Parameter-Efficient Transfer Learning](https://openreview.net/pdf?id=Nk2pDtuhTq)