2023/08/22
Outline
- Introduction
- Related work
- Guiding Autoregressive Generation with Tractable Probabilistic Models
- Efficient Probabilistic Reasoning with Hidden Markov Models
- Experiments
- Conclusion
Introduction
- It remains a major challenge for autoregressive large language models to generate text that satisfies complex constraints:
- Sampling from the conditional distribution $\Pr(x_{1:n} \mid \alpha)$ is intractable for even the simplest lexical constraints $\alpha$.
- We propose to use tractable probabilistic models (TPMs) to impose lexical constraints in autoregressive text generation models, which we refer to as GeLaTo (Generating Language with Tractable Constraints).
- Our goal is to generate text that effectively follows the conditional distribution $\Pr(x_{1:n} \mid \alpha)$ for arbitrary lexical constraints $\alpha$.
- TPMs can efficiently compute the joint probability distribution over the input sequence and the constraints, which allows for more precise control over the generation process.
- Pre-trained LMs only model the next token distribution given some prefix, and conditioning on constraints can be intractable even for simple constraints.
- We use distilled hidden Markov models (HMMs):
- With an HMM, we can efficiently compute $\Pr(\alpha \mid x_{1:t})$ to guide autoregressive generation from GPT2.
- We propose a dynamic programming algorithm that efficiently computes these conditional probabilities.
- Our study demonstrates the potential of TPMs in controlling large language models and motivates the development of more expressive TPMs.
- We train a TPM via maximum likelihood estimation (MLE) on samples drawn from $\Pr_{\text{LM}}$, which is equivalent to minimizing the KL-divergence between $\Pr_{\text{TPM}}$ and $\Pr_{\text{LM}}$ (a sampling sketch follows below);
- At generation time, we compute $\Pr_{\text{TPM}}(\alpha \mid x_{1:t+1})$ efficiently and combine it with $\Pr_{\text{LM}}(x_{t+1} \mid x_{1:t})$ to approximate $\Pr_{\text{LM}}(x_{t+1} \mid x_{1:t}, \alpha)$ for reliable control.
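A minimal sketch of the distillation data step described above: drawing samples $x \sim \Pr_{\text{LM}}$ to serve as MLE training data for the TPM. The Hugging Face `transformers` GPT2-large checkpoint is assumed as the base LM (the notes only say "the pre-trained LM"); the function name and sampling settings are illustrative, and the HMM fitting itself (EM or gradient-based MLE) is omitted.

```python
# Sketch: sample sequences from the pretrained LM; these samples are the
# MLE training data for the TPM (here, an HMM). The HMM fitting step is omitted.
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2-large")
model = GPT2LMHeadModel.from_pretrained("gpt2-large").eval()

@torch.no_grad()
def sample_training_sequences(num_samples: int = 8, max_length: int = 32):
    """Unconditionally sample token sequences x ~ Pr_LM for distillation."""
    bos = torch.full((num_samples, 1), tokenizer.bos_token_id, dtype=torch.long)
    out = model.generate(
        bos,
        do_sample=True,                      # ancestral sampling, not greedy decoding
        max_length=max_length,
        top_k=0,                             # no truncation: keep the full distribution
        pad_token_id=tokenizer.eos_token_id,
    )
    return out                               # (num_samples, <=max_length) token ids

samples = sample_training_sequences()
print(samples.shape)
```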
Tractable probabilistic models
- A class of queries $\mathcal{Q}$ is tractable on a family of probabilistic models $\mathcal{M}$ iff any query $q \in \mathcal{Q}$ on any model $m \in \mathcal{M}$ can be computed in time polynomial in the sizes of $q$ and $m$.
- We also say that $\mathcal{M}$ is a tractable model for $\mathcal{Q}$.
- Tractable probabilistic models support efficient probabilistic inference.
- Probabilistic circuits (PCs) are a unified framework for a large family of tractable probabilistic models:
- hidden Markov models
- bounded tree-width graphical models
- sum-product networks (SPNs)
Controllable Autoregressive Language Generation
- One line of research on constrained text generation focuses on modifying the decoding algorithm to inject constraints into the beam search process
- Search-based
- constrained beam search
- NeuroLogic Decoding
- A*esque NeuroLogic Decoding
- Token-level
- Insertion-based
Guiding Autoregressive Generation with Tractable Probabilistic Models
- Our goal is to generate from the following conditional distribution: $\Pr_{\text{LM}}(x_{1:n} \mid \alpha)$.
- $\Pr_{\text{LM}}(x_{t+1} \mid x_{1:t}, \alpha)$ is intractable.
- We can assume that $\Pr_{\text{TPM}}(x_{t+1} \mid x_{1:t}, \alpha)$ can be efficiently computed.
- We train the TPM model via MLE: $\max_{\theta}\; \mathbb{E}_{x \sim \Pr_{\text{LM}}}\!\left[\log \Pr_{\text{TPM}}(x; \theta)\right]$,
- which effectively minimizes their KL-divergence $D_{\mathrm{KL}}\!\left(\Pr_{\text{LM}} \,\|\, \Pr_{\text{TPM}}\right)$.
- We assume that there exists some “quality” constraint $\beta$ such that $\Pr_{\text{TPM}}(\cdot \mid \beta)$ is even closer to $\Pr_{\text{LM}}(\cdot)$.
- We make the key independence assumption: the lexical constraint $\alpha$ and the quality constraint $\beta$ are conditionally independent given the generated prefix.
- Unsupervised setting
- Assume that the base pre-trained LM is not fine-tuned with task-specific supervision (a short derivation of the resulting guided next-token distribution follows this block).
- It may still be adapted to generate text in a specific domain or context.
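A short derivation, using the chain rule of probability and the assumptions above, of the guided next-token distribution in the unsupervised setting; the proportionality is over $x_{t+1}$, and the last step substitutes the TPM for the constraint term and the LM for the next-token term, as described earlier in these notes:

$$
\Pr(x_{t+1} \mid x_{1:t}, \alpha)
= \frac{\Pr(x_{t+1}, \alpha \mid x_{1:t})}{\Pr(\alpha \mid x_{1:t})}
\;\propto\; \Pr(\alpha \mid x_{1:t+1}) \cdot \Pr(x_{t+1} \mid x_{1:t})
\;\approx\; \Pr_{\text{TPM}}(\alpha \mid x_{1:t+1}) \cdot \Pr_{\text{LM}}(x_{t+1} \mid x_{1:t}).
$$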
- Supervised setting
- Assume that $\Pr_{\text{LM}}$ is fine-tuned in a sequence-to-sequence manner.
- We adopt an alternative formulation by viewing $\Pr_{\text{TPM}}(\cdot \mid \alpha)$ and $\Pr_{\text{LM}}(\cdot)$ as classifiers trained for the same task yet with different biases.
- To summarize, GeLaTo consists of two major steps:
- Distillation: we train a TPM on samples drawn from the pre-trained LM via MLE to effectively minimize the KL-divergence between $\Pr_{\text{TPM}}$ and $\Pr_{\text{LM}}$.
- Probabilistic reasoning: for each step of autoregressive generation, we compute $\Pr_{\text{TPM}}(\alpha \mid x_{1:t+1})$ and generate from the conditional next-token distribution defined above (see the code sketch after this list).
- Two advantages:
- The sentences generated following this distribution are guaranteed to satisfy the lexical constraint $\alpha$.
- The TPM training is independent of the lexical constraint α, which is only enforced at inference time.
- There is no need to re-train the TPM model, no matter how $\alpha$ changes.
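A minimal sketch of the probabilistic-reasoning step in the unsupervised setting. `lm_next_token_probs(prefix)` (the LM's next-token distribution) and `tpm_prob_alpha(prefix)` (an estimate of $\Pr_{\text{TPM}}(\alpha \mid x_{1:t+1})$, computed by the HMM dynamic program in practice) are hypothetical helpers, not APIs from the paper's code.

```python
import numpy as np

def guided_next_token_probs(prefix, vocab_size, lm_next_token_probs, tpm_prob_alpha):
    """p(x_{t+1} | x_{1:t}, alpha)  is proportional to  Pr_TPM(alpha | x_{1:t+1}) * Pr_LM(x_{t+1} | x_{1:t})."""
    p_lm = lm_next_token_probs(prefix)                  # shape: (vocab_size,)
    scores = np.empty(vocab_size)
    for tok in range(vocab_size):
        # Pr_TPM(alpha | x_{1:t}, x_{t+1} = tok), weighted by the LM next-token probability
        scores[tok] = tpm_prob_alpha(prefix + [tok]) * p_lm[tok]
    return scores / scores.sum()                        # renormalize over the vocabulary

# Toy usage with stub models: uniform LM over 4 tokens, constraint "token 2 must appear".
toy_lm = lambda prefix: np.full(4, 0.25)
toy_tpm = lambda seq: 1.0 if 2 in seq else 0.3          # stand-in for Pr_TPM(alpha | prefix)
print(guided_next_token_probs([0, 1], 4, toy_lm, toy_tpm))
```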
Efficient Probabilistic Reasoning with Hidden Markov Models (HMMs)
Hidden Markov Models
- The joint probability is defined as: $\Pr(x_{1:n}, z_{1:n}) = \Pr(x_1 \mid z_1)\Pr(z_1)\prod_{t=2}^{n} \Pr(x_t \mid z_t)\Pr(z_t \mid z_{t-1})$
- The parameters of the HMM are given by the initial probability $\Pr(z_1)$, the emission matrix $\Pr(x_t \mid z_t)$ and the transition matrix $\Pr(z_t \mid z_{t-1})$, which stay the same across different positions $t$.
- Forward algorithm (a NumPy sketch follows this list): $\Pr(x_{1:t}, z_t) = \Pr(x_t \mid z_t) \sum_{z_{t-1}} \Pr(z_t \mid z_{t-1}) \Pr(x_{1:t-1}, z_{t-1})$
- By padding sequences with an end-of-sequence token, the HMM over $x_{1:n}$ effectively defines a distribution over all texts with length ≤ n.
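A minimal NumPy sketch of the forward algorithm above for a categorical-emission HMM; variable names and the toy parameters are illustrative, not taken from the paper's implementation.

```python
import numpy as np

def forward(pi, T, E, x):
    """Forward algorithm: returns alpha[t, j] = Pr(x_{1:t+1}, z_{t+1} = j).

    pi: (H,)    initial distribution Pr(z_1)
    T:  (H, H)  transition matrix, T[i, j] = Pr(z_t = j | z_{t-1} = i)
    E:  (H, V)  emission matrix,   E[j, v] = Pr(x_t = v | z_t = j)
    x:  list of observed token ids
    """
    H = len(pi)
    alpha = np.zeros((len(x), H))
    alpha[0] = pi * E[:, x[0]]                          # Pr(x_1, z_1)
    for t in range(1, len(x)):
        alpha[t] = (alpha[t - 1] @ T) * E[:, x[t]]      # sum over z_{t-1}, then emit x_t
    return alpha

# Toy example: 2 hidden states, vocabulary of 3 tokens.
pi = np.array([0.6, 0.4])
T = np.array([[0.7, 0.3],
              [0.2, 0.8]])
E = np.array([[0.5, 0.4, 0.1],
              [0.1, 0.3, 0.6]])
alpha = forward(pi, T, E, [0, 2, 1])
print("Pr(x_{1:3}) =", alpha[-1].sum())                 # marginal likelihood of the sequence
```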
An Efficient Dynamic Programming Algorithm
- $\alpha'$ is some CNF formula obtained by removing from the original $\alpha$ the clauses that are already satisfied.
- $s$ is either the empty string or a suffix of some keystring in $\alpha$,
- $\alpha'$ is a CNF consisting of a subset of clauses in $\alpha$, and $z$ is a latent state of the HMM.
- Case 1: $s$ is a non-empty suffix; then the next tokens are pinned to spell out $s$ and the recursion proceeds along $s$.
- Case 2: $s$ is the empty string; we reduce the problem to Case 1 by enumerating the next token over the vocabulary.
- At step $t$, we guide generation by computing $\Pr(\alpha \mid x_{1:t})$, where $x_{1:t}$ denotes the first $t$ tokens that have been generated.
- The time complexity of GeLaTo is $O(2^{|\alpha|} \cdot n \cdot m)$ (a simplified single-keyword sketch follows this list), where
- $|\alpha|$ is the number of clauses in $\alpha$,
- $n$ is the maximum sequence length, and
- $m$ is the number of different suffixes over all keystrings in $\alpha$.
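The full dynamic program over remaining CNFs and suffixes is not reproduced here. As a simplified illustration of the same idea, the sketch below computes, for an HMM, the probability that a single given token ever appears in the remaining text $x_{t:n}$, i.e., the special case of one clause containing one single-token keystring. The function name is mine, and the toy parameters reuse the earlier HMM example.

```python
import numpy as np

def prob_keyword_in_suffix(T, E, n, t, keyword):
    """f[z] = Pr(the token `keyword` appears somewhere in x_{t:n} | z_t = z).

    T: (H, H) transitions, T[i, j] = Pr(z_{t+1} = j | z_t = i)
    E: (H, V) emissions,   E[j, v] = Pr(x_t = v | z_t = j)
    Backward recursion:
      f_t(z) = Pr(x_t = keyword | z)
               + Pr(x_t != keyword | z) * sum_{z'} Pr(z' | z) * f_{t+1}(z')
    """
    H = T.shape[0]
    f = np.zeros(H)                     # base case f_{n+1}(z) = 0: no positions left
    for _ in range(n, t - 1, -1):       # fill positions n, n-1, ..., t
        f = E[:, keyword] + (1.0 - E[:, keyword]) * (T @ f)
    return f

# Toy example: Pr(token 2 appears in x_{4:10} | z_4). Combined with forward
# probabilities over the generated prefix, such suffix probabilities give
# Pr(alpha | x_{1:t-1}) for this one-keyword constraint.
T = np.array([[0.7, 0.3], [0.2, 0.8]])
E = np.array([[0.5, 0.4, 0.1], [0.1, 0.3, 0.6]])
print(prob_keyword_in_suffix(T, E, n=10, t=4, keyword=2))
```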
Experiments
- Fine-tuning GPT2-large
- domain adaptation
- sequence-to-sequence
- Training HMMs
- To enforce lexical constraints in autoregressive generation, we train HMMs on samples drawn from the fine-tuned GPT2-large models.
- Constraint Formulation
- Lexical constraints $\alpha$ are encoded as CNFs: conjunctions of clauses, where each clause is a disjunction of keystrings (e.g., alternative tokenizations or inflections of a keyword).
- Decoding
- We adopt beam search to greedily search for the sequence $x_{1:n}$ that maximizes the conditional likelihood $p(x_{1:n} \mid \alpha)$ (a generic beam-search sketch follows).
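A generic beam-search sketch over a guided next-token distribution; `guided_next_token_probs(prefix)` stands for any callable returning $p(x_{t+1} \mid x_{1:t}, \alpha)$ over the vocabulary (for instance, a closure over the earlier sketch), and `eos_id`, the beam width, and the length limit are illustrative defaults, not the paper's settings.

```python
import numpy as np

def beam_search(guided_next_token_probs, eos_id, beam_size=4, max_len=20):
    """Keep the `beam_size` highest log-probability prefixes under the guided distribution."""
    beams = [([], 0.0)]                                   # (token prefix, cumulative log-prob)
    for _ in range(max_len):
        candidates = []
        for prefix, logp in beams:
            if prefix and prefix[-1] == eos_id:           # finished beams are carried over
                candidates.append((prefix, logp))
                continue
            probs = guided_next_token_probs(prefix)       # p(x_{t+1} | x_{1:t}, alpha)
            for tok in np.argsort(probs)[-beam_size:]:    # expand only the top tokens
                candidates.append(
                    (prefix + [int(tok)], logp + float(np.log(probs[tok] + 1e-12)))
                )
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
    return beams[0][0]                                    # highest-scoring sequence
```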
- Metrics

Conclusion
- We propose GeLaTo, where we use tractable probabilistic models (TPMs) to impose complex lexical constraints (denoted α) in autoregressive language generation from large language models.
- With hidden Markov models (HMMs) as a running example:
- We present an efficient dynamic programming algorithm for conditioning HMMs on complex lexical constraints.
- We demonstrate the effectiveness of GeLaTo on various constrained generation benchmarks.
Appendix
- An autoregressive language model is a machine learning model that predicts the next word in a sequence based on the words that have come before it, generating text one token at a time.
- A Hidden Markov Model (HMM) is a statistical model used to describe a sequence of observable events or symbols in terms of an underlying sequence of hidden states.
- Given a sequence of observations, a common inference task for HMMs is to find the most likely sequence of hidden states that generated those observations (Viterbi decoding; a sketch follows below).
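A minimal NumPy sketch of Viterbi decoding for the most likely hidden-state sequence; the toy parameters reuse the HMM from the forward-algorithm sketch and are purely illustrative.

```python
import numpy as np

def viterbi(pi, T, E, x):
    """Most likely hidden-state sequence z_{1:n} for observations x_{1:n}.

    pi: (H,) initial distribution, T: (H, H) transitions, E: (H, V) emissions.
    """
    H, n = len(pi), len(x)
    logdelta = np.log(pi) + np.log(E[:, x[0]])      # best log-prob of paths ending in each state
    backptr = np.zeros((n, H), dtype=int)
    for t in range(1, n):
        scores = logdelta[:, None] + np.log(T)      # scores[i, j]: come from state i, move to j
        backptr[t] = scores.argmax(axis=0)
        logdelta = scores.max(axis=0) + np.log(E[:, x[t]])
    z = [int(logdelta.argmax())]                    # backtrack from the best final state
    for t in range(n - 1, 0, -1):
        z.append(int(backptr[t][z[-1]]))
    return z[::-1]

# Toy example reusing the earlier HMM parameters.
pi = np.array([0.6, 0.4])
T = np.array([[0.7, 0.3], [0.2, 0.8]])
E = np.array([[0.5, 0.4, 0.1], [0.1, 0.3, 0.6]])
print(viterbi(pi, T, E, [0, 2, 1]))
```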
