# Training Nonlinear Transformers for Chain-of-Thought Inference: A Theoretical Generalization Analysis

https://arxiv.org/abs/2410.02167

## Abstract

- Provides a theoretical study of training Transformers to obtain the CoT generalization capability on unseen tasks.
- First, quantifies the number of training samples and iterations required to train a Transformer model towards CoT ability.
- Then, proves that its CoT generalization succeeds on unseen tasks with distribution-shifted testing data.
- Moreover, theoretically characterizes the conditions under which CoT produces an accurate output even when the provided examples contain noise.

## Problem Formulation

- $K$-step reasoning tasks.
- Each task $f = f_K \circ \cdots \circ f_2 \circ f_1$ is a composition of functions $\{f_i\}_{i=1}^K$ and outputs labels $\boldsymbol{z}_1, \boldsymbol{z}_2, \cdots, \boldsymbol{z}_K$ for the input $\boldsymbol{x}_{\text{query}}$.
- During the $k$-th reasoning step, $k \in [K]$, the label is $\boldsymbol{z}_k = f_k(\boldsymbol{z}_{k-1})$, where $\boldsymbol{z}_0 := \boldsymbol{x}_{\text{query}}$.

### Training

#### Prompts and labels

- The training prompt includes multiple $K$-step reasoning examples and a $(k-1)$-step reasoning of $\boldsymbol{x}_{\text{query}}$ for a given $k \in [K]$, and the label for this prompt is $\boldsymbol{z}_k$.
- For the query input $\boldsymbol{z}_{k-1}$, the prompt $\boldsymbol{P}$ contains $l_{tr}$ in-context reasoning examples:
$$
\begin{aligned}
\boldsymbol{P} & = \left(\boldsymbol{E}_1, \boldsymbol{E}_2, \cdots, \boldsymbol{E}_{l_{tr}}, \boldsymbol{Q}_k\right) \in \mathbb{R}^{2 d_{\mathcal{X}} \times (l_{tr} K + k)}, \\
\text{where } \boldsymbol{E}_i & = \begin{pmatrix} \boldsymbol{x}_i & \boldsymbol{y}_{i,1} & \cdots & \boldsymbol{y}_{i,K-1} \\ \boldsymbol{y}_{i,1} & \boldsymbol{y}_{i,2} & \cdots & \boldsymbol{y}_{i,K} \end{pmatrix}, \quad \boldsymbol{Q}_k = \begin{pmatrix} \boldsymbol{z}_0 & \boldsymbol{z}_1 & \cdots & \boldsymbol{z}_{k-2} & \boldsymbol{z}_{k-1} \\ \boldsymbol{z}_1 & \boldsymbol{z}_2 & \cdots & \boldsymbol{z}_{k-1} & \mathbf{0} \end{pmatrix}, \quad i \in [l_{tr}].
\end{aligned}
$$
- Let $\boldsymbol{p}_s$ and $\boldsymbol{p}_{\text{query}}$ be the $s$-th column and the last column of $\boldsymbol{P}$, respectively, for $s \in [l_{tr} K + k - 1]$. Here $\boldsymbol{x}_i, \boldsymbol{y}_{i,k}, \boldsymbol{z}_j \in \mathbb{R}^{d_{\mathcal{X}}}$ for $i \in [l_{tr}]$ and $j, k \in [K]$.
- Positional encodings $\{\boldsymbol{c}_k\}_{k=1}^K \subset \mathbb{R}^{2 d_{\mathcal{X}}}$ are added to the columns of $\boldsymbol{P}$ by $\tilde{\boldsymbol{p}}_i = \boldsymbol{p}_i + \boldsymbol{c}_{(i \bmod K)}$, $i \in [K(l_{tr}+1)]$.
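The column layout of $\boldsymbol{P}$ and the positional encoding can be sketched in NumPy as below. This is a minimal illustration, not the paper's code: the function names are hypothetical, and reading $\boldsymbol{c}_0$ as $\boldsymbol{c}_K$ (so that every column gets one of $\boldsymbol{c}_1, \dots, \boldsymbol{c}_K$) is our assumption about the $i \bmod K$ indexing.

```python
import numpy as np

def build_training_prompt(examples, query_steps, K):
    """Stack P = (E_1, ..., E_{l_tr}, Q_k) column by column (hypothetical helper).

    examples:    list of l_tr reasoning chains [x_i, y_{i,1}, ..., y_{i,K}] (vectors in R^{d_X})
    query_steps: [z_0, z_1, ..., z_{k-1}] with z_0 = x_query, so k = len(query_steps)
    returns:     array of shape (2 * d_X, l_tr * K + k)
    """
    d = query_steps[0].shape[0]
    cols = []
    for chain in examples:
        assert len(chain) == K + 1, "each example needs K + 1 vectors"
        # E_i: column j stacks the input of step j+1 over its output
        for j in range(K):
            cols.append(np.concatenate([chain[j], chain[j + 1]]))
    # Q_k: the first k-1 query columns are (z_j; z_{j+1})
    for j in range(len(query_steps) - 1):
        cols.append(np.concatenate([query_steps[j], query_steps[j + 1]]))
    # last column p_query = (z_{k-1}; 0): z_k is what the model must predict
    cols.append(np.concatenate([query_steps[-1], np.zeros(d)]))
    return np.stack(cols, axis=1)

def add_positional_encoding(P, C):
    """Add c_{(i mod K)} to column i. C has shape (2 * d_X, K) with C[:, 0] = c_1.

    Assumption: with 1-indexed columns i, c_0 is read as c_K, so the
    0-indexed column j receives C[:, j % K].
    """
    K = C.shape[1]
    return P + C[:, np.arange(P.shape[1]) % K]

# Toy usage: d_X = 4, K = 3, l_tr = 2 examples, k = 2 (one completed query step)
E = np.eye(4)
examples = [[E[0], E[1], E[2], E[3]], [E[1], E[2], E[3], E[0]]]
P = build_training_prompt(examples, [E[2], E[3]], K=3)
print(P.shape)  # (8, 8) = (2 * d_X, l_tr * K + k)
```

Each context column pairs a step input with its output, so the model sees "input over output" pairs, while the final query column leaves the bottom half as the zero slot to be filled by $\boldsymbol{z}_k$.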
#### Model

- Given a prompt $\boldsymbol{P}$ with $\operatorname{len}(\boldsymbol{P})$ columns, the model output is
$$
F(\Psi; \boldsymbol{P}) = \sum_{i=1}^{\operatorname{len}(\boldsymbol{P})-1} \boldsymbol{W}_V \tilde{\boldsymbol{p}}_i \cdot \operatorname{softmax}\left((\boldsymbol{W}_K \tilde{\boldsymbol{p}}_i)^{\top} \boldsymbol{W}_Q \tilde{\boldsymbol{p}}_{\text{query}}\right),
$$
where the softmax normalizes over $i \in [\operatorname{len}(\boldsymbol{P})-1]$, $\boldsymbol{W}_Q, \boldsymbol{W}_K \in \mathbb{R}^{m \times (2 d_{\mathcal{X}})}$, $\boldsymbol{W}_V \in \mathbb{R}^{d_{\mathcal{X}} \times (2 d_{\mathcal{X}})}$, and $\Psi := \{\boldsymbol{W}_Q, \boldsymbol{W}_K, \boldsymbol{W}_V\}$ is the set of all model weights.

#### Loss function

- The trainer uses $N$ prompt-label pairs $\{\boldsymbol{P}^n, \boldsymbol{z}^n\}_{n=1}^N$ to minimize the empirical risk
$$
\min_{\Psi} R_N(\Psi) := \frac{1}{N} \sum_{n=1}^N \ell(\Psi; \boldsymbol{P}^n, \boldsymbol{z}^n).
$$
- For the $n$-th sample, the query input $\boldsymbol{x}_{\text{query}}^n$ and the context inputs $\boldsymbol{x}_i^n$ are all sampled from an unknown distribution $\mathcal{D}$, the training task $f^n$ is sampled from $\mathcal{T}$, and $k$ is randomly selected from $1$ to $K$.
- The loss function is the squared loss, i.e., $\ell(\Psi; \boldsymbol{P}^n, \boldsymbol{z}^n) = \frac{1}{2} \|\boldsymbol{z}^n - F(\Psi; \boldsymbol{P}^n)\|^2$.

### Inference

Consider another $K$-step reasoning task $f \in \mathcal{T}'$, whose target is to predict the labels $\{\boldsymbol{z}_k\}_{k=1}^K$ given the input query $\boldsymbol{x}_{\text{query}}$. Here $\mathcal{T}'$ is the set of testing tasks, and $\mathcal{T}' \neq \mathcal{T}$.
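The inference steps below repeatedly evaluate the model $F(\Psi; \boldsymbol{P})$ defined in the Model subsection. A minimal NumPy sketch of that forward pass and of the squared loss follows; it assumes the prompt is already position-encoded (one column per $\tilde{\boldsymbol{p}}_i$), uses hypothetical function names, and does not show any training procedure.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())  # shift for numerical stability
    return e / e.sum()

def model_output(W_Q, W_K, W_V, P_tilde):
    """F(Psi; P): the last column p~_query attends over columns 1..len(P)-1.

    W_Q, W_K: (m, 2*d_X); W_V: (d_X, 2*d_X); P_tilde: (2*d_X, len(P)), already position-encoded.
    """
    context = P_tilde[:, :-1]                    # p~_1, ..., p~_{len(P)-1}
    query = P_tilde[:, -1]                       # p~_query
    scores = (W_K @ context).T @ (W_Q @ query)   # (W_K p~_i)^T W_Q p~_query for each i
    weights = softmax(scores)                    # normalized over i
    return W_V @ context @ weights               # sum_i W_V p~_i * weight_i, shape (d_X,)

def squared_loss(W_Q, W_K, W_V, P_tilde, z):
    """l(Psi; P, z) = 1/2 * ||z - F(Psi; P)||^2."""
    r = z - model_output(W_Q, W_K, W_V, P_tilde)
    return 0.5 * float(r @ r)

def empirical_risk(W_Q, W_K, W_V, prompts, labels):
    """R_N(Psi): average squared loss over N prompt-label pairs."""
    return float(np.mean([squared_loss(W_Q, W_K, W_V, P, z)
                          for P, z in zip(prompts, labels)]))
```

A quick sanity check on shapes: with $\boldsymbol{W}_Q = \boldsymbol{W}_K = \mathbf{0}$ all attention scores are equal, so the weights are uniform and $F$ reduces to $\boldsymbol{W}_V$ applied to the mean context column.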
- The testing prompt $\boldsymbol{P}$ is composed of $l_{ts}$ ($\leq l_{tr}$) context examples of $K$ steps plus a query:
$$
\boldsymbol{P} = \left(\boldsymbol{E}_1, \boldsymbol{E}_2, \cdots, \boldsymbol{E}_{l_{ts}}, \boldsymbol{p}_{\text{query}}\right) \in \mathbb{R}^{(2 d_{\mathcal{X}}) \times (l_{ts} K + 1)}, \quad \boldsymbol{p}_{\text{query}} = \left(\boldsymbol{x}_{\text{query}}^{\top}, \mathbf{0}^{\top}\right)^{\top}.
$$
- Inference process for a $K$-step CoT with $l_{ts}$ examples on a task $f \in \mathcal{T}'$: given the testing prompt $\boldsymbol{P}$, let $\boldsymbol{P}_1 = \boldsymbol{P}$ and let $\boldsymbol{P}_0$ be the first $K \cdot l_{ts}$ columns of $\boldsymbol{P}$. To predict the $k$-th step:
  - Generate the output $\boldsymbol{v}_k$, $k \in [K]$, via greedy decoding by feeding the $k$-th step prompt $\boldsymbol{P}_k$ to the trained model $\Psi$.
  - Greedy decoding outputs the most probable token from the discrete set $\mathcal{Y}$:
$$
\boldsymbol{v}_k = \arg\min_{\boldsymbol{u} \in \mathcal{Y}} \frac{1}{2} \left\|F(\Psi; \boldsymbol{P}_k) - \boldsymbol{u}\right\|^2. \quad \text{(greedy decoding)}
$$
  - Then, use the output $\boldsymbol{v}_k$ to update $\boldsymbol{P}_k$, and use $\boldsymbol{v}_k$ as the query input to form the prompt $\boldsymbol{P}_{k+1}$ for the next step (with $\boldsymbol{v}_0 := \boldsymbol{x}_{\text{query}}$):
$$
\begin{aligned}
& \boldsymbol{P}_k = (\boldsymbol{P}_{k-1}\ \boldsymbol{q}_k) \in \mathbb{R}^{(2 d_{\mathcal{X}}) \times (K l_{ts} + k)}, \quad \boldsymbol{P}_{k+1} = (\boldsymbol{P}_k\ \boldsymbol{q}_{k+1}) \in \mathbb{R}^{(2 d_{\mathcal{X}}) \times (K l_{ts} + k + 1)}, \\
& \text{where } \boldsymbol{q}_k = (\boldsymbol{v}_{k-1}^{\top}\ \boldsymbol{v}_k^{\top})^{\top}, \quad \boldsymbol{q}_{k+1} = (\boldsymbol{v}_k^{\top}\ \mathbf{0}^{\top})^{\top}.
\end{aligned}
$$
- The model finally outputs $\boldsymbol{v}_1, \cdots, \boldsymbol{v}_K$ as the CoT result for the query $\boldsymbol{x}_{\text{query}}$.
- The CoT generalization error, given the testing query $\boldsymbol{x}_{\text{query}}$, the testing data distribution $\mathcal{D}'$, and the labels $\{\boldsymbol{z}_k\}_{k=1}^K$ of a $K$-step testing task $f \in \mathcal{T}'$, is defined as
$$
\bar{R}^f_{\text{CoT}, \boldsymbol{x}_{\text{query}} \sim \mathcal{D}', f \in \mathcal{T}'}(\Psi) = \mathbb{E}_{\boldsymbol{x}_{\text{query}} \sim \mathcal{D}'}\left[\frac{1}{K} \sum_{k=1}^K \mathbb{1}\left[\boldsymbol{z}_k \neq \boldsymbol{v}_k\right]\right],
$$
which measures the fraction of reasoning steps whose output differs from the label, averaged over testing queries.

## Theoretical Result

### Settings

- Training: $M$ orthonormal training-relevant (TRR) patterns.
- Testing: $M'$ orthonormal testing-relevant (TSR) patterns.
- The reasoning in some testing context examples contains incorrect steps.

#### Training data and tasks

- Consider $M$ TRR patterns $\boldsymbol{\mu}_1, \boldsymbol{\mu}_2, \cdots, \boldsymbol{\mu}_M$, which form an orthonormal set $\mathcal{M} = \{\boldsymbol{\mu}_i\}_{i=1}^M$, with $M = \Theta(d)$ and $M \leq d$. Moreover, $(\boldsymbol{\mu}_i^{\top}, \mathbf{0}_{d_{\mathcal{X}}}^{\top})^{\top} \perp \boldsymbol{c}_k$ for $i \in [M]$, $k \in [K]$.
- Training prompt: $\boldsymbol{P}$ contains training examples from the same training task $f \in \mathcal{T}$.
- Training task $f = f_K \circ \cdots \circ f_2 \circ f_1$, where $f_k \in \mathcal{F}$.
- The $k$-th step label of the query is $\boldsymbol{z}_k = f_k(\boldsymbol{z}_{k-1})$ given the $k$-th step input $\boldsymbol{z}_{k-1}$, with $\boldsymbol{z}_k \in \mathcal{M}$, $k \in [K]$.
- Moreover, the $k$-th step label of the $i$-th ($i \in [l_{tr}]$) context example is $\boldsymbol{y}_{i,k} = f_k(\boldsymbol{y}_{i,k-1})$ given the $(k-1)$-th step input $\boldsymbol{y}_{i,k-1}$, $k \in [K]$, with $\boldsymbol{x}_i, \boldsymbol{y}_{i,k} \in \mathcal{M}$, where $\boldsymbol{y}_{i,0} := \boldsymbol{x}_i$.
- We assume that $f_k(\boldsymbol{x}) \neq f_{k'}(\boldsymbol{x}')$ if and only if $\boldsymbol{x} \neq \boldsymbol{x}'$ or $f_k \neq f_{k'}$.
- Let $\alpha \in (0, 1-c]$, for some constant $c > 0$, denote the fraction of context examples whose input shares the same TRR pattern as the query input.

#### Testing task and query

- Consider $M'$ TSR patterns $\boldsymbol{\mu}_1', \boldsymbol{\mu}_2', \cdots, \boldsymbol{\mu}_{M'}'$, which form an orthonormal set $\mathcal{M}' = \{\boldsymbol{\mu}_i'\}_{i=1}^{M'}$, with $M' \leq M$ and $\boldsymbol{\mu}_i' \perp \boldsymbol{c}_k$ for $i \in [M']$, $k \in [K]$.
- Let $\mathcal{T}'$ denote the set of testing tasks, which all operate on patterns in $\mathcal{M}'$ rather than on the patterns in $\mathcal{M}$ used by the training tasks in $\mathcal{T}$.
- Every testing task is a composition $f = f_K \circ \cdots \circ f_2 \circ f_1 \in \mathcal{T}'$.
- The reasoning for the testing query is assumed to be noiseless and accurate:
$$
\boldsymbol{z}_k \in \mathcal{M}' \text{ for all } k \in \{0\} \cup [K], \text{ and } \boldsymbol{z}_k = f_k(\boldsymbol{z}_{k-1}), \ \boldsymbol{z}_0 = \boldsymbol{x}_{\text{query}}.
$$

#### Noisy testing prompt

- In noisy examples, all inputs and outputs of each step are noisy versions of TSR patterns, i.e.,
$$
\boldsymbol{x}_i, \boldsymbol{y}_{i,k} \in \left\{\boldsymbol{b} \in \mathbb{R}^d \mid \boldsymbol{b} = \boldsymbol{\mu}_j' + \delta,\ j \in [M'],\ \delta \perp \mathcal{M}',\ \|\delta\| \leq \sqrt{2}/2\right\},
$$
with noise $\delta \neq 0$ for $i \in [K l_{ts}^f]$, $k \in [K]$.
- We consider the case where, in at least an $\alpha'$ fraction of the context examples, the input $\boldsymbol{y}_{s,1}$, $s \in [l_{ts}^f]$, has the same TSR pattern as $\boldsymbol{x}_{\text{query}}$.
- To model noisy examples formally, we define the step-wise transition matrices $\{\boldsymbol{A}_k^f\}_{k=1}^K \subset \mathbb{R}^{M' \times M'}$, where $\boldsymbol{A}_k^f$ gives the reasoning probabilities of step $k$ in the test examples.
- Specifically, there exists some constant $\rho^f \in (0,1)$ such that, for all $s \in [l_{ts}^f]$ and $k \in [K]$, the $(i,j)$-th entry of $\boldsymbol{A}_k^f$ satisfies
$$
A_{k(i,j)}^f = \Pr\left(\operatorname{TSR}(\boldsymbol{y}_{s,k}) = j \mid \operatorname{TSR}(\boldsymbol{y}_{s,k-1}) = i\right),
$$
and $A_{k(i,j^*)}^f \geq \frac{1}{1-\rho^f} \cdot A_{k(i,j)}^f$ for all $j \in [M']$ with $j \neq j^*$, where $\boldsymbol{\mu}_{j^*}' = f_k(\boldsymbol{\mu}_i')$ and $\operatorname{TSR}: \mathbb{R}^d \to \mathbb{Z}^+$ is a function that outputs the index of the TSR pattern of a noisy input.
- Let $\boldsymbol{B}^f = \prod_{k=1}^K \boldsymbol{A}_k^f$ be the $K$-step transition matrix. We similarly define $\rho_o^f \in (0,1)$ as the primacy of $\boldsymbol{B}^f$, where
$$
B_{(i,j^*)}^f \geq \frac{1}{1-\rho_o^f} \cdot B_{(i,j)}^f, \quad \forall j \in [M'],\ j \neq j^*, \quad j^* = \arg\max_{j \in [M']} B_{(i,j)}^f.
$$

#### Example 1.

- Consider a simple two-step inference example with $K = 2$, TSR patterns $\boldsymbol{\mu}_1', \boldsymbol{\mu}_2'$, and $\delta = 0$ in the inputs and outputs of every step.
- The black solid arrows denote the correct inference process, where $f_1(\boldsymbol{\mu}_1') = \boldsymbol{\mu}_1'$, $f_1(\boldsymbol{\mu}_2') = \boldsymbol{\mu}_2'$, $f_2(\boldsymbol{\mu}_1') = \boldsymbol{\mu}_2'$, and $f_2(\boldsymbol{\mu}_2') = \boldsymbol{\mu}_1'$. Hence, $\boldsymbol{\mu}_1' \rightarrow \boldsymbol{\mu}_1' \rightarrow \boldsymbol{\mu}_2'$ and $\boldsymbol{\mu}_2' \rightarrow \boldsymbol{\mu}_2' \rightarrow \boldsymbol{\mu}_1'$ are the two inference trajectories under the function $f$.
- The testing examples contain errors and follow the transition matrices $\boldsymbol{A}_1^f$ and $\boldsymbol{A}_2^f$ (brown dashed arrows).
- We let $\boldsymbol{A}_1^f = \begin{pmatrix} 0.6 & 0.4 \\ 0.4 & 0.6 \end{pmatrix}$ and $\boldsymbol{A}_2^f = \begin{pmatrix} 0.4 & 0.6 \\ 0.8 & 0.2 \end{pmatrix}$, which results in $\boldsymbol{B}^f = \boldsymbol{A}_1^f \boldsymbol{A}_2^f = \begin{pmatrix} 0.56 & 0.44 \\ 0.64 & 0.36 \end{pmatrix}$.

![image](https://hackmd.io/_uploads/SJ3T3uFmll.png)

### Training Convergence Guarantee

---

**Theorem 1.**

- For any $\epsilon > 0$, when
  - (i) the number of context examples in every training sample is
$$
l_{tr} \geq \Omega(\alpha^{-1}),
$$
  - (ii) the number of iterations satisfies
$$
T \geq \Omega\left(\eta^{-1} \alpha^{-2} K^3 \log \frac{K}{\epsilon} + \eta^{-1} M K (\alpha^{-1} + \epsilon^{-1})\right),
$$
  - (iii) the training tasks and samples are selected such that every TRR pattern is equally likely in every inference step and in each training batch, with batch size $B \geq \Omega(\max\{\epsilon^{-2}, M\} \cdot \log M)$, step size $\eta < 1$, and $N = BT$ samples,
- then, with high probability, the returned model guarantees
$$
\mathbb{E}_{\boldsymbol{x}_{\text{query}} \in \mathcal{M}, f \in \mathcal{T}}[\ell(\Psi; \boldsymbol{P}, \boldsymbol{z})] \leq \mathcal{O}(\epsilon).
$$

---

### CoT Generalization Guarantee

---

**Definition 1.**

- For $f = f_K \circ \cdots \circ f_1 \in \mathcal{T}'$, we define the min-max trajectory transition probability as
$$
\tau^f = \min_{i \in [M']} \prod_{k=1}^K A^f_{k\left(\operatorname{TSR}(f_{k-1} \circ \cdots \circ f_0(\boldsymbol{\mu}_i')),\ \operatorname{TSR}(f_k \circ \cdots \circ f_0(\boldsymbol{\mu}_i'))\right)}, \quad \text{where } f_0(\boldsymbol{\mu}_i') := \boldsymbol{\mu}_i',\ \forall i \in [M'],
$$
which measures the minimum probability, over all initial TSR patterns, of the $K$-step reasoning trajectory that has the highest probability over all $K$-step trajectories.
- We also define the min-max input-label transition probability as
$$
\tau_o^f = \min_{i \in [M']} \max_{j \in [M']} B^f_{(i,j)},
$$
which measures the minimum probability, over all initial TSR patterns, of the output that has the highest probability over all outputs.
- For instance, in Example 1, $\tau^f = \min\{0.36, 0.48\} = 0.36$, where $0.36 = A^f_{1(1,1)} A^f_{2(1,2)} = 0.6 \cdot 0.6$ and $0.48 = A^f_{1(2,2)} A^f_{2(2,1)} = 0.6 \cdot 0.8$; and $\tau_o^f = \min\{0.56, 0.64\} = 0.56$, the smaller of the two row maxima of $\boldsymbol{B}^f$.
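The numbers in Example 1 can be checked with a short NumPy sketch (0-based indices, so index 0 is $\boldsymbol{\mu}_1'$ and index 1 is $\boldsymbol{\mu}_2'$). Besides $\boldsymbol{B}^f$, $\tau^f$, and $\tau_o^f$, it reports the largest $\rho^f$ and $\rho_o^f$ compatible with the primacy inequalities; taking that largest value (a minimum over all steps, rows, and $j \neq j^*$) is our assumption, since the definitions only require that such a constant exists. All function names are hypothetical.

```python
import numpy as np

# Example 1: K = 2, M' = 2; index 0 is mu'_1 and index 1 is mu'_2
A = [np.array([[0.6, 0.4], [0.4, 0.6]]),  # A_1^f
     np.array([[0.4, 0.6], [0.8, 0.2]])]  # A_2^f
# Correct steps as index maps: f_1 is the identity, f_2 swaps the two patterns
f = [np.array([0, 1]), np.array([1, 0])]

def k_step_transition(A):
    """B^f = A_1^f A_2^f ... A_K^f (row i = current pattern, column j = next pattern)."""
    B = np.eye(A[0].shape[0])
    for Ak in A:
        B = B @ Ak
    return B

def tau(A, f):
    """tau^f: min over start patterns of the probability of the correct K-step trajectory."""
    probs = []
    for i in range(A[0].shape[0]):
        state, p = i, 1.0
        for Ak, fk in zip(A, f):
            nxt = fk[state]
            p *= Ak[state, nxt]
            state = nxt
        probs.append(p)
    return min(probs)

def tau_o(B):
    """tau_o^f = min_i max_j B^f_(i, j)."""
    return B.max(axis=1).min()

def largest_primacy(T, best):
    """Largest rho with T[i, best[i]] >= T[i, j] / (1 - rho) for every row i and j != best[i]."""
    rho = 1.0
    for i, js in enumerate(best):
        for j in range(T.shape[1]):
            if j != js:
                rho = min(rho, 1.0 - T[i, j] / T[i, js])
    return rho

B = k_step_transition(A)
rho = min(largest_primacy(Ak, fk) for Ak, fk in zip(A, f))  # must hold for every step k
rho_o = largest_primacy(B, B.argmax(axis=1))
print(np.round(B, 2).tolist())                  # [[0.56, 0.44], [0.64, 0.36]]
print(round(tau(A, f), 2), round(tau_o(B), 2))  # 0.36 0.56
print(round(rho, 3), round(rho_o, 3))           # 0.333 0.214
```

Here $\rho^f \approx 1/3$ is limited by the $0.6$ versus $0.4$ rows, and $\rho_o^f = 1 - 0.44/0.56 = 3/14$ comes from the first row of $\boldsymbol{B}^f$.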
---

---

**Theorem 2 (CoT generalization).**

- Given a trained model whose training process satisfies conditions (i) to (iii) in Theorem 1, then as long as
  - (iv) each TSR pattern $\boldsymbol{\mu}_j'$ in the orthonormal set $\{\boldsymbol{\mu}_j'\}_{j=1}^{M'}$ satisfies
$$
\boldsymbol{\mu}_j' = \boldsymbol{\lambda}_j + \tilde{\boldsymbol{\mu}}_j,
$$
where $\boldsymbol{\lambda}_j \perp \operatorname{span}(\boldsymbol{\mu}_1, \cdots, \boldsymbol{\mu}_M)$, $\tilde{\boldsymbol{\mu}}_j \in \operatorname{span}(\boldsymbol{\mu}_1, \cdots, \boldsymbol{\mu}_M)$, and $\|\tilde{\boldsymbol{\mu}}_j\| \geq \Theta((\log \epsilon^{-1})^{-1})$,
  - (v) the number of testing examples for every $f \in \mathcal{T}'$ satisfies
$$
l_{ts}^f \geq \Omega\left((\alpha' \tau^f \rho^f)^{-2} \log M\right),
$$
- we have $\bar{R}^f_{\text{CoT}, \boldsymbol{x}_{\text{query}} \in \mathcal{M}', f \in \mathcal{T}'}(\Psi) = 0$.

---
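Theorem 2 states that the CoT error $\bar{R}_{\text{CoT}}$ is zero under conditions (iv) and (v). To make that quantity concrete, here is a minimal NumPy sketch of the greedy $K$-step inference loop from the Inference section (decode, fill the last column with $\boldsymbol{v}_k$, append $(\boldsymbol{v}_k; \mathbf{0})$) and of the per-query error $\frac{1}{K}\sum_k \mathbb{1}[\boldsymbol{z}_k \neq \boldsymbol{v}_k]$. Assumptions: `model` is any callable standing in for the trained $F(\Psi; \cdot)$ (positional encoding, if used, is its job); the `oracle` in the toy check is hypothetical and is not the trained Transformer; all names are illustrative.

```python
import numpy as np

def greedy_decode(output, candidates):
    """v = argmin_{u in Y} 0.5 * ||output - u||^2, with the rows of `candidates` forming Y."""
    dists = 0.5 * np.sum((candidates - output) ** 2, axis=1)
    return candidates[np.argmin(dists)]

def cot_inference(model, P0, x_query, K, candidates):
    """Greedy K-step CoT: returns [v_1, ..., v_K].

    model: hypothetical stand-in for the trained F(Psi; .), mapping a
           (2 * d_X, n) prompt to a vector in R^{d_X}.
    P0:    the K * l_ts context columns of the testing prompt.
    """
    d = x_query.shape[0]
    # P_1 = (P_0, p_query) with p_query = (x_query; 0)
    P = np.hstack([P0, np.concatenate([x_query, np.zeros(d)])[:, None]])
    outputs = []
    for _ in range(K):
        v = greedy_decode(model(P), candidates)
        outputs.append(v)
        P[d:, -1] = v  # last column becomes (v_{k-1}; v_k)
        # append (v_k; 0) so that v_k is the query input of the next step
        P = np.hstack([P, np.concatenate([v, np.zeros(d)])[:, None]])
    return outputs

def cot_step_error(outputs, labels):
    """(1/K) * sum_k 1[z_k != v_k] for one query; averaging over queries estimates R_CoT."""
    return sum(not np.allclose(v, z) for v, z in zip(outputs, labels)) / len(labels)

# Toy check with a hypothetical oracle in place of a trained model:
# patterns are standard basis vectors, f_1 = identity, f_2 = cyclic shift.
d, K = 3, 2
Y = np.eye(d)
steps = [np.array([0, 1, 2]), np.array([1, 2, 0])]  # f_k as index maps
P0 = np.zeros((2 * d, 0))                            # no context examples in this toy

def oracle(P):
    k = P.shape[1] - P0.shape[1]                     # current step (1-based)
    z_prev = P[:d, -1]
    return Y[steps[k - 1][np.argmax(z_prev)]] + 0.1  # correct answer plus a small offset

outs = cot_inference(oracle, P0, Y[0], K, Y)
print(cot_step_error(outs, [Y[0], Y[1]]))            # 0.0
```

Because greedy decoding snaps the model output to the nearest element of $\mathcal{Y}$, a small deviation in $F$ (the $0.1$ offset above) does not change $\boldsymbol{v}_k$; the error only counts a step once the output lands closer to a wrong pattern.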