# [Tractable Control for Autoregressive Language Generation](https://openreview.net/attachment?id=ET6qkbzeOx&name=pdf)

*2023/08/22*

###### tags: `RL Group meeting`

# Outline

* Introduction
* Related work
* Guiding Autoregressive Generation with Tractable Probabilistic Models
* Efficient Probabilistic Reasoning with Hidden Markov Models
* Experiments
* Conclusion

# Introduction

* For autoregressive large language models, generating text that satisfies complex constraints remains a major challenge:
    * Sampling from the conditional distribution $\operatorname{Pr}(\operatorname{text} \mid \alpha)$ is intractable for even the simplest lexical constraints $\alpha$.
* We propose to use tractable probabilistic models (TPMs) to impose lexical constraints in autoregressive text generation models, which we refer to as **GeLaTo** (Generating Language with Tractable Constraints).
* Our goal is to **generate text effectively** following the conditional distribution $\operatorname{Pr}_{\mathrm{LM}}(x_{1:n} \mid \alpha)$ for arbitrary lexical constraints $\alpha$.
* TPMs can efficiently compute the joint probability distribution over the input sequence and the constraints, which allows for more precise control over the generation process.
* Pre-trained LMs only model the next-token distribution given some prefix, and conditioning on constraints can be intractable even for simple constraints.
* We use distilled **hidden Markov models**:
    1. We can efficiently compute $\operatorname{Pr}(\operatorname{text} \mid \alpha)$ to guide autoregressive generation from GPT2.
    2. We propose a dynamic programming algorithm that efficiently computes the conditional probabilities $\operatorname{Pr}_{\mathrm{HMM}}(\cdot \mid \alpha)$.
* Our study demonstrates the potential of TPMs in controlling large language models and motivates the development of more expressive TPMs.

![](https://hackmd.io/_uploads/BJvVxX5nh.png)

1. We train a TPM $\operatorname{Pr}_{\mathrm{TPM}}$ via maximum likelihood estimation (MLE) on samples drawn from $\operatorname{Pr}_{\mathrm{LM}}$, which is equivalent to minimizing the KL-divergence between $\operatorname{Pr}_{\mathrm{TPM}}$ and $\operatorname{Pr}_{\mathrm{LM}}$;
2. At generation time, we compute $\operatorname{Pr}_{\mathrm{TPM}}(x_{t+1} \mid x_{1:t}, \alpha)$ efficiently and combine it with $\operatorname{Pr}_{\mathrm{LM}}(x_{t+1} \mid x_{1:t})$ to approximate $\operatorname{Pr}_{\mathrm{LM}}(x_{t+1} \mid x_{1:t}, \alpha)$ for reliable control.
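A minimal sketch of the data-collection part of step 1, assuming the base LM is GPT2-large accessed through HuggingFace `transformers`; the prompt, sample count, and sequence length are placeholder choices, and the actual TPM/HMM fitting on these samples is not shown.

```python
# Sketch: draw unconditional samples x_{1:n} ~ Pr_LM as distillation data (step 1).
# Assumptions: HuggingFace transformers, GPT2-large as the base LM; all
# hyperparameters below are illustrative, not the paper's exact setup.
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2-large")
model = GPT2LMHeadModel.from_pretrained("gpt2-large")

inputs = tokenizer("The", return_tensors="pt")   # placeholder prompt
samples = model.generate(
    **inputs,
    do_sample=True,                 # sample from Pr_LM instead of decoding greedily
    max_length=32,
    num_return_sequences=8,
    pad_token_id=tokenizer.eos_token_id,
)
texts = tokenizer.batch_decode(samples, skip_special_tokens=True)
# The sampled token sequences would then be used to fit the TPM (e.g. an HMM)
# by maximum likelihood.
```

Maximizing the TPM likelihood on such samples minimizes, up to a constant, the KL-divergence spelled out in the next section.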
# Related work

## Tractable probabilistic models

* A class of queries $\mathbf{Q}$ is tractable on a family of probabilistic models $\mathcal{M}$ iff any query $q \in \mathbf{Q}$ on a model $m \in \mathcal{M}$ can be computed in time $\mathcal{O}(\operatorname{poly}(|m|))$.
    * We also say that $\mathcal{M}$ is a tractable model for $\mathbf{Q}$.
* Tractable probabilistic models support efficient probabilistic inference.
* Probabilistic circuits (PCs) are a unified framework for a large family of tractable probabilistic models:
    * hidden Markov models
    * bounded tree-width graphical models
    * sum-product networks (SPNs)

## Controllable Autoregressive Language Generation

* One line of research on constrained text generation focuses on modifying the decoding algorithm to inject constraints into the beam search process:
    * Search-based
        * constrained beam search
        * NeuroLogic Decoding
        * A*esque NeuroLogic Decoding
    * Token-level
        * NADO
        * FUDGE
    * Insertion-based

# Guiding Autoregressive Generation with Tractable Probabilistic Models

* Our goal is to generate from the following conditional distribution:

$$
\operatorname{Pr}_{\mathrm{LM}}\left(x_{1: n} \mid \alpha\right)=\prod_t \operatorname{Pr}_{\mathrm{LM}}\left(x_{t+1} \mid x_{1: t}, \alpha\right)
$$

* $\operatorname{Pr}_{\mathrm{LM}}(x_{t+1} \mid x_{1:t}, \alpha)$ is intractable.
* We assume that $\operatorname{Pr}_{\mathrm{TPM}}(x_{t+1} \mid x_{1:t}, \alpha)$ can be computed efficiently.
* We train the TPM model via MLE:

$$
\mathbb{E}_{x_{1: n} \sim \operatorname{Pr}_{\mathrm{LM}}} \log \operatorname{Pr}_{\mathrm{TPM}}\left(x_{1: n}\right)
$$

* This effectively minimizes their KL-divergence:

$$
\begin{aligned}
& D_{\mathrm{KL}}\left(\operatorname{Pr}_{\mathrm{LM}} \| \operatorname{Pr}_{\mathrm{TPM}}\right) \\
& =\mathbb{E}_{x_{1: n} \sim \operatorname{Pr}_{\mathrm{LM}}} \log \operatorname{Pr}_{\mathrm{LM}}\left(x_{1: n}\right)-\mathbb{E}_{x_{1: n} \sim \operatorname{Pr}_{\mathrm{LM}}} \log \operatorname{Pr}_{\mathrm{TPM}}\left(x_{1: n}\right)
\end{aligned}
$$

* We assume that there exists some "quality" constraint $\beta$ such that $\operatorname{Pr}_{\mathrm{TPM}}(\cdot \mid \beta)$ is even closer to $\operatorname{Pr}_{\mathrm{LM}}$:

$$
\operatorname{Pr}_{\mathrm{TPM}}\left(x_{1: n} \mid \alpha, \beta\right)=\prod_t \operatorname{Pr}_{\mathrm{TPM}}\left(x_{t+1} \mid x_{1: t}, \alpha, \beta\right)
$$

* We make the key independence assumption that $\alpha$ is conditionally independent of $\beta$ given the generated prefix, so that

$$
\begin{aligned}
& \operatorname{Pr}_{\mathrm{TPM}}\left(x_{t+1} \mid x_{1: t}, \alpha, \beta\right) \\
& \quad \propto \operatorname{Pr}_{\mathrm{TPM}}\left(\alpha \mid x_{1: t+1}, \beta\right) \cdot \operatorname{Pr}_{\mathrm{TPM}}\left(x_{t+1} \mid x_{1: t}, \beta\right) \\
& \quad \propto \operatorname{Pr}_{\mathrm{TPM}}\left(\alpha \mid x_{1: t+1}\right) \cdot \operatorname{Pr}_{\mathrm{LM}}\left(x_{t+1} \mid x_{1: t}\right) .
\end{aligned}
$$

* Unsupervised setting
    * Assume that the base pre-trained LM is not fine-tuned with task-specific supervision.
    * It may still be adapted to generate text in a specific domain or context.

$$
p\left(x_{t+1} \mid x_{1: t}, \alpha\right) \propto \operatorname{Pr}_{\mathrm{TPM}}\left(\alpha \mid x_{1: t+1}\right) \cdot \operatorname{Pr}_{\mathrm{LM}}\left(x_{t+1} \mid x_{1: t}\right) .
$$

* Supervised setting
    * Assume that $\operatorname{Pr}_{\mathrm{LM}}$ is fine-tuned in a sequence-to-sequence manner.
    * We adopt an alternative formulation by viewing $\operatorname{Pr}_{\mathrm{TPM}}(x_{t+1} \mid x_{1:t}, \alpha)$ and $\operatorname{Pr}_{\mathrm{LM}}(x_{t+1} \mid x_{1:t})$ as **classifiers** trained for the same task yet with different biases:

$$
\begin{aligned}
& p\left(x_{t+1} \mid x_{1: t}, \alpha\right) \\
& \quad \propto \operatorname{Pr}_{\mathrm{TPM}}\left(x_{t+1} \mid x_{1: t}, \alpha\right)^w \cdot \operatorname{Pr}_{\mathrm{LM}}\left(x_{t+1} \mid x_{1: t}\right)^{1-w}
\end{aligned}
$$
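To make the two settings concrete, here is a minimal NumPy sketch of both combination rules; the function and argument names are hypothetical, and each argument is a vocabulary-sized probability vector computed for the current prefix $x_{1:t}$.

```python
import numpy as np

def unsupervised_next_token(pr_lm_next, pr_tpm_alpha_given_next):
    """p(x_{t+1} | x_{1:t}, alpha) proportional to
    Pr_TPM(alpha | x_{1:t+1}) * Pr_LM(x_{t+1} | x_{1:t})."""
    p = pr_tpm_alpha_given_next * pr_lm_next
    return p / p.sum()

def supervised_next_token(pr_lm_next, pr_tpm_next_given_alpha, w=0.5):
    """p(x_{t+1} | x_{1:t}, alpha) proportional to
    Pr_TPM(x_{t+1} | x_{1:t}, alpha)^w * Pr_LM(x_{t+1} | x_{1:t})^(1-w)."""
    p = (pr_tpm_next_given_alpha ** w) * (pr_lm_next ** (1.0 - w))
    return p / p.sum()
```

In the supervised rule, $w$ is a mixing weight between the two "classifiers"; $w=0.5$ above is just an illustrative default.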
* To summarize, GeLaTo consists of two major steps:
    * Distillation: we train a TPM on samples drawn from the pre-trained LM via MLE, which effectively minimizes the KL divergence between $\operatorname{Pr}_{\mathrm{LM}}$ and $\operatorname{Pr}_{\mathrm{TPM}}$.
    * Probabilistic reasoning: at each step of autoregressive generation, we compute $\operatorname{Pr}_{\mathrm{TPM}}(\cdot \mid \alpha)$ and generate from the conditional next-token distribution $p(x_{t+1} \mid x_{1:t}, \alpha)$ defined above.
* Two advantages:
    * Sentences generated following $p(x_{t+1} \mid x_{1:t}, \alpha)$ are guaranteed to satisfy the lexical constraint $\alpha$.
    * TPM training is independent of the lexical constraint $\alpha$, which is only enforced at inference time.
        * There is no need to re-train the TPM, no matter how $\alpha$ changes.

# Efficient Probabilistic Reasoning with Hidden Markov Models (HMMs)

![](https://hackmd.io/_uploads/HyuTPsn32.png)

* In both settings, it suffices to compute quantities of the form $\operatorname{Pr}_{\mathrm{TPM}}(x_{1:t}, \alpha)$:
    * Unsupervised setting: $\operatorname{Pr}(\alpha \mid x_{1:t+1}) = \operatorname{Pr}(x_{1:t+1}, \alpha) / \operatorname{Pr}(x_{1:t+1})$.
    * Supervised setting: $\operatorname{Pr}(x_{t+1} \mid x_{1:t}, \alpha) \propto \operatorname{Pr}(x_{1:t+1}, \alpha)$.
* We describe a **dynamic programming algorithm** that computes $\operatorname{Pr}(x_{1:t}, \alpha)$ for HMMs, where $\alpha$ is a lexical constraint encoded in conjunctive normal form (CNF):

$$
\left(I\left(w_{1,1}\right) \vee \cdots \vee I\left(w_{1, d_1}\right)\right) \wedge \cdots \wedge\left(I\left(w_{m, 1}\right) \vee \cdots \vee I\left(w_{m, d_m}\right)\right)
$$

* $w_{i,j}$ is a string of tokens (a keystring).
* $I(w_{i,j})$ is the indicator variable for whether $w_{i,j}$ appears in the generated text.

## Hidden Markov Models

* The joint probability $\operatorname{Pr}(x_{1:n}, z_{1:n})$ is defined as:

$$
\operatorname{Pr}\left(x_1 \mid z_1\right) \operatorname{Pr}\left(z_1\right) \prod_{2 \leq t \leq n} \operatorname{Pr}\left(x_t \mid z_t\right) \operatorname{Pr}\left(z_t \mid z_{t-1}\right)
$$

* The parameters of the HMM are the initial probability $\operatorname{Pr}(z_1)$, the emission matrix $\operatorname{Pr}(x_t \mid z_t)$, and the transition matrix $\operatorname{Pr}(z_{t+1} \mid z_t)$, which stay the same across positions $t$.
* By the Markov property, future tokens are independent of past tokens given the current hidden state:

$$
\operatorname{Pr}\left(x_{t: n} \mid z_t, x_{1: t-1}\right)=\operatorname{Pr}\left(x_{t: n} \mid z_t\right) .
$$

* Forward algorithm:

$$
\operatorname{Pr}\left(x_{1: t}, z_t\right)=\sum_{1 \leq z_{t-1} \leq h} \operatorname{Pr}\left(x_t \mid z_t\right) \operatorname{Pr}\left(z_t \mid z_{t-1}\right) \operatorname{Pr}\left(x_{1: t-1}, z_{t-1}\right)
$$

* $\operatorname{Pr}_{\mathrm{HMM}}(x_{1:n})$ effectively defines a distribution over all texts with length $\leq n$.
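As a concrete reference for the recursion above, here is a minimal NumPy sketch of the forward algorithm; `init`, `trans`, and `emit` are hypothetical names for the HMM parameters, and `x` is a list of token ids.

```python
import numpy as np

def forward(x, init, trans, emit):
    """Forward algorithm for an HMM.

    init[z]     = Pr(z_1 = z)
    trans[i, j] = Pr(z_t = j | z_{t-1} = i)
    emit[z, v]  = Pr(x_t = v | z_t = z)

    Returns alpha with alpha[t - 1, z] = Pr(x_{1:t}, z_t = z).
    """
    n, h = len(x), init.shape[0]
    alpha = np.zeros((n, h))
    alpha[0] = init * emit[:, x[0]]                          # Pr(x_1, z_1)
    for t in range(1, n):
        alpha[t] = emit[:, x[t]] * (trans.T @ alpha[t - 1])  # sum over z_{t-1}
    return alpha

# Pr(x_{1:n}) is then forward(x, init, trans, emit)[-1].sum().
```

Roughly speaking, the dynamic programming algorithm in the next subsection generalizes this recursion: in addition to the observed prefix, it keeps track of which clauses of $\alpha$ remain to be satisfied.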
## An Efficient Dynamic Programming Algorithm

![](https://hackmd.io/_uploads/B1J2Ps3hn.png)

* $\alpha'$ is the CNF formula obtained by removing from the original $\alpha$ the clauses that are already satisfied.
* $x_{l:r}$ is either the empty string or a suffix of some keystring in $\alpha$.
* $\psi$ is a CNF consisting of a subset of clauses in $\alpha$, and $z_l$ is a latent state for $Z_l$.
* $$S(x, \alpha):=\left\{s: \exists x^{\prime} \text { a suffix of } x \text { s.t. } x^{\prime} \oplus s \text { lies in } \alpha\right\}$$
* Case 1. $x_{l:r} \neq \emptyset$; then

$$
\begin{aligned}
& \operatorname{Pr}\left(x_{l: r}, \alpha_{l: n} \mid z_l\right) \\
& =\sum_{z_{r+1}} \operatorname{Pr}\left(x_{l: r}, z_{r+1} \mid z_l\right)\left(\operatorname{Pr}\left(\alpha_{r+1: n} \mid z_{r+1}\right)\right. \\
& \quad +\sum_{s \in S\left(x_{l: r}, \alpha\right)} \operatorname{Pr}\left(s_{r+1: r+|s|},\left(\alpha \backslash x_{l: r} \oplus s\right)_{r+1: n} \mid z_{r+1}\right) \\
& \quad \left.-\sum_{s \in S\left(x_{l: r}, \alpha\right)} \operatorname{Pr}\left(s_{r+1: r+|s|}, \alpha_{r+1: n} \mid z_{r+1}\right)\right) ;
\end{aligned}
$$

* Case 2. $x_{l:r} = \emptyset$; we reduce the problem to Case 1 by enumerating $x_l$ over the vocabulary:

$$
\operatorname{Pr}\left(\alpha_{l: n} \mid z_l\right)=\sum_{x_l \in \text { vocabulary }} \operatorname{Pr}\left(x_l, \alpha_{l: n} \mid z_l\right)
$$

* At generation step $t$, we compute $\operatorname{Pr}(x_{1:t-1}, x_t, \alpha_{1:n})$ for each candidate next token $x_t$, where $x_{1:t-1}$ denotes the first $t-1$ tokens that have been generated:

$$
\operatorname{Pr}\left(x_{1: t}, \alpha_{1: n}\right)=\sum_{z_1} \operatorname{Pr}\left(z_1\right) \operatorname{Pr}\left(x_{1: t}, \alpha_{1: n} \mid z_1\right)
$$

* The time complexity of GeLaTo is $\mathcal{O}(2^{|\alpha|} \cdot n \cdot m)$:
    * $|\alpha|$ is the number of clauses in $\alpha$,
    * $n$ is the maximum sequence length,
    * $m$ is the number of different suffixes over all keystrings in $\alpha$.

# Experiments

1. Fine-tuning GPT2-large
    * domain adaptation
    * sequence-to-sequence
2. Training HMMs
    * to enforce lexical constraints in autoregressive generation
3. Constraint formulation

    $$
    \begin{aligned}
    & {[I(\text { catch }) \vee I(\text { caught }) \vee \ldots] } \\
    \wedge & {[I(\text { fr } \oplus \text { is } \oplus \text { bee }) \vee I(\text { fr } \oplus \text { is } \oplus \text { bees }) \vee \ldots] } \\
    \wedge & {[I(\text { snow }) \vee I(\text { snow } \oplus \text { ing }) \vee I(\text { snow } \oplus \text { ed }) \vee \ldots] }
    \end{aligned}
    $$

4. Decoding
    * We adopt beam search to greedily search for $x_{1:n}$ that maximizes $p(x_{1:n} \mid \alpha)$.
5. Metrics
    * ROUGE
    * BLEU
    * CIDEr
    * SPICE

![](https://hackmd.io/_uploads/S1TP3hnn2.png)

# Conclusion

* We propose GeLaTo, which uses tractable probabilistic models (TPMs) to impose complex lexical constraints (denoted $\alpha$) in autoregressive language generation from large language models.
* With hidden Markov models as a running example:
    * We present an efficient **dynamic programming algorithm** for conditioning HMMs on complex lexical constraints.
    * We demonstrate the effectiveness of GeLaTo on various constrained generation benchmarks.

# Appendix

## [Autoregressive model](https://deepchecks.com/glossary/autoregressive-model/)

* An autoregressive language model is a type of machine learning model that uses **autoregressive techniques** to predict the next word in a sequence based on the words that have come before it.

$$
y(t)=c+w_1 y(t-1)+w_2 y(t-2)+\ldots+w_p y(t-p)+e(t)
$$

## [HMMs](https://wisdomml.in/hidden-markov-model-hmm-in-nlp-python/)

* A Hidden Markov Model (HMM) is a statistical model used to describe a sequence of observable events or symbols in terms of an underlying sequence of hidden states.
* Given a sequence of observations, the goal of HMM inference is to find the most likely sequence of hidden states that generated those observations.

![](https://hackmd.io/_uploads/SyUG_yChn.png)
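As a small companion to the HMM description above, here is a minimal NumPy sketch of Viterbi decoding, which answers exactly that query (the most likely hidden-state sequence). The parameter names match the forward-algorithm sketch earlier and are hypothetical; this is background for HMMs in general, not part of the GeLaTo algorithm.

```python
import numpy as np

def viterbi(x, init, trans, emit):
    """Return the hidden-state sequence maximizing Pr(z_{1:n}, x_{1:n}).

    init[z]     = Pr(z_1 = z)
    trans[i, j] = Pr(z_t = j | z_{t-1} = i)
    emit[z, v]  = Pr(x_t = v | z_t = z)
    """
    n, h = len(x), init.shape[0]
    log_trans = np.log(trans)
    log_delta = np.log(init) + np.log(emit[:, x[0]])   # best log-prob ending in each state
    backptr = np.zeros((n, h), dtype=int)
    for t in range(1, n):
        scores = log_delta[:, None] + log_trans        # scores[i, j]: from state i to state j
        backptr[t] = scores.argmax(axis=0)
        log_delta = scores.max(axis=0) + np.log(emit[:, x[t]])
    states = [int(log_delta.argmax())]
    for t in range(n - 1, 0, -1):                      # backtrack through the pointers
        states.append(int(backptr[t, states[-1]]))
    return states[::-1]
```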