<style> .reveal { font-size: 24px; } code[class*="language-"], pre[class*="language-"]{ font-size: 20px; } </style>

<!-- .slide: style="font-size: 32px;" -->
# Why think step by step? Reasoning emerges from the locality of experience

https://arxiv.org/pdf/2304.03843.pdf

---

## Chain of Thought (CoT)

https://arxiv.org/pdf/2201.11903.pdf

![image](https://hackmd.io/_uploads/ryDJBWWZC.png)

---

## Abstract

- Investigates why and how **chain-of-thought** reasoning is useful.
- Training data consists of overlapping local clusters of variables that influence each other strongly.
- This enables chaining accurate local inferences to estimate relationships between variables that were never seen together in training.
- Proves that a **reasoning gap** exists: reasoning through intermediate variables reduces bias.

---

## Introduction

- We may know $A$ and want to know $C$, so we try to estimate $P(C|A)$.
- However, if we have rarely seen $A$ and $C$ together, we would struggle to estimate $P(C|A)$ directly.
- By conditioning on an intermediate variable $B$, we can compute:
$$
P(C|A) = \sum_{B}P(C|B)P(B|A).
$$

---

## Example

- Answer "What is the [c]climate of [a]France's capital?" with a model trained on Wikipedia.
- Wikipedia contains:
    - [a]France's capital is [b]Paris
    - [b]Paris has an [c]oceanic climate
- Generate "[a]France's capital is [b]Paris" as an intermediate step before producing "[a]France's capital has an [c]oceanic climate".

---

## Task setup

- Define a Bayes net with random variables $\{Y_i\}_{i=1}^N$ taking support on a finite set $\mathcal{X}$.
- Denote the distribution of the Bayes net by $p_d$.
- Training data consists of sequences of variable indices $i \in\{1, \ldots, N\}$ and their values $v_i \in \mathcal{X}$.
- Create local structure via an observation distribution $p_{\mathrm{obs}}$, which is a distribution over subsets of the variable indices.
- Observation distributions take support on a set $\mathcal{Y}_{\text{obs}} \subseteq \mathcal{P}(\{1, \ldots, N\})$.
- The generative process for a training sequence consists of two steps:
    - Sample a set of variable indices $\left\{i_t\right\}_{t=1}^K$ according to $p_{\mathrm{obs}}$.
    - Sample the values $\left\{v_{i_t}\right\}_{t=1}^K$ according to $p_d\left(Y_{i_1}, \ldots, Y_{i_K}\right)$.

---

- A sample from a local neighborhood of the Bayes net:
```
target: X5
X17=0
X52=1
X24=1
X34=0
X12=1
X20=0
X5=1
```
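---

- A minimal sketch of this two-step generative process (our illustration, not the paper's code). It assumes a chain-structured Bayes net over binary variables and an observation distribution that picks a contiguous window of indices; the constants and helper names are ours.

```python
import random

# Sketch of the two-step generative process (assumed chain-structured Bayes net
# over binary variables; p_obs picks a contiguous window of indices).
N = 50          # number of variables Y_1 .. Y_N
WINDOW = 7      # size of a local observation neighborhood
FLIP = 0.1      # P(Y_{i+1} != Y_i): strong local influence along the chain

def sample_bayes_net():
    """Sample a full assignment (y_1, ..., y_N) from the chain Bayes net p_d."""
    values = [random.randint(0, 1)]
    for _ in range(N - 1):
        prev = values[-1]
        values.append(prev if random.random() > FLIP else 1 - prev)
    return values

def sample_training_example():
    """Step 1: draw indices from p_obs; step 2: draw their values from p_d."""
    start = random.randrange(N - WINDOW + 1)
    indices = list(range(start, start + WINDOW))   # one local cluster of indices
    random.shuffle(indices)                        # order within a sample is arbitrary
    values = sample_bayes_net()                    # reading off a joint sample gives the marginal
    target = indices[-1]                           # the last observed index plays the role of the target
    lines = [f"target: X{target + 1}"]
    lines += [f"X{i + 1}={values[i]}" for i in indices]
    return "\n".join(lines)

print(sample_training_example())
```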
---

## Estimators

- Given a model $q$ that is trained to predict variable indices and values.
- To estimate the probability that $Y_i$ takes value $y_i$ given an observed $Y_j = y_j$, we use:
    - Direct prediction
    - Scaffolded generation
    - Free generation

---

## Direct prediction

- The baseline, without any reasoning:
$$
\hat{q}_D\left(Y_i=y_i \mid Y_j=y_j\right)=q\left(Y_i=y_i \mid Y_j=y_j\right).
$$
- We simply prompt the model to predict the target directly.
- Example
    - To estimate $p\left(X_2 \mid X_1=0\right)$, use the prompt:
```
target: X2
X1=0
X2=
```
    - Read off the probability $q$ assigns to '1' or '0' after `X2=`.

---

## Scaffolded generation

- Idealized reasoning, assuming we know the best set of intermediate steps.
- Scaffold: an ordered set $S$ of variables observed together.
- Sample each scaffold variable's value from $q$ given the observed variable and the previously generated scaffold values, then average the model's prediction for the target over $M$ scaffold samples:
$$
\begin{gathered}
\hat{q}_S\left(Y_i=y_i \mid Y_j=y_j\right)=\frac{1}{M} \sum_{k=1}^M q\left(Y_i=y_i \mid\left\{Y_s=y_s^{(k)}\right\}_{s \in S}, Y_j=y_j\right) \\
\text{where } y_s^{(k)} \sim q\left(Y_s \mid\left\{Y_t=y_t^{(k)}\right\}_{t \in S,\, t \prec s}, Y_j=y_j\right)
\end{gathered}
$$

---

- Example
    - To estimate $p(X_4 \mid X_1 = 0)$ with $X_2$ and $X_3$ as the scaffold, start with:
```
target: X4
X1=0
X2=
```
    - If $q$ generates "1", give this prompt next:
```
target: X4
X1=0
X2=1
X3=
```
    - Repeat until all scaffold variables have values:
```
target: X4
X1=0
X2=1
X3=0
X4=
```

---

## Free generation

- Similar to scaffolded generation, but the model chooses the intermediate variables itself.
- Sample variable indices and values from $q$ until it generates the index of the target variable.
- Then read off the probability of the target variable's value.

---

- Example
    - To infer $p(X_4 \mid X_1 = 0)$, the first prompt is:
```
target: X4
X1=0
```
    - Generate one intermediate variable and its value:
```
target: X4
X1=0
X5=1
```
    - Repeat until the model outputs the target variable's index:
```
target: X4
X1=0
X5=1
X2=0
X7=0
X3=1
X4=
```
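---

- A minimal sketch of the three estimators (our illustration, not the paper's code). The model interface is assumed: `model.next_distribution(prompt)` returns `{token: probability}` for the next token, and `model.sample_index(prompt)` samples the next variable index.

```python
import random

# Sketch of the three estimators over binary variables, using an assumed model interface.
VALUES = ["0", "1"]

def sample_value(model, prompt, idx):
    """Sample a value for X{idx} from the model's next-token distribution."""
    dist = model.next_distribution(prompt + f"X{idx}=")
    return random.choices(VALUES, weights=[dist.get(v, 0.0) for v in VALUES])[0]

def direct(model, target, obs):
    """Direct prediction: read q(Y_target | Y_obs) straight off the model."""
    dist = model.next_distribution(f"target: X{target}\n{obs}\nX{target}=")
    return {v: dist.get(v, 0.0) for v in VALUES}

def scaffolded(model, target, obs, scaffold, m=100):
    """Scaffolded generation: sample the fixed, ordered scaffold, then average
    the model's prediction for the target over M scaffold samples."""
    totals = {v: 0.0 for v in VALUES}
    for _ in range(m):
        prompt = f"target: X{target}\n{obs}\n"
        for s in scaffold:                       # generate scaffold variables in order
            prompt += f"X{s}={sample_value(model, prompt, s)}\n"
        dist = model.next_distribution(prompt + f"X{target}=")
        totals = {v: totals[v] + dist.get(v, 0.0) for v in VALUES}
    return {v: totals[v] / m for v in VALUES}

def free(model, target, obs, m=100, max_steps=20):
    """Free generation: the model picks intermediate variables itself; stop once
    it emits the target index, then read off its prediction."""
    totals = {v: 0.0 for v in VALUES}
    for _ in range(m):
        prompt = f"target: X{target}\n{obs}\n"
        for _ in range(max_steps):
            idx = model.sample_index(prompt)     # assumed helper: sample the next "Xk" index
            if idx == target:
                break
            prompt += f"X{idx}={sample_value(model, prompt, idx)}\n"
        dist = model.next_distribution(prompt + f"X{target}=")
        totals = {v: totals[v] + dist.get(v, 0.0) for v in VALUES}
    return {v: totals[v] / m for v in VALUES}
```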
---

## Experiment

![image](https://hackmd.io/_uploads/ry4e89D1A.png)

---

## Theoretical Analysis

- Assume that $q$ is trained on sequences of alternating variable indices $i_t \in \{1, \ldots, N\}$ and variable values $v_t \in \mathcal{X}$.
- Assume that the joint distribution $p_d$ over $Y_i$, $i \in [N]$, factorizes as a chain: $p_d\left(Y_1, \ldots, Y_N\right)=p_d\left(Y_1\right) \prod_{i=1}^{N-1} p_d\left(Y_{i+1} \mid Y_i\right)$.
- Assume that $p_{\mathrm{obs}}$ only assigns non-zero probability to adjacent variable pairs, i.e. $p_{\mathrm{obs}}(\{i, j\})=0$ if $|i-j|>1$ and $p_{\mathrm{obs}}(X)=0$ if $|X| \neq 2$.
- The training set consists of i.i.d. samples from $p$, the distribution over complete sequences of variable indices and values defined by $p_{\mathrm{obs}}$ and $p_d$.

---

## Theorem 3.1.

- Let $\mathcal{S}$ be the space of sequences of variable indices followed by values.
- Let $u$ be the uniform distribution over $\mathcal{S}$.
- Let $H(p, q)$ denote the cross entropy between $p$ and $q$. Consider the risk:
$$
R(q)=H(p, q)+H(u, q)
$$
- Let $q^*=\arg \min _q R(q)$ be a minimizer of the risk, and let $\hat{q}_S^*$, $\hat{q}_D^*$ denote the scaffolded and direct estimators computed with $q^*$.
- Then, for any $y_i, y_j \in \mathcal{X}$ with $|i-j|>1$ (the expectation is over the scaffold values sampled from $q^*$):
$$
\begin{array}{r}
\left|\mathbb{E}\left[\hat{q}_S^*\left(Y_i=y_i \mid Y_j=y_j\right)\right]-p_d\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2 \\
<\left|\hat{q}_D^*\left(Y_i=y_i \mid Y_j=y_j\right)-p_d\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2
\end{array}
$$

---

## Proposition 1

- Let $\alpha_1 \geq 0, \alpha_2 \geq 0$.
- Let $R(q)=\alpha_1 \mathbb{E}_{p_1(x)}[-\log q(x)]+\alpha_2 \mathbb{E}_{p_2(x)}[-\log q(x)]$.
- Then $q^*=\arg \min_q R(q)$ is given by
$$
q^*(x)=\frac{\alpha_1}{\alpha_1+\alpha_2} p_1(x)+\frac{\alpha_2}{\alpha_1+\alpha_2} p_2(x).
$$

<!--
## Preliminaries
- **Factorization of data distribution:**
    - Given $Y_1, Y_2, \ldots, Y_N$ take support in some finite set $\mathcal{X}$, we assume that the joint distribution $p_d$ over $Y_1, Y_2, \ldots, Y_N$ factorizes as
$$
p_d\left(Y_1, \ldots, Y_N\right)=p_d\left(Y_1\right) \prod_{j=1}^{N-1} p_d\left(Y_{j+1} \mid Y_j\right)
$$
- **Local observation distribution:**
    - For any non-adjacent pair $Y_i$ and $Y_j$ and for any variable values $y_i$ and $y_j$,
$$
p\left(i_1=i, v_1=y_i, i_2=j, v_2=y_j\right)=0
$$
    - For adjacent pairs $Y_{i+1}$ and $Y_i$ and for variable values $y_{i+1}$ and $y_i$,
$$
p\left(i_1=i, v_1=y_i, i_2=i+1, v_2=y_{i+1}\right) \propto p_d\left(Y_i=y_i, Y_{i+1}=y_{i+1}\right)
$$
-->

---

## Proof

- The Lagrangian is given by
$$
\begin{aligned}
\mathcal{L}\left(q, \lambda_0\right)&=-\alpha_1 \sum_x p_1(x) \log q(x)-\alpha_2 \sum_x p_2(x) \log q(x) \\
&+\lambda_0\left(\sum_x q(x)-1\right)
\end{aligned}
$$
- The first-order conditions are
$$
\begin{array}{l}
\frac{\partial \mathcal{L}}{\partial q(x)}=-\alpha_1 \frac{p_1(x)}{q(x)}-\alpha_2 \frac{p_2(x)}{q(x)}+\lambda_0=0 \\
\frac{\partial \mathcal{L}}{\partial \lambda_0}=\sum_x q(x)-1=0
\end{array}
$$
- Hence
$$
q(x)=\frac{\alpha_1 p_1(x)+\alpha_2 p_2(x)}{\lambda_0} \Rightarrow \sum_x \frac{\alpha_1 p_1(x)+\alpha_2 p_2(x)}{\lambda_0}=1 \Rightarrow \lambda_0=\alpha_1+\alpha_2,
$$
and the result follows immediately.
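---

- A quick numeric sanity check of Proposition 1 for a binary variable (our own illustration, with arbitrarily chosen $\alpha_1, \alpha_2, p_1, p_2$): grid-search the risk and compare the minimizer with the closed-form mixture.

```python
import numpy as np

# Numeric check of Proposition 1 in the binary case: q is parameterized by q1 = q(x=1).
a1, a2 = 1.0, 0.5          # alpha_1, alpha_2 (arbitrary non-negative weights)
p1, p2 = 0.9, 0.5          # p_1(x=1) and p_2(x=1)

def risk(q1):
    # R(q) = a1 * H(p1, q) + a2 * H(p2, q), written out for the binary case
    return -(a1 * (p1 * np.log(q1) + (1 - p1) * np.log(1 - q1))
             + a2 * (p2 * np.log(q1) + (1 - p2) * np.log(1 - q1)))

grid = np.linspace(1e-4, 1 - 1e-4, 100_000)
numeric = grid[np.argmin(risk(grid))]
closed_form = (a1 * p1 + a2 * p2) / (a1 + a2)   # the weighted mixture from Proposition 1

print(numeric, closed_form)   # both ~0.7667: the risk minimizer is the weighted mixture
```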
---

## Theorem A.1.

- Let $u$ be the uniform distribution.
- Let $p$ be the distribution over variable indices and values defined by $p_{\mathrm{obs}}$ and $p_d$.
- Let $H(p, q)$ denote the cross entropy between $p$ and $q$.
- We consider the following risk:
$$
R(q)=H(p, q)+H(u, q).
$$
- Then $q^*=\arg \min _q R(q)$ satisfies the following properties.
    - For all pairs of adjacent variables $Y_i$ and $Y_{i-1}$,
$$
q^*\left(Y_i \mid Y_{i-1}\right)=\lambda p_d\left(Y_i \mid Y_{i-1}\right)+(1-\lambda) \frac{1}{|\mathcal{X}|}, \quad \text{for some } \lambda \in(0,1).
$$
    - For all pairs of non-adjacent variables,
$$
q^*\left(Y_i \mid Y_j\right)=\frac{1}{|\mathcal{X}|}.
$$

---

## Proof.

- We can write the risk as a sum across timesteps:
$$
R(q)=\sum_{t=1}^T \mathbb{E}_{p\left(x_{1: t}\right)}\left[-\log q\left(x_t \mid x_{1: t-1}\right)\right]+\sum_{t=1}^T \mathbb{E}_{u\left(x_{1: t}\right)}\left[-\log q\left(x_t \mid x_{1: t-1}\right)\right]
$$
- Expand the $t=4$ term (the prediction of $v_2$ given $i_1, v_1, i_2$) of the first sum, the cross entropy with $p$:
$$
\begin{array}{l}
\mathbb{E}_{p\left(i_{1: 2}, v_{1: 2}\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right] \\
=\mathbb{E}_{p\left(i_{1: 2}, v_1\right)}\left[\mathbb{E}_{p\left(v_2 \mid i_{1: 2}, v_1\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right]\right] \\
=\sum_{i_1} \sum_{v_1} \sum_{i_2} \mathbb{E}_{p\left(v_2 \mid i_{1: 2}, v_1\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right] p\left(i_{1: 2}, v_1\right) \\
=\sum_{\left(i_1, i_2\right) \in \mathcal{Y}_{\text {obs }}} \sum_{v_1} \mathbb{E}_{p\left(v_2 \mid i_{1: 2}, v_1\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right] p\left(i_{1: 2}, v_1\right)
\end{array}
$$
- Expand the $t=4$ term of the second sum, the cross entropy with $u$:
$$
\begin{array}{l}
\mathbb{E}_{u\left(i_{1: 2}, v_{1: 2}\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right] \\
=\mathbb{E}_{u\left(i_{1: 2}, v_1\right)}\left[\mathbb{E}_{u\left(v_2 \mid i_{1: 2}, v_1\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right]\right] \\
=\sum_{i_1} \sum_{v_1} \sum_{i_2} \mathbb{E}_{u\left(v_2 \mid i_{1: 2}, v_1\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right] u\left(i_{1: 2}, v_1\right)
\end{array}
$$

---

- Adjacent pairs:
    - Suppose $i_1=i$ and $i_2=i+1$ for some $i$, and fix some value for $v_1$. The corresponding terms of the two expansions are
$$
\begin{array}{r}
\frac{1}{2} \mathbb{E}_{p\left(v_2 \mid i_{1: 2}, v_1\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right] p\left(i_{1: 2}, v_1\right) \\
+\frac{1}{2} \mathbb{E}_{u\left(v_2 \mid i_{1: 2}, v_1\right)}\left[-\log q\left(v_2 \mid i_{1: 2}, v_1\right)\right] u\left(i_{1: 2}, v_1\right)
\end{array}
$$
    - By Proposition 1, this sum is minimized by
$$
q^*\left(Y_{i+1} \mid Y_i\right)=\lambda_{i+1, i} \frac{1}{|\mathcal{X}|}+\left(1-\lambda_{i+1, i}\right) p\left(Y_{i+1} \mid Y_i\right)
$$
for $\lambda_{i+1, i}=\frac{u\left(i_{1: 2}, v_1\right)}{u\left(i_{1: 2}, v_1\right)+p\left(i_{1:2},v_1\right)}$.
- Non-adjacent pairs:
    - The term from the cross entropy with $p$ is zero, since $p\left(i_1=i, i_2=j\right)=0$ when $|i-j|>1$.
    - Only the uniform term remains, so $q^*\left(Y_i \mid Y_j\right)=\frac{1}{|\mathcal{X}|}$.

---

## Theorem A.2.

- We assume that $\sum_{y_i} p\left(Y_i=y_i \mid Y_j=y_j\right)=1=\sum_{y_j} p\left(Y_i=y_i \mid Y_j=y_j\right)$.
- For all $y_i, y_j \in \mathcal{X}$ with $|i-j|>1$,
$$
\begin{aligned}
&\left|\mathbb{E}\left[\hat{q}_S^*\left(Y_i=y_i \mid Y_j=y_j\right)\right]-p\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2 \\
& < \left|\hat{q}_D^*\left(Y_i=y_i \mid Y_j=y_j\right)-p\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2
\end{aligned}
$$
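---

- A small numeric illustration of Theorem A.2 (our own, with assumed numbers): a 3-variable binary chain with a doubly stochastic transition matrix `T` and assumed uniform-mixing weights `lam21`, `lam32` from Theorem A.1.

```python
import numpy as np

# Illustration of Theorem A.2 with assumed numbers (not from the paper).
T = np.array([[0.9, 0.1],
              [0.1, 0.9]])          # p(Y_{i+1} | Y_i), rows indexed by the value of Y_i
U = np.full((2, 2), 0.5)            # the uniform distribution over X = {0, 1}
lam21, lam32 = 0.3, 0.4             # assumed mixing weights toward uniform

Q21 = (1 - lam21) * T + lam21 * U   # q*(Y_2 | Y_1) from Theorem A.1
Q32 = (1 - lam32) * T + lam32 * U   # q*(Y_3 | Y_2)

true = T @ T                        # p(Y_3 | Y_1): chain the true conditionals
scaffolded = Q21 @ Q32              # E[q_S*(Y_3 | Y_1)]: chain the learned conditionals
direct = U                          # q*(Y_3 | Y_1) = 1/|X| for the non-adjacent pair

bias_scaffolded = (scaffolded - true) ** 2
bias_direct = (direct - true) ** 2
print(bias_scaffolded < bias_direct)                           # True entrywise: the reasoning gap
lam = (1 - lam21) * (1 - lam32)
print(np.allclose(scaffolded, lam * true + (1 - lam) * U))     # the mixture form from the proof
```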
---

<!-- .slide: style="font-size: 20px;" -->
## Proof.

- First, we show that
$$
\mathbb{E}\left[\hat{q}_S^*\left(Y_i=y_i \mid Y_j=y_j\right)\right] =\lambda p\left(Y_i=y_i \mid Y_j=y_j\right)+(1-\lambda) \frac{1}{|\mathcal{X}|} \quad \text{for some } \lambda \in (0,1).
$$
- We prove this by induction on $|i-j|$. Consider the base case $|i-j|=2$, with $j=1$, $i=3$, and scaffold $\{Y_2\}$; the expectations are over the sampled scaffold value $y_2 \sim q^*\left(Y_2 \mid Y_1=y_1\right)$:
$$
\begin{aligned}
\mathbb{E}& \left[\hat{q}_S^*\left(Y_3=y_3 \mid Y_1=y_1\right)\right] = \mathbb{E} \left[\left(1-\lambda_{3,2}\right) p\left(Y_3=y_3 \mid Y_2=y_2\right)+\lambda_{3,2} \frac{1}{|\mathcal{X}|}\right] \\
= & \lambda_{3,2} \frac{1}{|\mathcal{X}|}+\left(1-\lambda_{3,2}\right) \mathbb{E}\left[p\left(Y_3=y_3 \mid Y_2=y_2\right)\right] \\
= & \lambda_{3,2} \frac{1}{|\mathcal{X}|}+\left(1-\lambda_{3,2}\right) \sum_{y_2} p\left(Y_3=y_3 \mid Y_2=y_2\right) \left[\left(1-\lambda_{2,1}\right) p\left(Y_2=y_2 \mid Y_1=y_1\right)+\lambda_{2,1} \frac{1}{|\mathcal{X}|}\right] \\
= & \lambda_{3,2} \frac{1}{|\mathcal{X}|}+\left(1-\lambda_{3,2}\right)\left(1-\lambda_{2,1}\right) \sum_{y_2} p\left(Y_3=y_3 \mid Y_2=y_2\right) p\left(Y_2=y_2 \mid Y_1=y_1\right) \\
& +\left(1-\lambda_{3,2}\right) \lambda_{2,1} \sum_{y_2} \frac{1}{|\mathcal{X}|} p\left(Y_3=y_3 \mid Y_2=y_2\right) \\
= & \lambda_{3,2} \frac{1}{|\mathcal{X}|}+\left(1-\lambda_{3,2}\right)\left(1-\lambda_{2,1}\right) p\left(Y_3=y_3 \mid Y_1=y_1\right) +\left(1-\lambda_{3,2}\right) \lambda_{2,1} \frac{1}{|\mathcal{X}|} \\
= & (1-\lambda) \frac{1}{|\mathcal{X}|}+\lambda p\left(Y_3=y_3 \mid Y_1=y_1\right), \quad \text{where } \lambda=\left(1-\lambda_{3,2}\right)\left(1-\lambda_{2,1}\right)
\end{aligned}
$$

---

- For the induction step, note that the expectations $\mathbb{E}\left[\hat{q}_S^*\left(Y_i=y_i \mid Y_j=y_j\right)\right]$ and $\mathbb{E}\left[\hat{q}_S^*\left(Y_{i-1}=y_{i-1} \mid Y_j=y_j\right)\right]$ are related as follows:
$$
\begin{aligned}
& \mathbb{E}\left[\hat{q}_S^*\left(Y_i=y_i \mid Y_j=y_j\right)\right] \\
& =\lambda_{i, i-1} \frac{1}{|\mathcal{X}|}+\left(1-\lambda_{i, i-1}\right) \sum_{y_{i-1}} p\left(Y_i=y_i \mid Y_{i-1}=y_{i-1}\right) \times \mathbb{E}\left[\hat{q}_S^*\left(Y_{i-1}=y_{i-1} \mid Y_j=y_j\right)\right]
\end{aligned}
$$

---

- The squared bias of the direct estimator is $\left|\frac{1}{|\mathcal{X}|}-p\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2$, since $\hat{q}_D^*\left(Y_i \mid Y_j\right)=\frac{1}{|\mathcal{X}|}$ for non-adjacent pairs by Theorem A.1.
- By the previous results, $\mathbb{E}\left[\hat{q}_S^*\left(Y_i=y_i \mid Y_j=y_j\right)\right]=\lambda p\left(Y_i=y_i \mid Y_j=y_j\right)+(1-\lambda) \frac{1}{|\mathcal{X}|}$ for some $\lambda \in(0,1)$. Hence
$$
\begin{array}{l}
\left|\mathbb{E}\left[\hat{q}_S^*\left(Y_i=y_i \mid Y_j=y_j\right)\right]-p\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2 \\
=\left|\lambda p\left(Y_i=y_i \mid Y_j=y_j\right)+(1-\lambda) \frac{1}{|\mathcal{X}|}-p\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2 \\
=\left|\lambda p\left(Y_i=y_i \mid Y_j=y_j\right)+\frac{1}{|\mathcal{X}|}-\lambda \frac{1}{|\mathcal{X}|}-p\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2 \\
=\left|(\lambda-1) p\left(Y_i=y_i \mid Y_j=y_j\right)+(1-\lambda) \frac{1}{|\mathcal{X}|}\right|^2 \\
=(1-\lambda)^2 \left|\frac{1}{|\mathcal{X}|}-p\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2 \\
<\left|\frac{1}{|\mathcal{X}|}-p\left(Y_i=y_i \mid Y_j=y_j\right)\right|^2, \quad \text{since } \lambda \in(0,1).
\end{array}
$$

<!--
## Reason for Chain-of-Thought
- Direct prediction is inaccurate for some inferences because the relevant variables are rarely seen together in training.
- Chain-of-thought reasoning improves estimation by incrementally chaining local statistical dependencies that are observed frequently in training.
- The combination of locally structured training data and reasoning with self-generated intermediate variables yields much greater data efficiency than training on data containing all variables.

---
-->
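---

- A toy numeric instance of the final inequality (the numbers are our own, chosen for illustration): take $|\mathcal{X}|=2$, $p\left(Y_i=y_i \mid Y_j=y_j\right)=0.82$, and $\lambda=0.42$.
$$
\begin{aligned}
\text{direct:}\quad & \left|\tfrac{1}{2}-0.82\right|^2 = 0.1024, \\
\text{scaffolded:}\quad & (1-\lambda)^2\left|\tfrac{1}{2}-0.82\right|^2 = 0.58^2 \cdot 0.1024 \approx 0.034.
\end{aligned}
$$
- Chaining through the scaffold shrinks the squared bias by the factor $(1-\lambda)^2$; this shrinkage is the reasoning gap quantified by Theorem A.2.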