<style>
.reveal {
font-size: 24px;
}
code[class*="language-"], pre[class*="language-"]{
font-size: 20px;
}
</style>
<!-- .slide: style="font-size: 32px;" -->
# Why think step by step? Reasoning emerges from the locality of experience
https://arxiv.org/pdf/2304.03843.pdf
---
## Chain of Thought (CoT)
https://arxiv.org/pdf/2201.11903.pdf

---
## Abstract
- Investigates why and how **chain-of-thought** reasoning is useful.
- Training data consists of overlapping local clusters of variables that strongly influence each other.
- This enables chaining accurate local inferences to estimate relationships between variables that were not seen together in training.
- Proves that a **reasoning gap** exists: reasoning through intermediate variables reduces bias.
---
## Introduction
- We may know $A$ and want to know $C$, so we try to estimate $P(C|A)$.
- However, if we have not often seen $A$ and $C$ together, we would struggle to estimate $P(C|A)$ directly.
- By conditioning on an intermediate variable $B$, we can compute:
$$
P(C|A) = \sum_{B}P(C|B)P(B|A).
$$
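- For example, with illustrative numbers $P(B{=}1 \mid A)=0.8$, $P(C{=}1 \mid B{=}1)=0.9$, and $P(C{=}1 \mid B{=}0)=0.2$ (binary $B$):
$$
P(C{=}1 \mid A)=0.9 \cdot 0.8+0.2 \cdot 0.2=0.76 .
$$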
---
## Example
- To answer “What is the [c]climate of [a]France’s capital?” using a model trained on Wikipedia.
- Wikipedia contains:
- [a]France's capital is [b]Paris
- [b]Paris has an [c]oceanic climate
- Generate “[a]France's capital is [b]Paris” as an intermediate step before answering “[a]France's capital has an [c]oceanic climate”.
---
## Task setup
- Define a Bayes net with random variables $\{Y_i\}_{i=1}^N$ taking support on a finite set $\mathcal{X}$.
- Denote the distribution of the Bayes net by $p_d$.
- Training data is a sequence of variable indices $i \in \{1, \ldots, N\}$ interleaved with their values $v_i \in \mathcal{X}$.
- Create local structure via an observation distribution $p_{\text{obs}}$, which is a distribution over subsets of the variable indices.
- Observation distributions take support on a set $\mathcal{Y}_{\text {obs }} \subseteq \mathcal{P}(\{1, \ldots, N\})$.
- The generative process for each training sequence consists of two steps (sketched on the next slide).
- Sample a set of variable indices $\left\{i_t\right\}_{t=1}^K$ according to $p_{\mathrm{obs}}$.
- Sample the values $\left\{v_{i_t}\right\}_{t=1}^K$ according to $p_d\left(Y_{i_1}, \ldots, Y_{i_K}\right)$.
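---
- A minimal Python sketch of this two-step generative process (the chain structure, window size, and probabilities below are illustrative assumptions, not the paper's exact setup):
```python
import random

N = 50       # number of variables (illustrative)
STAY = 0.9   # chain transition probability P(Y_{i+1} = Y_i) (illustrative)

def sample_chain():
    """Sample every variable of a simple binary chain Bayes net p_d."""
    values = [random.randint(0, 1)]
    for _ in range(N - 1):
        values.append(values[-1] if random.random() < STAY else 1 - values[-1])
    return values

def sample_obs_indices(k=7, window=10):
    """Draw a local neighborhood of variable indices (a sample from p_obs)."""
    center = random.randrange(N)
    nearby = [i for i in range(center - window, center + window + 1) if 0 <= i < N]
    return random.sample(nearby, min(k, len(nearby)))

def sample_training_sequence():
    """Step 1: indices from p_obs; step 2: their values from p_d (target listed last)."""
    indices = sample_obs_indices()
    values = sample_chain()   # unobserved variables are implicitly marginalized out
    lines = [f"target: X{indices[-1]}"] + [f"X{i}={values[i]}" for i in indices]
    return "\n".join(lines)

print(sample_training_sequence())
```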
---
- A sample from a local neighborhood of the Bayes net:
```
target: X5
X17=0
X52=1
X24=1
X34=0
X12=1
X20=0
X5=1
```
---
## Estimators
- Given a model $q$ trained to predict variable indices and values.
- To estimate the value $y_i$ of a target variable $Y_i$ given an observed variable $Y_j = y_j$, we use one of:
- Direct prediction
- Scaffolded generation
- Free generation
---
## Direct prediction
- The baseline without any reasoning.
$$
\hat{q}_D\left(Y_i=y_i \mid Y_j=y_j\right)=q\left(Y_i=y_i \mid Y_j=y_j\right) .
$$
- We simply query the model directly, with no intermediate steps.
- Example
- To estimate $p\left(X_2 \mid X_1=0\right)$, use the prompt:
```
target: X2
X1=0
X2=
```
- $q$ assigns a probability to ‘0’ and ‘1’ after `X2=`.
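---
- A sketch of the direct estimator in Python, assuming a hypothetical `q_next_value_probs(prompt)` helper that returns the trained model's distribution over the next value token (this interface is an assumption for illustration, not the paper's code):
```python
def direct_prediction(q_next_value_probs, target, observed):
    """Direct estimator q_D: one query to the model, no intermediate variables.

    q_next_value_probs -- hypothetical callable: prompt -> {value: probability}
    target             -- index of the target variable, e.g. 2
    observed           -- observed variables, e.g. {1: 0}
    """
    lines = [f"target: X{target}"]
    lines += [f"X{i}={v}" for i, v in observed.items()]
    lines += [f"X{target}="]                       # ask the model to fill in the target value
    return q_next_value_probs("\n".join(lines))    # e.g. {0: 0.7, 1: 0.3}
```
- `direct_prediction(q, 2, {1: 0})` builds exactly the prompt shown on the previous slide.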
---
## Scaffolded generation
- The ideal form of reasoning, if we know the best set of intermediate steps.
- Scaffold: an ordered set $S$ of intermediate variables to reason through.
- Estimate each scaffold variable's value with $q$, given the observed variable and the previously generated scaffold values, then average over $M$ samples:
$$
\begin{gathered}
\hat{q}_S\left(Y_i=y_i \mid Y_j=y_j\right)=\frac{1}{M} \sum_{k=1}^M q\left(Y_i=y_i \mid\left\{Y_s=y_s^{(k)}\right\}_{s \in S}, Y_j=y_j\right) \\
\text { where } y_s^{(k)} \sim q\left(Y_s \mid\left\{Y_t=y_t^{(k)}\right\}_{t \in S \mid t \prec s}, Y_j=y_j\right)
\end{gathered}
$$
---
- Example
- To estimate $p(X_4 \mid X_1 = 0)$ with $X_2$ and $X_3$ as the scaffold, start with:
```
target: X4
X1=0
X2=
```
- If $q$ generates "1", give this prompt next:
```
target: X4
X1=0
X2=1
X3=
```
- Repeat until all scaffold variables have values, then query the target:
```
target: X4
X1=0
X2=1
X3=0
X4=
```
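---
- A sketch of the scaffolded estimator, reusing the hypothetical `q_next_value_probs` helper from the direct-prediction sketch (all interfaces and defaults are illustrative assumptions):
```python
import random

def q_sample_value(q_next_value_probs, prompt):
    """Sample a value for the next variable from the model's distribution."""
    probs = q_next_value_probs(prompt)
    values, weights = zip(*probs.items())
    return random.choices(values, weights=weights)[0]

def scaffolded_estimate(q_next_value_probs, target, observed, scaffold,
                        target_value, M=100):
    """Scaffolded estimator q_S: average the target's probability over M
    sampled completions of the scaffold variables (in the given order)."""
    total = 0.0
    for _ in range(M):
        lines = [f"target: X{target}"] + [f"X{i}={v}" for i, v in observed.items()]
        for s in scaffold:                                  # fill in scaffold variables
            prompt = "\n".join(lines + [f"X{s}="])
            lines.append(f"X{s}={q_sample_value(q_next_value_probs, prompt)}")
        prompt = "\n".join(lines + [f"X{target}="])         # finally query the target
        total += q_next_value_probs(prompt).get(target_value, 0.0)
    return total / M
```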
---
## Free generation
- Similar to scaffolded generation, but the model chooses the intermediate variables itself.
- Sample variable indices and values from $q$ until it generates the index of the target variable.
- Compute the probability of the target variable.
---
- Example
- To infer $p(X_4 \mid X_1 = 0)$, start with this prompt:
```
target: X4
X1=0
```
- Generate one intermediate variable and its value:
```
target: X4
X1=0
X5=1
```
- Repeat until the model outputs the target variable.
```
target: X4
X1=0
X5=1
X2=0
X7=0
X3=1
X4=
```
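---
- A sketch of the free-generation estimator, additionally assuming a hypothetical `q_next_index(prompt)` helper that lets the model choose which variable to generate next (it reuses `q_sample_value` from the scaffolded sketch):
```python
def free_generation_estimate(q_next_value_probs, q_next_index, target, observed,
                             target_value, M=100, max_steps=20):
    """Free generation: the model picks its own intermediate variables,
    stopping once it names the target variable."""
    total = 0.0
    for _ in range(M):
        lines = [f"target: X{target}"] + [f"X{i}={v}" for i, v in observed.items()]
        for _ in range(max_steps):
            nxt = q_next_index("\n".join(lines))            # model chooses the next variable
            if nxt == target:                               # stop once it names the target
                break
            prompt = "\n".join(lines + [f"X{nxt}="])
            lines.append(f"X{nxt}={q_sample_value(q_next_value_probs, prompt)}")
        prompt = "\n".join(lines + [f"X{target}="])
        total += q_next_value_probs(prompt).get(target_value, 0.0)   # prob. of the target value
    return total / M
```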
---
## Experiment

---
## Theoretical Analysis
- Assume that $q$ is trained on sequences of alternating variable indices $i_t \in \{1, \ldots, N\}$ and variable values $v_t \in \mathcal{X}$.
- Assume that the joint distribution $p_d$ over $Y_1, \ldots, Y_N$ factorizes as a chain: $p_d\left(Y_1, \ldots, Y_N\right)=p_d\left(Y_1\right) \prod_{i=1}^{N-1} p_d\left(Y_{i+1} \mid Y_i\right)$.
- Assume that $p_{\text {obs }}$ only assigns non-zero probability to adjacent variable pairs, i.e. $p_{\mathrm{obs}}(\{i, j\})=0$ if $|i-j|>1$ and $p_{\mathrm{obs}}(X)=0$ if $|X| \neq 2$.
- The training set consists of i.i.d. samples from $p$, which is the distribution over complete sequences of variable indices and values defined by $p_{\text{obs}}$ and $p_d$.
---
## Theorem 3.1.
- Let $\mathcal{S}$ be the space of sequences of variable indices followed by values.
- Let $u$ be the uniform distribution over $\mathcal{S}$.
- Let $H(p, q)$ denote the cross entropy between $p$ and $q$. Consider the risk:
$$ R(q)=H(p, q)+H(u, q) $$
- Let $q^*=\arg \min _q R(q)$ be a minimizer of the risk.
- Then, for any $y_i, y_j \in \mathcal{X}$ and any pair $(i, j)$ with $|i-j|>1$:
$$
\begin{array}{r}
\left|\mathbb{E}\left[\hat{q}_S^*\left(Y_i=y_i \mid Y_j=y_j\right)\right]-p_d\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2 \\
<\left|\hat{q}_D^*\left(Y_i=y_i \mid Y_j=y_j\right)-p_d\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2
\end{array}
$$
---
## Proposition 1
- Let $\alpha_1 \geq 0, \alpha_2 \geq 0$ with $\alpha_1+\alpha_2>0$.
- Let $R(q)=\alpha_1 \mathbb{E}_{p_1(x)}[-\log q(x)]+\alpha_2 \mathbb{E}_{p_2(x)}[-\log q(x)]$.
- Then the minimizer $q^*=\arg \min_q R(q)$ over distributions $q$ is
$$
q^*(x)=\frac{\alpha_1}{\alpha_1+\alpha_2} p_1(x)+\frac{\alpha_2}{\alpha_1+\alpha_2} p_2(x).
$$
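---
- A quick numerical sanity check of Proposition 1 (illustrative distributions and weights; a brute-force comparison against random candidates, not a proof):
```python
import numpy as np

p1 = np.array([0.7, 0.2, 0.1])   # illustrative distribution p_1
p2 = np.array([0.1, 0.3, 0.6])   # illustrative distribution p_2
a1, a2 = 2.0, 1.0                # illustrative weights alpha_1, alpha_2

def risk(q):
    """Weighted cross-entropy risk R(q) = a1 * H(p1, q) + a2 * H(p2, q)."""
    return -a1 * np.sum(p1 * np.log(q)) - a2 * np.sum(p2 * np.log(q))

q_star = (a1 * p1 + a2 * p2) / (a1 + a2)   # claimed minimizer: the weighted mixture

# No randomly drawn distribution should achieve a lower risk than q_star.
rng = np.random.default_rng(0)
candidates = rng.dirichlet(np.ones(3), size=10_000)
assert risk(q_star) <= min(risk(q) for q in candidates)
print("risk(q*) =", risk(q_star))
```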
<!-- ## Preliminaries
- **Factorization of data distribution:**
- Given $Y_1, Y_2, \ldots, Y_N$ take support in some finite set $\mathcal{X}$, we assume that the joint distribution $p_d$ over $Y_1, Y_2, \ldots, Y_N$ factorizes as
$$
p_d\left(Y_1, \ldots Y_N\right)=p_d\left(Y_1\right) \prod_{j=1}^N P_d\left(Y_{j+1} \mid Y_j\right)
$$
- **Local observation distribution:**
- For any non-adjacent pairs $Y_i$ and $Y_j$ and for any variable values $y_i$ and $y_j$,
$$
p\left(i_1=i, v_1=y_i, i_2=j, v_1=y_j\right)=0
$$
- For adjacent pairs $Y_{i+1}$ and $Y_i$ and for variable values $y_{i+1}$ and $y_i$,
$$
p\left(i_1=i, v_1=y_i, i_2=i+1, v_1=y_{i+1}\right) \propto p_d\left(Y_i=y_i, Y_j=y_j\right)
$$
-->
---
## Proof
- The Lagrangian is given by
$$
\begin{aligned}
\mathcal{L}\left(q, \lambda_0\right)&=-\alpha_1 \sum_x p_1(x) \log q(x)-\alpha_2 \sum_x p_2(x) \log q(x) \\
&+\lambda_0\left(\sum_x q(x)-1\right)
\end{aligned}
$$
- The first-order conditions are
$$
\begin{array}{l}
\frac{\partial \mathcal{L}}{\partial q(x)}=-\alpha_1 \frac{p_1(x)}{q(x)}-\alpha_2 \frac{p_2(x)}{q(x)}+\lambda_0=0 \\
\frac{\partial \mathcal{L}}{\partial \lambda_0}=\sum_x q(x)-1=0
\end{array}
$$
- Hence
$$
q(x)=\frac{\alpha_1 p_1(x)+\alpha_2 p_2(x)}{\lambda_0} \Rightarrow \sum_x \frac{\alpha_1 p_1(x)+\alpha_2 p_2(x)}{\lambda_0}=1 \Rightarrow\lambda_0=\alpha_1+\alpha_2
$$
and the result follows immediately.
---
## Theorem A.1.
- Let $u$ be the uniform distribution.
- Let $p$ be the distribution over variable indices and values defined by $p_{\text {obs }}$ and $p_d$.
- Let $H(p, q)$ denote the cross entropy between $p$ and $q$.
- We consider the following risk:
$$
R(q)=H(p, q)+H(u, q).
$$
- Then $q^*=\arg \min _q R(q)$ satisfies the following properties.
- For all pairs of adjacent variables $Y_i$ and $Y_{i-1}$, $$ q^*\left(Y_i \mid Y_{i-1}\right)=\lambda p_d\left(Y_i \mid Y_{i-1}\right)+(1-\lambda) \frac{1}{|\mathcal{X}|}, \quad \text{for some $\lambda \in(0,1)$}
$$
- For all pairs of non-adjacent variables,
$$q^*\left(Y_i \mid Y_j\right)=\frac{1}{|\mathcal{X}|}.
$$
---
## Proof.
- We can write the risk as a sum across timesteps.
$$
R(q)=\sum_{t=1}^T \mathbb{E}_{p\left(x_{1: t}\right)}\left[-\log q\left(x_t \mid x_{1: t-1}\right)\right]+\sum_{t=1}^T \mathbb{E}_{u\left(x_{1: t}\right)}\left[-\log q\left(x_t \mid x_{1: t-1}\right)\right]
$$
- Decompose the $t=4$ term (the prediction of $v_2$) of the first sum:
$$
\begin{array}{l}
\mathbb{E}_{p\left(i_{1: 2}, v_{1: 2}\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right] \\
=\mathbb{E}_{p\left(i_{1: 2}, v_1\right)}\left[\mathbb{E}_{p\left(v_2 \mid i_{1: 2}, v_1\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right]\right] \\
=\sum_{i_1} \sum_{v_1} \sum_{i_2} \mathbb{E}_{p\left(v_2 \mid i_{1: 2}, v_1\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right] p\left(i_{1: 2}, v_1\right) \\
=\sum_{\left(i_1, i_2\right) \in \mathcal{Y}_{\text {obs }}} \sum_{v_1} \mathbb{E}_{p\left(v_2 \mid i_{1: 2}, v_1\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right] p\left(i_{1: 2}, v_1\right)
\end{array}
$$
- Decompose the $t=4$ term of the second sum in the same way:
$$
\begin{array}{l}
\mathbb{E}_{u\left(i_{1: 2}, v_{1: 2}\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right] \\
=\mathbb{E}_{u\left(i_{1: 2}, v_1\right)}\left[\mathbb{E}_{u\left(v_2 \mid i_{1: 2}, v_1\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right]\right] \\
=\sum_{i_1} \sum_{v_1} \sum_{i_2} \mathbb{E}_{u\left(v_2 \mid i_{1: 2}, v_1\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right] u\left(i_{1: 2}, v_1\right)
\end{array}
$$
---
- Adjacent pairs:
- Suppose $i_2=i+1$ and $i_1=i$ for some $i$. In addition, fix some value for $v_1$:
$$
\begin{array}{r}
\frac{1}{2} \mathbb{E}_{p\left(v_2 \mid i_{1: 2}, v_1\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right] p\left(i_{1: 2}, v_1\right) \\
+\frac{1}{2} \mathbb{E}_{u\left(v_2 \mid i_{1: 2}, v_1\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right] u\left(i_{1: 2}, v_1\right)
\end{array}
$$
- By Proposition 1, the sum is minimized by
$$
q^*\left(Y_i \mid Y_j\right)=\lambda_{i, j} \frac{1}{|\mathcal{X}|}+\left(1-\lambda_{i, j}\right) p\left(Y_i \mid Y_j\right)
$$
for $\lambda_{i, j}=\frac{u\left(i_{1: 2}, v_1\right)}{u\left(i_{1: 2}, v_1\right)+p\left(i_{1:2},v_1\right)}$.
- Non-adjacent pairs:
- The first (data) term is zero since $p(i_1=i, i_2=j)=0$, so only the uniform term remains.
- Therefore, $q^*\left(Y_i \mid Y_j\right)=\frac{1}{|\mathcal{X}|}$.
---
## Theorem A.2.
- We assume that $\sum_{y_i} p\left(Y_i=y_i \mid Y_j=y_j\right)=1=\sum_{y_j} p\left(Y_i=y_i \mid Y_j=y_j\right)$ (i.e., the conditional distributions are doubly stochastic).
- For all $y_i, y_j \in \mathcal{X}$ with $|i-j|>1$,
$$
\begin{aligned}
&\left|\mathbb{E}\left[\hat{q}_S^*\left(Y_i=y_i \mid Y_j=y_j\right)\right]-p\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2 \\
& < \left|\hat{q}_D^*\left(Y_i=y_i \mid Y_j=y_j\right)-p\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2
\end{aligned}
$$
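---
- A numerical illustration of this gap on a binary chain, assuming a symmetric (hence doubly stochastic) transition matrix and the same mixing weight for every adjacent pair; all numbers are illustrative:
```python
import numpy as np

X = 2                                   # |X|: binary variables
lam = 0.7                               # weight on p_d in q*(Y_{i+1} | Y_i), illustrative
T = np.array([[0.9, 0.1],
              [0.1, 0.9]])              # p_d(Y_{i+1} | Y_i): symmetric chain transitions
U = np.full((X, X), 1.0 / X)            # uniform distribution over values

Q = lam * T + (1 - lam) * U             # optimal adjacent-pair predictor (Theorem A.1)

d = 4                                   # distance |i - j| > 1 between target and evidence
p_true = np.linalg.matrix_power(T, d)   # true p(Y_i | Y_j) along the chain
q_direct = U                            # direct estimator is uniform for non-adjacent pairs
q_scaffold = np.linalg.matrix_power(Q, d)   # E[scaffolded estimate]: chained local mixtures

bias_scaffold = (q_scaffold - p_true) ** 2
bias_direct = (q_direct - p_true) ** 2
assert np.all(bias_scaffold < bias_direct)  # the reasoning gap of Theorem A.2
print(bias_scaffold.max(), "<", bias_direct.max())
```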
---
<!-- .slide: style="font-size: 20px;" -->
## Proof.
- First, we show that
$$
\begin{aligned}
& \mathbb{E}\left[\hat{q}_S^*\left(Y_i=y_i \mid Y_j=y_j\right)\right] =\lambda p\left(Y_i=y_i \mid Y_j=y_j\right)+(1-\lambda) \frac{1}{|\mathcal{X}|}, \quad \text{for some } \lambda \in(0,1)
\end{aligned}
$$
- We prove this by induction on $|i-j|$. Consider the base case $|i-j|=2$.
$$
\begin{aligned}
\mathbb{E}& \left[\hat{q}_S^*\left(Y_3=y_3 \mid Y_1=y_1\right)\right] = \mathbb{E} \left[\left(1-\lambda_{3,2}\right) p\left(Y_3=y_3 \mid Y_2=y_2\right)+\lambda_{3,2} \frac{1}{|\mathcal{X}|}\right] \\
= & \lambda_{3,2} \frac{1}{|\mathcal{X}|}+\left(1-\lambda_{3,2}\right) \mathbb{E}\left[p\left(Y_3=y_3 \mid Y_2=y_2\right)\right] \\
= & \lambda_{3,2} \frac{1}{|\mathcal{X}|}+\left(1-\lambda_{3,2}\right) \sum_{y_2} p\left(Y_3=y_3 \mid Y_2=y_2\right) \left[\left(1-\lambda_{2,1}\right) p\left(Y_2=y_2 \mid Y_1=y_1\right)+\lambda_{2,1} \frac{1}{|\mathcal{X}|}\right] \\
= & \lambda_{3,2} \frac{1}{|\mathcal{X}|}+\left(1-\lambda_{3,2}\right)\left(1-\lambda_{2,1}\right) \sum_{y_2} p\left(Y_3=y_3 \mid Y_2=y_2\right) p\left(Y_2=y_2 \mid Y_1=y_1\right) \\
& +\left(1-\lambda_{3,2}\right) \lambda_{2,1} \sum_{y_2} \frac{1}{|\mathcal{X}|} p\left(Y_3=y_3 \mid Y_2=y_2\right) \\
= & \lambda_{3,2} \frac{1}{|\mathcal{X}|}+\left(1-\lambda_{3,2}\right)\left(1-\lambda_{2,1}\right) p\left(Y_3=y_3 \mid Y_1=y_1\right) +\left(1-\lambda_{3,2}\right) \lambda_{2,1} \frac{1}{|\mathcal{X}|} \\
= & (1-\lambda) \frac{1}{|\mathcal{X}|}+\lambda p\left(Y_3=y_3 \mid Y_1=y_1\right) \quad \text{where $\lambda=\left(1-\lambda_{3,2}\right)\left(1-\lambda_{2,1}\right)$}
\end{aligned}
$$
---
- For the induction step, we note that the expectations $\mathbb{E}\left[\hat{q}_S^*\left(Y_i=y_i \mid Y_j=y_j\right)\right]$ and $\mathbb{E}\left[\hat{q}_S^*\left(Y_{i-1}=y_{i-1} \mid Y_j=y_j\right)\right]$ are related as follows
$$
\begin{aligned}
& \mathbb{E}\left[\hat{q}_S^*\left(Y_i=y_i \mid Y_j=y_j\right)\right] \\
& =\lambda \frac{1}{|\mathcal{X}|}+(1-\lambda) \sum_{y_{i-1}} p\left(Y_i=y_i \mid Y_{i-1}=y_{i-1}\right) \times \mathbb{E}\left[\hat{q}_S^*\left(Y_{i-1}=y_{i-1} \mid Y_j=y_j\right)\right]
\end{aligned}
$$
---
- Since $\hat{q}_D^*\left(Y_i=y_i \mid Y_j=y_j\right)=\frac{1}{|\mathcal{X}|}$ for non-adjacent pairs (Theorem A.1), the squared bias of the direct estimator is $\left|\frac{1}{|\mathcal{X}|}-p\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2$.
- By the previous results, $\mathbb{E}\left[\hat{q}_S^*\left(Y_i=y_i \mid Y_j=y_j\right)\right]=\lambda p\left(Y_i \mid Y_j\right)+(1-\lambda) \frac{1}{|\mathcal{X}|}$ for some $\lambda \in(0,1)$.
$$
\begin{array}{l}
\left|\mathbb{E}\left[\hat{q}_S^*\left(Y_i=y_i \mid Y_j=y_j\right)\right]-p\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2 \\
=\left|\lambda p\left(Y_i=y_i \mid Y_j=y_j\right)+(1-\lambda) \frac{1}{|\mathcal{X}|}-p\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2 \\
=\left|\lambda p\left(Y_i=y_i \mid Y_j=y_j\right)+\frac{1}{|\mathcal{X}|}-\lambda \frac{1}{|\mathcal{X}|}-p\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2 \\
=\left|(\lambda-1) p\left(Y_i=y_i \mid Y_j=y_j\right)+(1-\lambda) \frac{1}{|\mathcal{X}|}\right|^2 \\
=(1-\lambda)^2\left|\frac{1}{|\mathcal{X}|}-p\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2 \\
<\left|\frac{1}{|\mathcal{X}|}-p\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2
\end{array}
$$
<!--
## Reason for Chain-of-Thought
- direct prediction is inaccurate for some inferences because the relevant variables are rarely seen together in training.
- chain-of-thought reasoning improves estimation by incrementally chaining local statistical dependencies that are observed frequently in training.
- The combination of locally structured training data and reasoning with self-generated intermediate variables yields much greater data efficiency than training on data containing all variables.
---
-->