
Multitask Prompt Tuning Enables Parameter-Efficient Transfer Learning

tags: RL Group meeting

Outline

  • Introduction
  • Approach
  • Experiment
  • Conclusion

Introduction

  • Prompt tuning (PT), which prepends tunable continuous prompt vectors to the input, has emerged as a promising approach for parameter-efficient transfer learning with PLMs.
  • We propose multitask prompt tuning (MPT), which first learns a single transferable prompt by distilling knowledge from multiple task-specific source prompts.
  • We learn multiplicative low-rank updates to this shared prompt to efficiently adapt it to each downstream target task.
    • We decompose the soft prompt of each source task into a multiplication of a shared matrix and a low-rank task-specific matrix.
    • This decomposition is more effective than simply sharing the prompt matrix across all tasks.

Approach

  • Our goal is to learn a single soft prompt over the set of source tasks $\mathcal{S}$ that can be adapted to each target task $\mathcal{T}_i$ in a parameter-efficient way.
  • Simply training a single soft prompt on $\mathcal{S}$ and then fine-tuning on each $\mathcal{T}_i$ is sub-optimal, as it can fail to leverage commonalities across source tasks while minimizing interference at the same time.
  • MPT aims to compress task-shared knowledge in $\mathcal{S}$ into a single prompt matrix $\phi_{\mathcal{S}}$ via knowledge distillation to improve performance on $\mathcal{T}$ while filtering out task-specific information that is less useful for transfer learning.
  • Prompt tuning randomly initializes a small number of learnable prompt vectors $P$ to be prepended to the input embeddings of the PLM while freezing the model parameters $\Theta$ (a minimal sketch follows below).
    $\mathcal{L}_{\mathrm{PLM}} = -\sum_{i} \log P(y_i \mid x_i;\, \Theta, P)$
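
A minimal PyTorch-style sketch of vanilla prompt tuning (illustrative, not the paper's code): `plm` is assumed to be a pretrained model that accepts input embeddings directly, and the prompt length / embedding dimension are placeholder values.

```python
# Prompt tuning: only the prompt matrix P (l x d) is trainable; the PLM
# parameters Theta are frozen.
import torch
import torch.nn as nn

class PromptTuning(nn.Module):
    def __init__(self, plm, prompt_len=100, d_model=768):
        super().__init__()
        self.plm = plm
        for p in self.plm.parameters():          # freeze Theta
            p.requires_grad = False
        # P: randomly initialized learnable prompt vectors
        self.prompt = nn.Parameter(0.02 * torch.randn(prompt_len, d_model))

    def forward(self, input_embeds):             # input_embeds: (batch, seq_len, d_model)
        batch = input_embeds.size(0)
        prompt = self.prompt.unsqueeze(0).expand(batch, -1, -1)
        # Prepend the prompt to the input embeddings and run the frozen PLM
        return self.plm(torch.cat([prompt, input_embeds], dim=1))
```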

Multitask prompt tuning

  • Prompt matrices for the source tasks are decomposed into a task-shared matrix and a low-rank task-specific matrix (prompt decomposition), where the former is shared across all tasks.
  • This decomposition into shared and task-specific components is learned through knowledge distillation.
  • Once learned, the shared prompt matrix is adapted to a downstream target task via low-rank multiplicative updates.

Prompt decomposition

  • Let $P \in \mathbb{R}^{l \times d}$ denote the shared prompt across all tasks.
  • Let $u_k \in \mathbb{R}^{l}$, $v_k \in \mathbb{R}^{d}$ be the task-specific vectors for each task $k$; the prompt for task $k$ is then the Hadamard product of the shared matrix and the rank-one matrix $W_k = u_k v_k^{\top}$ (see the sketch below):

$\hat{P}_k = P \circ W_k = P \circ (u_k v_k^{\top})$
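
A minimal sketch of the prompt decomposition, assuming prompt length `l` and embedding dimension `d`; the module and parameter names are illustrative.

```python
# Each source-task prompt is the Hadamard product of a shared matrix P (l x d)
# and a rank-one task-specific matrix W_k = u_k v_k^T.
import torch
import torch.nn as nn

class DecomposedPrompts(nn.Module):
    def __init__(self, num_tasks, prompt_len=100, d_model=768):
        super().__init__()
        self.P = nn.Parameter(0.02 * torch.randn(prompt_len, d_model))    # shared across all tasks
        self.u = nn.Parameter(0.02 * torch.randn(num_tasks, prompt_len))  # task-specific u_k
        self.v = nn.Parameter(0.02 * torch.randn(num_tasks, d_model))     # task-specific v_k

    def forward(self, k):
        W_k = torch.outer(self.u[k], self.v[k])  # rank-one matrix u_k v_k^T, shape (l, d)
        return self.P * W_k                      # Hadamard product: prompt P_hat_k for task k
```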

Prompt distillation

  • We first obtain a teacher prompt $P_k^{(\mathrm{teacher})}$ for the $k$-th source task by conventional prompt tuning.
  • We randomly initialize a corresponding student prompt $\hat{P}_k$, where all student prompts share $P$ and have their own task-specific vectors as described above.
  • We use distillation to transfer cross-task knowledge into the shared prompt matrix.
  • The first loss matches the output probability distributions of students and teachers:
    $\mathcal{L}_{\mathrm{Logits}} = \sum_{k=1}^{|\mathcal{S}|} \sum_{(x_i, y_i) \in \mathcal{S}_k} \mathrm{KL}\big[ P(y_i \mid x_i;\, \Theta, P_k^{(\mathrm{teacher})}) \,\|\, P(y_i \mid x_i;\, \Theta, \hat{P}_k) \big]$

  • A temperature $T$ is used to smooth the output distributions of both teacher and student, where $z_j$ is the logit for class $j$ and $Z$ is the normalization factor:

$p_j = \frac{1}{Z} \exp(z_j / T)$

  • Mean squared loss between the student and teacher hidden states:
    $\mathcal{L}_{\mathrm{Hidden}} = \sum_{k=1}^{|\mathcal{S}|} \sum_{(x_i, y_i) \in \mathcal{S}_k} \big( H_{k,i} - H_{k,i}^{(\mathrm{teacher})} \big)^2$
  • The total loss for training the student source prompts combines the task loss and the two distillation terms (see the sketch below):
    $\mathcal{L}_{\mathrm{Total}} = \mathcal{L}_{\mathrm{PLM}} + \lambda \left( \mathcal{L}_{\mathrm{Logits}} + \mathcal{L}_{\mathrm{Hidden}} \right)$
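
A sketch of the two distillation terms, assuming the logits and hidden states of teacher and student are already computed; the temperature and loss weight are placeholder hyperparameters.

```python
# Distillation losses used to train the student source prompts.
import torch
import torch.nn.functional as F

def distillation_losses(student_logits, teacher_logits,
                        student_hidden, teacher_hidden, T=2.0):
    # L_Logits: KL divergence between temperature-smoothed teacher and student
    # output distributions (teacher distribution is the KL target).
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    l_logits = F.kl_div(log_p_student, p_teacher, reduction="batchmean")

    # L_Hidden: mean squared error between student and teacher hidden states.
    l_hidden = F.mse_loss(student_hidden, teacher_hidden)
    return l_logits, l_hidden

# Accumulated over the source tasks in a mini-batch and combined with the task
# loss as L_Total = L_PLM + lam * (L_Logits + L_Hidden).
```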

Source training

  1. The teacher prompts for all source tasks are pretrained individually through vanilla prompt tuning.
  2. We perform multitask training on $\mathcal{S} = \{\mathcal{S}_1, \ldots, \mathcal{S}_\kappa\}$ to jointly learn the single shared prompt via the knowledge distillation loss function.
  • For each batch of multitask samples, we first randomly select a number $K$ from $[2, \kappa]$, then randomly choose $K$ tasks from $\mathcal{S}$ and their corresponding samples to constitute the mini-batch (see the batching sketch below).
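
A sketch of the batching step above, with a hypothetical `source_datasets` dict mapping each task name to its list of examples (each task assumed to contain at least `per_task` examples).

```python
# Build one multitask mini-batch: sample K in [2, kappa], pick K source tasks,
# then draw per_task examples from each chosen task.
import random

def sample_multitask_batch(source_datasets, per_task=8):
    kappa = len(source_datasets)
    K = random.randint(2, kappa)                        # number of tasks for this step
    tasks = random.sample(list(source_datasets), K)     # choose K tasks from S
    batch = [(t, ex) for t in tasks
             for ex in random.sample(source_datasets[t], per_task)]
    return tasks, batch
```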

Target adaptation

  • We initialize the target prompt for target task $\mathcal{T}_t$ to be the Hadamard product of the shared prompt matrix and the task-specific low-rank prompt matrix, and optimize it with the regular task loss (see the sketch below).
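
A sketch of target adaptation, assuming the learned shared matrix and fresh task-specific vectors are all tuned with the target-task loss; initializing $u_t$, $v_t$ to ones (so the initial prompt equals the learned shared prompt) is an illustrative choice, not necessarily the paper's.

```python
# Target prompt = Hadamard product of the shared matrix and a rank-one
# task-specific matrix; P, u, v are all updated on the target task.
import torch
import torch.nn as nn

class TargetPrompt(nn.Module):
    def __init__(self, shared_P):
        super().__init__()
        l, d = shared_P.shape
        self.P = nn.Parameter(shared_P.detach().clone())  # start from the learned shared matrix
        self.u = nn.Parameter(torch.ones(l))              # task-specific column vector u_t
        self.v = nn.Parameter(torch.ones(d))              # task-specific row vector v_t

    def forward(self):
        return self.P * torch.outer(self.u, self.v)       # prompt prepended to target inputs
```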

Parameter-efficiency

  • The total number of tunable parameters for a single target task is $(l \times d) + (l + d)$.
  • After training, this can further be compressed into a single prompt matrix of size $l \times d$.
  • For a group of target tasks, the total number of tunable parameters is $(l \times d) + (l + d)\tau$, where $\tau$ is the number of target tasks (a worked count follows this list).
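
A quick check of the counts with illustrative numbers (prompt length $l = 100$ and $d = 768$ for T5-Base are assumptions for this example):

```python
# Tunable-parameter counts for MPT with example values.
l, d, tau = 100, 768, 5                  # prompt length, embedding dim, number of target tasks

per_task   = l * d + (l + d)             # shared matrix + rank-one task-specific vectors
compressed = l * d                       # collapsed into a single prompt matrix after training
group      = l * d + (l + d) * tau       # one shared matrix reused across tau target tasks

print(per_task, compressed, group)       # 77668 76800 81140
```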

Experiment

  • We mainly experiment using the publicly available pretrained T5-Base model with 220M parameters.

  • Full-dataset adaptation

  • Few-shot adaptation

  • Natural language generation tasks

  • Prompt decomposition and distillation

Conclusion

  • MPT learns a single transferable prompt by decomposing and distilling knowledge from multiple source tasks and their task-specific source prompts.
  • MPT decomposes the task prompt as the Hadamard product of a shared prompt matrix and a rank-one task-specific matrix.
  • The shared component is then transferred and adapted to target tasks for further tuning.
  • We find that this approach enables parameter-efficient transfer learning to downstream target tasks across diverse NLP benchmarks.

Appendix

  • Vanilla PoT (prompt transfer) first trains a soft prompt $u_s$ on one source task and then uses the trained prompt to initialize the prompt for a target task.
  • Lastly, the prompt initialized with $u_s$ is further fine-tuned on the target task to obtain the task-specific target prompt $u_t$.

Reference

Multitask Prompt Tuning Enables Parameter-Efficient Transfer Learning