
Tractable Control for Autoregressive Language Generation

2023/08/22

tags: RL Group meeting

Outline

  • Introduction
  • Related work
  • Guiding Autoregressive Generation with Tractable Probabilistic Models
  • Efficient Probabilistic Reasoning with Hidden Markov Models
  • Experiments
  • Conclusion

Introduction

  • It remains a major challenge for autoregressive large language models to generate text that satisfies complex constraints:

    • Sampling from the conditional distribution $\Pr(\text{text} \mid \alpha)$ is intractable even for the simplest lexical constraints $\alpha$.
  • We propose to use tractable probabilistic models (TPMs) to impose lexical constraints in autoregressive text generation models, which we refer to as GeLaTo (Generating Language with Tractable Constraints).

  • Our goal is to generate text following the conditional distribution $\Pr_{\text{LM}}(x_{1:n} \mid \alpha)$ for arbitrary lexical constraints $\alpha$.

    • TPMs can efficiently compute the joint probability of the input sequence and the constraints, which allows more precise control over the generation process.
    • Pre-trained LMs only model the next-token distribution given some prefix, and conditioning on constraints can be intractable even for simple constraints.
  • We use distilled hidden Markov models (HMMs), for which we can efficiently compute $\Pr(\text{text} \mid \alpha)$, to guide autoregressive generation from GPT2.

    • We propose a dynamic programming algorithm that efficiently computes the conditional probabilities $\Pr_{\text{HMM}}(\cdot \mid \alpha)$.
  • Our study demonstrates the potential of TPMs in controlling large language models and motivates the development of more expressive TPMs.


  1. We train a TPM $\Pr_{\text{TPM}}$ via maximum likelihood estimation (MLE) on samples drawn from $\Pr_{\text{LM}}$, which is equivalent to minimizing the KL-divergence between $\Pr_{\text{TPM}}$ and $\Pr_{\text{LM}}$ (see the toy sketch below);
  2. At generation time, we compute $\Pr_{\text{TPM}}(x_{t+1} \mid x_{1:t}, \alpha)$ efficiently and combine it with $\Pr_{\text{LM}}(x_{t+1} \mid x_{1:t})$ to approximate $\Pr_{\text{LM}}(x_{t+1} \mid x_{1:t}, \alpha)$ for reliable control.
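
A toy numerical sketch of step 1 (not the paper's code): fitting a categorical model by MLE on samples drawn from a fixed "teacher" distribution drives the KL-divergence from the teacher toward zero. The 5-symbol vocabulary and sample size below are made-up stand-ins for $\Pr_{\text{LM}}$ and $\Pr_{\text{TPM}}$.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size = 5
p = rng.dirichlet(np.ones(vocab_size))               # stand-in for Pr_LM
samples = rng.choice(vocab_size, size=50_000, p=p)   # "distillation" data drawn from Pr_LM

# MLE for a categorical model = normalized empirical counts (stand-in for Pr_TPM)
counts = np.bincount(samples, minlength=vocab_size)
q = counts / counts.sum()

# D_KL(p || q) is close to zero after the MLE fit
kl = np.sum(p * (np.log(p) - np.log(q + 1e-12)))
print(f"KL(Pr_LM || Pr_TPM) after MLE: {kl:.6f}")
```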

Related work

Tractable probabilistic models

  • A class of queries $\mathcal{Q}$ is tractable on a family of probabilistic models $\mathcal{M}$ iff any query $q \in \mathcal{Q}$ on a model $m \in \mathcal{M}$ can be computed in time $O(\mathrm{poly}(|m|))$.
  • We also say that $\mathcal{M}$ is a tractable model for $\mathcal{Q}$.
  • Tractable probabilistic models support efficient probabilistic inference.
  • Probabilistic circuits (PCs) are a unified framework for a large family of tractable probabilistic models:
    • hidden Markov models
    • bounded tree-width graphical models
    • sum-product networks (SPNs)

Controllable Autoregressive Language Generation

  • One line of research on constrained text generation focuses on modifying the decoding algorithm to inject constraints into the beam search process
    • Search-based
      • constrained beam search
      • NeuroLogic Decoding
      • A*esque NeuroLogic Decoding
    • Token-level
      • NADO
      • FUDGE
    • Insertion-based

Guiding Autoregressive Generation with Tractable Probabilistic Models

  • Our goal is to generate from the following conditional distribution:
    $\Pr_{\text{LM}}(x_{1:n} \mid \alpha) = \prod_t \Pr_{\text{LM}}(x_{t+1} \mid x_{1:t}, \alpha)$
    • $\Pr_{\text{LM}}(x_{t+1} \mid x_{1:t}, \alpha)$ is intractable.
    • We assume that $\Pr_{\text{TPM}}(x_{t+1} \mid x_{1:t}, \alpha)$ can be efficiently computed.
  • We train the TPM model via MLE, i.e., maximizing
    $\mathbb{E}_{x_{1:n} \sim \Pr_{\text{LM}}} \log \Pr_{\text{TPM}}(x_{1:n})$
  • which effectively minimizes their KL-divergence:
    $D_{\text{KL}}(\Pr_{\text{LM}} \,\|\, \Pr_{\text{TPM}}) = \mathbb{E}_{x_{1:n} \sim \Pr_{\text{LM}}} \log \Pr_{\text{LM}}(x_{1:n}) - \mathbb{E}_{x_{1:n} \sim \Pr_{\text{LM}}} \log \Pr_{\text{TPM}}(x_{1:n})$
  • We assume that there exists some “quality” constraint $\beta$ such that $\Pr_{\text{TPM}}(\cdot \mid \beta)$ is even closer to $\Pr_{\text{LM}}$:
    $\Pr_{\text{TPM}}(x_{1:n} \mid \alpha, \beta) = \prod_t \Pr_{\text{TPM}}(x_{t+1} \mid x_{1:t}, \alpha, \beta)$
  • We make the key independence assumption:
    $\Pr_{\text{TPM}}(x_{t+1} \mid x_{1:t}, \alpha, \beta) \propto \Pr_{\text{TPM}}(\alpha \mid x_{1:t+1}, \beta)\,\Pr_{\text{TPM}}(x_{t+1} \mid x_{1:t}, \beta) \approx \Pr_{\text{TPM}}(\alpha \mid x_{1:t+1})\,\Pr_{\text{LM}}(x_{t+1} \mid x_{1:t}).$
    • Unsupervised setting (see the sketch after this list)

      • Assume that the base pre-trained LM is not fine-tuned with task-specific supervision.
      • It may still be adapted to generate text in a specific domain or context.
        $p(x_{t+1} \mid x_{1:t}, \alpha) \propto \Pr_{\text{TPM}}(\alpha \mid x_{1:t+1})\,\Pr_{\text{LM}}(x_{t+1} \mid x_{1:t}).$
    • Supervised setting

      • Assume that $\Pr_{\text{LM}}$ is fine-tuned in a sequence-to-sequence manner.
      • We adopt an alternative formulation by viewing $\Pr_{\text{TPM}}(x_{t+1} \mid x_{1:t}, \alpha)$ and $\Pr_{\text{LM}}(x_{t+1} \mid x_{1:t})$ as classifiers trained for the same task yet with different biases:
        $p(x_{t+1} \mid x_{1:t}, \alpha) \propto \Pr_{\text{TPM}}(x_{t+1} \mid x_{1:t}, \alpha)^{w}\,\Pr_{\text{LM}}(x_{t+1} \mid x_{1:t})^{1-w}$
  • To summarize, GeLaTo consists of two major steps:
    • Distillation: we train a TPM on samples drawn from the pre-trained LM via MLE, to effectively minimize the KL-divergence between $\Pr_{\text{LM}}$ and $\Pr_{\text{TPM}}$.
    • Probabilistic reasoning: at each step of autoregressive generation, we compute $\Pr_{\text{TPM}}(\cdot \mid \alpha)$ and generate from the conditional next-token distribution $p(x_{t+1} \mid x_{1:t}, \alpha)$ defined above.
  • Two advantages:
    • The sentences generated following $p(x_{t+1} \mid x_{1:t}, \alpha)$ are guaranteed to satisfy the lexical constraint $\alpha$.
    • The TPM training is independent of the lexical constraint $\alpha$, which is only enforced at inference time.
      • There is no need to re-train the TPM model no matter how $\alpha$ changes.
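
A minimal one-step sketch of the two decoding rules above (not the paper's implementation). The arrays `p_lm`, `p_tpm_cond`, and `p_alpha` are hypothetical inputs standing in for $\Pr_{\text{LM}}(x_{t+1} \mid x_{1:t})$, $\Pr_{\text{TPM}}(x_{t+1} \mid x_{1:t}, \alpha)$, and $\Pr_{\text{TPM}}(\alpha \mid x_{1:t+1})$; the weight `w` is illustrative.

```python
import numpy as np

def next_token_unsupervised(p_lm, p_alpha):
    """p(x_{t+1} | x_{1:t}, alpha) ∝ Pr_TPM(alpha | x_{1:t+1}) * Pr_LM(x_{t+1} | x_{1:t})."""
    scores = p_alpha * p_lm
    return scores / scores.sum()

def next_token_supervised(p_tpm_cond, p_lm, w=0.3):
    """p(x_{t+1} | x_{1:t}, alpha) ∝ Pr_TPM(x_{t+1} | x_{1:t}, alpha)^w * Pr_LM(x_{t+1} | x_{1:t})^(1-w)."""
    scores = (p_tpm_cond ** w) * (p_lm ** (1.0 - w))
    return scores / scores.sum()

# Toy 4-token vocabulary; token 0 would make alpha unsatisfiable, so it gets probability 0
p_lm = np.array([0.5, 0.3, 0.15, 0.05])
p_alpha = np.array([0.0, 0.9, 0.6, 0.2])
print(next_token_unsupervised(p_lm, p_alpha))
```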

Efficient Probabilistic Reasoning with Hidden Markov Models (HMMs)


  • We need to compute $\Pr_{\text{TPM}}(x_{1:t+1}, \alpha)$:

    • unsupervised setting: $\Pr(\alpha \mid x_{1:t+1}) = \Pr(x_{1:t+1}, \alpha) / \Pr(x_{1:t+1})$
    • supervised setting: $\Pr(x_{t+1} \mid x_{1:t}, \alpha) \propto \Pr(x_{1:t+1}, \alpha)$
  • We describe a dynamic programming algorithm that computes $\Pr(x_{1:t}, \alpha)$ for HMMs, where $\alpha$ is a lexical constraint encoded in conjunctive normal form (CNF):
    $(I(w_{1,1}) \lor \dots \lor I(w_{1,d_1})) \land \dots \land (I(w_{m,1}) \lor \dots \lor I(w_{m,d_m}))$

    • $w_{i,j}$ is a string of tokens (a keystring).
    • $I(w_{i,j})$ is an indicator variable representing whether $w_{i,j}$ appears in the generated text (see the sketch below).
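
A small, self-contained illustration (not the paper's code) of this CNF semantics: a keystring is a tuple of tokens, $I(w)$ holds iff $w$ occurs as a contiguous run of tokens in the generated sequence, and $\alpha$ requires at least one satisfied literal per clause. The token strings below are made up for illustration.

```python
def occurs(keystring, tokens):
    """I(w): does the keystring w appear as a contiguous run of tokens?"""
    k = len(keystring)
    return any(tuple(tokens[i:i + k]) == keystring for i in range(len(tokens) - k + 1))

def satisfies(cnf, tokens):
    """alpha = conjunction over clauses; each clause = disjunction over keystrings."""
    return all(any(occurs(w, tokens) for w in clause) for clause in cnf)

# Hypothetical constraint: ("catch" or "caught") and ("fr","is","bee" as consecutive tokens)
cnf = [
    [("catch",), ("caught",)],
    [("fr", "is", "bee"), ("fr", "is", "bees")],
]
tokens = ["the", "dog", "caught", "the", "fr", "is", "bee"]
print(satisfies(cnf, tokens))  # True
```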

Hidden Markov Models

  • The joint probability $\Pr(x_{1:n}, z_{1:n})$ is defined as:
    $\Pr(x_1 \mid z_1)\Pr(z_1) \prod_{2 \le t \le n} \Pr(x_t \mid z_t)\Pr(z_t \mid z_{t-1})$

  • The parameters of an HMM are given by the initial probability $\Pr(z_1)$, the emission matrix $\Pr(x_t \mid z_t)$, and the transition matrix $\Pr(z_{t+1} \mid z_t)$, which stay the same across different positions $t$. By the Markov property,
    $\Pr(x_{t:n} \mid z_t, x_{1:t-1}) = \Pr(x_{t:n} \mid z_t).$

  • forward algorithm (see the sketch below):

    $\Pr(x_{1:t}, z_t) = \sum_{1 \le z_{t-1} \le h} \Pr(x_t \mid z_t)\Pr(z_t \mid z_{t-1})\Pr(x_{1:t-1}, z_{t-1})$

  • $\Pr_{\text{HMM}}(x_{1:n})$ effectively defines a distribution over all texts of length $\le n$.
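
A minimal numpy sketch of the forward recursion above (not the paper's code); `pi`, `T`, and `E` form a made-up toy HMM.

```python
# pi[z] = Pr(z_1 = z), T[z, z'] = Pr(z_{t+1} = z' | z_t = z), E[z, x] = Pr(x_t = x | z_t = z)
import numpy as np

def forward(pi, T, E, x):
    """Return alpha where alpha[t, z] = Pr(x_{1:t+1}, z_{t+1} = z) for 0-indexed t."""
    n, h = len(x), len(pi)
    alpha = np.zeros((n, h))
    alpha[0] = pi * E[:, x[0]]                       # Pr(x_1, z_1)
    for t in range(1, n):
        # Pr(x_{1:t}, z_t) = sum_{z_{t-1}} Pr(x_t|z_t) Pr(z_t|z_{t-1}) Pr(x_{1:t-1}, z_{t-1})
        alpha[t] = E[:, x[t]] * (alpha[t - 1] @ T)
    return alpha

# Toy HMM: 2 hidden states, vocabulary of 3 symbols
pi = np.array([0.6, 0.4])
T = np.array([[0.7, 0.3], [0.2, 0.8]])
E = np.array([[0.5, 0.4, 0.1], [0.1, 0.3, 0.6]])
x = [0, 2, 1]
print(forward(pi, T, E, x).sum(axis=1))  # Pr(x_{1:t}) for t = 1..3
```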

An Efficient Dynamic Programming Algorithm


  • In the recursion, $\alpha$ denotes the CNF formula obtained by removing from the original constraint the clauses that are already satisfied.
  • $x_{l:r}$ is either the empty string or a suffix of some keystring in $\alpha$.
  • $\psi$ is a CNF consisting of a subset of clauses in $\alpha$, and $z_l$ is a latent state for $Z_l$.
  • $S(x, \alpha) := \{\, s : \exists x' \text{ a suffix of } x \text{ s.t. } x's \text{ lies in } \alpha \,\}$
  • Case 1. $x_{l:r} \neq \emptyset$; then
    $\Pr(x_{l:r}, \alpha_{l:n} \mid z_l) = \sum_{z_{r+1}} \Pr(x_{l:r}, z_{r+1} \mid z_l)\,\Big(\Pr(\alpha_{r+1:n} \mid z_{r+1}) + \sum_{s \in S(x_{l:r}, \alpha)} \Pr\big(s_{r+1:r+|s|}, (\alpha \setminus x_{l:r}s)_{r+1:n} \mid z_{r+1}\big) - \sum_{s \in S(x_{l:r}, \alpha)} \Pr\big(s_{r+1:r+|s|}, \alpha_{r+1:n} \mid z_{r+1}\big)\Big);$
  • Case 2. $x_{l:r} = \emptyset$; we reduce the problem to Case 1 by enumerating $x_l$ over the vocabulary:
    $\Pr(\alpha_{l:n} \mid z_l) = \sum_{x_l \in \text{vocabulary}} \Pr(x_l, \alpha_{l:n} \mid z_l)$
  • At step $t$, we compute $\Pr(x_{1:t-1}, x_t, \alpha_{1:n})$, where $x_{1:t-1}$ denotes the first $t-1$ tokens that have been generated (see the simplified sketch after this list):
    $\Pr(x_{1:t}, \alpha_{1:n}) = \sum_{z_1} \Pr(z_1) \Pr(x_{1:t}, \alpha_{1:n} \mid z_1)$
  • The time complexity of GeLaTo is $O(2^{|\alpha|} n m)$:
    • $|\alpha|$ is the number of clauses in $\alpha$
    • $n$ is the maximum sequence length
    • $m$ is the number of different suffixes over all keystrings in $\alpha$.
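
A much-simplified, self-contained sketch of the idea (not the general algorithm): for a constraint with a single one-token keystring, $\alpha$ = "token $K$ appears somewhere in $x_{1:n}$", there is no suffix bookkeeping and $\Pr(x_{1:t}, \alpha_{1:n})$ reduces to a forward pass plus a backward DP over hidden states. The toy HMM parameters below are made up.

```python
import numpy as np

def prob_prefix_and_constraint(pi, T, E, x_prefix, K, n):
    """Pr(x_{1:t} = x_prefix, token K appears somewhere in x_{1:n}) under an HMM."""
    t = len(x_prefix)
    # forward pass: a[z] = Pr(x_{1:t}, z_t = z)
    a = pi * E[:, x_prefix[0]]
    for tok in x_prefix[1:]:
        a = E[:, tok] * (a @ T)
    if K in x_prefix:                      # the (single) clause is already satisfied by the prefix
        return float(a.sum())
    # backward DP: g[z] = Pr(K appears in x_{l:n} | z_l = z), filled from l = n down to l = t+1
    g = np.zeros(len(pi))
    for _ in range(n - t):
        g = E[:, K] + (1.0 - E[:, K]) * (T @ g)
    # Pr(K appears in x_{t+1:n} | z_t) = sum_{z'} Pr(z_{t+1} = z' | z_t) * g[z']
    return float(a @ (T @ g))

# Toy HMM: 2 hidden states, vocabulary of 3 symbols
pi = np.array([0.6, 0.4])
T = np.array([[0.7, 0.3], [0.2, 0.8]])
E = np.array([[0.5, 0.4, 0.1], [0.1, 0.3, 0.6]])
print(prob_prefix_and_constraint(pi, T, E, x_prefix=[0, 1], K=2, n=5))
```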

Experiments

  1. Fine-tuning GPT2-large

    • domain adaptation
    • sequence-to-sequence
  2. Training HMMs

    • To enforce lexical constraints in autoregressive generation
  3. Constraint Formulation

    $[I(\text{“catch”}) \lor I(\text{“caught”})] \land [I(\text{“fr is bee”}) \lor I(\text{“fr is bees”})] \land [I(\text{“snow”}) \lor I(\text{“snow ing”}) \lor I(\text{“snow ed”})]$

    • Each keystring is a sequence of tokens (e.g., “fr is bee” is a tokenization of “frisbee”), and each clause covers different inflections of a keyword.

  4. Decoding

    • We adopt beam search to greedily search for $x_{1:n}$ that maximizes $p(x_{1:n} \mid \alpha)$ (a generic sketch follows this list).
  5. Metrics

    • ROUGE
    • BLEU
    • CIDEr
    • SPICE
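
A generic beam-search sketch for step 4 (not the paper's decoder). `next_token_dist` is a hypothetical callback returning $p(x_{t+1} \mid x_{1:t}, \alpha)$ over the vocabulary, e.g. the combined GeLaTo distribution sketched earlier, and `eos_id` marks the end-of-sequence token.

```python
import numpy as np

def beam_search(next_token_dist, eos_id, beam_size=4, max_len=20):
    """Greedily search for a high-probability sequence under p(x_{1:n} | alpha)."""
    beams = [((), 0.0)]                                    # (token prefix, log-probability)
    for _ in range(max_len):
        candidates = []
        for prefix, logp in beams:
            if prefix and prefix[-1] == eos_id:            # finished beams are kept as-is
                candidates.append((prefix, logp))
                continue
            probs = next_token_dist(prefix)
            for tok in np.argsort(probs)[-beam_size:]:     # expand only the top tokens
                candidates.append((prefix + (int(tok),),
                                   logp + float(np.log(probs[tok] + 1e-12))))
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:beam_size]
        if all(p and p[-1] == eos_id for p, _ in beams):
            break
    return beams[0]
```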

Conclusion

  • We propose GeLaTo, where we use tractable probabilistic models (TPMs) to impose complex lexical constraints (denoted α) in autoregressive language generation from large language models.
  • With hidden Markov models as a running example:
    • We present an efficient dynamic programming algorithm for conditioning HMMs on complex lexical constraints.
    • We demonstrate the effectiveness of GeLaTo on various constrained generation benchmarks.

Appendix

Autoregressive model

  • An autoregressive language model is a machine learning model that predicts the next word in a sequence based on the words that have come before it. The classical autoregressive (AR) model for a numeric sequence has the form (illustrated below):
    • $y_t = c + w_1 y_{t-1} + w_2 y_{t-2} + \dots + w_p y_{t-p} + e_t$
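
A tiny numerical sketch of the AR($p$) equation above, with illustrative weights (the noise term $e_t$ is dropped for the point forecast).

```python
import numpy as np

def ar_predict(history, weights, c=0.0):
    """y_t = c + w_1*y_{t-1} + ... + w_p*y_{t-p} (point forecast, noise term omitted)."""
    p = len(weights)
    lags = history[-1:-p - 1:-1]               # y_{t-1}, ..., y_{t-p}
    return c + float(np.dot(weights, lags))

history = [1.0, 2.0, 1.5, 2.5]                 # ..., y_{t-2}, y_{t-1}
weights = [0.6, 0.3]                           # w_1, w_2  (p = 2)
print(ar_predict(history, weights, c=0.1))     # 0.1 + 0.6*2.5 + 0.3*1.5 = 2.05
```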

HMMs

  • A Hidden Markov Model (HMM) is a statistical model used to describe a sequence of observable events or symbols in terms of an underlying sequence of hidden states.
  • Given a sequence of observations, the goal of HMMs is to find the most likely sequence of hidden states that generated those observations.