Notation:

- $x$: prompt (includes few-shot examples)
- $z$: CoT part of the completion
- $y$: answer part of the completion
- $c = (z, y)$: completion, CoT and answer together
- $s(c)$: indicator of correctness of the answer $y$ contained in completion $c$
- $p(\cdot)$: original model
- $p_{new}(\cdot)$: probability under the new model, after the gradient update

Here is one algorithm, using importance weighting:

1. Sample many completions $c_i \sim p(c \vert x)$ from the original model.
2. Evaluate $\log p(c_i \vert x)$ for each sample.
3. Update the model to obtain $p_{new}$.
4. Evaluate $\log p_{new}(c_i \vert x)$ for each of the original samples.
5. Estimate the difference in correctness using the following identity:

\begin{align}
&\mathbb{E}_{c\sim p_{new}(c\vert x)}\, s(c) - \mathbb{E}_{c\sim p(c\vert x)}\, s(c) \\
&= \mathbb{E}_{c\sim p(c\vert x)} \frac{p_{new}(c\vert x)}{p(c\vert x)}\, s(c) - \mathbb{E}_{c\sim p(c\vert x)}\, s(c) \\
&= \mathbb{E}_{c\sim p(c\vert x)} \left( \frac{p_{new}(c\vert x)}{p(c\vert x)} - 1\right) s(c) \\
&\approx \frac{1}{N} \sum_{i=1}^N \left( \frac{p_{new}(c_i\vert x)}{p(c_i\vert x)} - 1\right) s(c_i)
\end{align}

If $s$ returns $0$ when the answer is incorrect, the estimator may have very high variance, because incorrect completions contribute nothing to the sum and all the signal has to come from the correct ones. It is probably better to let $s$ return $-1$ for incorrect answers and $+1$ for correct ones, and then map the result back to a difference in probability of correctness: under $\pm 1$ scoring, $\mathbb{E}[s] = 2\Pr(\text{correct}) - 1$, so halving the estimated difference recovers the difference in accuracy (see the sketch below).

In theory, you should be able to compute $\log p$ and $\log p_{new}$ of any string, because these log-probabilities are exactly what the gradient-descent step uses. To extract $\log p(c \vert x)$, compute $\log p(x, c) - \log p(x)$; in practice this is just the sum of the per-token log-probabilities of the completion, conditioned on the prompt.
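Here is a minimal sketch of the estimator, assuming you already have arrays of log-probabilities under both models and a correctness flag per sample. The function name `estimate_accuracy_gain` and its signature are illustrative, not from the source.

```python
import numpy as np

def estimate_accuracy_gain(logp_old, logp_new, correct):
    """Importance-weighted estimate of the change in expected correctness.

    logp_old : log p(c_i | x) under the original model, shape (N,)
    logp_new : log p_new(c_i | x) under the updated model, shape (N,)
    correct  : boolean array, whether completion c_i answered correctly, shape (N,)
    """
    # Importance ratios p_new(c|x) / p(c|x), computed in log space for stability.
    ratios = np.exp(np.asarray(logp_new) - np.asarray(logp_old))
    # +/-1 scoring: incorrect completions also contribute to the estimate,
    # unlike {0, 1} scoring, which throws their information away.
    s = np.where(correct, 1.0, -1.0)
    # (1/N) sum_i (ratio_i - 1) * s_i, the estimator from the align block above.
    diff_pm1 = np.mean((ratios - 1.0) * s)
    # E[s] = 2 * P(correct) - 1 under +/-1 scoring, so halve the difference
    # to get the change in probability of a correct answer.
    return diff_pm1 / 2.0
```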
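And here is one way to get $\log p(c \vert x)$ out of an autoregressive LM, sketched against the Hugging Face `transformers` API. The helper name `completion_logprob` is hypothetical, and the sketch assumes the tokenizer splits `prompt + completion` cleanly at the prompt boundary (not always true when tokens merge across it).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def completion_logprob(model, tokenizer, prompt, completion):
    """Return log p(completion | prompt) as a sum of per-token log-probs.

    Equivalent to log p(x, c) - log p(x): the prompt tokens are scored by
    the model but excluded from the sum, so only completion tokens count.
    """
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids
    with torch.no_grad():
        logits = model(full_ids).logits  # (1, seq_len, vocab)
    # The token at position t is predicted by the logits at position t-1.
    logprobs = torch.log_softmax(logits[:, :-1], dim=-1)
    targets = full_ids[:, 1:]
    token_lp = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    # Keep only the positions that belong to the completion.
    n_prompt = prompt_ids.shape[1]
    return token_lp[:, n_prompt - 1 :].sum().item()
```

Running this once with the original weights and once after the update gives the `logp_old` and `logp_new` arrays consumed by the estimator above.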