<style>
.reveal {
font-size: 24px;
}
code[class*="language-"], pre[class*="language-"]{
font-size: 20px;
}
</style>
# Self-Improvement in Language Models: The Sharpening Mechanism
https://arxiv.org/pdf/2412.01951
---
### Abstract
- A new perspective on the capabilities of self-improvement, framed through **sharpening**
- Analyze two natural families of self-improvement algorithms based on **SFT** and **RLHF**
- Empirically validate the sharpening mechanism via inference-time and amortization experiments
---
### The Sharpening Mechanism
- Consider a learner with access to a base model $\pi_{\text {base }}: \mathcal{X} \rightarrow \Delta(\mathcal{Y})$ representing a conditional distribution that maps a prompt $x \in \mathcal{X}$ to a distribution over responses (i.e., $\pi_{\text {base }}(y \mid x)$).
- $\pi_{\text {base }}$ has already been trained in some manner
- with the key feature being that $\pi_{\text {base }}$ is a good verifier
- as measured by some self-reward function $r_{\text {self }}\left(y \mid x ; \pi_{\text {base }}\right)$ measuring model certainty.
- **Sharpening**
- Any process that tilts $\pi_{\text {base }}$ toward responses that are more certain in the sense that they enjoy greater self-reward $r_{\text {self }}$.
- That is, a sharpened model $\widehat{\pi}$ is one that (approximately) maximizes the self-reward:
$$
\widehat{\pi}(x) \approx \underset{y \in \mathcal{Y}}{\arg \max } r_{\text {self }}\left(y \mid x ; \pi_{\text {base }}\right)
$$
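As a toy illustration (not from the paper), sharpening can be pictured as re-weighting a small discrete response distribution toward its self-reward maximizer; all names and probabilities below are made up:

```python
import math

# Toy conditional distribution pi_base(y | x) over three candidate responses
# for a single fixed prompt (illustrative numbers, not from the paper).
pi_base = {"yes": 0.5, "maybe": 0.3, "no": 0.2}

def r_self(y, pi):
    """Self-reward r_self(y | x; pi_base); here the log-likelihood variant."""
    return math.log(pi[y])

# A fully sharpened model concentrates all mass on the self-reward maximizer.
y_star = max(pi_base, key=lambda y: r_self(y, pi_base))
pi_sharp = {y: (1.0 if y == y_star else 0.0) for y in pi_base}
```

In practice the tilting is softer (e.g. best-of-$N$ or KL-regularized RL); the later slides make this quantitative.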
---
### A Statistical Framework for Sharpening
---
### Maximum-likelihood sharpening objective
- Maximize $r_{\text {self }}(y \mid x)$ using conditional samples $y \sim \pi_{\text {base }}(\cdot \mid x)$ from the base model.
$$
r_{\text {self }}(y \mid x):=\log \pi_{\text {base }}(y \mid x),
$$
- Given access to a sampling oracle that can sample $y \sim \pi_{\text {base }}(\cdot \mid x)$.
- Can we efficiently amortize maximum-likelihood inference (optimization) for a conditional distribution $\pi_{\text {base }}(y \mid x)$?
- Let
$$
\mathbf{y}^{\star}(x):=\underset{y \in \mathcal{Y}}{\arg \max } \log \pi_{\text {base }}(y \mid x).
$$
- We interpret $\mathbf{y}^{\star}(x) \subset \mathcal{Y}$ as a set to accommodate non-unique maximizers, and will write $y^{\star}(x)$ to indicate a unique maximizer.
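A minimal sketch of the arg-max set $\mathbf{y}^{\star}(x)$, using a toy distribution with tied maximizers (names and numbers are illustrative):

```python
import math

# Toy conditional distribution with two tied maximizers.
pi_base = {"a": 0.4, "b": 0.4, "c": 0.2}

def argmax_set(pi, tol=1e-12):
    """y*(x): all responses maximizing log pi_base(y | x), up to a tolerance."""
    best = max(math.log(p) for p in pi.values())
    return {y for y, p in pi.items() if math.log(p) >= best - tol}
```

Here `argmax_set(pi_base)` returns `{"a", "b"}`; when the set is a singleton we write $y^{\star}(x)$.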
---
### Definition 3.1 (Sharpened model).
- We say that a model $\widehat{\pi}$ is $(\epsilon, \delta)$-sharpened relative to $\pi_{\text {base }}$ if
$$
\mathbb{P}_{x \sim \mu}\left[\widehat{\pi}\left(\mathbf{y}^{\star}(x) \mid x\right) \geq 1-\delta\right] \geq 1-\epsilon
$$
- That is, an $(\epsilon, \delta)$-sharpened model places
- at least $1-\delta$ mass on the arg-max responses $\mathbf{y}^{\star}(x)$
- for all but an $\epsilon$-fraction of prompts under $\mu$.
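The definition can be checked directly on a finite toy example (all tables below are hypothetical stand-ins):

```python
# Check Definition 3.1 on a toy prompt distribution mu, given as
# (prompt, weight) pairs; pi_hat and y_star are illustrative stand-ins.
def is_sharpened(pi_hat, y_star, mu, eps, delta):
    """P_{x~mu}[ pi_hat(y*(x) | x) >= 1 - delta ] >= 1 - eps."""
    good = sum(w for x, w in mu
               if sum(pi_hat[x][y] for y in y_star[x]) >= 1 - delta)
    return good >= 1 - eps

pi_hat = {"x1": {"a": 0.95, "b": 0.05}, "x2": {"a": 0.6, "b": 0.4}}
y_star = {"x1": {"a"}, "x2": {"a"}}
mu = [("x1", 0.5), ("x2", 0.5)]
# At delta = 0.1 only x1 clears the 1 - delta threshold, so this model is
# (0.5, 0.1)-sharpened but not (0.4, 0.1)-sharpened.
```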
---
### Maximum-likelihood sharpening for autoregressive models.
- Autoregressive setting in which $\mathcal{Y}=\mathcal{V}^H$ for a vocabulary space $\mathcal{V}$ and sequence length $H$, and where $\pi_{\text {base }}$ has the autoregressive structure $\pi_{\text {base }}\left(y_{1: H} \mid x\right)=\prod_{h=1}^H \pi_{\text {base }, h}\left(y_h \mid y_{1: h-1}, x\right)$ for $y=y_{1: H} \in \mathcal{Y}$.
- When the response $y=\left(y_1, \ldots, y_H\right) \in \mathcal{Y}=\mathcal{V}^H$ is a sequence of tokens, the maximum-likelihood sharpening objective is:
$$
\underset{y_{1: H}}{\arg \max } \log \pi_{\text {base }}\left(y_{1: H} \mid x\right)
$$
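On a toy two-token model ($H=2$, $\mathcal{V}=\{A, B\}$; the step distributions are made up), the sequence log-likelihood decomposes token-wise, while the arg-max in principle requires a search over all $|\mathcal{V}|^H$ sequences:

```python
import math
from itertools import product

# Token-level conditionals pi_base,h(y_h | y_{1:h-1}, x) for a toy model.
step = [
    lambda prefix: {"A": 0.7, "B": 0.3},  # h = 1
    lambda prefix: {"A": 0.6, "B": 0.4} if prefix == ("A",)  # h = 2
                   else {"A": 0.1, "B": 0.9},
]

def seq_logprob(y):
    """log pi_base(y_{1:H} | x) = sum_h log pi_base,h(y_h | y_{1:h-1}, x)."""
    return sum(math.log(step[h](y[:h])[y[h]]) for h in range(len(y)))

# Exhaustive search over V^H; feasible only at toy sizes.
best = max(product("AB", repeat=2), key=seq_logprob)
```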
---
### Proposition 3.1 (Greedy decoding succeeds for sharpened policies).
- Let $\pi=\pi_{1: H}$ be an autoregressive model defined over response space $\mathcal{Y}=\mathcal{V}^H$.
- For a given prompt $x \in \mathcal{X}$, if
- $\mathbf{y}^{\star}(x)=\left\{y^{\star}(x)\right\}$ is a singleton
- $\pi\left(y^{\star}(x) \mid x\right)>1 / 2$
- Then the greedy decoding strategy guarantees that $\widehat{y}=y^{\star}(x)$ by selecting
$$
\widehat{y}_h=\underset{y_h \in \mathcal{V}}{\arg \max } \pi_h\left(y_h \mid \widehat{y}_1, \ldots, \widehat{y}_{h-1}, x\right).
$$
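A sketch of the greedy decoding strategy, on a toy model whose unique maximizer $y^{\star}=(A, A)$ has probability $0.8 \cdot 0.7 = 0.56 > 1/2$, so the proposition's conditions hold (numbers are illustrative):

```python
# Toy autoregressive policy satisfying the proposition's conditions.
step = [
    lambda prefix: {"A": 0.8, "B": 0.2},
    lambda prefix: {"A": 0.7, "B": 0.3} if prefix == ("A",)
                   else {"A": 0.4, "B": 0.6},
]

def greedy_decode(H=2):
    """y_hat_h = argmax over y_h in V of pi_h(y_h | y_hat_{1:h-1}, x)."""
    y = []
    for h in range(H):
        dist = step[h](tuple(y))
        y.append(max(dist, key=dist.get))
    return tuple(y)
```

Without the $\pi\left(y^{\star}(x) \mid x\right)>1/2$ condition greedy decoding can fail: a high-probability first token may lead into a low-probability continuation.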
---
### Sample Complexity Framework
---
### Definition 3.2 (Sample-and-evaluate framework).
- In the sample-and-evaluate framework, the algorithm designer accesses $\pi_{\text {base }}$ only through sample-and-evaluate queries:
- The learner is allowed to sample $n$ prompts $x \sim \mu$.
- For each prompt $x$, they can sample $N$ responses $y_1, y_2, \ldots, y_N \sim \pi_{\text {base }}(\cdot \mid x)$ and observe the likelihood $\pi_{\text {base }}\left(y_i \mid x\right)$ for each such response.
- The efficiency, or sample complexity, of the algorithm is measured through the total number of sample-and-evaluate queries $m:=n \cdot N$.
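A minimal sketch of the sample-and-evaluate oracle, with a toy table model standing in for $\pi_{\text{base}}$ (all names and numbers are illustrative):

```python
import random

# Toy base model and prompt distribution (illustrative tables).
pi_base = {"x1": {"a": 0.7, "b": 0.3}, "x2": {"a": 0.4, "b": 0.6}}
prompts = ["x1", "x2"]

def sample_and_evaluate(n, N, seed=0):
    """n prompts x ~ mu, N responses each, with likelihoods revealed."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        x = rng.choice(prompts)
        ys = rng.choices(list(pi_base[x]), weights=list(pi_base[x].values()), k=N)
        data.append((x, [(y, pi_base[x][y]) for y in ys]))
    return data  # total sample complexity: m = n * N queries

queries = sample_and_evaluate(n=3, N=4)
```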
---
- **SFT-Sharpening** and **RLHF-Sharpening** can learn an $(\epsilon, \delta)$-sharpened model with sample complexity
$$
m=\operatorname{poly}\left(\epsilon^{-1}, \delta^{-1}, C_{\text {prob }}\right)
$$
where $C_{\text {prob }}$ is a potentially problem-dependent constant.
---
### Fundamental Limits
- The performance of any sharpening algorithm based on sampling should depend on how well the base model $\pi_{\text {base }}$ covers the arg-max response $y^{\star}(x)$.
- Coverage coefficient:
$$
C_{\mathrm{cov}}=\mathbb{E}_{x \sim \mu}\left[\frac{1}{\pi_{\text {base }}\left(\boldsymbol{y}^{\star}(x) \mid x\right)}\right]
$$
- For a model $\pi$, we define $\boldsymbol{y}^\pi(x)=\arg \max _{y \in \mathcal{Y}} \pi(y \mid x)$ and $C_{\text {cov }}(\pi)=\mathbb{E}_{x \sim \mu}\left[\frac{1}{\pi\left(\boldsymbol{y}^\pi(x) \mid x\right)}\right]$.
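A small numeric sketch of the coverage coefficient, assuming a unique maximizer per prompt (toy tables, not from the paper):

```python
# C_cov = E_{x~mu}[ 1 / pi_base(y*(x) | x) ]: large when the base model
# rarely produces its own arg-max response. Toy numbers for illustration.
pi_base = {"x1": {"a": 0.5, "b": 0.3, "c": 0.2}, "x2": {"a": 0.9, "b": 0.1}}
mu = {"x1": 0.5, "x2": 0.5}

def coverage(pi, mu):
    # Assumes a unique maximizer, so pi(y*(x) | x) = max_y pi(y | x).
    return sum(w / max(pi[x].values()) for x, w in mu.items())

C_cov = coverage(pi_base, mu)  # 0.5/0.5 + 0.5/0.9 ≈ 1.56
```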
---
### Theorem 3.1 (Lower bound for sharpening).
- Fix an integer $d \geq 1$ and parameters $\epsilon \in(0,1)$ and $C \geq 1$. There exists a class of models $\Pi$ such that
- (i) $\log |\Pi| \asymp d\left(1+\log \left(C \epsilon^{-1}\right)\right)$,
- (ii) $\sup _{\pi \in \Pi} C_{\operatorname{cov}}(\pi) \lesssim C$
- (iii) $\boldsymbol{y}^\pi(x)$ is a singleton for all $\pi \in \Pi, x \in \mathcal{X}$.
- Any sharpening algorithm $\widehat{\pi}$ that achieves $\mathbb{E}\left[\mathbb{P}_{x \sim \mu}\left[\widehat{\pi}\left(\boldsymbol{y}^{\pi_{\text {base }}}(x) \mid x\right)>1 / 2\right]\right] \geq$ $1-\epsilon$ for all $\pi_{\text {base }} \in \Pi$ must collect a total number of samples $m=n \cdot N$ at least
$$
m \gtrsim \frac{C \log |\Pi|}{\epsilon^2 \cdot\left(1+\log \left(C \epsilon^{-1}\right)\right)}
$$
---
### Analysis of Sharpening Algorithms
---
### Analysis of SFT-Sharpening
- The SFT-Sharpening algorithm takes the form
$$
\widehat{\pi}^{\mathrm{BON}}=\underset{\pi \in \Pi}{\arg \max } \sum_{i=1}^n \log \pi\left(y_i^{\mathrm{BON}} \mid x_i\right),
$$
where $y_i^{\mathrm{BON}}=\arg \max _{j \in[N]}\left\{\log \pi_{\text {base }}\left(y_{i, j} \mid x_i\right)\right\}$ for $y_{i, 1}, \ldots, y_{i, N} \sim \pi_{\text {base }}\left(\cdot \mid x_i\right)$.
- Let $\pi_N^{\mathrm{BON}}(x)$ denote the distribution of the best-of-$N$ response $y^{\mathrm{BON}}(x):=\arg \max _{i \in[N]}\left\{\log \pi_{\text {base }}\left(y_i \mid x\right)\right\}$ for $y_1, \ldots, y_N \sim \pi_{\text {base }}(\cdot \mid x)$.
- **Assumption 4.1.** The model class $\Pi$ satisfies $\pi_N^{\mathrm{BON}} \in \Pi$.
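A sketch of the best-of-$N$ distillation data that SFT-Sharpening trains on, with a toy table model in place of an LLM sampler (all names and numbers are illustrative):

```python
import math
import random

pi_base = {"x": {"a": 0.6, "b": 0.3, "c": 0.1}}  # toy model, one prompt

def best_of_n(x, N, rng):
    """Draw N responses from pi_base(.|x) and keep the most likely one."""
    ys = rng.choices(list(pi_base[x]), weights=list(pi_base[x].values()), k=N)
    return max(ys, key=lambda y: math.log(pi_base[x][y]))  # y_i^BON

rng = random.Random(0)
dataset = [("x", best_of_n("x", N=8, rng=rng)) for _ in range(100)]
# SFT-Sharpening then fits pi_hat by maximizing sum_i log pi(y_i^BON | x_i);
# as N grows, the best-of-N distribution concentrates on y*(x) = "a".
```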
---
### Theorem 4.1 (Sample complexity of SFT-Sharpening).
- Suppose $N=N^{\star} \log \left(2 \delta^{-1}\right)$ for a parameter $N^{\star} \in \mathbb{N}$.
- If Assumption 4.1 holds, then for any $n \in \mathbb{N}$, SFT-Sharpening produces $\widehat{\pi}$ such that with probability at least $1-\rho$,
$$
\mathbb{P}_{x \sim \mu}\left[\widehat{\pi}\left(\boldsymbol{y}^{\star}(x) \mid x\right) \leq 1-\delta\right] \lesssim \frac{1}{\delta} \cdot \frac{\log \left(|\Pi| \rho^{-1}\right)}{n}+\frac{C_{\mathrm{cov}}}{N^{\star}}
$$
- In particular, by setting $n=c \cdot \frac{\log |\Pi|}{\delta \epsilon}$ and $N^{\star}=c \cdot \frac{C_{\mathrm{cov}}}{\epsilon}$ for a sufficiently large constant $c>0$, we have $\mathbb{P}_{x \sim \mu}\left[\widehat{\pi}\left(\boldsymbol{y}^{\star}(x) \mid x\right) \leq 1-\delta\right] \leq \epsilon$, and
$$
m=O\left(\frac{C_{\mathrm{cov}} \log \left(|\Pi| \rho^{-1}\right) \log \left(\delta^{-1}\right)}{\delta \epsilon^2}\right)
$$
---
### Analysis of RLHF-Sharpening
- The RL objective used by RLHF-Sharpening takes the form
$$
\widehat{\pi} \approx \underset{\pi \in \Pi}{\arg \max }\left\{\mathbb{E}_\pi\left[\log \pi_{\text {base }}(y \mid x)\right]-\beta D_{\mathrm{KL}}\left(\pi \| \pi_{\text {base }}\right)\right\}
$$
- The exact optimizer $\pi_\beta^{\star}=\arg \max _{\pi \in \Pi}\left\{\mathbb{E}_\pi\left[\log \pi_{\text {base }}(y \mid x)\right]-\beta D_{\mathrm{KL}}\left(\pi \| \pi_{\text {base }}\right)\right\}$
- We consider the algorithm that solves
$$
\begin{multline}
\widehat{\pi} \in \underset{\pi \in \Pi}{\arg \min } \sum_{\left(x, y, y^{\prime}\right) \in \mathcal{D}_{\text {pref }}}\Bigg(\beta \log \frac{\pi(y \mid x)}{\pi_{\text {base }}(y \mid x)}-\beta \log \frac{\pi\left(y^{\prime} \mid x\right)}{\pi_{\text {base }}\left(y^{\prime} \mid x\right)} \\
-\left(\log \pi_{\text {base }}(y \mid x)-\log \pi_{\text {base }}\left(y^{\prime} \mid x\right)\right)\Bigg)^2
\end{multline}
$$
- We define two concentrability coefficients for a model $\pi$ :
$$
\mathcal{C}_\pi=\mathbb{E}_\pi\left[\frac{\pi(y \mid x)}{\pi_{\text {base }}(y \mid x)}\right], \quad \text { and } \quad \mathcal{C}_{\pi / \pi^{\prime} ; \beta}:=\mathbb{E}_\pi\left[\left(\frac{\pi(y \mid x)}{\pi^{\prime}(y \mid x)}\right)^\beta\right]
$$
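The regression objective above can be sanity-checked numerically: if $\pi$ tilts $\pi_{\text{base}}$ by $\pi_{\text{base}}^{1/\beta}$ (the shape of the exact optimizer $\pi_\beta^{\star}$), the per-pair loss vanishes. A minimal sketch with toy numbers:

```python
import math

def pair_loss(logp_pi, logp_base, x, y, yp, beta):
    """One term of the squared DPO-style regression loss."""
    implicit = beta * ((logp_pi(x, y) - logp_base(x, y))
                       - (logp_pi(x, yp) - logp_base(x, yp)))
    target = logp_base(x, y) - logp_base(x, yp)  # self-reward difference
    return (implicit - target) ** 2

beta = 0.5
base = {"a": 0.7, "b": 0.3}                         # toy pi_base(. | x)
Z = sum(p ** (1 + 1 / beta) for p in base.values())
tilted = {y: p ** (1 + 1 / beta) / Z for y, p in base.items()}  # ~ pi*_beta

loss = pair_loss(lambda x, y: math.log(tilted[y]),
                 lambda x, y: math.log(base[y]), "x", "a", "b", beta)
```

Plugging in $\pi = \pi_{\text{base}}$ instead leaves a strictly positive loss, since the implicit reward difference is then zero while the target is not.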
---
- **Assumption 4.2 (Margin).** For a margin parameter $\gamma_{\text {margin }}>0$, the base model $\pi_{\text {base }}$ satisfies
$$
\max _{y \in \mathcal{Y}} \pi_{\text {base }}(y \mid x) \geq\left(1+\gamma_{\text {margin }}\right) \cdot \pi_{\text {base }}\left(y^{\prime} \mid x\right) \quad \forall y^{\prime} \notin \boldsymbol{y}^{\star}(x), \quad \forall x \in \operatorname{supp}(\mu) .
$$
- **Assumption 4.3 (Realizability).** The model class $\Pi$ satisfies $\pi_\beta^{\star} \in \Pi.$
- **Assumption 4.4 (Concentrability).** All $\pi \in \Pi$ satisfy $\mathcal{C}_\pi \leq C_{\text {conc }}$ for a parameter $C_{\text {conc }} \geq C_{\text {cov }}$, and $\mathcal{C}_{\pi_{\text {base }} / \pi ; \beta} \leq C_{\text {loss }}$ for a parameter $C_{\text {loss }} \geq|\mathcal{Y}|$.
---
### Lemma 4.1.
- The model $\pi_\beta^{\star}$ satisfies $\mathcal{C}_{\pi_\beta^{\star}} \leq C_{\text {cov }}$ and $\mathcal{C}_{\pi_{\text {base }} / \pi_\beta^{\star} ; \beta} \leq|\mathcal{Y}|$.
---
### Theorem 4.2.
- Set $\beta \lesssim \gamma_{\text {margin }} \delta \epsilon$, and suppose that Assumptions 4.2 to 4.4 hold with parameters $C_{\text {conc }}, C_{\text {loss }}$, and $\gamma_{\text {margin }}>0$. The DPO algorithm ensures that with probability at least $1-\rho$, $\mathbb{P}_{x \sim \mu}\left[\widehat{\pi}\left(\boldsymbol{y}^{\star}(x) \mid x\right) \leq 1-\delta\right] \leq \epsilon$, and
$$
m=\widetilde{O}\left(\frac{C_{\mathrm{conc}} \log ^3\left(C_{\mathrm{loss}}|\Pi| \rho^{-1}\right)}{\gamma_{\text {margin }}^2 \delta^2 \epsilon^2}\right)
$$
$$