<style> .reveal { font-size: 24px; } </style>

<!-- .slide: style="font-size: 32px;" -->

## Quiet-STaR: Language Models Can Teach Themselves to Think Before Speaking

https://arxiv.org/pdf/2403.09629.pdf

---

## Preliminaries

- Chain of Thought (CoT)
- Self-Taught Reasoner (STaR)

---

## Chain of Thought (CoT)

https://arxiv.org/pdf/2201.11903.pdf

![image](https://hackmd.io/_uploads/BkpvKe6R6.png)

---

## Self-Taught Reasoner (STaR)

https://arxiv.org/pdf/2203.14465.pdf

![image](https://hackmd.io/_uploads/H15f8WTCT.png)

---

<!--
## Abstract

- A generalization of STaR in which LMs learn to generate rationales to explain future text.
- Address the key challenges:
  - the computational cost of generating continuations.
  - the fact that the LM does not initially know how to generate or use internal thoughts.
  - the need to predict beyond individual next tokens.
- A tokenwise parallel sampling algorithm.

--- -->

## Problem Statement

- Introduce an auxiliary **rationale** variable between each pair of observed tokens in the sequence.
- Optimize a language model with parameters $\theta$ to generate intermediate thoughts (rationales) that make the remaining text more likely:
  $$
  \theta^*=\arg\max_\theta E_x\left[\log p_\theta\left(x_{i:n} \mid x_{0:i}, \text{rationale}_\theta\left(x_{0:i}\right)\right)\right]
  $$
- In principle, this provides no advantage over an **optimal language model** that already correctly models the language's distribution over strings.
- In practice, extensive prior work (Nye et al., 2021; Zelikman et al., 2022; Wei et al., 2022b) has shown that language models benefit from intermediate rationales on reasoning tasks.
- Some work explains why chain-of-thought reasoning helps, attributing it to "locality of experience" (Prystawski et al., 2024).

---

## Three main steps

1) Think: Parallel rationale generation
2) Talk: Mixing post-rationale and base predictions
3) Learn: Optimizing rationale generation

![image](https://hackmd.io/_uploads/SkyWm4s0T.png)

---

## 1. Parallel rationale generation

- In parallel across all $n$ tokens $x_i$ of an input sequence $x_{0:n}$, generate $r$ rationales of length $t$, $c_i=\left(c_{i1}, \ldots, c_{it}\right)$, resulting in $n \times r$ rationale candidates.
- Insert learned `<|startofthought|>` and `<|endofthought|>` tokens to mark the start and end of each rationale.

---

## 2. Mixing post-rationale and base predictions

- Train a "mixing head" that outputs a weight controlling how much the **post-rationale next-token logits** are mixed with the **base language model's next-token logits**.
- This eases the distribution shift that introducing rationales causes early in fine-tuning.

---

## 3. Optimizing rationale generation

- Optimize the rationale-generation parameters (start/end-of-thought tokens and LM weights) to increase the likelihood of rationales that make future text more probable.
- Use REINFORCE to give each rationale a learning signal based on its impact on future-token prediction.

<!-- - To reduce variance, we apply a teacher-forcing trick to include in the loss the likelihood of predicting not only the token after the thought but also later tokens. -->

---

## Objective

- $p_{j:j+n_{\text{true}}}^{\text{talk}}(X_{j+1:j+n_{\text{true}}+1})$: the likelihood of the $n_{\text{true}}$ true next tokens given the preceding observed tokens and a particular rationale.
- The reward $r_j$ for each rationale $T_j$ is the difference between $\log p_{j:j+n_{\text{true}}}^{\text{talk}}$ and $\log \bar{p}_{j:j+n_{\text{true}}}^{\text{talk}}$, the average over the rationales sampled at the same position:
  $$
  r_j=\log p_{j:j+n_{\text{true}}}^{\text{talk}}\left(X_{j+1:j+n_{\text{true}}+1}\right)-\log \bar{p}_{j:j+n_{\text{true}}}^{\text{talk}}\left(X_{j+1:j+n_{\text{true}}+1}\right)
  $$
- REINFORCE gradient for rationale $T_j$:
  $$
  \nabla_\theta \mathcal{L}_j^{\text{REINFORCE}}=-r_j \cdot \nabla_\theta \log p_\theta\left(T_j \mid \left[X_{:j};\ <\mid\text{startofthought}\mid>\right]\right)
  $$

---

![image](https://hackmd.io/_uploads/r1vkh4jC6.png)

---

## Experiments and Results

![image](https://hackmd.io/_uploads/HyVRMEa0p.png)

---

## Example

![image](https://hackmd.io/_uploads/rJgnmNp0p.png)
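
---

## Sketch: Parallel thought attention mask

Step 1 (Think) generates a rationale after every token at once. This sketch assumes the masking scheme the paper describes for parallel generation: each thought token may attend to the observed prefix up to its own position and to the earlier tokens of its own thought, but never to thoughts started at other positions. It builds that boolean mask in plain Python for one rationale per position ($r = 1$). The flattened token layout and the name `parallel_thought_mask` are hypothetical choices for illustration, not the authors' implementation.

```python
def parallel_thought_mask(n: int, t: int) -> list[list[bool]]:
    """Boolean attention mask for n observed tokens plus t thought steps per position.

    Hypothetical flattened layout (step-major):
      index p            (0 <= p < n) -> observed token x_p
      index n + k*n + i  (0 <= k < t) -> thought token c_{i,k+1}, i.e. step k+1
                                         of the rationale inserted after x_i
    mask[q][key] is True when query position q may attend to key position key.
    """
    size = n + n * t
    mask = [[False] * size for _ in range(size)]

    # Observed tokens: ordinary causal attention over x_0..x_q.
    for q in range(n):
        for key in range(q + 1):
            mask[q][key] = True

    # Thought tokens: the thought after x_i sees x_0..x_i and its own earlier
    # steps (including itself), never thoughts that belong to other positions.
    for k in range(t):
        for i in range(n):
            q = n + k * n + i
            for key in range(i + 1):
                mask[q][key] = True
            for prev in range(k + 1):
                mask[q][n + prev * n + i] = True
    return mask


if __name__ == "__main__":
    # 3 observed tokens, 2 thought steps each: a 9 x 9 mask.
    for row in parallel_thought_mask(3, 2):
        print("".join("1" if allowed else "." for allowed in row))
```

Between one thought-step block and the previous ones the mask is the identity, which gives the diagonal pattern; in this layout all $n$ rationales can advance by one token per forward pass.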
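
---

## Sketch: Mixing head

Step 2 (Talk) interpolates two sets of next-token logits with a learned weight. This sketch assumes the weight $w \in [0, 1]$ is a sigmoid of a single linear layer over the concatenated base and post-rationale hidden states, and that mixing is the convex combination $w \cdot \text{post} + (1 - w) \cdot \text{base}$. The helper names and the one-layer head are hypothetical stand-ins, not the paper's exact architecture.

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def mixing_weight(h_base: list[float], h_post: list[float],
                  weights: list[float], bias: float) -> float:
    """Hypothetical mixing head: linear layer on [h_base; h_post], then a sigmoid."""
    features = h_base + h_post
    if len(weights) != len(features):
        raise ValueError("weights must match the concatenated hidden size")
    z = sum(w * f for w, f in zip(weights, features)) + bias
    return sigmoid(z)


def mix_logits(base_logits: list[float], post_logits: list[float], w: float) -> list[float]:
    """Convex combination: w = 0 keeps the base LM, w = 1 uses the post-rationale logits."""
    return [w * post + (1.0 - w) * base for base, post in zip(base_logits, post_logits)]


if __name__ == "__main__":
    h_base, h_post = [0.2, -0.1], [0.5, 0.3]
    # A strongly negative bias makes w start near 0, so predictions start close to the base LM.
    w = mixing_weight(h_base, h_post, weights=[0.0] * 4, bias=-4.0)
    print(round(w, 4), mix_logits([1.0, 2.0, 0.5], [3.0, 0.0, 0.5], w))
```

Starting $w$ near zero is one way to ease the early distribution shift: while the rationales are still uninformative, the mixed predictions stay close to the base LM.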
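
---

## Sketch: Reward and REINFORCE term

The reward and REINFORCE gradient from the Objective slide can be computed per position roughly as follows. This sketch assumes that, for each of the $r$ rationales sampled after one token, you already have teacher-forced logits for the $n_{\text{true}}$ future steps and the log-probability of the rationale itself. It reads $\bar{p}^{\text{talk}}$ literally as the arithmetic mean of the probabilities, computed stably in log space. All function names are hypothetical, and because plain Python has no autodiff, `reinforce_surrogate` only returns the scalar whose gradient, with rewards held constant, equals the sum of the per-rationale REINFORCE terms.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def talk_log_likelihood(step_logits: list[list[float]], true_ids: list[int]) -> float:
    """log p^talk for one rationale: sum over the n_true future steps of the
    log-probability given to the true token (logits assumed teacher-forced)."""
    return sum(log_softmax(logits)[tok] for logits, tok in zip(step_logits, true_ids))


def log_mean_exp(values: list[float]) -> float:
    """log((1/k) * sum(exp(v))): the log of the average probability, computed stably."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def rationale_rewards(log_p_talk: list[float]) -> list[float]:
    """r_j = log p_j^talk - log mean_k p_k^talk over the rationales at one position."""
    baseline = log_mean_exp(log_p_talk)
    return [lp - baseline for lp in log_p_talk]


def reinforce_surrogate(rewards: list[float], rationale_logprobs: list[float]) -> float:
    """sum_j -r_j * log p_theta(T_j | [X_{:j}; <|startofthought|>]).

    Rewards must be treated as constants (detached in an autodiff framework),
    so the gradient of this value is sum_j -r_j * grad log p_theta(T_j | ...).
    """
    return sum(-r * lp for r, lp in zip(rewards, rationale_logprobs))


if __name__ == "__main__":
    # Toy logits for n_true = 2 future steps under two rationales; true tokens are id 1.
    log_p_talk = [talk_log_likelihood([[2.0, 0.0], [0.0, 1.0]], [1, 1]),
                  talk_log_likelihood([[0.0, 2.0], [0.0, 1.0]], [1, 1])]
    rewards = rationale_rewards(log_p_talk)  # second rationale favors the true tokens: reward > 0
    print(rewards, reinforce_surrogate(rewards, [-5.0, -4.0]))
```

Because the baseline is the average over rationales at the same position, a rationale is reinforced only when it makes the true continuation more likely than its siblings do, which lowers the variance of the REINFORCE signal.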