# [SERI MATS 2023] Stochastic n-recurrence: a toy model of probabilistic inference

The proposed experiment will enable studying a toy problem of probabilistic inference implemented through in-context learning. [In-context learning](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=xN5Cp0bPFOnQvWGmALV2z4Rb) refers to a situation where a language model uses distant tokens to predict the next token. It seems fairly natural that a language model performing probabilistic inference would do so via in-context learning. After all, probabilistic inference requires taking novel information into account (here, simply the previous tokens) in order to update a prior probability distribution.

Moreover, induction circuits ([Elhage et al., 2021](https://transformer-circuits.pub/2021/framework/index.html#induction-heads)) appear to be central to in-context learning: [Olsson et al. (2022)](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html) present evidence that they may constitute the mechanism behind most in-context learning in all the GPT-style^[By "GPT-style models" I mean language models based on the transformer decoder-only architecture.] models they studied, up to 13 billion parameters. This result points to induction circuits and in-context learning as promising entry points for studying more complex behaviors and capabilities in GPT-style models. The apparent generality of induction circuits (from 2-layer attention-only transformers up to 13-billion-parameter GPT-style models) gives some hope that an understanding of probabilistic inference in toy GPT-style models may also generalize to larger models. Even if that understanding transfers only partially, or in a very limited way, it may still serve as a starting point for understanding how bigger models reason probabilistically.

Understanding probabilistic inference is likely important for the safety of powerful machine learning models.
Adequate understanding of this aspect of cognition might enable interventions in the model's belief or representation dynamics, influencing the probabilities it assigns to particular scenarios. For example, we could lower the expected utility^[Here, I am using the term "utility" loosely to mean "satisfaction of one's own preferences".] that the model assigns to scenarios involving deceiving humans,^[Admittedly, this would also require the ability to recognize the thought of deception, but given that ability, lowering the utility of that option would be *a way* to increase the model's alignment.] or we could make it believe that it is still in the training distribution, even given evidence that it has been deployed to the real world, potentially preventing a sharp left turn.

## Experimental setup

We will train the model to the point of grokking ([Power et al., 2022](https://arxiv.org/abs/2201.02177)), since we want the model to internalize the rule governing the data distribution. I suggest starting with a simple task, which I call "deterministic n-recurrence", described below.^[Loosely inspired by [the n-back task](https://en.wikipedia.org/wiki/N-back) used in psychological research to study working memory and related aspects of cognition.] Once we get a model to grok it, we move on to the harder "stochastic n-recurrence". This differs from deterministic n-recurrence in having two sources of randomness: one coming from the changing rules (see $p_\text{transition}$) and the other from the indeterminacy of these rules (see $p_\text{noise}$).

### Deterministic n-recurrence

We use synthetic data consisting of procedurally generated sequences of tokens. A sequence is specified by two parameters: the interval $n$ and the target token $t$. All tokens are randomly sampled from the vocabulary, except $t$, which recurs at every $n$-th position (hence the name of the task).
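To make the generation rule concrete, here is a minimal Python sketch of a deterministic n-recurrence sampler. It also includes a brute-force check of which rules are consistent with a given sequence. The sketch rests on assumptions the proposal leaves open. Tokens are the integers `0` to `d_vocab - 1`. A hypothetical `phase` parameter fixes the first target position. Finally, $t$ never appears outside target positions, which is one of the options listed under "Other variable task parameters". The function names are illustrative only.

```python
import random


def deterministic_n_recurrence(length, n, t, d_vocab, phase=0, rng=None):
    """Sample a sequence in which token t occupies every n-th position.

    Hypothetical helper; assumes tokens are 0..d_vocab-1, the first target
    position is `phase` (0 <= phase < n), and non-target positions are drawn
    uniformly from the vocabulary *excluding* t.
    """
    assert 0 <= phase < n and 0 <= t < d_vocab
    rng = rng or random.Random()
    non_targets = [v for v in range(d_vocab) if v != t]
    return [t if i % n == phase else rng.choice(non_targets)
            for i in range(length)]


def consistent_rules(seq, n_values, d_vocab):
    """Brute-force list of every (n, t, phase) that could have produced seq.

    Under the assumptions above, a rule fits iff t appears exactly at the
    positions congruent to `phase` modulo n and nowhere else.
    """
    return [(n, t, phase)
            for n in n_values
            for phase in range(n)
            for t in range(d_vocab)
            if all((tok == t) == (i % n == phase) for i, tok in enumerate(seq))]


if __name__ == "__main__":
    rng = random.Random(0)
    seq = deterministic_n_recurrence(40, n=5, t=3, d_vocab=10, phase=2, rng=rng)
    print(seq)
    print(consistent_rules(seq, n_values=range(2, 9), d_vocab=10))
```

The true rule `(5, 3, 2)` is always among the consistent rules. Excluding $t$ from filler positions rules out multiples of $n$: for instance, $n = 10$ would leave $t$ at non-target positions. In short sequences, however, a spurious rule can occasionally survive if some filler token happens to look periodic. This is one reason why a longer context should make the task easier.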
For example, if $n=5$ and $t=3$, then a part of the sequence may look like the following:

$$
\ldots\ 2\ 1\ {\color{green}3}\ 9\ 8\ 7\ 1\ {\color{green}3}\ 1\ 0\ 9\ 5\ {\color{green}3}\ 1\ 2\ 1\ \ldots
$$

The positions at which token $t$ appears are called *target positions*. The model is going to be fed sequences with different target tokens $t$ and different values of $n$. Once the model groks the task, it should be able to easily infer $t$ and $n$ of any valid sequence given to it.

### Stochastic n-recurrence

Stochastic n-recurrence elaborates on deterministic n-recurrence by adding two sources of stochasticity.

First, the rules become non-deterministic: there is some probability $p_\text{noise}$ that a randomly chosen token other than $t$ appears in a target position. We call that token the *fake target token* and denote it $f$.

Second, the rules change over time. At every position, there is some probability $p_\text{transition}$ that the *true* $t$ and $n$ change after that position. This looks like the current sequence ending and another sequence (with a different choice of $t$ and/or $n$) continuing afterward. Shortly after such a transition, it is indistinguishable from stochastic noise. However, the longer the sequence continues in a way incongruent with the previously in-context-learned rules, the more evidence the model gathers that should make it update toward a new choice of $n$ and $t$.

Normatively speaking,^[That is, for the model to closely approximate ideal probabilistic inference.] the model's "in-context learning rate" (how quickly it updates upon seeing anomalies in the sequence) should depend on the values of $p_\text{noise}$ and $p_\text{transition}$ that characterized its training distribution. However, the context length is also likely to play a role. This is not so much for normative reasons as because models with shorter context windows lose sight of previous tokens more quickly, which forces them to attend to more recent ones.
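The normative claim can be made concrete with an idealized Bayesian observer. The sketch below (Python, standard library only) samples stochastic n-recurrence sequences. It then runs exact forward filtering, as in a hidden Markov model, over hypotheses $(n, t, \text{phase})$ and returns the posterior over $(n, t)$ after each token. Its generative assumptions go beyond the proposal and are labelled in the code. First, $t$ never appears at non-target positions. Second, $f$ is uniform over the other $d_\text{vocab} - 1$ tokens. Third, after a transition the new $(n, t)$ pair is uniform over all pairs different from the old one, with a uniformly random new phase. All function names are hypothetical. Unlike a transformer, this observer has unbounded memory, so it isolates the normative part of the story from the context-length effect described above. It is a reference point for ideal inference, not a claim about what a trained model computes.

```python
import random
from collections import defaultdict


def sample_stochastic_n_recurrence(length, n_values, d_vocab, p_noise, p_transition, rng):
    """Return (tokens, rules), where rules[i] is the true (n, t, phase) at position i.

    Hypothetical generator. Assumptions beyond the proposal: t never appears at
    non-target positions; the fake target f is uniform over the other tokens;
    after a transition the new (n, t) pair is uniform over all pairs different
    from the old one, with a uniformly random phase.
    """
    def draw_rule(old=None):
        while True:  # rejection sampling: uniform over (n, t) pairs != old pair
            n, t = rng.choice(n_values), rng.randrange(d_vocab)
            if old is None or (n, t) != old[:2]:
                return n, t, rng.randrange(n)

    rule, tokens, rules = draw_rule(), [], []
    for i in range(length):
        n, t, phase = rule
        if i % n == phase and rng.random() >= p_noise:
            tokens.append(t)  # target position, no noise
        else:
            # fake target f (noisy target position) or ordinary filler token
            tokens.append(rng.choice([v for v in range(d_vocab) if v != t]))
        rules.append(rule)
        if rng.random() < p_transition:  # the rule may change after this position
            rule = draw_rule(rule)
    return tokens, rules


def emission_prob(state, i, x, d_vocab, p_noise):
    """P(token x at position i | rule `state`) under the assumptions above."""
    n, t, phase = state
    if i % n == phase:
        return 1 - p_noise if x == t else p_noise / (d_vocab - 1)
    return 0.0 if x == t else 1 / (d_vocab - 1)


def filter_rules(tokens, n_values, d_vocab, p_noise, p_transition):
    """Exact forward filtering; returns P((n, t) | tokens[:i+1]) for every i."""
    states = [(n, t, ph) for n in n_values for t in range(d_vocab) for ph in range(n)]
    num_pairs = len(n_values) * d_vocab
    post = {s: 1 / (num_pairs * s[0]) for s in states}  # uniform (n, t), uniform phase
    history = []
    for i, x in enumerate(tokens):
        if i > 0:  # prediction step: the rule may have changed before position i
            pair_mass = defaultdict(float)
            for (n, t, _), w in post.items():
                pair_mass[(n, t)] += w
            post = {s: (1 - p_transition) * w
                    + p_transition * (1 - pair_mass[s[:2]]) / ((num_pairs - 1) * s[0])
                    for s, w in post.items()}
        # update step: weight each hypothesis by the likelihood of the new token
        post = {s: w * emission_prob(s, i, x, d_vocab, p_noise) for s, w in post.items()}
        z = sum(post.values())  # zero only if the data contradict every hypothesis
        post = {s: w / z for s, w in post.items()}
        pair_post = defaultdict(float)
        for (n, t, _), w in post.items():
            pair_post[(n, t)] += w
        history.append(dict(pair_post))
    return history


if __name__ == "__main__":
    rng = random.Random(0)
    tokens, rules = sample_stochastic_n_recurrence(300, range(2, 7), 10, 0.1, 0.02, rng)
    posteriors = filter_rules(tokens, range(2, 7), 10, 0.1, 0.02)
    best = [max(p, key=p.get) for p in posteriors]  # most probable (n, t) per position
    agree = sum(b == r[:2] for b, r in zip(best, rules))
    print(f"MAP (n, t) matches the true rule at {agree}/{len(tokens)} positions")
```

One possible use of this observer is to measure how many tokens pass after each true transition before the most probable $(n, t)$ switches. Under this idealized model, that lag should shrink as $p_\text{transition}$ grows, because change is more expected a priori. It should also shrink as $p_\text{noise}$ shrinks, because each anomaly is then more diagnostic. The same per-position posteriors could serve as a target when probing a trained model's activations.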
Other model parameters are also likely to matter. These include the dimension of the residual stream, the attention head dimension, the MLP dimension, the number of attention heads, the number of layers, the vocabulary size, and even the choice of activation function and normalization method. Varying them and measuring how this variation translates into behavior and performance is therefore another interesting direction of investigation.

### Experimental procedure

We start by finding the minimal architecture capable of grokking deterministic n-recurrence for some initial choice of the three task parameters that affect its difficulty:

1. the range of $n$ (from which $n$ is sampled),
2. the context length,
3. $d_\text{vocab}$.

The context length should be at least three times the maximum value of $n$. Otherwise, sequences with large $n$ would contain too few occurrences of the target token for the rule to be inferred reliably, which would likely hinder the task.

Once that minimal architecture is found, we try training it on stochastic n-recurrence with some initial choice of $p_\text{noise}$ and $p_\text{transition}$ (tentatively, we could start by setting both to $1/(5n)$) and the same task parameters as in the deterministic task. We try training it both from scratch and by fine-tuning the model pre-trained on the deterministic task. Most likely, it will fail to grok the stochastic task. We then gradually increase the model size in order to find a minimal architecture capable of grokking it. With these two minimal architectures established as baselines, we can move on to reverse-engineering (some aspect of) the circuit(s) underlying the models' capacity to perform the simpler and the more complex versions of the task.

## Addenda

This section lists some other directions this research could take.

### Other variable task parameters

- Does the target token $t$ occur only in target positions, or can it also occur elsewhere?
- Is the fake target token $f$ sampled uniformly from the vocabulary, or only from the tokens that appear in the most recent context (e.g., the last $3n$ tokens)?
- Do rule transitions in the stochastic n-recurrence task change both $n$ and $t$, only $n$, or only $t$?

For each of the above, we can check whether varying that parameter meaningfully affects the difficulty of the task, the training dynamics and, most importantly, the algorithm learned by the network and the circuitry implementing it.

### What else could be assessed?

- When a model updates from one choice of rules ($n$ and $t$) to another, is it a relatively discrete phase change, or is there an intermediate period of uncertainty?
  - If the latter, can we derive the probability distribution over values of $n$ and $t$ from the model's activations?
  - Does the discreteness or continuity of the update depend in some interesting way on particular task parameters, model parameters, or even the particular sequence being fed to the model?
- Do the progress measures for grokking developed by [Nanda et al. (2023)](https://arxiv.org/abs/2301.05217) adequately capture the training dynamics of models trained on this task?
- Is it possible to read the model's prior probability distribution over $n$ and $t$ *statically*, i.e., from its weights alone?
  - Can we observe how manipulating the probability distributions over $n$ and $t$ in the task (which are uniform by default) affects that prior as represented in the weights?
- How useful is pre-training a model on deterministic n-recurrence if we want to fine-tune it to grok stochastic n-recurrence?
- What failure modes do models trained on the deterministic version of the task exhibit when given the stochastic version? Do they tell us something interesting about their learned algorithms?
- Can we "edit" the model (either its weights or a small set of activations) so that it "believes" the $n$ and $t$ describing the current sequence have particular values?
(For example, see Meng et al., [2022a](https://rome.baulab.info/), [2022b](https://arxiv.org/abs/2210.07229).)

## References

Elhage, N., Nanda, N., Olsson, C., Henighan, T., Joseph, N., Mann, B., Askell, A., Bai, Y., Chen, A., Conerly, T., DasSarma, N., Drain, D., Ganguli, D., Hatfield-Dodds, Z., Hernandez, D., Jones, A., Kernion, J., Lovitt, L., Ndousse, K., … Olah, C. (2021). [A mathematical framework for transformer circuits.](https://transformer-circuits.pub/2021/framework/index.html) Transformer Circuits Thread.

Meng, K., Bau, D., Andonian, A., & Belinkov, Y. (2022a). [Locating and editing factual associations in GPT.](https://rome.baulab.info/) Advances in Neural Information Processing Systems, 35.

Meng, K., Sharma, A. S., Andonian, A., Belinkov, Y., & Bau, D. (2022b). [Mass-editing memory in a transformer.](https://arxiv.org/abs/2210.07229)

Nanda, N., Chan, L., Lieberum, T., Smith, J., & Steinhardt, J. (2023). [Progress measures for grokking via mechanistic interpretability.](https://arxiv.org/abs/2301.05217)

Olsson, C., Elhage, N., Nanda, N., Joseph, N., DasSarma, N., Henighan, T., Mann, B., Askell, A., Bai, Y., Chen, A., Conerly, T., Drain, D., Ganguli, D., Hatfield-Dodds, Z., Hernandez, D., Johnston, S., Jones, A., Kernion, J., Lovitt, L., … Olah, C. (2022). [In-context learning and induction heads.](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html) Transformer Circuits Thread.

Power, A., Burda, Y., Edwards, H., Babuschkin, I., & Misra, V. (2022). [Grokking: Generalization beyond overfitting on small algorithmic datasets.](https://arxiv.org/abs/2201.02177)