# [Tractable Control for Autoregressive Language Generation](https://openreview.net/attachment?id=ET6qkbzeOx&name=pdf)
*2023/08/22*
###### tags: `RL Group meeting`
# Outline
* Introduction
* Related work
* Guiding Autoregressive Generation with Tractable Probabilistic Models
* Efficient Probabilistic Reasoning with Hidden Markov Models
* Experiments
* Conclusion
# Introduction
* Generating text that satisfies complex constraints remains a major challenge for autoregressive large language models:
* Sampling from the conditional distribution $\operatorname{Pr}(\operatorname{text} | α)$ is intractable for even the simplest lexical constraints $α$.
* We propose to use tractable probabilistic models (TPMs) to impose lexical constraints in autoregressive text generation models, which we refer to as **GeLaTo** (Generating Language with Tractable Constraints).
* Our goal is to **generate text effectively** following the conditional distribution $\operatorname{Pr_{LM}}(x_{1:n} | α)$ for arbitrary lexical constraints α.
* TPMs can efficiently compute the joint probability distribution over the input sequence and the constraints, which allows for more precise control over the generation process.
* Pre-trained LMs only model the next token distribution given some prefix, and conditioning on constraints can be intractable even for simple constraints.
* We use distilled **hidden Markov models** (HMMs):
    1. HMMs allow us to efficiently compute $\operatorname{Pr}(\operatorname{text} | α)$, which we use to guide autoregressive generation from GPT2.
    2. We propose a dynamic programming algorithm that efficiently computes the conditional probabilities $\operatorname{Pr_{HMM}}(· | α)$.
* Our study demonstrates the potential of TPMs in controlling large language models and motivates the development of more expressive TPMs.

* GeLaTo proceeds in two steps:
1. We train a TPM $\operatorname{Pr_{TPM}}$ via maximum likelihood estimation (MLE) on samples drawn from $\operatorname{Pr_{LM}}$, which is equivalent to minimizing the KL-divergence between $\operatorname{Pr_{TPM}}$ and $\operatorname{Pr_{LM}}$;
2. At generation time, we compute $\operatorname{Pr_{TPM}}(x_{t+1} | x_{1:t}, α)$ efficiently and combine it with $\operatorname{Pr_{LM}}(x_{t+1} | x_{1:t})$ to approximate $\operatorname{Pr_{LM}}(x_{t+1} | x_{1:t}, α)$ for reliable control.
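A minimal sketch of step 1 (distillation), assuming HuggingFace `transformers` for the LM and `hmmlearn`'s `CategoricalHMM` as a stand-in TPM; this is not the paper's training pipeline, and the model name, sample counts, and HMM size are illustrative only.

```python
# Sketch: distill Pr_LM into an HMM by MLE on samples drawn from the LM.
# Assumption: recent `transformers` and `hmmlearn` (>= 0.2.8, for CategoricalHMM).
from transformers import AutoModelForCausalLM, AutoTokenizer
from hmmlearn.hmm import CategoricalHMM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")

# Step 1: draw unconditional samples x_{1:n} ~ Pr_LM.
prompt = tokenizer(tokenizer.bos_token, return_tensors="pt")
samples = lm.generate(
    **prompt, do_sample=True, max_length=32,
    num_return_sequences=64, pad_token_id=tokenizer.eos_token_id,
)

# Step 2: fit the TPM (here an HMM) by maximum likelihood (Baum-Welch / EM),
# which minimizes KL(Pr_LM || Pr_TPM) up to the constant entropy term of Pr_LM.
X = samples.numpy().reshape(-1, 1)                  # concatenated token ids
lengths = [samples.shape[1]] * samples.shape[0]     # one length per sampled sequence
hmm = CategoricalHMM(n_components=128, n_iter=10)   # 128 hidden states, illustrative
hmm.fit(X, lengths)
```

In the paper the HMM is trained at much larger scale; the point here is only that the TPM is fit to samples from the LM, independently of any constraint α.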
# Related work
## Tractable probabilistic models
* A class of queries $\mathbf{Q}$ is tractable on a family of probabilistic models $\mathcal{M}$ iff any query $q \in \mathbf{Q}$ on a model $m \in \mathcal{M}$ can be computed in time $\mathcal{O}($ poly $(|m|))$.
* We also say that $\mathcal{M}$ is a tractable model for $\mathbf{Q}$.
* Tractable probabilistic models support efficient probabilistic inference.
* Probabilistic circuits (PCs) provide a unified framework for a large family of tractable probabilistic models:
* hidden Markov models
* bounded tree-width graphical models
* sum-product networks (SPNs)
## Controllable Autoregressive Language Generation
* One line of research on constrained text generation focuses on modifying the decoding algorithm to inject constraints into the beam search process
* Search-based
* constrained beam search
* NeuroLogic Decoding
* A*esque NeuroLogic Decoding
* Token-level
* NADO
* FUDGE
* Insertion-based
# Guiding Autoregressive Generation with Tractable Probabilistic Models
* Our goal is to generate from the following conditional distribution:
\begin{equation} \operatorname{Pr}_{\mathrm{LM}}\left(x_{1: n} \mid \alpha\right)=\prod_t \operatorname{Pr}_{\mathrm{LM}}\left(x_{t+1} \mid x_{1: t}, \alpha\right) \end{equation}
* $\operatorname{Pr}_{\operatorname{LM}}(x_{t+1} | x_{1:t}, α)$ is intractable
* We can assume that $\operatorname{Pr_{TPM}}(x_{t+1} | x_{1:t}, α)$ can be efficiently computed.
* We train the TPM model via MLE:
$$
\mathbb{E}_{x_{1: n} \sim \operatorname{Pr}_{\mathrm{LM}}} \log \operatorname{Pr}_{\mathrm{TPM}}\left(x_{1: n}\right)
$$
* This effectively minimizes their KL-divergence, since the first term below (the negative entropy of $\operatorname{Pr}_{\mathrm{LM}}$) does not depend on $\operatorname{Pr_{TPM}}$:
$$
\begin{aligned}
& D_{\mathrm{KL}}\left(\operatorname{Pr}_{\mathrm{LM}} \| \operatorname{Pr}_{\mathrm{TPM}}\right) \\
& =\mathbb{E}_{x_{1: n} \sim \operatorname{Pr}_{\mathrm{LM}}} \log \operatorname{Pr}_{\mathrm{LM}}\left(x_{1: n}\right)-\mathbb{E}_{x_{1: n} \sim \operatorname{Pr}_{\mathrm{LM}}} \log \operatorname{Pr}_{\mathrm{TPM}}\left(x_{1: n}\right)
\end{aligned}
$$
* We assume that there exists some “quality” constraint $β$ such that $\operatorname{Pr_{TPM}}(\cdot \mid β)$ is even closer to $\operatorname{Pr_{LM}}$:
$$
\operatorname{Pr}_{\mathrm{TPM}}\left(x_{1: n} \mid \alpha, \beta\right)=\prod_t \operatorname{Pr}_{\mathrm{TPM}}\left(x_{t+1} \mid x_{1: t}, \alpha, \beta\right)
$$
* We make the key independence assumption ($α$ is conditionally independent of $β$ given the text) and approximate $\operatorname{Pr_{TPM}}(x_{t+1} \mid x_{1:t}, β)$ by $\operatorname{Pr_{LM}}(x_{t+1} \mid x_{1:t})$:
$$
\begin{aligned}
& \operatorname{Pr}_{\mathrm{TPM}}\left(x_{t+1} \mid x_{1: t}, \alpha, \beta\right) \\
& \quad \propto \operatorname{Pr}_{\mathrm{TPM}}\left(\alpha \mid x_{1: t+1}, \beta\right) \cdot \operatorname{Pr}_{\mathrm{TPM}}\left(x_{t+1} \mid x_{1: t}, \beta\right) \\
& \quad \propto \operatorname{Pr}_{\mathrm{TPM}}\left(\alpha \mid x_{1: t+1}\right) \cdot \operatorname{Pr}_{\mathrm{LM}}\left(x_{t+1} \mid x_{1: t}\right) .
\end{aligned}
$$
* Unsupervised setting
* Assume that the base pre-trained LM is not fine-tuned given task-specific supervision.
* It may still be adapted to generate text in a specific domain or context.
$$
p\left(x_{t+1} \mid x_{1: t}, \alpha\right) \propto \operatorname{Pr}_{\mathrm{TPM}}\left(\alpha \mid x_{1: t+1}\right) \cdot \operatorname{Pr}_{\mathrm{LM}}\left(x_{t+1} \mid x_{1: t}\right) .
$$
* Supervised setting
* Assume that $\operatorname{Pr_{LM}}$ is fine-tuned in a sequence-to-sequence manner.
* We adopt an alternative formulation by viewing $\operatorname{Pr_{TPM}}(x_{t+1} | x_{1:t}, α)$ and $\operatorname{Pr_{LM}}(x_{t+1} | x_{1:t})$ as **classifiers** trained for the same task yet with different biases.
$$
\begin{aligned}
& p\left(x_{t+1} \mid x_{1: t}, \alpha\right) \\
& \quad \propto \operatorname{Pr}_{\mathrm{TPM}}\left(x_{t+1} \mid x_{1: t}, \alpha\right)^w \cdot \operatorname{Pr}_{\mathrm{LM}}\left(x_{t+1} \mid x_{1: t}\right)^{1-w}
\end{aligned}
$$
* To summarize, GeLaTo consists of two major steps:
    * Distillation: we train a TPM on samples drawn from the pre-trained LM via MLE to effectively minimize the KL divergence between $\operatorname{Pr_{LM}}$ and $\operatorname{Pr_{TPM}}$.
    * Probabilistic reasoning: for each step of autoregressive generation, we compute $\operatorname{Pr_{TPM}}(· | α)$ and generate from the conditional next-token distribution $p(x_{t+1} | x_{1:t}, α)$ defined above (see the sketch after this list).
* Two advantages:
* The sentences generated following $p(x_{t+1} |x_{1:t}, α)$ are guaranteed to satisfy the lexical constraint α.
* The TPM training is independent of the lexical constraint α, which is only enforced at inference time.
* No need to re-train the TPM model no matter how α changes.
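A minimal numpy sketch of the probabilistic-reasoning step referenced above, covering both the unsupervised product and the supervised geometric mixture; `pr_tpm_alpha_given_prefix`, `pr_tpm_next_given_alpha`, and `pr_lm_next` are hypothetical callables standing in for the TPM and LM queries.

```python
import numpy as np

def next_token_dist(prefix, vocab_size, pr_tpm_alpha_given_prefix,
                    pr_tpm_next_given_alpha, pr_lm_next, w=None):
    """Sketch of the combined next-token distribution p(x_{t+1} | x_{1:t}, alpha).

    pr_tpm_alpha_given_prefix(prefix, x) -> Pr_TPM(alpha | x_{1:t}, x) for candidate x
    pr_tpm_next_given_alpha(prefix)      -> Pr_TPM(x_{t+1} | x_{1:t}, alpha) over the vocab
    pr_lm_next(prefix)                   -> Pr_LM(x_{t+1} | x_{1:t}) over the vocab
    w: None for the unsupervised formula, a weight in [0, 1] for the supervised one.
    """
    p_lm = pr_lm_next(prefix)                         # shape (vocab_size,)
    if w is None:
        # Unsupervised: p ∝ Pr_TPM(alpha | x_{1:t+1}) * Pr_LM(x_{t+1} | x_{1:t})
        p_alpha = np.array([pr_tpm_alpha_given_prefix(prefix, x)
                            for x in range(vocab_size)])
        scores = p_alpha * p_lm
    else:
        # Supervised: p ∝ Pr_TPM(x_{t+1} | x_{1:t}, alpha)^w * Pr_LM(x_{t+1} | x_{1:t})^(1-w)
        scores = pr_tpm_next_given_alpha(prefix) ** w * p_lm ** (1 - w)
    return scores / scores.sum()                      # renormalize over the vocabulary
```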
# Efficient Probabilistic Reasoning with Hidden Markov Models (HMMs)

* In both settings, we need to compute quantities of the form $\operatorname{Pr_{TPM}}(x_{1:t}, α)$:
* unsupervised setting: $\operatorname{Pr}(α | x_{1:t+1})$ = $\operatorname{Pr}(x_{1:t+1}, α)/\operatorname{Pr}(x_{1:t+1})$
* supervised setting: $\operatorname{Pr}(x_{t+1} | x_{1:t}, α) ∝ \operatorname{Pr}(x_{1:t+1}, α)$
* We describe a **dynamic programming algorithm** that computes $\operatorname{Pr}(x_{1:t}, α)$ for HMMs, where **α** is some lexical constraint encoded in a conjunctive normal form (CNF):
$$
\left(I\left(w_{1,1}\right) \vee \cdots \vee I\left(w_{1, d_1}\right)\right) \wedge \cdots \wedge\left(I\left(w_{m, 1}\right) \vee \cdots \vee I\left(w_{m, d_m}\right)\right)
$$
* $w_{i,j}$ is a string of tokens.
* $I(w_{i,j})$ is the indicator variable that represents whether $w_{i,j}$ appears in the generated text.
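For illustration, a CNF constraint of this form can be represented as a list of clauses, each a list of keystrings (tuples of token ids); the helper below is a sketch of how the indicator $I(w_{i,j})$ is evaluated on generated text, not the paper's code.

```python
# Sketch: a CNF lexical constraint alpha as a list of clauses; each clause is a
# list of keystrings, and each keystring is a tuple of token ids.
# alpha is satisfied iff every clause contains at least one keystring that
# occurs as a contiguous subsequence of the generated token sequence.

def occurs(keystring, tokens):
    """True iff `keystring` appears as a contiguous subsequence of `tokens`."""
    k = len(keystring)
    return any(tuple(tokens[i:i + k]) == keystring
               for i in range(len(tokens) - k + 1))

def satisfies(alpha, tokens):
    """Evaluate the CNF: conjunction over clauses, disjunction inside each clause."""
    return all(any(occurs(w, tokens) for w in clause) for clause in alpha)

# Example with made-up token ids: clause 1 requires (17,) or (17, 42),
# clause 2 requires (101,) or (202,).
alpha = [[(17,), (17, 42)], [(101,), (202,)]]
print(satisfies(alpha, [5, 101, 9, 17, 42]))  # True
```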
## Hidden Markov Models
* The joint probability $\operatorname{Pr}(x_{1:n}, z_{1:n})$ is defined as:
$$
\operatorname{Pr}\left(x_1 \mid z_1\right) \operatorname{Pr}\left(z_1\right) \prod_{2 \leq t \leq n} \operatorname{Pr}\left(x_t \mid z_t\right) \operatorname{Pr}\left(z_t \mid z_{t-1}\right)
$$
* The parameters of HMM are given by the initial probability $\operatorname{Pr}(z_1)$, emission matrix $\operatorname{Pr}(x_t | z_t)$ and the transition matrix $\operatorname{Pr}(z_{t+1} | z_t)$, which stay the same across different positions t.
$$
\operatorname{Pr}\left(x_{t: n} \mid z_t, x_{1: t-1}\right)=\operatorname{Pr}\left(x_{t: n} \mid z_t\right) .
$$
* Forward algorithm (a code sketch follows at the end of this subsection):
$$
\operatorname{Pr}\left(x_{1: t}, z_t\right)=\sum_{1 \leq z_{t-1} \leq h} \operatorname{Pr}\left(x_t \mid z_t\right) \operatorname{Pr}\left(z_t \mid z_{t-1}\right) \operatorname{Pr}\left(x_{1: t-1}, z_{t-1}\right)
$$
* $\operatorname{Pr_{HMM}}(x_{1:n})$ effectively defines a distribution over all texts with length ≤ n.
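A minimal numpy sketch of an HMM with the parameters above: the joint probability of a complete $(x_{1:n}, z_{1:n})$ path and the forward recursion; the dimensions are illustrative.

```python
import numpy as np

h, v = 3, 5                                # illustrative: hidden states, vocab size
rng = np.random.default_rng(0)
pi = rng.dirichlet(np.ones(h))             # Pr(z_1)
A = rng.dirichlet(np.ones(h), size=h)      # A[i, j] = Pr(z_{t+1}=j | z_t=i)
B = rng.dirichlet(np.ones(v), size=h)      # B[i, x] = Pr(x_t=x | z_t=i)

def joint(x, z):
    """Pr(x_{1:n}, z_{1:n}) = Pr(x_1|z_1)Pr(z_1) * prod_t Pr(x_t|z_t)Pr(z_t|z_{t-1})."""
    p = pi[z[0]] * B[z[0], x[0]]
    for t in range(1, len(x)):
        p *= A[z[t - 1], z[t]] * B[z[t], x[t]]
    return p

def forward(x):
    """Forward algorithm: f[t, k] = Pr(x_{1:t+1}, z_{t+1}=k) (0-indexed positions)."""
    f = np.zeros((len(x), h))
    f[0] = pi * B[:, x[0]]
    for t in range(1, len(x)):
        f[t] = B[:, x[t]] * (f[t - 1] @ A)
    return f

x = [0, 2, 1, 4]
print(forward(x)[-1].sum())                # Pr(x_{1:n}), summing out the last state
```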
## An Efficient Dynamic Programming Algorithm

* The dynamic program recursively computes terms of the form $\operatorname{Pr}(x_{l:r}, \psi_{l:n} \mid z_l)$, where:
    * $ψ$ is a CNF formula consisting of a subset of the clauses in $α$; in particular, $α′$ denotes the CNF obtained by removing from the original $α$ the clauses that are already satisfied,
    * $x_{l:r}$ is either the empty string or a suffix of some keystring in $α$,
    * $z_l$ is a latent state for $Z_l$.
* $S(x, α)$ is the set of strings that complete a partially generated keystring:
$$S(x, \alpha):=\left\{s: \exists x^{\prime} \text { a suffix of } x \text { s.t. } x^{\prime} \oplus s \text { lies in } \alpha\right\}$$
* Case 1. $x_{l:r} \neq \emptyset$; then
$$
\begin{aligned}
& \operatorname{Pr}\left(x_{l: r}, \alpha_{l: n} \mid z_l\right) \\
& =\sum_{z_{r+1}} \operatorname{Pr}\left(x_{l: r}, z_{r+1} \mid z_l\right)\left(\operatorname{Pr}\left(\alpha_{r+1: n} \mid z_{r+1}\right)\right. \\
& \quad+\sum_{s \in S\left(x_{l: r}, \alpha\right)} \operatorname{Pr}\left(s_{r+1: r+|s|},\left(\alpha \backslash x_{l: r} \oplus s\right)_{r+1: n} \mid z_{r+1}\right) \\
& \quad\left.-\sum_{s \in S\left(x_{l: r}, \alpha\right)} \operatorname{Pr}\left(s_{r+1: r+|s|}, \alpha_{r+1: n} \mid z_{r+1}\right)\right) ;
\end{aligned}
$$
* Case 2. $x_{l:r} = \emptyset$; we reduce the problem to Case 1 by enumerating $x_l$ over the vocabulary:
$$
\operatorname{Pr}\left(\alpha_{l: n} \mid z_l\right)=\sum_{x_l \in \text { vocabulary }} \operatorname{Pr}\left(x_l, \alpha_{l: n} \mid z_l\right)
$$
* At step $t$, we guide generation by computing $\operatorname{Pr}(x_{1:t−1}, x_t, α_{1:n})$ for each candidate next token $x_t$, where $x_{1:t−1}$ denotes the first $t−1$ tokens that have been generated:
$$
\operatorname{Pr}\left(x_{1: t}, \alpha_{1: n}\right)=\sum_{z_1} \operatorname{Pr}\left(z_1\right) \operatorname{Pr}\left(x_{1: t}, \alpha_{1: n} \mid z_1\right)
$$
* The time complexity of GeLaTo is $\mathcal{O}(2^{|α|} n m)$, where:
    * $|α|$ is the number of clauses in $α$,
    * $n$ is the maximum sequence length,
    * $m$ is the number of different suffixes over all keystrings in $α$.
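The general dynamic program tracks suffixes of keystrings and subsets of clauses; the sketch below is a simplification (mine, not the paper's implementation) for the special case of a single one-token keyword $w$, where $\operatorname{Pr}(x_{1:t}, α_{1:n})$ reduces to a forward pass plus a backward pass over the event that $w$ is never emitted in the remaining positions.

```python
import numpy as np

def pr_prefix_and_keyword(x_prefix, w, n, pi, A, B):
    """Simplified constrained query (sketch, single one-token keyword w):
    Pr(x_{1:t}, alpha_{1:n}) where alpha = "token w appears somewhere in x_{1:n}".

    Uses Pr(x_{1:t}, alpha) = Pr(x_{1:t}) - Pr(x_{1:t}, w absent from x_{1:n}).
    """
    h, t = len(pi), len(x_prefix)

    # Forward pass: f[k] = Pr(x_{1:t}, z_t = k).
    f = pi * B[:, x_prefix[0]]
    for s in range(1, t):
        f = B[:, x_prefix[s]] * (f @ A)
    pr_prefix = f.sum()

    if w in x_prefix:                       # constraint already satisfied by the prefix
        return pr_prefix
    if t == n:                              # no positions left, w never appeared
        return 0.0

    # Backward pass: g[k] = Pr(positions s..n never emit w | z_s = k), s = n down to t+1.
    g = 1.0 - B[:, w]                       # s = n
    for s in range(n - 1, t, -1):
        g = (1.0 - B[:, w]) * (A @ g)
    # Pr(x_{1:t}, w absent from x_{t+1:n}) marginalizes z_t and z_{t+1}.
    pr_absent = f @ (A @ g)
    return pr_prefix - pr_absent
```

The general algorithm additionally caches these quantities across keystring suffixes and clause subsets, which is where the $2^{|α|}$ and $m$ factors in the complexity come from.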
# Experiments
1. Fine-tuning GPT2-large
* domain adaptation
* sequence-to-sequence
2. Training HMMs
* To enforce lexical constraints in autoregressive generation.
3. Constraint Formulation
$$
\begin{aligned}
& {[I(\text { catch }) \vee I(\text { caught }) \vee \ldots] } \\
\wedge & {[I(\text { fr } \oplus \text { is } \oplus \text { bee }) \vee I(\text { fr } \oplus \text { is } \oplus \text { bees }) \vee \ldots] } \\
\wedge & {[I(\text { snow }) \vee I(\text { snow } \oplus \text { ing }) \vee I(\text { snow } \oplus \text { ed }) \vee \ldots] }
\end{aligned}
$$
4. Decoding
* We adopt beam search to greedily search for $x_{1:n}$ that maximizes $p(x_{1:n} | α)$ (see the sketch after this list).
5. Metrics
* ROUGE
* BLEU
* CIDEr
* SPICE
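A minimal sketch of the decoding step (item 4 above): beam search over the combined next-token distribution $p(x_{t+1} \mid x_{1:t}, α)$; `combined_next_dist` is a hypothetical scorer like the one sketched earlier, and `eos_id` marks end of sequence.

```python
import numpy as np

def constrained_beam_search(combined_next_dist, eos_id, beam_size=4, max_len=32):
    """Beam search for x_{1:n} approximately maximizing p(x_{1:n} | alpha).

    combined_next_dist(prefix) -> array of p(x_{t+1} | x_{1:t}, alpha) over the vocab,
    e.g. the GeLaTo-style product of TPM and LM scores.
    """
    beams = [((), 0.0)]                                 # (prefix, log-probability)
    finished = []
    for _ in range(max_len):
        candidates = []
        for prefix, logp in beams:
            probs = combined_next_dist(prefix)
            top = np.argsort(probs)[-beam_size:]        # expand the best tokens
            for tok in top:
                cand = (prefix + (int(tok),),
                        logp + float(np.log(probs[tok] + 1e-12)))
                (finished if tok == eos_id else candidates).append(cand)
        if not candidates:
            break
        # keep the highest-scoring partial hypotheses
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
    finished.extend(beams)
    return max(finished, key=lambda c: c[1])[0]
```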

# Conclusion
* We propose GeLaTo, where we use tractable probabilistic models (TPMs) to impose complex lexical constraints (denoted α) in autoregressive language generation from large language models.
* With hidden Markov models (HMMs) as a running example:
* We present an efficient **dynamic programming algorithm** for conditioning HMMs on complex lexical constraints.
* We demonstrate the effectiveness of GeLaTo on various constrained generation benchmarks.
# Appendix
## [Autoregressive model](https://deepchecks.com/glossary/autoregressive-model/)
* An autoregressive language model is a type of Machine Learning model that uses **autoregressive techniques** to predict the next word in a sequence of words based on the words that have come before it.
* $$y(t)=c+w_1 y(t-1)+w_2 y(t-2)+\ldots+w_p y(t-p)+e(t)$$
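A tiny numpy illustration of the AR($p$) recurrence above, with made-up coefficients.

```python
import numpy as np

# Simulate y(t) = c + w_1*y(t-1) + ... + w_p*y(t-p) + e(t) with made-up coefficients.
rng = np.random.default_rng(0)
c, w = 0.5, np.array([0.6, -0.2])                 # p = 2
y = [0.0, 0.0]                                    # initial values y(1), y(2)
for t in range(2, 50):
    e = rng.normal(scale=0.1)                     # noise term e(t)
    y.append(c + w @ np.array(y[-1:-3:-1]) + e)   # w_1*y(t-1) + w_2*y(t-2)
print(y[-5:])
```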
## [HMMs](https://wisdomml.in/hidden-markov-model-hmm-in-nlp-python/)
* A Hidden Markov Model (HMM) is a statistical model used to describe a sequence of observable events or symbols in terms of an underlying sequence of hidden states.
* Given a sequence of observations, the goal of HMMs is to find the most likely sequence of hidden states that generated those observations.
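A minimal numpy sketch of the Viterbi algorithm, which recovers the most likely hidden-state sequence for a given observation sequence; the parameters $(\pi, A, B)$ follow the HMM sketch earlier.

```python
import numpy as np

def viterbi(x, pi, A, B):
    """Most likely hidden-state sequence argmax_z Pr(z_{1:n} | x_{1:n}) for an HMM."""
    n, h = len(x), len(pi)
    delta = np.zeros((n, h))                # best log-prob of any path ending in state k
    back = np.zeros((n, h), dtype=int)      # backpointers
    delta[0] = np.log(pi) + np.log(B[:, x[0]])
    for t in range(1, n):
        scores = delta[t - 1][:, None] + np.log(A)   # scores[i, j]: transition i -> j
        back[t] = scores.argmax(axis=0)
        delta[t] = scores.max(axis=0) + np.log(B[:, x[t]])
    # trace back the best path from the best final state
    z = [int(delta[-1].argmax())]
    for t in range(n - 1, 0, -1):
        z.append(int(back[t][z[-1]]))
    return z[::-1]
```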
