---
# System prepended metadata

title: "Quiet-STaR: Language Models Can Teach Themselves to Think Before Speaking"
tags: ["\_presentation", slides]

---

<style>
.reveal {
  font-size: 24px;
}
</style>
<!-- .slide: style="font-size: 32px;" -->

## Quiet-STaR: Language Models Can Teach Themselves to Think Before Speaking
https://arxiv.org/pdf/2403.09629.pdf


---

## Preliminaries

- Chain of Thought (CoT)
- Self-Taught Reasoner (STaR)

---

## Chain of Thought (CoT)
https://arxiv.org/pdf/2201.11903.pdf

![Image](https://hackmd.io/_uploads/BkpvKe6R6.png)


---

## Self-Taught Reasoner (STaR)
https://arxiv.org/pdf/2203.14465.pdf
![Image](https://hackmd.io/_uploads/H15f8WTCT.png)


---

<!-- 
## Abstract

- A generalization of STaR in which LMs learn to generate rationales to explain future text. 
- Address the key challenges: 
    - the computational cost of generating continuations. 
    - the fact that the LM does not initially know how to generate or use internal thoughts. 
    - the need to predict beyond individual next tokens. 
- A tokenwise parallel sampling algorithm. 


---
 -->


## Problem Statement
- Introduce an auxiliary **rationale** variable between each pair of observed tokens in the sequence.
- Optimize a language model with parameters $\theta$ to generate intermediate thoughts:
$$
\theta^*=\arg \max _\theta E_x\left[\log p_\theta\left(x_{i: n} \mid x_{0: i}, \text { rationale }_\theta\left(x_{0: i}\right)\right)\right]
$$
- In principle, this provides no advantage over an **optimal language model** that already correctly models the language’s distribution over strings. 
- In practice, extensive prior work (Nye et al., 2021; Zelikman et al., 2022; Wei et al., 2022b) has shown that language models benefit from intermediate rationales on reasoning tasks. 
- Some work explains why chain-of-thought reasoning helps, attributing the effect to “locality of experience” (Prystawski et al., 2024).

---

## Three main steps 
1) Think: Parallel rationale generation 
2) Talk: Mixing post-rationale and base predictions
3) Learn: Optimizing rationale generation

![Image](https://hackmd.io/_uploads/SkyWm4s0T.png)

---

## 1. Parallel rationale generation

- In parallel across the $n$ tokens $x_i$ of an input sequence $x_{0: n}$, generate $r$ rationales of length $t$ each, $c_i=\left(c_{i 1}, \ldots, c_{i t}\right)$, resulting in $n \times r$ rationale candidates.
- Insert learned `<|startofthought|>` and `<|endofthought|>` tokens to mark each rationale's start and end.
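
The bookkeeping above can be sketched in a few lines of Python. This is only a sequential reference version, written under assumptions: `toy_next_token` is a hypothetical stand-in for sampling from the language model, and the loops replace the batched, tokenwise-parallel sampling that the actual method uses. It shows the shape of the result ($n \times r$ candidates, each wrapped in the two marker tokens), not how the parallelism is achieved.

```python
import random

START, END = "<|startofthought|>", "<|endofthought|>"

def toy_next_token(context, rng):
    """Hypothetical stand-in for sampling one token from the LM given a context."""
    vocab = ["a", "b", "c"]
    return rng.choice(vocab)

def generate_rationales(tokens, r, t, rng):
    """For every prefix tokens[:i+1], sample r rationales of t tokens each.

    Returns a list of (position, rationale) pairs, where each rationale is
    wrapped in the start/end thought markers.
    """
    candidates = []
    for i in range(len(tokens)):
        prefix = tokens[: i + 1]
        for _ in range(r):
            thought = []
            for _ in range(t):
                # Each thought token conditions on the prefix, the start
                # marker, and the thought tokens sampled so far.
                context = prefix + [START] + thought
                thought.append(toy_next_token(context, rng))
            candidates.append((i, [START] + thought + [END]))
    return candidates

cands = generate_rationales(["2", "+", "2", "="], r=2, t=3, rng=random.Random(0))
print(len(cands))  # n * r = 4 * 2 = 8
```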


---

## 2. Mixing post-rationale and base predictions

- Train a "mixing head" that outputs a weight controlling how much the **post-rationale next-token predicted logits** are mixed with the **base language model predicted logits**.
- This eases the distribution shift that introducing rationales causes early in finetuning.
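
A minimal sketch of the mixing step, assuming a plain linear interpolation in logit space with a scalar weight `w` on the post-rationale prediction. The function names and the example logits are illustrative, and the weight is passed in directly rather than produced by a trained head.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def mix_predictions(base_logits, thought_logits, w):
    """Interpolate base and post-rationale logits (assumed mixing rule).

    w = 0 ignores the rationale entirely; w = 1 uses only the
    post-rationale prediction.
    """
    if not 0.0 <= w <= 1.0:
        raise ValueError("w must lie in [0, 1]")
    mixed = [(1.0 - w) * b + w * t for b, t in zip(base_logits, thought_logits)]
    return log_softmax(mixed)

base = [2.0, 0.5, -1.0]      # made-up logits without a rationale
thought = [0.0, 3.0, -1.0]   # made-up logits after a rationale
for w in (0.0, 0.5, 1.0):
    probs = [round(math.exp(lp), 3) for lp in mix_predictions(base, thought, w)]
    print(w, probs)
```

With `w` near 0 the model behaves like the base language model, which is why a small initial weight limits the damage from untrained rationales.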


---


## 3. Optimizing rationale generation 

- Optimize the rationale generation parameters (start/end tokens and LM weights) to increase the likelihood of rationales that make future text more probable. 
- Use REINFORCE to provide a learning signal to rationales based on their impact on future-token prediction. 
<!-- - To reduce variance, we apply a teacher-forcing trick to include in the loss the likelihood of predicting not only the token after the thought but also later tokens. -->

---

## Objective

- $\log p_{j: j+n_{\text {true }}}^{\text {talk }}(X_{j+1: j+n_{\text {true }}+1})$: the log-likelihood of the $n_{\text {true }}$ true next tokens given the previously observed tokens and a particular rationale.
- The reward $r_j$ for each rationale $T_j$ is the difference between $\log p_{j: j+n_{\text {true }}}^{\text {talk }}$ and $\log \bar{p}_{j: j+n_{\text {true }}}^{\text {talk }}$, where $\bar{p}$ is the average over the rationales sampled for that token:
$$
r_j=\log p_{j: j+n_{\text {true }}}^{\text {talk }}\left(X_{j+1: j+n_{\text {true }}+1}\right)-\log \bar{p}_{j: j+n_{\text {true }}}^{\text {talk }}\left(X_{j+1: j+n_{\text {true }}+1}\right)
$$
- REINFORCE loss term:
$$
\nabla_\theta \mathcal{L}_j^{\text {REINFORCE }}=-r_j \cdot \nabla_\theta \log p_\theta\left(T_j \mid\left[X_{: j} ;<\mid \text {startofthought} \mid>\right]\right)
$$
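
The reward and loss above can be illustrated with a small numeric sketch. It assumes $\bar{p}$ is the arithmetic mean of the per-rationale likelihoods at one position (an implementation could instead subtract the mean log-likelihood), and it treats the rewards as constants, so the scalar returned by `reinforce_surrogate` is one whose gradient with respect to $\theta$ would match the REINFORCE term. All function names and numbers are illustrative.

```python
import math

def rewards(log_liks):
    """r_j = log p_j - log mean(p) over the rationales sampled at one position."""
    m = max(log_liks)
    # log of the arithmetic mean of the likelihoods, computed stably
    log_mean = m + math.log(sum(math.exp(x - m) for x in log_liks) / len(log_liks))
    return [x - log_mean for x in log_liks]

def reinforce_surrogate(rs, rationale_log_probs):
    """Sum of -r_j * log p_theta(T_j | ...), with rewards held constant."""
    return -sum(r * lp for r, lp in zip(rs, rationale_log_probs))

# Two rationales at the same position: future-token likelihoods 0.2 and 0.4.
log_liks = [math.log(0.2), math.log(0.4)]
rs = rewards(log_liks)
print([round(r, 3) for r in rs])  # [-0.405, 0.288]

# Made-up log-probabilities of the rationales themselves under the model.
rationale_log_probs = [-5.0, -6.0]
loss = reinforce_surrogate(rs, rationale_log_probs)
print(round(loss, 3))
```

The rationale that makes the true continuation more likely than average gets a positive reward, so minimizing the surrogate raises its probability; the below-average rationale is pushed down.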

---


![Image](https://hackmd.io/_uploads/r1vkh4jC6.png)

---

## Experiments and Results

![Image](https://hackmd.io/_uploads/HyVRMEa0p.png)


---

## Example

![Image](https://hackmd.io/_uploads/rJgnmNp0p.png)
