Multitask Prompt Tuning Enables Parameter-Efficient Transfer Learning
Outline
- Introduction
- Approach
- Experiment
- Conclusion
Introduction
- Prompt tuning (PT), which prepends tunable continuous prompt vectors to the input, has emerged as a promising approach for parameter-efficient transfer learning with PLMs.
- We propose multitask prompt tuning (MPT), which first learns a single transferable prompt by distilling knowledge from multiple task-specific source prompts.
- We learn multiplicative low-rank updates to this shared prompt to efficiently adapt it to each downstream target task.
- We decompose the soft prompt of each source task into the product of a shared matrix and a low-rank task-specific matrix.
- This decomposition is more effective than simply sharing the prompt matrix across all tasks.
Approach
- Our goal is to learn a single soft prompt over the source tasks $\mathcal{S} = \{S_1, \ldots, S_\kappa\}$ that can be adapted to each target task $\mathcal{T}_t$ in a parameter-efficient way.
- Simply training a single soft prompt on $\mathcal{S}$ and then fine-tuning it on each $\mathcal{T}_t$ is sub-optimal, as it can fail to leverage commonalities across source tasks while minimizing interference at the same time.
- MPT aims to compress task-shared knowledge in $\mathcal{S}$ into a single prompt matrix $P^{*}$ via knowledge distillation, improving performance on $\mathcal{T}$ while filtering out task-specific information that is less useful for transfer learning.
- Prompt tuning randomly initializes a small number of learnable prompt vectors to be prepended to the input embeddings of the PLM while freezing model parameters Θ.
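- A minimal sketch of this vanilla prompt-tuning setup is shown below, assuming a Hugging Face-style model that accepts `inputs_embeds`; the class name, prompt length, and embedding size are illustrative, not taken from the paper's code.

```python
import torch
import torch.nn as nn

class PromptTuningWrapper(nn.Module):
    """Prepend l learnable prompt vectors to the input embeddings of a frozen PLM."""

    def __init__(self, plm, prompt_length=100, embed_dim=768):
        super().__init__()
        self.plm = plm
        for p in self.plm.parameters():  # freeze the PLM parameters Θ
            p.requires_grad = False
        # P ∈ R^{l×d} is the only trainable tensor in vanilla prompt tuning
        self.prompt = nn.Parameter(torch.randn(prompt_length, embed_dim) * 0.02)

    def forward(self, input_embeds, **kwargs):
        # input_embeds: (batch, seq_len, d) token embeddings from the PLM's embedding layer
        prompt = self.prompt.unsqueeze(0).expand(input_embeds.size(0), -1, -1)
        return self.plm(inputs_embeds=torch.cat([prompt, input_embeds], dim=1), **kwargs)
```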
Multitask prompt tuning
- Prompt matrices for the source tasks are decomposed into a task-shared matrix and a low-rank task-specific matrix (prompt decomposition), where the former is shared across all tasks.
- This decomposition into shared and task-specific components is learned through knowledge distillation.
- Once learned, the shared prompt matrix is adapted to a downstream target task via low-rank multiplicative updates.
Prompt decomposition
- Let $P^{*} \in \mathbb{R}^{l \times d}$ denote the shared prompt across all tasks, where $l$ is the prompt length and $d$ is the embedding dimension.
- Let $u_k \in \mathbb{R}^{l}$ and $v_k \in \mathbb{R}^{d}$ be the task-specific vectors for each task $k$; the prompt for the $k$-th task is then parameterized as $\widehat{P}_k = P^{*} \circ W_k$, where $W_k = u_k v_k^{\top}$ and $\circ$ denotes the Hadamard product.
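- A minimal sketch of this decomposition in PyTorch is given below; the class and variable names are illustrative.

```python
import torch
import torch.nn as nn

class DecomposedPrompt(nn.Module):
    """Task prompt P̂_k = P* ∘ (u_k v_kᵀ): a shared matrix rescaled by a rank-one task-specific matrix."""

    def __init__(self, num_tasks, prompt_length, embed_dim):
        super().__init__()
        self.shared = nn.Parameter(torch.randn(prompt_length, embed_dim) * 0.02)  # P* ∈ R^{l×d}
        self.u = nn.Parameter(torch.ones(num_tasks, prompt_length))               # u_k ∈ R^l
        self.v = nn.Parameter(torch.ones(num_tasks, embed_dim))                   # v_k ∈ R^d

    def forward(self, task_id):
        # W_k = u_k v_kᵀ has rank one; the Hadamard product rescales P* per task
        W_k = torch.outer(self.u[task_id], self.v[task_id])
        return self.shared * W_k
```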
Prompt distillation
- We first obtain a teacher prompt $P_k^{\text{teacher}}$ for the $k$-th source task by conventional prompt tuning.
- We then randomly initialize a corresponding student prompt $\widehat{P}_k = P^{*} \circ (u_k v_k^{\top})$, where all student prompts share $P^{*}$ and have their own task-specific vectors $u_k, v_k$ as described above.
- We use distillation to transfer cross-task knowledge into the shared prompt matrix.
- The first loss matches the output probability distributions of students and teachers via the KL divergence: $\mathcal{L}_{\text{Logits}} = \sum_{k \in [\kappa]} \sum_{(x_i, y_i) \in \mathcal{S}_k} \mathrm{KL}\big[ P(y_i \mid x_i ; \Theta, P_k^{\text{teacher}}) \,\|\, P(y_i \mid x_i ; \Theta, \widehat{P}_k) \big]$
- A mean squared loss further matches the teacher and student hidden states: $\mathcal{L}_{\text{Hidden}} = \sum_{k \in [\kappa]} \sum_{(x_i, y_i) \in \mathcal{S}_k} \big( H_{k,i} - \widehat{H}_{k,i} \big)^2$, where $H_{k,i}$ and $\widehat{H}_{k,i}$ denote the teacher and student hidden states.
- The total loss function for training the student source prompts is $\mathcal{L}_{\text{Total}} = \mathcal{L}_{\text{PLM}} + \lambda \big( \mathcal{L}_{\text{Logits}} + \mathcal{L}_{\text{Hidden}} \big)$, where $\mathcal{L}_{\text{PLM}}$ aggregates the task losses of the source tasks and $\lambda$ balances the distillation terms.
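- A hedged sketch of this objective for one multitask batch, assuming the teacher and student forward passes expose logits and (pooled) hidden states; the weight `lam` is an illustrative hyperparameter.

```python
import torch.nn.functional as F

def mpt_distillation_loss(task_loss, teacher_logits, student_logits,
                          teacher_hidden, student_hidden, lam=1.0):
    """L_Total = L_PLM + λ (L_Logits + L_Hidden) for a batch of source-task samples."""
    # L_Logits: KL[teacher || student] over the output distributions
    kl = F.kl_div(F.log_softmax(student_logits, dim=-1),
                  F.softmax(teacher_logits, dim=-1),
                  reduction="batchmean")
    # L_Hidden: mean squared error between teacher and student hidden states
    mse = F.mse_loss(student_hidden, teacher_hidden)
    return task_loss + lam * (kl + mse)
```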
Source training
- The teacher prompts for all source tasks are pretrained individually through vanilla prompt tuning.
- We then perform multitask training on $\mathcal{S} = \{S_1, \ldots, S_\kappa\}$ to jointly learn the single shared prompt via the knowledge distillation loss above.
- For each batch of multitask samples, we first randomly select a number $K$ from $[2, \kappa]$, then randomly choose $K$ tasks from $\mathcal{S}$ and their corresponding samples to constitute the mini-batch.
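- A small sketch of this sampling scheme, assuming each source task's data is held in an in-memory list; `per_task_batch` is an illustrative knob not specified here.

```python
import random

def sample_multitask_batch(task_datasets, kappa, per_task_batch=8):
    """Pick K ∈ [2, κ] source tasks, then draw examples from each to build one mini-batch."""
    K = random.randint(2, kappa)            # number of tasks appearing in this batch
    tasks = random.sample(range(kappa), K)  # which source tasks participate
    batch = []
    for k in tasks:
        batch.extend(random.sample(task_datasets[k], per_task_batch))
    return tasks, batch
```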
Target adaptation
- We initialize the target prompt for target task $\mathcal{T}_t$ as the Hadamard product of the shared prompt matrix and the task-specific low-rank prompt matrix, $P_t = P^{*} \circ (u_t v_t^{\top})$, and optimize $\{P^{*}, u_t, v_t\}$ with the regular task loss.
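- A minimal sketch of target adaptation under illustrative shapes; the shared prompt is assumed to come from source training, and the all-ones initialization of $u_t, v_t$ is an illustrative choice that makes the initial $P_t$ equal $P^{*}$.

```python
import torch
import torch.nn as nn

l, d = 100, 768                                          # illustrative prompt length / embedding dim
shared_prompt = nn.Parameter(torch.randn(l, d) * 0.02)   # stand-in for the learned P*
u_t = nn.Parameter(torch.ones(l))                        # task-specific vectors for target task t
v_t = nn.Parameter(torch.ones(d))

def target_prompt():
    # P_t = P* ∘ (u_t v_tᵀ); with the all-ones init above, P_t starts at P*
    return shared_prompt * torch.outer(u_t, v_t)

# All three components are updated with the regular task loss on the target task
optimizer = torch.optim.AdamW([shared_prompt, u_t, v_t], lr=0.3)  # lr is illustrative
```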
Parameter-efficiency
- The total number of tunable parameters for a single target task is $(l \times d) + (l + d)$.
- After training, this can further be compressed into a single prompt matrix of size $l \times d$.
- For a group of target tasks, the total number of tunable parameters is $(l \times d) + \tau (l + d)$, where $\tau$ is the number of target tasks.
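- A quick worked example of these counts, assuming the illustrative values $l = 100$ prompt tokens, $d = 768$ (the T5-Base hidden size), and $\tau = 5$ target tasks:

```python
l, d, tau = 100, 768, 5               # illustrative prompt length, embedding dim, number of target tasks
single_task = l * d + (l + d)         # 77,668 tunable parameters for one target task
compressed  = l * d                   # 76,800 after merging P* and the rank-one update into one matrix
group       = l * d + tau * (l + d)   # 81,140 for τ = 5 target tasks sharing the same P*
```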
Experiment
- We mainly experiment using the publicly available pretrained T5-Base model with 220M parameters.
- Full-dataset adaptation
- Few-shot adaptation
- Natural language generation tasks
- Prompt decomposition and distillation
Conclusion
- MPT learns a single transferable prompt by decomposing and distilling knowledge from multiple source tasks and their task-specific source prompts.
- MPT decomposes the task prompt as the Hadamard product of a shared prompt matrix and a rank-one task-specific matrix.
- The shared component is then transferred and adapted to target tasks for further tuning.
- We found this approach enables parameter-efficient transfer learning to target downstream tasks across diverse NLP benchmarks.
Appendix
- Vanilla prompt transfer (PoT) first trains a soft prompt $P_s$ on one source task and then uses the trained prompt to initialize the prompt for a target task.
- Lastly, the prompt initialized with $P_s$ is further fine-tuned on the target task to obtain the task-specific target prompt $P_t$.
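- A hedged, pseudocode-style sketch of this baseline; `train_prompt`, `plm`, and the datasets are hypothetical placeholders, with `train_prompt(plm, data, init=...)` assumed to run standard prompt tuning and return the learned prompt.

```python
import copy

# Vanilla PoT: source prompt tuning, prompt transfer, then target fine-tuning.
source_prompt = train_prompt(plm, source_dataset)                    # learn P_s on the source task
target_init = copy.deepcopy(source_prompt)                           # initialize the target prompt with P_s
target_prompt = train_prompt(plm, target_dataset, init=target_init)  # fine-tune to obtain P_t
```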
Reference
Multitask Prompt Tuning Enables Parameter-Efficient Transfer Learning