# Understanding Search in Transformers
# Summary
Transformers are capable of performing a huge variety of tasks, and for the most part we know very little about how they do it. In particular, understanding how an AI system implements search is [probably very important for AI safety](https://www.alignmentforum.org/posts/6mysMAqvo9giHC4iX/what-s-general-purpose-search-and-why-might-we-expect-to-see). In this project, we will aim to:
- gain a mechanistic understanding of how transformers implement search for toy tasks
- explore how the search process can be retargeted, with the goal of ensuring that the AI system is aligned with human preferences
- attempt to find scaling laws for search-oriented tasks and compare them to existing scaling laws
# Introduction
> I'm now realizing that many folks assume there must be somebody on earth who knows how LLMs work. This is false. Nobody knows how they work or how to program one. We know how to find LLMs using SGD, but that doesn't tell us anything about how they work or about how to program one
> - [@RatOrthodox](https://twitter.com/RatOrthodox/status/1604877371114389505)
Most recent ML papers start with a long description of how Transformers have been incredibly successful at a huge variety of tasks. Capabilities are advancing rapidly, but our understanding of *how* Transformers do what they do is limited. There is a lot of good mechanistic interpretability research being done[^mech_interp], and this project aims to contribute to that growing body of literature. In particular, we'll take to heart the [importance of search](https://www.alignmentforum.org/posts/FDjTgDcGPc7B98AES/searching-for-search-4) and attempt to make progress on understanding how search happens in transformers.
[^mech_interp]: Neel Nanda has a [good list](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=eL6tFQqNwd4LbYlO1DVIen8K)
While at Conjecture, I worked on a project aiming to find search in a (fairly large) transformer trained on human chess games. This, for a variety of reasons, turned out to be a bit more ambitious than we expected. In fact, it's not really clear what it means to "find search" in the first place -- this is still a pretty open question. I expect this project to mostly consist of:
- training a transformer on some toy[^toy_tasks] "searchy" task
- performing interpretability work on the trained model
- using these results to inform a new task or set of experiments to investigate
[^toy_tasks]: "toy" tasks here are tasks for which solutions are known, meaning both that training data can be generated and that the complexity of the task is in some sense lower -- this makes them significantly easier to study.
## Experiments
This list is tentative, in the sense that the experiments we choose to do will be informed by the results we get. We'll definitely be doing the "basic maze-transformer" first, but we are afforded some freedom after that. If you have any suggestions, please feel free to comment or email me. Many of these experiments rely on intermediate steps (often training a transformer to do some task) being successful -- if an intermediate step turns out not to be possible, that is still a useful data point about something transformers are bad at.[^no_search_good] Most of these tasks center on mazes, since mazes are relatively simple, very "searchy", and easy to generate training data for, but I am not committed to only exploring maze or graph-traversal type tasks.
[^no_search_good]: If the results of these experiments are that transformers are somehow [incapable of search in some way](https://www.alignmentforum.org/posts/FDjTgDcGPc7B98AES/searching-for-search-4#Learned_Search_in_Transformers), that is *really good news* since that buys us a bit of time.
- **Basic maze-transformer:** Generate synthetic sequential data consisting of: maze adjacency list, target location, start location, and the optimal path (a sketch of what one generated example might look like follows this list). Train an autoregressive transformer to predict the next token in a sequence, and evaluate its performance on finding optimal paths. Then, by exploring attention scores, linear probes of the residual stream, and anything else we think of, try to determine how the transformer is learning to solve mazes. I'm aware this is pretty vague, but I don't think there is a straightforward model for doing interpretability research other than "try things and see what works." Will the transformer learn Dijkstra's algorithm, or will it do something weird in a latent space? For this experiment, the data generation code is done, as are some very rudimentary attention-viewing experiments. There is much to try, including seeing how different tokenizations affect the transformer.
- The fundamentally interesting thing about this experiment is that if a transformer learns to solve this task, then it must actually be *solving the maze by the time the first path token is produced.* Suppose that from the starting point, there are two somewhat plausible choices for the path -- if the transformer reliably picks the correct one, it must in some sense be solving the entire maze by that point. Being able to develop a technique for extracting the future path (or even constraints on the future path) determined by the transformer would be incredibly useful (see the linear-probe sketch after this list for one possible starting point).
- **Scaling laws for search:** Using the above model, and by training/evaluating a variety of model sizes on a variety of maze sizes, attempt to find scaling laws for this particular search task (a toy curve-fitting sketch follows this list). If possible, see whether these scaling laws generalize in any way to other "searchy" tasks, and whether there are any meaningful differences between loss scaling and accuracy scaling. This is of particular interest since it would be incredibly useful to be able to detect, from the loss alone, that a phase transition in search ability has occurred.
- **Exploring inner alignment:** One of the main reasons we care about search in the context of alignment is the possibility of a misaligned mesa-optimizer. To explore this, consider the following setup:
- A maze (or any other graph) where each node is "labelled" with a value sampled from some distribution $P$ over $\mathbb{R}^n$
- A "secret" target set $T$, with nontrivial intersection with the support of $P$
- Training data consisting of node labels, an adjacency list, a starting node, and an optimal path to the nearest node which is in $T$
Now, the problem with inner alignment is that a system whose goals appear to align with the outer objective on the training distribution can be revealed to have a different objective *when the distribution changes*. So, after we train a transformer (and verify that it does, in fact, target $T$), we shift the distribution of node labels to some $P'$ and figure out what the inner objective $T'$ of the transformer is by seeing which nodes it ends up targeting (a sketch of this evaluation follows this list). If $T'$ differs noticeably from $T$, this can become a good starting point for future inner alignment research.
- **Retargeting via prompting:** Extending the above experiment, train a transformer with a longer context window such that multiple (nodes, adjacency, path) tuples fit in the context window. Then, determine how sensitive the transformer's target set is to changing the target set in the prompt. In particular, we still have a target set $T$ in the training data, but we then pick some very different $T_*$ and use $T_*$ to generate the paths in the prompt. What sort of set does the transformer target now -- $T \cup T_*$, $T \cap T_*$, or something else entirely? Naively, it would appear better to have a transformer which maintains its alignment to a target set regardless of prompting, but this might not always be the case. Either way, getting some hard data on how susceptible transformers are to this sort of goal modification would be useful.
- **Anti-target set:** Same idea as in the previous experiments, except take a page from neuroscience and create a set $A \subset \mathbb{R}^n$ which denotes states to be avoided. Paths in the training set will now be the shortest path from the start node to a node in $T$, given that nodes in $A$ are off-limits. How robustly does the transformer avoid these states? Can we somehow extract its knowledge about what $A$ looks like? If we prompt it with paths which fail to avoid nodes in $A$, will it also stop avoiding those nodes? What about prompting with paths that avoid some other set $A_*$ instead?
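
For the basic maze-transformer experiment, here is a minimal sketch of what generating and serializing a single training example could look like. The maze-construction method (randomized depth-first search on a grid), the coordinate tokens, and the special-token names are illustrative choices for this sketch only, not the project's actual data format.

```python
import random
from collections import deque

def generate_maze(n, seed=0):
    """Generate an n x n maze as an adjacency list via randomized depth-first search."""
    rng = random.Random(seed)
    adj = {(r, c): [] for r in range(n) for c in range(n)}
    visited, stack = {(0, 0)}, [(0, 0)]
    while stack:
        r, c = stack[-1]
        unvisited = [
            (r + dr, c + dc)
            for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1))
            if (r + dr, c + dc) in adj and (r + dr, c + dc) not in visited
        ]
        if unvisited:
            nxt = rng.choice(unvisited)
            adj[(r, c)].append(nxt)  # carve a passage in both directions
            adj[nxt].append((r, c))
            visited.add(nxt)
            stack.append(nxt)
        else:
            stack.pop()
    return adj

def shortest_path(adj, start, target):
    """Breadth-first search; in a DFS-generated (tree-shaped) maze the path is unique."""
    parent, queue = {start: None}, deque([start])
    while queue:
        node = queue.popleft()
        if node == target:
            break
        for nbr in adj[node]:
            if nbr not in parent:
                parent[nbr] = node
                queue.append(nbr)
    path, node = [], target
    while node is not None:
        path.append(node)
        node = parent[node]
    return path[::-1]

def to_tokens(adj, start, target, path):
    """Serialize one training example; the special-token names are placeholders."""
    tokens = ["<ADJLIST_START>"]
    for node, nbrs in adj.items():
        for nbr in nbrs:
            if node < nbr:  # emit each undirected edge once
                tokens += [f"({node[0]},{node[1]})", "<-->", f"({nbr[0]},{nbr[1]})", ";"]
    tokens += ["<ADJLIST_END>", "<TARGET>", f"({target[0]},{target[1]})",
               "<START>", f"({start[0]},{start[1]})", "<PATH>"]
    tokens += [f"({r},{c})" for r, c in path]
    tokens.append("<EOS>")
    return tokens

adj = generate_maze(n=4, seed=42)
start, target = (0, 0), (3, 3)
print(" ".join(to_tokens(adj, start, target, shortest_path(adj, start, target))))
```

Varying this serialization (edge ordering, coordinate pairs vs. one token per node, etc.) is also a cheap way to run the "how do different tokenizations affect the transformer" comparison.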
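
One concrete way to test whether the maze is already solved before the first path token is generated is a linear probe on the residual stream at that position. The sketch below uses randomly generated placeholder activations and labels purely to show the shape of the experiment; in practice `resid` would be cached activations from the trained model (e.g. via TransformerLens) and `first_move` would be derived from the optimal paths.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Placeholder data: resid[i] stands in for the residual-stream vector at the <PATH>
# token position for maze i, and first_move[i] in {0, 1, 2, 3} encodes which
# neighbour the optimal path visits first.
rng = np.random.default_rng(0)
n_mazes, d_model = 2000, 128
resid = rng.normal(size=(n_mazes, d_model))
first_move = rng.integers(0, 4, size=n_mazes)

X_train, X_test, y_train, y_test = train_test_split(
    resid, first_move, test_size=0.2, random_state=0
)

# If a simple logistic-regression probe can read off the first (or a later) move
# before any path tokens have been generated, that is evidence the model has
# already committed to a plan at that point. On this random placeholder data the
# probe should do no better than chance (~25%).
probe = LogisticRegression(max_iter=1000)
probe.fit(X_train, y_train)
print("held-out probe accuracy:", probe.score(X_test, y_test))
```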
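
For the scaling-laws experiment, the simplest starting point is fitting a power law to a sweep of runs. The numbers below are placeholders, not results; the point is only the shape of the analysis.

```python
import numpy as np

# Hypothetical sweep: (non-embedding parameter count, eval loss) at a fixed maze size.
params = np.array([1e4, 3e4, 1e5, 3e5, 1e6, 3e6])
loss = np.array([2.10, 1.55, 1.02, 0.71, 0.44, 0.30])

# Fit loss ~ a * N^(-b) by linear regression in log-log space.
slope, log_a = np.polyfit(np.log(params), np.log(loss), deg=1)
print(f"scaling exponent b = {-slope:.3f}, prefactor a = {np.exp(log_a):.3f}")

# Accuracy (fraction of mazes solved optimally) can be fit the same way; a region of
# model sizes where accuracy jumps sharply while the loss fit stays smooth would be
# exactly the kind of loss-invisible phase transition in search ability to look for.
```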
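
For the inner-alignment experiment, evaluating the model under a distribution shift might look roughly like the sketch below. The Gaussian label distributions, the half-space target set, and the `model_rollout` interface are all stand-ins for whatever the project ends up using.

```python
import numpy as np

rng = np.random.default_rng(0)

def in_target(label):
    """Membership in the 'secret' target set T -- here just a half-space, as an example."""
    return label[0] > 1.0

def label_nodes(nodes, mean, cov):
    """Sample a label in R^n for each node; the Gaussian stands in for P (or P')."""
    return {node: rng.multivariate_normal(mean, cov) for node in nodes}

# Training distribution P and a shifted distribution P' (illustrative choices).
n_dim = 2
P_mean, P_cov = np.zeros(n_dim), np.eye(n_dim)
P_shift_mean = np.array([2.0, -1.0])  # the shift moves probability mass relative to T

def fraction_in_T(model_rollout, mazes, mean, cov):
    """model_rollout(adj, labels, start) -> final node of the transformer's generated
    path; it is a hypothetical interface standing in for the actual sampling code."""
    hits = 0
    for adj, start in mazes:
        labels = label_nodes(adj.keys(), mean, cov)
        end_node = model_rollout(adj, labels, start)
        hits += in_target(labels[end_node])
    return hits / len(mazes)

# Comparing how often the model's paths still terminate in T before vs. after the
# shift (and inspecting *which* labels it targets under P') gives an estimate of T':
# frac_P  = fraction_in_T(model_rollout, eval_mazes, P_mean, P_cov)
# frac_P2 = fraction_in_T(model_rollout, eval_mazes, P_shift_mean, P_cov)
```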
## Results
In the best-case scenario for this project, we would gain a mechanistic understanding of how search happens in transformers, be able to construct a transformer[^construct_transformer] (as opposed to finding one via SGD) that performs search in a toy domain, and explicitly detect/set which states are being targeted by a transformer. I'm not under the impression that this is a particularly likely outcome, but I think there are quite a few fairly straightforward experiments we can run that might tell us something interesting.
[^construct_transformer]: being able to reconstruct a circuit explicitly is the standard way of testing if you actually understand the circuit -- essentially, testing how good your compressed representation is.
# Output
The final report of this project, ideally to be posted on both the Alignment Forum and ArXiv, will hopefully contain some kind of explanation of transformer search in a limited context. If the project is less successful, and a proper model of the search process is not found, then at minimum a blog post detailing what work was done and why it failed will be produced.
Another key deliverable would be a GitHub repository containing all code produced during the project. Portions of it -- the dataset generation, the interpretability tools created, etc. -- should potentially be packaged into a library if they might prove useful for other projects. Additionally, if meaningful results are found, it will be crucial to package a minimal example reproducing the key results in a Colab/Jupyter notebook, in the style of [those produced by the transformer circuits thread](https://colab.research.google.com/github/anthropics/toy-models-of-superposition/blob/main/toy_models.ipynb).
# Risks and downsides
I don't think the risk of advancing capabilities from this research direction is significantly different from the risk from any other direction of mechanistic interpretability. Understanding how LLMs work always has the possibility of advancing capabilities, and mechanistic interpretability research which tells us exactly how to build circuits that perform a given function can reduce training time, improve model inference efficiency, or provide insight into new architectures. However, given how little research is being done in AI interpretability relative to capabilities, the chance of a small project like this one shortening timelines is minimal.
# Acknowledgements
I first got into this line of mechanistic interpretability research while at [Conjecture](https://www.conjecture.dev), so credit is due to everyone there. In particular, the team I worked with and continue to correspond with: [Kyle McDonell](https://www.alignmentforum.org/users/janus-1), [Laria](https://twitter.com/repligate) [Reynolds](https://generative.ink), and [Nicholas Kees](https://www.alignmentforum.org/users/nicholaskees). This project is primarily an extension of my work with them, and the post [Searching for Search](https://www.alignmentforum.org/posts/FDjTgDcGPc7B98AES/searching-for-search-4) is an expression of that research direction. Also, I'm thankful for the support of my PhD advisors [Cecilia Diniz Behn](https://inside.mines.edu/~cdinizbe/) and [Samy Wu Fung](https://swufung.github.io).
# Team
## Team Size / composition
I'm willing to take on up to 4 group members if the organizational duties are my responsibility -- if someone is willing to help out with scheduling meetings and other administrative duties, a bigger group would be possible.
Splitting up to work on separate experiments might be an option at some point, but the current plan is to start with everyone working on the same experiment. I expect most of the division of labour to come either from writing code or from trying different interpretability approaches on the same experiment.
## Research Lead
Michael Ivanitskiy. [Github](https://github.com/mivanit). Contact me at [miv@knc.ai](mailto:miv@knc.ai).
I'm currently a 2nd year PhD student in applied mathematics, and my background was originally in mathematics and computational neuroscience. I interned under Kyle & Laria (janus) at Conjecture, mostly working on interpretability for transformers trained to play chess, with the goal of "finding search". I expect to be able to commit about 15-25 hours a week to this project, since this is basically just my PhD research. I'm open to any level of commitment from group members.
## Team Coordinator
Having a team member take on the role of team coordinator would probably be the most efficient choice.
## Skill Requirements
- proficient in Python and an ML framework (the project will use PyTorch; experience with JAX/TF is acceptable)
- willingness to use git
- decent understanding of transformer networks, sampling techniques, and attention heads
- basic familiarity with the inner alignment problem
- preferred: familiarity with existing transformer interpretability work. Neel Nanda's [list on the subject](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J) is a good place to start.
I envision this project being a good mix of conceptual and engineering work. I don't believe that conceptual work alone is sufficient, and I think our best bet for figuring out how search works in transformers is to try a bunch of different interpretability techniques and see what works. I'm a decent programmer, and I have a reasonably solid math background, but I only have so many ideas on how to look for internal search. I'm hoping to find team members who can help both with the actual implementation, and help come up with new ideas for what interpretability techniques to try.