# Training Nonlinear Transformers for Chain-of-Thought Inference: A Theoretical Generalization Analysis
https://arxiv.org/abs/2410.02167
## Abstract
- Provide a theoretical study of training Transformers to obtain CoT generalization capability on unseen tasks.
- First, quantify the required number of training samples and iterations to train a Transformer toward CoT ability.
- Then, prove the success of its CoT generalization on unseen tasks with distribution-shifted testing data.
- Moreover, theoretically characterize the conditions for an accurate CoT output even when the provided examples contain noise.
## Problem Formulation
- $K$-step reasoning tasks.
- Each task $f= f_K \circ \cdots \circ f_2 \circ f_1$ is a composition of functions $\left\{f_i\right\}_{i=1}^K$ and outputs labels $\boldsymbol{z}_1, \boldsymbol{z}_2, \cdots, \boldsymbol{z}_K$ for the input $\boldsymbol{x}_{\text {query }}$.
- During the $k$-th reasoning step, $k \in[K]$, the label is $\boldsymbol{z}_k=f_k\left(\boldsymbol{z}_{k-1}\right)$, where $\boldsymbol{z}_0:=\boldsymbol{x}_{\text {query }}$.
### Training
#### Prompts and labels
- A training prompt includes multiple $K$-step reasoning examples and a $(k-1)$-step reasoning of $\boldsymbol{x}_{\text {query }}$ for some $k \in [K]$; the label for this prompt is $\boldsymbol{z}_k$.
- For the query input $\boldsymbol{z}_{k-1}$, the prompt $\boldsymbol{P}$ contains $l_{t r}$ in-context reasoning examples:
$$
\begin{aligned}
\boldsymbol{P} & =\left(\boldsymbol{E}_1, \boldsymbol{E}_2, \cdots, \boldsymbol{E}_{l_{t r}}, \boldsymbol{Q}_k\right) \in \mathbb{R}^{2 d_{\mathcal{X}} \times\left(l_{t r} K+k\right)} \\
\text { where } \boldsymbol{E}_i & =\left(\begin{array}{cccc}
\boldsymbol{x}_i & \boldsymbol{y}_{i, 1} & \cdots & \boldsymbol{y}_{i, K-1} \\
\boldsymbol{y}_{i, 1} & \boldsymbol{y}_{i, 2} & \cdots & \boldsymbol{y}_{i, K}
\end{array}\right), \boldsymbol{Q}_k=\left(\begin{array}{ccccc}
\boldsymbol{z}_0 & \boldsymbol{z}_1 & \cdots & \boldsymbol{z}_{k-2} & \boldsymbol{z}_{k-1} \\
\boldsymbol{z}_1 & \boldsymbol{z}_2 & \cdots & \boldsymbol{z}_{k-1} & \mathbf{0}
\end{array}\right), i \in\left[l_{t r}\right]
\end{aligned}
$$
- Let $\boldsymbol{p}_s$ and $\boldsymbol{p}_{\text {query }}$ be the $s$-th column and the last column of $\boldsymbol{P}$, respectively, for $s \in\left[l_{t r} K+k-1\right]$. Here $\boldsymbol{x}_i, \boldsymbol{y}_{i, k}, \boldsymbol{z}_j \in \mathbb{R}^{d_{\mathcal{X}}}$ for $i \in\left[l_{t r}\right]$ and $j, k \in[K]$.
- Positional encodings $\left\{\boldsymbol{c}_k\right\}_{k=1}^K$ with $\boldsymbol{c}_k \in \mathbb{R}^{2 d_{\mathcal{X}}}$ are added to the columns of $\boldsymbol{P}$ by $\tilde{\boldsymbol{p}}_i=\boldsymbol{p}_i+\boldsymbol{c}_{(i \bmod K)}$, $i \in\left[K\left(l_{t r}+1\right)\right]$ (a construction sketch follows below).
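A minimal NumPy sketch of how a training prompt could be assembled under this formulation. All sizes and vectors below are illustrative placeholders (random vectors stand in for TRR patterns), not the paper's data model.

```python
import numpy as np

# Illustrative (hypothetical) sizes: pattern dim, reasoning steps, context examples, current step k.
d_x, K, l_tr, k = 4, 3, 2, 2
rng = np.random.default_rng(0)

def example_block(x, ys):
    """E_i: top row holds the step inputs (x_i, y_{i,1}, ..., y_{i,K-1}),
    bottom row holds the step outputs (y_{i,1}, ..., y_{i,K})."""
    top = np.column_stack([x] + ys[:-1])
    bottom = np.column_stack(ys)
    return np.vstack([top, bottom])                      # shape (2*d_x, K)

def query_block(zs, k):
    """Q_k: the first k-1 reasoning steps of the query plus a zero-padded last column."""
    top = np.column_stack(zs[:k])                        # z_0, ..., z_{k-1}
    bottom = np.column_stack(zs[1:k] + [np.zeros(d_x)])  # z_1, ..., z_{k-1}, 0
    return np.vstack([top, bottom])                      # shape (2*d_x, k)

# Random vectors stand in for TRR patterns and per-step labels (illustration only).
examples = [example_block(rng.standard_normal(d_x),
                          [rng.standard_normal(d_x) for _ in range(K)])
            for _ in range(l_tr)]
zs = [rng.standard_normal(d_x) for _ in range(k)]        # z_0 = x_query, ..., z_{k-1}
P = np.column_stack(examples + [query_block(zs, k)])     # shape (2*d_x, l_tr*K + k)

# Positional encodings c_1, ..., c_K added column-wise by step index (0-indexed here).
C = rng.standard_normal((K, 2 * d_x))
P_tilde = P + np.stack([C[i % K] for i in range(P.shape[1])], axis=1)
```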
#### Model
- Given a prompt $\boldsymbol{P}$, with $\operatorname{len}(\boldsymbol{P})$ denoting its number of columns, the model output is
$$
F(\Psi ; \boldsymbol{P})=\sum_{i=1}^{\operatorname{len}(\boldsymbol{P})-1} \boldsymbol{W}_V \tilde{\boldsymbol{p}}_i \cdot \operatorname{softmax}\left(\left(\boldsymbol{W}_K \tilde{\boldsymbol{p}}_i\right)^{\top} \boldsymbol{W}_Q \tilde{\boldsymbol{p}}_{q u e r y}\right),
$$
where $\boldsymbol{W}_Q, \boldsymbol{W}_K \in \mathbb{R}^{m \times\left(2 d_{\mathcal{X}}\right)}$, $\boldsymbol{W}_V \in \mathbb{R}^{d_{\mathcal{X}} \times\left(2 d_{\mathcal{X}}\right)}$, and $\Psi:=\left\{\boldsymbol{W}_Q, \boldsymbol{W}_K, \boldsymbol{W}_V\right\}$ is the set of all model weights (a numerical sketch of $F$ follows).
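The model $F(\Psi ; \boldsymbol{P})$ transcribes directly into NumPy. The shapes follow the definitions above; the random weights are placeholders, since the paper analyzes trained weights.

```python
import numpy as np

def softmax(v):
    """Numerically stable softmax over a 1-D vector of attention scores."""
    e = np.exp(v - v.max())
    return e / e.sum()

def F(W_Q, W_K, W_V, P_tilde):
    """F(Psi; P): single-head softmax attention of the query (last) column over the others."""
    p_query = P_tilde[:, -1]
    context = P_tilde[:, :-1]                        # columns 1, ..., len(P) - 1
    scores = (W_K @ context).T @ (W_Q @ p_query)     # one attention logit per context column
    return (W_V @ context) @ softmax(scores)         # weighted sum of value vectors, in R^{d_x}

# Illustrative shapes; random weights are placeholders for the trained parameters Psi.
d_x, m, n_cols = 4, 6, 9
rng = np.random.default_rng(1)
W_Q, W_K = rng.standard_normal((m, 2 * d_x)), rng.standard_normal((m, 2 * d_x))
W_V = rng.standard_normal((d_x, 2 * d_x))
out = F(W_Q, W_K, W_V, rng.standard_normal((2 * d_x, n_cols)))   # prediction for the label z
```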
#### Loss function
- The trainer uses $N$ prompt-label pairs $\left\{\boldsymbol{P}^n, \boldsymbol{z}^n\right\}_{n=1}^N$ to minimize the empirical risk
$$
\min _{\Psi} R_N(\Psi):=\frac{1}{N} \sum_{n=1}^N \ell\left(\Psi ; \boldsymbol{P}^n, \boldsymbol{z}^n\right)
$$
- For the $n$-th sample, $\boldsymbol{x}_{\text {query }}^n$ and the context inputs $\boldsymbol{x}_i^n$ are all sampled from an unknown distribution $\mathcal{D}$, the training task $f^n$ is sampled from $\mathcal{T}$, and $k$ is randomly selected from $[K]$.
- The loss function is the squared loss, i.e., $\ell\left(\Psi ; \boldsymbol{P}^n, \boldsymbol{z}^n\right)=1 / 2 \cdot\left\|\boldsymbol{z}^n-F\left(\Psi ; \boldsymbol{P}^n\right)\right\|^2$ (sketched below).
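A minimal sketch of the squared loss and the empirical risk $R_N$, assuming the `F` from the model sketch above; `samples` is a plain list of (prompt, label) pairs.

```python
def squared_loss(weights, P_tilde, z):
    """l(Psi; P, z) = 1/2 * ||z - F(Psi; P)||^2 (uses F from the model sketch above)."""
    r = z - F(*weights, P_tilde)
    return 0.5 * float(r @ r)

def empirical_risk(weights, samples):
    """R_N(Psi): average squared loss over N (prompt, label) pairs."""
    return sum(squared_loss(weights, P, z) for P, z in samples) / len(samples)
```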
### Inference
- Consider another $K$-step reasoning task $f \in \mathcal{T}^{\prime}$, whose target is to predict the labels $\left\{\boldsymbol{z}_k\right\}_{k=1}^K$ given the input query $\boldsymbol{x}_{\text {query}}$. Here $\mathcal{T}^{\prime}$ is the set of testing tasks, and $\mathcal{T}^{\prime} \neq \mathcal{T}$.
- The testing prompt $\boldsymbol{P}$ is composed of $l_{t s}\left(\leq l_{t r}\right)$ context examples of $K$ steps plus a query:
$$
\boldsymbol{P}=\left(\boldsymbol{E}_1, \boldsymbol{E}_2, \cdots, \boldsymbol{E}_{l_{t s}}, \boldsymbol{p}_{q u e r y}\right) \in \mathbb{R}^{\left(2 d_{\mathcal{X}}\right) \times\left(l_{t s} K+1\right)}, \boldsymbol{p}_{q u e r y}=\left(\boldsymbol{x}_{q u e r y}^{\top}, \mathbf{0}^{\top}\right)^{\top}
$$
- The inference process for a $K$-step CoT with $l_{t s}$ examples on a task $f \in \mathcal{T}^{\prime}$ proceeds as follows. Given the testing prompt $\boldsymbol{P}$, let $\boldsymbol{P}_1=\boldsymbol{P}$ and let $\boldsymbol{P}_0$ be the first $K \cdot l_{t s}$ columns of $\boldsymbol{P}$. The $k$-th step is predicted as follows:
- Generate the output $\boldsymbol{v}_k, k \in[K]$ via greedy decoding by feeding the $k$-th step prompt $\boldsymbol{P}_k$ to the trained model $\Psi$.
- Greedy decoding outputs the most probable token, i.e., the token in the discrete set $\mathcal{Y}$ closest to the model output:
$$
\boldsymbol{v}_k=\arg \min _{\boldsymbol{u} \in \mathcal{Y}} \frac{1}{2}\left\|F\left(\Psi ; \boldsymbol{P}_k\right)-\boldsymbol{u}\right\|^2, \text { (greedy decoding) }
$$
- Then, use the output $\boldsymbol{v}_k$ to update $\boldsymbol{P}_k$ and use $\boldsymbol{v}_k$ as the query input to form the input prompt $\boldsymbol{P}_{k+1}$ for the next step:
$$
\begin{aligned}
& \boldsymbol{P}_k=\left(\boldsymbol{P}_{k-1} \boldsymbol{q}_k\right) \in \mathbb{R}^{\left(2 d_{\mathcal{X}}\right) \times\left(K l_{t s}+k\right)}, \boldsymbol{P}_{k+1}=\left(\boldsymbol{P}_k \boldsymbol{q}_{k+1}\right) \in \mathbb{R}^{\left(2 d_{\mathcal{X}}\right) \times\left(K l_{t s}+k+1\right)}, \\
& \text { where } \boldsymbol{q}_k=\left(\boldsymbol{v}_{k-1}^{\top} \boldsymbol{v}_k^{\top}\right)^{\top}, \boldsymbol{q}_{k+1}=\left(\boldsymbol{v}_k^{\top} \mathbf{0}^{\top}\right)^{\top}
\end{aligned}
$$
- The model finally outputs $\boldsymbol{v}_1, \cdots, \boldsymbol{v}_K$ as the CoT result for the query $\boldsymbol{x}_{\text {query }}$; the full loop is sketched below.
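A sketch of the $K$-step CoT inference loop with greedy decoding, assuming the `F` from the model sketch above. `Y` is a matrix whose rows are the candidate tokens of $\mathcal{Y}$; positional encodings are omitted for brevity.

```python
import numpy as np

def greedy_decode(out, Y):
    """Return the row of Y (candidate tokens) closest to the model output in squared distance."""
    return Y[np.argmin(((Y - out) ** 2).sum(axis=1))]

def cot_inference(weights, P, Y, K):
    """Run K reasoning steps: decode v_k, fill it into the current query column, append (v_k, 0)."""
    d_x = P.shape[0] // 2
    P_k = P.copy()                                   # P_1 = P; its last column is (x_query, 0)
    outputs = []
    for step in range(1, K + 1):
        v = greedy_decode(F(*weights, P_k), Y)       # F from the model sketch above
        outputs.append(v)
        P_k[d_x:, -1] = v                            # q_k = (v_{k-1}, v_k): complete the output half
        if step < K:
            P_k = np.column_stack([P_k, np.concatenate([v, np.zeros(d_x)])])  # q_{k+1} = (v_k, 0)
    return outputs                                   # v_1, ..., v_K
```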
- The CoT generalization error, given the testing query $\boldsymbol{x}_{\text {query }}$, the testing data distribution $\mathcal{D}^{\prime}$, and the labels $\left\{\boldsymbol{z}_k\right\}_{k=1}^K$ on a $K$-step testing task $f \in \mathcal{T}^{\prime}$, is defined as
$$
\bar{R}_{C o T, \boldsymbol{x}_{q u e r y} \sim \mathcal{D}^{\prime}, f \in \mathcal{T}^{\prime}}^f(\Psi)=\mathbb{E}_{\boldsymbol{x}_{q u e r y} \sim \mathcal{D}^{\prime}}\left[\frac{1}{K} \sum_{k=1}^K \mathbb{1}\left[\boldsymbol{z}_k \neq \boldsymbol{v}_k\right]\right]
$$
which measures the average error between the output and the label at each reasoning step (a minimal sketch follows).
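A sketch of the step-wise 0-1 error inside the expectation; averaging it over queries drawn from $\mathcal{D}^{\prime}$ estimates the CoT generalization error.

```python
import numpy as np

def stepwise_error(preds, labels):
    """(1/K) * sum_k 1[z_k != v_k] for the K predicted steps of a single query."""
    return np.mean([float(not np.array_equal(v, z)) for v, z in zip(preds, labels)])
```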
## Theoretical Result
### Settings
- Training: $M$ orthonormal training-relevant (TRR) patterns.
- Testing: $M^{\prime}$ orthonormal testing-relevant (TSR) patterns.
- The reasoning in some examples contains incorrect steps.
#### Training data and tasks
- Consider $M$ TRR patterns $\boldsymbol{\mu}_1, \boldsymbol{\mu}_2, \cdots, \boldsymbol{\mu}_M$, which form an orthonormal set $\mathcal{M}=\left\{\boldsymbol{\mu}_i\right\}_{i=1}^M$, with $M=\Theta(d)$ and $M \leq d$. Moreover, $\left(\boldsymbol{\mu}_i^{\top}, \boldsymbol{0}_{d_{\mathcal{X}}}^{\top}\right)^{\top} \perp \boldsymbol{c}_k$ for $i \in[M]$, $k \in[K]$.
- Training prompt: $\boldsymbol{P}$ contains training examples from the same training task $f\in\mathcal{T}$.
- Training task $f=f_K \circ \cdots \circ f_2 \circ f_1$ where $f_k\in\mathcal{F}$.
- The $k$-th step label of the query is $\boldsymbol{z}_k=f_k\left(\boldsymbol{z}_{k-1}\right)$ given the $k$-th step input $\boldsymbol{z}_{k-1}$ with $\boldsymbol{z}_k \in \mathcal{M}, k \in[K]$.
- Moreover, the $k$-th step label of the $i$-th ($i \in\left[l_{t r}\right]$) context example is $\boldsymbol{y}_{i, k}=f_k\left(\boldsymbol{y}_{i, k-1}\right)$ given the $(k-1)$-th step input $\boldsymbol{y}_{i, k-1}$, $k \in[K]$, with $\boldsymbol{x}_i, \boldsymbol{y}_{i, k} \in \mathcal{M}$, where $\boldsymbol{y}_{i, 0}:=\boldsymbol{x}_i$.
- We assume that $f_k(\boldsymbol{x}) \neq f_{k^{\prime}}\left(\boldsymbol{x}^{\prime}\right)$ if and only if either $\boldsymbol{x} \neq \boldsymbol{x}^{\prime}$ or $f_k \neq f_{k^{\prime}}$.
- Let $\alpha \in(0,1-c]$ for some constant $c>0$ denote the fraction of context examples whose input shares the same TRR pattern as the query input.
#### Testing task and query
- Consider $M^{\prime}$ TSR patterns $\boldsymbol{\mu}_1^{\prime}, \boldsymbol{\mu}_2^{\prime}, \cdots, \boldsymbol{\mu}_{M^{\prime}}^{\prime}$, which form an orthonormal set $\mathcal{M}^{\prime}=\left\{\boldsymbol{\mu}_i^{\prime}\right\}_{i=1}^{M^{\prime}}$ with $M^{\prime} \leq M$, and $\boldsymbol{\mu}_i^{\prime} \perp \boldsymbol{c}_k$ for $i \in\left[M^{\prime}\right]$, $k \in[K]$.
- Let $\mathcal{T}^{\prime}$ denote the set of testing tasks, which all operate on patterns in $\mathcal{M}^{\prime}$ rather than $\mathcal{M}$ in training tasks in $\mathcal{T}$.
- Every testing task $f=f_K \circ \cdots \circ f_2 \circ f_1 \in \mathcal{T}^{\prime}$ is a composition of $K$ steps over TSR patterns.
- The reasoning for the testing query is considered to be noiseless and accurate:
$$
\boldsymbol{z}_k \in \mathcal{M}^{\prime} \text { for all } k \in\{0\} \cup[K] \text {, and } \boldsymbol{z}_k=f_k\left(\boldsymbol{z}_{k-1}\right), \boldsymbol{z}_0=\boldsymbol{x}_{\text {query }} .
$$
#### Noisy testing prompt
- For noisy examples, all inputs and outputs of each step are noisy versions of TSR patterns, i.e.,
$$
\boldsymbol{x}_i, \boldsymbol{y}_{i, k} \in\left\{\boldsymbol{b} \in \mathbb{R}^d \mid \boldsymbol{b}=\boldsymbol{\mu}_j^{\prime}+\boldsymbol{\delta}, j \in\left[M^{\prime}\right], \boldsymbol{\delta} \perp \mathcal{M}^{\prime},\|\boldsymbol{\delta}\| \leq \sqrt{2} / 2\right\},
$$
with noise $\boldsymbol{\delta} \neq \mathbf{0}$ for $i \in\left[K l_{t s}^f\right], k \in[K]$.
- We consider the case where, for at least an $\alpha^{\prime}$ fraction of the context examples, the TSR pattern of the input $\boldsymbol{y}_{s, 1}$, $s \in\left[l_{t s}^f\right]$, is the same as that of $\boldsymbol{x}_{\text {query }}$.
- To formally model noisy examples, we define step-wise transition matrices $\left\{\boldsymbol{A}_k^f\right\}_{k=1}^K$, $\boldsymbol{A}_k^f \in \mathbb{R}^{M^{\prime} \times M^{\prime}}$, where $\boldsymbol{A}_k^f$ represents the reasoning probabilities of step $k$ in the test examples.
- Specifically, there exists some constant $\rho^f \in(0,1)$ such that for all $s \in\left[l_{t s}^f\right]$, $k \in[K]$, the $(i, j)$-th entry of $\boldsymbol{A}_k^f$ satisfies
$$
A_{k(i, j)}^f=\operatorname{Pr}\left(\operatorname{TSR}\left(\boldsymbol{y}_{s, k}\right)=j \mid \operatorname{TSR}\left(\boldsymbol{y}_{s, k-1}\right)=i\right),
$$
and $A_{k\left(i, j^*\right)}^f \geq 1 /\left(1-\rho^f\right) \cdot A_{k(i, j)}^f$ for all $j \in\left[M^{\prime}\right]$ with $j \neq j^*$, where $\boldsymbol{\mu}_{j^*}^{\prime}=f_k\left(\boldsymbol{\mu}_i^{\prime}\right)$ and $\mathrm{TSR}: \mathbb{R}^d \mapsto \mathbb{Z}^{+}$ is a function that outputs the index of the TSR pattern of a noisy input.
- Let $\boldsymbol{B}^f=\prod_{k=1}^K \boldsymbol{A}_k^f$ be the $K$-step transition matrix. We similarly define $\rho_o^f \in(0,1)$ as the primacy of $\boldsymbol{B}^f$, where
$$
B_{\left(i, j^*\right)}^f \geq 1 /\left(1-\rho_o^f\right) \cdot B_{(i, j)}^f, \forall j \in\left[M^{\prime}\right], j \neq j^*, \text { where } j^*=\arg \max _{j \in\left[M^{\prime}\right]} B_{(i, j)}^f
$$
#### Example 1.
- Consider a simple two-step inference example with $K=2$, $\boldsymbol{\mu}_1^{\prime}, \boldsymbol{\mu}_2^{\prime}$ as the TSR patterns, and $\boldsymbol{\delta}=\mathbf{0}$ in the inputs and outputs of every step.
- In the paper's illustration, black solid arrows denote the correct inference process, where $f_1\left(\boldsymbol{\mu}_1^{\prime}\right)=\boldsymbol{\mu}_1^{\prime}, f_1\left(\boldsymbol{\mu}_2^{\prime}\right)=\boldsymbol{\mu}_2^{\prime}, f_2\left(\boldsymbol{\mu}_1^{\prime}\right)=\boldsymbol{\mu}_2^{\prime}$, and $f_2\left(\boldsymbol{\mu}_2^{\prime}\right)=\boldsymbol{\mu}_1^{\prime}$. Hence, $\boldsymbol{\mu}_1^{\prime} \rightarrow \boldsymbol{\mu}_1^{\prime} \rightarrow \boldsymbol{\mu}_2^{\prime}$ and $\boldsymbol{\mu}_2^{\prime} \rightarrow \boldsymbol{\mu}_2^{\prime} \rightarrow \boldsymbol{\mu}_1^{\prime}$ are the two inference trajectories under the task $f$.
- The testing examples contain errors and follow the transition matrices $\boldsymbol{A}_1^f$ and $\boldsymbol{A}_2^f$ (brown dashed arrows in the figure).
- We let $\boldsymbol{A}_1^f=\left(\begin{array}{ll}0.6 & 0.4 \\ 0.4 & 0.6\end{array}\right), \boldsymbol{A}_2^f=\left(\begin{array}{ll}0.4 & 0.6 \\ 0.8 & 0.2\end{array}\right)$, which results in $\boldsymbol{B}^f=\left(\begin{array}{ll}0.56 & 0.44 \\ 0.64 & 0.36\end{array}\right)$ (checked numerically below).
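With row-stochastic $\boldsymbol{A}_k^f$, the $K$-step matrix is the ordered product, so Example 1 can be checked directly:

```python
import numpy as np

A1 = np.array([[0.6, 0.4],
               [0.4, 0.6]])
A2 = np.array([[0.4, 0.6],
               [0.8, 0.2]])
print(A1 @ A2)   # B^f = [[0.56, 0.44], [0.64, 0.36]], as stated above (up to float rounding)
```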

### Training Convergence Guarantee
---
**Theorem 1.**
- For any $\epsilon>0$, when
- (i) the number of context examples in every training sample is
$$
l_{t r} \geq \Omega\left(\alpha^{-1}\right)
$$
- (ii) the number of iterations satisfies
$$
T \geq \Omega\left(\eta^{-1} \alpha^{-2} K^3 \log \frac{K}{\epsilon}+\eta^{-1} M K\left(\alpha^{-1}+\epsilon^{-1}\right)\right)
$$
- (iii) the training tasks and samples are selected such that every TRR pattern is equally likely in every inference step and in each training batch, with batch size $B \geq \Omega\left(\max \left\{\epsilon^{-2}, M\right\} \cdot \log M\right)$, step size $\eta<1$, and $N=B T$ samples,
- then, with high probability, the returned model guarantees
$$
\mathbb{E}_{\boldsymbol{x}_{\text {query }} \in \mathcal{M}, f \in \mathcal{T}}[\ell(\Psi ; \boldsymbol{P}, \boldsymbol{z})] \leq \mathcal{O}(\epsilon) .
$$
---
### CoT Generalization Guarantee
---
**Definition 1.**
- For $f=f_K \circ \cdots \circ f_1 \in \mathcal{T}^{\prime}$, we define the min-max trajectory transition probability as:
$$
\tau^f=\min _{i \in\left[M^{\prime}\right]} \prod_{k=1}^K A_{k\left(\mathrm{TSR}\left(f_{k-1} \circ \cdots \circ f_0\left(\boldsymbol{\mu}_i^{\prime}\right)\right),\ \mathrm{TSR}\left(f_k \circ \cdots \circ f_0\left(\boldsymbol{\mu}_i^{\prime}\right)\right)\right)}^f, \text { where } f_0\left(\boldsymbol{\mu}_i^{\prime}\right):=\boldsymbol{\mu}_i^{\prime}, \forall i \in\left[M^{\prime}\right],
$$
which measures the minimum probability, over all the initial TSR patterns, of the $K$-step reasoning trajectory that has the highest probability over all $K$-step trajectories.
- We also define the min-max input-label transition probability as
$$
\tau_o^f=\min _{i \in\left[M^{\prime}\right]} \max _{j \in\left[M^{\prime}\right]} B_{(i, j)}^f,
$$
which measures the minimum probability, over all the initial TSR patterns, of the output that has the highest probability over outputs.
- For instance, in Example 1, $\tau^f=\min \{0.36,0.48\}=0.36$ and $\tau_o^f=\min \{0.56,0.64\}=0.56$ (verified in the sketch after this definition).
---
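A numerical check of Definition 1 on Example 1; the step maps below encode $f_1$ as the identity and $f_2$ as the swap of the two TSR patterns.

```python
import numpy as np
from functools import reduce

A = [np.array([[0.6, 0.4], [0.4, 0.6]]),   # A_1^f from Example 1
     np.array([[0.4, 0.6], [0.8, 0.2]])]   # A_2^f from Example 1
f_steps = [[0, 1], [1, 0]]                 # f_steps[k][i] = TSR index of f_{k+1}(mu'_i)

def tau(A, f_steps):
    """Minimum over starting patterns of the probability of the correct K-step trajectory."""
    probs = []
    for i in range(A[0].shape[0]):
        p, state = 1.0, i
        for A_k, f_k in zip(A, f_steps):
            p, state = p * A_k[state, f_k[state]], f_k[state]
        probs.append(p)
    return min(probs)

def tau_o(A):
    """Minimum over rows of the largest entry of B^f = A_1^f ... A_K^f."""
    return reduce(np.matmul, A).max(axis=1).min()

print(tau(A, f_steps), tau_o(A))           # approx. 0.36 and 0.56, matching the values above
```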
---
**Theorem 2 (CoT generalization).**
- Given a model trained under conditions (i) to (iii) in Theorem 1, as long as
- (iv) each TSR pattern $\boldsymbol{\mu}_j^{\prime}$ in the orthonormal set $\left\{\boldsymbol{\mu}_j^{\prime}\right\}_{j=1}^{M^{\prime}}$ satisfies
$$
\boldsymbol{\mu}_j^{\prime}=\boldsymbol{\lambda}_j+\tilde{\boldsymbol{\mu}}_j
$$
where $\boldsymbol{\lambda}_j \perp \operatorname{span}\left(\boldsymbol{\mu}_1, \cdots, \boldsymbol{\mu}_M\right), \tilde{\boldsymbol{\mu}}_j \in \operatorname{span}\left(\boldsymbol{\mu}_1, \cdots, \boldsymbol{\mu}_M\right)$, and $\left\|\tilde{\boldsymbol{\mu}}_j\right\| \geq \Theta\left(\left(\log \epsilon^{-1}\right)^{-1}\right)$,
- (v) the number of testing examples for any $f \in \mathcal{T}^{\prime}$ is
$$
l_{t s}^f \geq \Omega\left(\left(\alpha^{\prime} \tau^f \rho^f\right)^{-2} \log M\right)
$$
- we have $\bar{R}_{\text {CoT }, \boldsymbol{x}_{\text {query }} \in \mathcal{M}^{\prime}, f \in \mathcal{T}^{\prime}}^f(\Psi)=0$.
---