<style>
.reveal {
font-size: 24px;
}
</style>
<!-- .slide: style="font-size: 32px;" -->
## Quiet-STaR: Language Models Can Teach Themselves to Think Before Speaking
https://arxiv.org/pdf/2403.09629.pdf
---
## Preliminaries
- Chain of Thought (CoT)
- Self-Taught Reasoner (STaR)
---
## Chain of Thought (CoT)
https://arxiv.org/pdf/2201.11903.pdf

---
## Self-Taught Reasoner (STaR)
https://arxiv.org/pdf/2203.14465.pdf

---
<!--
## Abstract
- A generalization of STaR in which LMs learn to generate rationales to explain future text.
- Address the key challenges:
- the computational cost of generating continuations.
- the fact that the LM does not initially know how to generate or use internal thoughts.
- the need to predict beyond individual next tokens.
- A tokenwise parallel sampling algorithm.
---
-->
## Problem Statement
- Introduce an auxiliary **rationale** variable between each pair of observed tokens in the sequence.
- Optimize a language model with parameters $\theta$ to generate intermediate thoughts (see the code sketch below):
$$
\theta^*=\arg \max _\theta E_x\left[\log p_\theta\left(x_{i: n} \mid x_{0: i}, \text { rationale }_\theta\left(x_{0: i}\right)\right)\right]
$$
- In principle, this provides no advantage over an **optimal language model** that already correctly models the language’s distribution over strings.
- In practice, extensive prior work (Nye et al., 2021; Zelikman et al., 2022; Wei et al., 2022b) has shown that language models benefit from intermediate rationales on reasoning tasks.
- Some work explains the effects of chain-of-thought reasoning by attributing them to “locality of experience” (Prystawski et al., 2024).
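
A minimal sketch of the quantity inside the expectation, assuming `lm` is any causal LM that maps token ids `(batch, length)` to next-token logits `(batch, length, vocab)`; the helper name `continuation_logprob` is illustrative, not the authors' code.

```python
# Score log p_theta(x_{i:n} | x_{0:i}, rationale) for one sampled rationale.
# `lm` is assumed to be a causal LM: ids (B, T) -> next-token logits (B, T, V).
import torch
import torch.nn.functional as F

def continuation_logprob(lm, context_ids, future_ids):
    """Sum of log-probabilities of `future_ids` given `context_ids`."""
    full = torch.cat([context_ids, future_ids], dim=-1)
    logits = lm(full)                                    # (B, T, V)
    preds = logits[:, context_ids.size(1) - 1 : -1, :]   # positions predicting the future tokens
    logp = F.log_softmax(preds, dim=-1)
    return logp.gather(-1, future_ids.unsqueeze(-1)).squeeze(-1).sum(-1)

# A rationale is useful exactly when prepending it to the context raises this score:
# continuation_logprob(lm, torch.cat([prefix, rationale], -1), future)
#   > continuation_logprob(lm, prefix, future)
```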
---
## Three main steps
1) Think: Parallel rationale generation
2) Talk: Mixing post-rationale and base predictions
3) Learn: Optimizing rationale generation

---
## 1. Parallel rationale generation
- In parallel across $n$ tokens $x_i$ in an input sequence $x_{0: n}$, generate $r$ rationales of length $t: c_i=\left(c_{i 1}, \ldots, c_{i t}\right)$, resulting in $n \times r$ rationale candidates.
- Insert learned `<|startofthought|>` and `<|endofthought|>` tokens to mark each rationale's start and end.
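
A naive sketch of this step, assuming `lm` is a causal LM callable (ids to logits) and `start_id`/`end_id` are the learned thought-delimiter token ids; it loops over positions for clarity, whereas the paper samples all positions in a single pass with a special attention mask in which each thought attends only to itself and the preceding text.

```python
# Naive "think" step: for every position i, sample r thoughts of length t,
# each wrapped in the learned <|startofthought|> / <|endofthought|> tokens.
import torch

@torch.no_grad()
def generate_thoughts(lm, x, r, t, start_id, end_id):
    # x: (n,) long token ids -> thoughts: (n, r, t + 2)
    out = []
    for i in range(1, x.size(0) + 1):                  # one thought set per prefix x_{0:i}
        seq = torch.cat([x[:i].repeat(r, 1),
                         torch.full((r, 1), start_id)], dim=-1)
        for _ in range(t):                             # sample t thought tokens
            probs = torch.softmax(lm(seq)[:, -1, :], dim=-1)
            seq = torch.cat([seq, torch.multinomial(probs, 1)], dim=-1)
        seq = torch.cat([seq, torch.full((r, 1), end_id)], dim=-1)
        out.append(seq[:, i:])                         # keep only <sot> c_i1 .. c_it <eot>
    return torch.stack(out)                            # (n, r, t + 2)
```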
---
## 2. Mixing post-rationale and base predictions
- Train a "mixing head" producing a weight for how much the **post-rationale next-token predicted logits** should be mixed with **base language model predicted logits**.
- Eases the distribution shift early in finetuning, due to introducing rationales.
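
A sketch of such a mixing head, assuming its inputs are the hidden states at a token with and without the preceding thought; the architecture shown (layer sizes, concatenated inputs) is illustrative rather than the paper's exact configuration.

```python
# "Talk" step: a shallow MLP outputs a weight w in (0, 1) that interpolates
# the post-rationale logits with the base LM logits. w near 0 recovers the
# base model, which softens the distribution shift early in finetuning.
import torch
import torch.nn as nn

class MixingHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, h_base, h_thought, base_logits, thought_logits):
        w = torch.sigmoid(self.mlp(torch.cat([h_base, h_thought], dim=-1)))
        return w * thought_logits + (1 - w) * base_logits
```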
---
## 3. Optimizing rationale generation
- Optimize the rationale generation parameters (start/end tokens and LM weights) to increase the likelihood of rationales that make future text more probable.
- Use REINFORCE to provide a learning signal to rationales based on their impact on future-token prediction.
<!-- - To reduce variance, we apply a teacher-forcing trick to include in the loss the likelihood of predicting not only the token after the thought but also later tokens. -->
---
## Objective
- $\log p_{j: j+n_{\text{true}}}^{\text{talk}}\left(X_{j+1: j+n_{\text{true}}+1}\right)$: the log-likelihood of the $n_{\text{true}}$ true next tokens given the preceding observed tokens and a particular rationale.
- The reward $r_j$ for each rationale $T_j$ is the difference between this log-likelihood and that under $\bar{p}_{j: j+n_{\text{true}}}^{\text{talk}}$, the average of $p^{\text{talk}}$ across rationales:
$$
r_j=\log p_{j: j+n_{\text{true}}}^{\text{talk}}\left(X_{j+1: j+n_{\text{true}}+1}\right)-\log \bar{p}_{j: j+n_{\text{true}}}^{\text{talk}}\left(X_{j+1: j+n_{\text{true}}+1}\right)
$$
- REINFORCE loss term:
$$
\nabla_\theta \mathcal{L}_j^{\text{REINFORCE}}=-r_j \cdot \nabla_\theta \log p_\theta\left(T_j \mid\left[X_{:j} ; \text{<|startofthought|>}\right]\right)
$$
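
A sketch of this reward and the REINFORCE term, assuming `talk_logp[j, k]` holds $\log p^{\text{talk}}$ of the true next tokens under thought $k$ at position $j$ and `thought_logp[j, k]` holds $\log p_\theta(T_k \mid [X_{:j}; \text{<|startofthought|>}])$; only the REINFORCE term from the equation above is shown.

```python
# "Learn" step: reward each thought by how much it beats the average thought
# at its position, then apply REINFORCE to the thought tokens' log-probability.
import math
import torch

def reinforce_loss(talk_logp, thought_logp):
    # talk_logp, thought_logp: (n, r) for n positions and r thoughts each.
    # Baseline: log of the average p^talk across the r thoughts (log p-bar).
    baseline = talk_logp.logsumexp(dim=1, keepdim=True) - math.log(talk_logp.size(1))
    reward = (talk_logp - baseline).detach()   # r_j, no gradient through the reward
    return -(reward * thought_logp).mean()     # above-average thoughts are reinforced
```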
---
## Experiments and Results

---
## Example
