Authors: Benjamin Qi, Ziqian Zhong
Acknowledgement: The work was done as part of the 2023 CBAI Winter ML Bootcamp (1/16 - 1/19).
Inspired by Neel's 200 Concrete Open Problems document, we investigated how tiny transformers sort lists of small integers. Unexpectedly, we found that bidirectional attention significantly outperformed causal attention, despite causal attention having at least as much expressive power as bidirectional attention in our setup. With mechanistic analysis, we are able to provide a complete interpretation of these models' behaviors.
Notation-wise, we closely follow A Mathematical Framework for Transformer Circuits, so check that out if you're unsure of something!
We considered one-layer attention-only transformers with a single head without any normalization, as implemented in the TransformerLens library. The accompanying code can be found here.
Attention Direction: Causal
Input: a list of \(10\) integers \(a_0,\dots,a_9\) where each \(a_i\) is in the range \([0,9]\).
Output: At each position \(p\), unnormalized logits \(\ell_0,\dots,\ell_9\) for every token in the vocabulary, where \(\ell_i\) corresponds to the probability that \(i\) is the token at position \(p+1\) in the sorted list.
Token Vocab: BOS, MOS, 0-9
Input Format: BOS a_0 a_1 ... a_9 MOS sorted(a)_0 sorted(a)_1 ... sorted(a)_9 (22 total)
Output Format: ? ? ? ... ? sorted(a)_0 sorted(a)_1 sorted(a)_2 ... ?
Here, \(?\) means that we disregard the model's output for that position.
Attention Direction: Causal
This task is the same as the one above, except that all \(a_i\) must be distinct and the range for the \(a_i\) is slightly larger (\([0,14]\)) to ensure that the task is somewhat nontrivial. Keep in mind that, as there are only \(\binom{15}{10}=3003\) distinct sorted sequences, the model will likely see every sorted sequence at least once during training. However, the model structure and weight decay keep the model from merely memorizing them, as we will see later.
Attention Direction: Bidirectional
Input: a list of integers \(a_0,\dots,a_{n-1}\) where each \(a_i\) is in the range \([0,9]\) and \(1\le n\le 10\).
Output: At the position of each input element \(a_p\), unnormalized logits \(\ell_0,\dots,\ell_9\) for every token in the vocabulary, where \(\ell_i\) corresponds to the probability that \(i\) is the token at position \(p\) in the sorted list.
Token Vocab: BOS, EOS, PAD, 0-9
Input Format: BOS a_0 a_1 ... a_{n-1} EOS PAD PAD ... (12 total)
Output Format: ? sorted(a)_0 sorted(a)_1 ... sorted(a)_{n-1} ? ...
The length of the input list is fixed at \(10\) in the two causal tasks. For the bidirectional task, we chose \(n\) uniformly at random from \([1,10]\).
For the first task, we initially tried generating all the inputs uniformly in the range \([0,9]\), but we found that the resulting causal model performed poorly on inputs with many occurrences of the same number, and especially on inputs where all the numbers are very close together (such as 0 0 0 0 0 1 1 1 1 1). Specifically, near the middle of the sequence, it would often predict 0 instead of 1 or 1 instead of 0. So instead, we used the following procedure (described in pseudocode) to generate "harder" inputs:
with probability 2/3:
    while True:
        p = (a float chosen uniformly at random from (0,1))
        S = (a set where each integer from 0 to 9 is independently selected with probability p)
        if S is not empty:
            break
else:
    x = (a random integer from 0 to 9)
    y = (a random integer from 0 to 9)
    S = (all integers from min(x,y) to max(x,y))
return (a list of n elements sampled from S with replacement)
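For concreteness, here is a minimal runnable Python sketch of the generator above. The function name generate_hard_list and its parameters are our own illustrative choices, not taken from the accompanying code, and we assume the final step samples with replacement (which it must, since S can have fewer than n elements).

```python
import random
from typing import List, Optional


def generate_hard_list(n: int = 10, lo: int = 0, hi: int = 9,
                       rng: Optional[random.Random] = None) -> List[int]:
    """Sketch of the "harder" input generator described by the pseudocode above."""
    if rng is None:
        rng = random.Random()
    if rng.random() < 2 / 3:
        # Random subset: keep each value independently with a randomly drawn probability p.
        while True:
            p = rng.random()
            S = [v for v in range(lo, hi + 1) if rng.random() < p]
            if S:  # resample if the subset came out empty
                break
    else:
        # Contiguous range between two random endpoints (inclusive).
        x = rng.randint(lo, hi)
        y = rng.randint(lo, hi)
        S = list(range(min(x, y), max(x, y) + 1))
    # Sample with replacement, so the list usually contains duplicates.
    return [rng.choice(S) for _ in range(n)]


rng = random.Random(0)
batch = [generate_hard_list(rng=rng) for _ in range(1024)]
print(batch[0])
```

Seeding a dedicated random.Random instance keeps batches reproducible without touching the global random state.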
All models have hidden dimension \(56\). We used cross-entropy loss and the Adam optimizer with learning rate going from \(10^{-3}\) to \(10^{-4}\) (decreased manually after loss stops decreasing) and weight decay \(10^{-4}\). We trained on \(5000\) to \(10000\) batches, where each batch consists of \(1024\) lists generated by the procedure described in the previous section.
Training for \(5000\) batches takes around \(200\) seconds on an A100 GPU, though the bottleneck is generating the batches (which is not vectorized) rather than training the models.
For each model, we report the final mean cross-entropy loss, the fraction of \(4000\) randomly generated sequences sorted correctly, and the fraction of \(4000\) sequences generated by our pseudocode above sorted correctly.
Notice that our hidden dimension is only \(56\) and our models have only around \(15000\) parameters. The accuracy and loss do improve with a larger hidden dimension.
Experiment 3 (bidirectional): Average loss 0.0003 after 5000 batches and 100% accuracy in initial testing. However, after generating \(10^6\) more test cases using the procedure described above, we found two sequences that our model failed to sort correctly:
Input List: [9, 0, 9, 9, 9, 9, 9, 9]
Output List: [0, 5, 9, 9, 9, 9, 9, 9]
Input List: [9, 9, 9, 9, 8, 9, 9, 9, 9, 8]
Output List: [8, 9, 9, 9, 9, 9, 9, 9, 9, 9]
Both of these input lists consist of exactly two distinct numbers, with the smaller number occurring only once or twice.
We could probably eliminate this incorrect behavior by training on more data consisting of exactly two distinct numbers.
Experiment 2 (causal, no duplicates): Average loss 0.0006 after ~5000 batches and 100% accuracy (we tested all possible inputs up to reordering, and every one passed).
Experiment 1 (causal, duplicates allowed): Average loss 0.015 after ~9000 batches, with 98.6% accuracy on the random dataset and 96.8% accuracy on the hard dataset.
The following are some example cases where this model failed:
Input List: [1, 9, 9, 9, 1, 9, 1, 1, 9, 1]
Output List: [1, 1, 1, 1, 1, 8, 9, 9, 9, 9]
Input List: [8, 8, 8, 8, 8, 9, 9, 9, 9, 8]
Output List: [8, 8, 8, 8, 8, 8, 8, 9, 9, 9]
The first failure is only slightly off: at the position where the model outputs \(8\), the logit for \(9\) is also high. The second example shows that the model has trouble deciding when to switch from outputting \(8\) to outputting \(9\).
We also tried clipping each example's cross-entropy loss from below at 0.02 (losses below 0.02 are replaced by 0.02 and so contribute no gradient), so that the model aims for good all-around performance rather than perfection. With this we got 97.1% accuracy on the hard dataset and 99.5% accuracy on the random dataset. For a larger hidden dimension (300), this trick boosted accuracy from ~99.9% to 100% (from a few observed counterexamples to none). We are not sure whether this is a generally applicable trick or a matter of luck.
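A minimal sketch of the clipping, assuming "clipping at 0.02" means flooring each example's loss at 0.02 before averaging. It is written in plain Python to stay self-contained, and the helper names are hypothetical. In PyTorch, the same idea would look something like F.cross_entropy(logits, targets, reduction="none").clamp(min=0.02).mean().

```python
import math
from typing import List, Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Per-example cross-entropy: logsumexp(logits) - logits[target]."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target]


def clipped_mean_loss(batch_logits: List[Sequence[float]], targets: List[int],
                      floor: float = 0.02) -> float:
    """Mean of per-example losses, each floored at `floor` before averaging."""
    # Under autograd, a loss replaced by the constant `floor` has zero gradient,
    # so examples the model already fits well stop pushing the weights.
    losses = [max(cross_entropy(lg, tgt), floor) for lg, tgt in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


easy = [10.0, 0.0, 0.0]  # confidently correct for target 0: loss is about 9e-5, floored to 0.02
hard = [0.0, 0.0, 0.0]   # uniform logits: loss is log(3)
print(clipped_mean_loss([easy, hard], [0, 0]))
```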
Remark: It was surprising that the loss for the causal model was higher despite being given a strictly easier task! Unlike the bidirectional model, it only needs to deal with fixed length sequences, and it can additionally use the tokens it has previously generated as information to predict the next one.
We now present two sorting algorithms that our models discovered.
min{>previous}
If a causal model is fed non-duplicated numbers, we have the following simple algorithm:
Suppose the previously generated token is t.
Use attention as a filter to consider only tokens > t in the input list.
Out of the remaining tokens, "output" the minimum by giving smaller tokens greater logits.
This is exactly what we observe in the transformer trained on the no-duplicates task (experiment 2).
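The three steps of min{>previous} can be sketched in plain Python as follows. This illustrates the algorithm, not the model's actual computation, and the function name is hypothetical. The last example shows what goes wrong once duplicates are allowed.

```python
from typing import List


def sort_min_greater_than_previous(a: List[int]) -> List[int]:
    """min{>previous}: repeatedly output the smallest input token above the last output."""
    out: List[int] = []
    prev = -1  # before the first output, every token (all >= 0) passes the filter
    for _ in range(len(a)):
        candidates = [x for x in a if x > prev]  # "attention as a filter"
        if not candidates:  # only happens when the input contains duplicates
            break
        prev = min(candidates)  # smaller tokens get larger logits, so the minimum wins
        out.append(prev)
    return out


print(sort_min_greater_than_previous([7, 2, 14, 0, 9, 3, 11, 5, 13, 1]))
# [0, 1, 2, 3, 5, 7, 9, 11, 13, 14]
print(sort_min_greater_than_previous([1, 1, 2]))
# [1, 2]: the second 1 is skipped because the filter is strictly "greater than"
```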
However, this method doesn't work when there are duplicated numbers, because the next token could be t itself. It is possible to refine this method by incorporating position information, but the transformers didn't seem to go this way.
The model could be implementing something like the following function:
def element_at_pos(s: list, t: int, pos: int) -> int:
    """Return sorted(s)[pos], given t = sorted(s)[pos - 1]; s may contain duplicates."""
    greater = [x for x in s if x > t]
    if len(greater) >= len(s) - pos:
        # ^ note: the above inequality cannot be strict
        # min(greater) can easily be obtained assuming the transformer only pays attention to tokens with value > t:
        # a token increases the logit corresponding to its value, and tokens with lower values result in higher increases
        return min(greater)
    else:
        return t
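In terms of logits (and argmax), for a sufficiently large constant BIG, we could do:
out_val = BIG * (# of elements of s that are >= val) + val    if val > prev_val
out_val = BIG * (len(s) - pos) + val                          if val <= prev_val
and predict the val with the largest out_val, which is possible with attention.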
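As a sanity check of the logit formulation above, the following sketch computes out_val for every candidate val and takes the argmax, feeding each prediction back in as prev_val. The constant BIG = 10 and the helper names are our own assumptions; any BIG larger than 9 should separate the count term from the "+ val" tie-breaker for tokens 0-9.

```python
from typing import List

BIG = 10  # assumption of this sketch: any constant larger than 9 works for tokens 0-9


def logits_for_position(s: List[int], prev_val: int, pos: int, vocab: int = 10) -> List[int]:
    """out_val from the formula above, for every candidate val in 0..vocab-1."""
    out = []
    for val in range(vocab):
        if val > prev_val:
            out.append(BIG * sum(1 for x in s if x >= val) + val)
        else:
            out.append(BIG * (len(s) - pos) + val)
    return out


def sort_via_logits(s: List[int]) -> List[int]:
    result: List[int] = []
    prev_val = -1  # no previous output at pos = 0; -1 is below every token
    for pos in range(len(s)):
        logits = logits_for_position(s, prev_val, pos)
        prev_val = max(range(len(logits)), key=lambda v: logits[v])  # argmax
        result.append(prev_val)
    return result


print(sort_via_logits([8, 8, 8, 8, 8, 9, 9, 9, 9, 8]))
# [8, 8, 8, 8, 8, 8, 9, 9, 9, 9]
```

Unlike the plain min{>previous} filter, this version handles duplicates, because the val <= prev_val branch wins exactly when fewer than len(s) - pos elements exceed prev_val.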
The following is the actual attention pattern for the sequence \([0,1,2,3,4,5,6,7,8,9]\) in experiment 2. This should give us some confidence that the simple algorithm above is exactly what is "implemented."
Next, we provide a full proof with mechanistic analysis.
a) Regardless of position, the Query-Key (QK) matrices are "upper-triangular."
The following are the QK matrices observed at different positions. "Query token" and "key token" are the tokens at the respective positions. The position in each plot's title is the query position, and for these plots we place the key token at the first position of the unsorted list. The entries are standardized (a constant was added to each row so that the maximum becomes \(0\)) and clipped (at \(-99\)). We do not apply the \(1/\sqrt{d_{\text{head}}}\) scaling or the softmax exponentiation, and BOS and MOS are removed for simplicity.
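# W_E: token embedding; W_pos: positional embedding
# W_Q, b_Q, W_K, b_K: query and key weights and biases (layer 0, head 0)
# This is PyTorch code: @ stands for matrix multiplication and .T for transpose.
# Queries sit at position 11 + pos (11 is MOS), the position whose output predicts sorted(a)[pos];
# keys sit at position 1, the first element of the unsorted list.
QK = ((W_E + W_pos[11 + pos]) @ W_Q[0, 0] + b_Q[0, 0]) @ ((W_E + W_pos[1]) @ W_K[0, 0] + b_K[0, 0]).T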
We can see that attention is only paid to key tokens greater than the query token, and that, roughly, smaller key tokens receive more attention. So the attention pattern is indeed doing the filtering.
For a fixed query token \(q\), the model seems to pay more attention to \(k=q\) than to \(k\ge q+6\). This makes sense for \(k\ge q+7\), because the token following \(q\) in the sorted sequence must be at most \(q+6\): only \(15-10=5\) values are missing from any sorted sequence, so at most \(5\) values can be skipped between consecutive elements. As the token following \(q\) could be equal to \(q+6\), it is then unexpected that the model pays more attention to \(k=q\) than to \(k=q+6\). However, there are only \(9\) possible sorted sequences with a gap of size \(6\) (all of the form \([0,1,2,\dots,x,x+6,x+7,\dots,14]\)). Indeed, the input on which the model had the greatest loss was a permutation of \([0,6,7,\dots,14]\). The model most likely manages to sort all sequences of this form because it saw all of them in the training set.
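The counting argument above is easy to check by brute force. The sketch below enumerates every possible sorted output for experiment 2 (10 distinct values from \([0,14]\)); the helper name max_gap is hypothetical.

```python
from itertools import combinations
from typing import Sequence


def max_gap(seq: Sequence[int]) -> int:
    """Largest difference between consecutive elements of a sorted sequence."""
    return max(b - a for a, b in zip(seq, seq[1:]))


# Every sorted list of 10 distinct values from [0, 14], i.e. every possible sorted output.
all_sorted = list(combinations(range(15), 10))
gap6 = [s for s in all_sorted if max_gap(s) == 6]

print(len(all_sorted))                      # 3003
print(max(max_gap(s) for s in all_sorted))  # 6
print(len(gap6))                            # 9
```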
b) Source sequence tokens "emit" themselves post-attention.
The following plot shows the logits contributed by each source token, assuming it is the only token attended to.
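# W_E: token embedding
# W_V, b_V, W_O, b_O: value and output weights and biases
# W_U, b_U: unembedding weights and bias
# Logits contributed by `token` if it is the only token attended to (positional embedding omitted)
((model.W_E[token] @ model.W_V + model.b_V) @ model.W_O + model.b_O).squeeze() @ model.W_U + model.b_U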
The largest values can be found along the main diagonal, meaning that every input token is just emphasizing itself.
Combining the two parts, we see that the model indeed implements our desired algorithm.
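To see how parts a) and b) combine, here is a toy idealized head. It is a hypothetical stand-in for the trained model, not extracted from it: we assume a hard mask for keys less than or equal to the query, a QK score of -5·k for the remaining keys, and an OV circuit that maps each token exactly to its own logit.

```python
import math
from typing import List


def idealized_head_logits(keys: List[int], query: int, vocab: int = 15,
                          scale: float = 5.0) -> List[float]:
    """Hypothetical idealized head combining parts a) and b)."""
    # a) QK: keys <= query are masked out; among the rest, smaller keys score higher.
    scores = [-scale * k if k > query else -math.inf for k in keys]
    if all(sc == -math.inf for sc in scores):
        raise ValueError("no key token is larger than the query token")
    m = max(scores)
    weights = [math.exp(sc - m) for sc in scores]  # exp(-inf) == 0.0 for masked keys
    z = sum(weights)
    # b) OV: each attended key token adds its attention weight to its own logit.
    logits = [0.0] * vocab
    for k, w in zip(keys, weights):
        logits[k] += w / z
    return logits


def sort_with_idealized_head(a: List[int]) -> List[int]:
    out: List[int] = []
    query = -1  # stands in for MOS at the first output position
    for _ in range(len(a)):
        logits = idealized_head_logits(a, query)
        query = max(range(len(logits)), key=lambda v: logits[v])  # argmax
        out.append(query)
    return out


print(sort_with_idealized_head([7, 2, 14, 0, 9, 3, 11, 5, 13, 1]))
# [0, 1, 2, 3, 5, 7, 9, 11, 13, 14]
```

The argmax lands on the key with the most attention, which is the smallest key above the query, so the softmax only needs to rank the surviving keys correctly rather than attend to one key exclusively.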
The positions of key tokens have minimal effect on the QK matrix.
The positions of query tokens do not affect the above property of the QK matrix. Query positions do shift the QK values somewhat (this looks more like an in-dataset bias that favors larger values toward the end of the sequence), but the QK entries above are large enough in absolute value that the property still holds.
The positional embeddings emit the same information regardless of position (again, assuming all attention is put on them).
Note: Although the model from experiment 1 (sorting fixed-length lists with duplicates) incorporates some ideas from this algorithm, it isn't doing quite the same thing. See the last section for more information.
It is possible to directly compute the \(k\)-th (0-indexed) element of the sorted list without first computing the \((k-1)\)-th element of the sorted list. Furthermore, this algorithm works just as well for variable-length lists as fixed-length lists.
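Let the input list be \([a_0,a_1,\dots,a_{n-1}]\). Define the piecewise linear function \(f\) as \(f(x)\triangleq \sum_{i=0}^{n-1}\min(0,a_i-x)\). Note that this function is concave down. Next, define \(f_k(x)\triangleq f(x)+(k+0.5)x\). Then the \(k\)-th element of the sorted list is precisely \(\text{argmax}_x(f_k(x))\): away from the \(a_i\), \(f_k\) has slope \(k+0.5-\#\{i : a_i < x\}\), which is positive for \(x\) below the \(k\)-th element of the sorted list and negative for \(x\) above it.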
The diagram below shows that for \(a=[3,10]\), the element at index \(k=1\) of the sorted list is \(\text{argmax}_x(f_1(x))=10\).
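A small numerical check of this claim, as a plain-Python sketch (the function names are our own). It evaluates \(f_k\) on an integer grid and returns the argmax; for \(a=[3,10]\) and \(k=1\) it reproduces the diagram.

```python
from typing import Iterable, Optional, Sequence


def f(a: Sequence[int], x: float) -> float:
    """f(x) = sum_i min(0, a_i - x): concave and piecewise linear."""
    return sum(min(0, ai - x) for ai in a)


def f_k(a: Sequence[int], k: int, x: float) -> float:
    return f(a, x) + (k + 0.5) * x


def kth_via_argmax(a: Sequence[int], k: int, xs: Optional[Iterable[int]] = None) -> int:
    """k-th (0-indexed) smallest element of a, computed as argmax_x f_k(x) over integers xs."""
    if xs is None:
        xs = range(min(a), max(a) + 1)  # the argmax always lies in [min(a), max(a)]
    return max(xs, key=lambda x: f_k(a, k, x))


print(kth_via_argmax([3, 10], 1))  # 10, matching the diagram
print([kth_via_argmax([9, 0, 9, 4, 4], k, range(10)) for k in range(5)])  # [0, 4, 4, 9, 9]
```

The \(+0.5\) keeps the slope away from zero on every piece, so the maximum over integers is unique and no tie-breaking is needed.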
We claim that for the model from experiment 3, the logits from the \(k\)-th position of the output layer correspond to \([f_k(0), f_k(1), \dots, f_k(9)]\). Specifically…
a) The contribution of the token embedding for \(a_i\) towards the final logits is given by a piecewise linear function with peak at \(a_i\), corresponding to \(\min(0,a_i-x)\) (up to scaling, translation, and adding a multiple of \(x\)).
Similar to how we analyzed the previous algorithm, we computed how each token affects the final logits via the Output-Value (OV) circuit, assuming it is the only token attended to. We can see from the graph below that for each \(d\in [0,9]\), the effect of token \(d\) on the final logits is given by a (roughly) piecewise linear function with peak at \(d\). Furthermore, all of these piecewise linear functions have the same slope. Paying more attention to the BOS token increases the logits corresponding to larger output digits. Little attention is paid to the EOS and PAD tokens, so we omit them from our plot.
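import matplotlib.pyplot as plt
import torch as t

# model, cfg and item_to_string come from the accompanying code.
with t.inference_mode():
    es = []  # es[tt]: logits contributed by token tt if it is the only token attended to
    for tt in range(13):  # tokens 0-9, then BOS (10), EOS (11), PAD (12)
        emb = model.W_E[tt]
        if tt == 10:  # BOS always sits at position 0, so include that positional embedding
            emb = emb + model.W_pos[0]
        # Digits can appear at any position, so no positional embedding is added for them.
        s = emb @ model.W_V + model.b_V  # value vector
        s = s @ model.W_O + model.b_O  # output projection back into the residual stream
        s = s.squeeze()
        s = s @ model.W_U + model.b_U  # unembed to logits
        es.append(s.tolist())

fig, ax = plt.subplots(1, 2, figsize=(12.8, 4.8))
for label in range(2):
    inputs = "0-9" if label == 0 else "BOS"
    ax[label].set_title(f"OV Circuit ({inputs})")
    ax[label].set_xlabel("Output Token")
    ax[label].set_ylabel("Logits")
    if label == 0:
        for tt in range(10):
            ax[label].plot(es[tt][:10], label=str(tt))
    else:
        ax[label].plot(es[10][:10], label=item_to_string(cfg, t.tensor([10])))
    ax[label].legend(title="Input Token")
plt.show()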
b) For most destination positions, nearly equal attention is paid to all tokens with values in the range \([0,8]\).
We first show the attention pattern when the input list is \([0,1,2,3,4,5,6,7,8,9]\). For the destination tokens with values \(1\) through \(9\), we can see that the attention paid to the source tokens with values \(0\) through \(8\) is approximately uniform.
We verified that this pattern holds in general as follows: for each position, for each query token, we computed how much attention it paid to every key token. We found that for all positions and query tokens, nearly equal attention was paid to each of tokens 0-8, while the relative attention paid to the BOS token increased monotonically with position.
We conclude that the attention pattern for keys \([0,9]\) corresponds to summing all of \(\min(0,a_i-x)\) with equal weight, giving us \(f(x)\).
from typing import Optional

import matplotlib.pyplot as plt
import torch as t

# model, cfg and heatmap come from the accompanying code.
def plot_qk_for_out(pos: Optional[int] = None):
    """pos: index in the sorted list that you are trying to predict"""
    with t.inference_mode():
        embeds = model.W_E.clone()
        if pos is not None:
            if cfg.bidirectional:
                embeds += model.W_pos[1 + pos]
            else:
                embeds += model.W_pos[11 + pos]  # 11 is the position of MOS
        fig, ax = plt.subplots(1, cfg.num_heads, figsize=(9.6, 4.8), squeeze=False)
        ax = ax[0]
        key_embeds = model.W_E.clone()
        key_embeds[cfg.distinct_nums] += model.W_pos[0]  # BOS is always at position 0
        for head in range(cfg.num_heads):
            QK = (embeds @ model.W_Q[0, head] + model.b_Q[0, head]) @ (key_embeds @ model.W_K[0, head] + model.b_K[0, head]).T
            title = f"QK For Query Position {pos}"
            if pos is None:
                title = f"QK Head {head}"
            ax[head].set_title(title)
            if not cfg.bidirectional:
                QK = QK[:10, :11]  # query tokens 0-9; key tokens 0-9 and BOS
            else:
                QK = QK[:cfg.distinct_nums]  # query tokens 0-9; all key tokens
            xticklabels = [str(i) for i in range(10)] + ["BOS", "EOS", "PAD"]
            xticklabels = xticklabels[:QK.shape[1]]  # one label per key column
            heatmap(QK.cpu(), ax=ax[head], xticklabels=xticklabels, prec=2)
            ax[head].set_ylabel("Query Token")
            ax[head].set_xlabel("Key Token")
        plt.tight_layout()
Caveats: The source token with value 9 is an exception; the attention paid to it is generally lower than to the other tokens. This is likely because the value of \(f(x)\) on \(x=0,\dots,9\) remains unchanged if all terms involving \(a_i=9\) are removed from it. (An alternative explanation is that in our model, the contributions of BOS and token 9 to the logits are both increasing lines, so their coefficients are somewhat interchangeable.)
The destination token corresponding to the zeroth position of the sorted sequence is also an exception: it appears that for this position alone, the model determines what the element at that position in the sorted list should be by paying more attention to source tokens with lower values. We can verify this conjecture by looking at the attention patterns for a random permutation of \([0,1,2,\dots,9]\):
This pattern would look the same as the one for \([0,1,2,\dots,9]\) if the rows and columns were sorted in increasing order of token value.
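A quick check of the first explanation for the token-9 caveat: for output tokens \(x=0,\dots,9\), every term \(\min(0,9-x)\) is \(0\), so dropping all the 9s leaves \(f\) unchanged on the grid. The helper name below is hypothetical.

```python
from typing import List, Sequence


def f_on_grid(a: Sequence[int], grid: Sequence[int] = range(10)) -> List[int]:
    """Values of f(x) = sum_i min(0, a_i - x) at the output tokens x = 0..9."""
    return [sum(min(0, ai - x) for ai in a) for x in grid]


a = [9, 2, 9, 5, 0, 9]
without_nines = [ai for ai in a if ai != 9]
print(f_on_grid(a) == f_on_grid(without_nines))  # True: min(0, 9 - x) = 0 for every x <= 9
```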
c) Attention on BOS accounts for the "\((k+0.5)x\)" addition.
From the plots above, we see that the attention paid to key token BOS increases along with the position \(k\). This is consistent with increasing the coefficient of \(x\) in \(f_k(x)\) as \(k\) increases.
This becomes even clearer when we plot, for each position, the ratio of the attention paid to BOS to the attention paid to tokens 0-8: we see a roughly straight line! This matches our "\((k+0.5)x\)" claim, up to scaling and translation.
The plotted quantity is the difference between the QK value for BOS and the average QK value over tokens 0-8, scaled by \(1/\sqrt{\text{head size}}\) and exponentiated.
import matplotlib.pyplot as plt
import numpy as np
import torch as t

# model and cfg come from the accompanying code; assumes a single head, as in our models.
def plot_qk_line():
    with t.inference_mode():
        avg_0_through_8 = []  # mean QK value over query tokens 0-9 and key tokens 0-8
        avg_bos = []  # mean QK value over query tokens 0-9 with key BOS
        single = [[] for _ in range(10)]  # single[i]: mean QK value with key token i
        fig, ax = plt.subplots(1, 2, figsize=(9.6, 4.8))
        for pos in range(1, 10):
            embeds = model.W_E.clone()
            if cfg.bidirectional:
                embeds += model.W_pos[1 + pos]
            else:
                embeds += model.W_pos[11 + pos]  # 11 is the position of MOS
            key_embeds = model.W_E.clone()
            key_embeds[cfg.distinct_nums] += model.W_pos[0]  # BOS is always at position 0
            for head in range(cfg.num_heads):
                QK = (embeds @ model.W_Q[0, head] + model.b_Q[0, head]) @ (key_embeds @ model.W_K[0, head] + model.b_K[0, head]).T
                avg_0_through_8.append(QK[:10, :9].mean().item())
                for i in range(10):
                    single[i].append(QK[:10, i].mean().item())
                avg_bos.append(QK[:10, 10].mean().item())
        ax[0].plot(range(1, 10), avg_0_through_8, label="Average Over 0-8")
        ax[0].plot(range(1, 10), single[9], label="9")  # key token 9, the exception discussed above
        ax[0].plot(range(1, 10), avg_bos, label="BOS")
        ax[0].legend(title="Key Token")
        ax[0].set_title("QK Values Averaged Over Q")
        ax[0].set_xlabel("Position")
        ax[0].set_ylabel("Average QK Value")
        ax[1].set_title("Attention Paid to BOS Relative To 0-8")
        ax[1].plot(range(1, 10), np.exp((np.array(avg_bos) - np.array(avg_0_through_8)) / np.sqrt(cfg.d_model)), label="exp((BOS - Avg(0-8))/sqrt(56))")
        ax[1].legend()
        ax[1].set_xlabel("Position")
        ax[1].set_ylabel("Attention Paid to BOS / Attention Paid to 0-8")
        plt.tight_layout()

plot_qk_line()
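To see why a roughly linear BOS ratio is what the \(f_k\) claim predicts, here is an idealized sketch. We assume uniform attention on the list tokens, an OV contribution of \(\min(0,a_i-x)\) for each token and \(x\) for BOS (both with slope normalized to \(1\)), and a BOS attention ratio of exactly \(k+0.5\); the function names are hypothetical. Because the softmax divides everything by the same normalizer, the output is proportional to \(f_k(x)\), so its argmax is the \(k\)-th element. In the real model the OV slopes differ, so the ratio is linear in \(k\) only up to scaling and translation.

```python
from typing import List, Sequence


def head_output(a: Sequence[int], bos_ratio: float, vocab: int = 10) -> List[float]:
    """Idealized head output when each list token gets equal attention and BOS gets
    `bos_ratio` times as much attention as a single list token."""
    z = len(a) + bos_ratio  # softmax normalizer, in units of one list token's weight
    w_tok, w_bos = 1 / z, bos_ratio / z
    logits = []
    for x in range(vocab):
        from_tokens = sum(w_tok * min(0, ai - x) for ai in a)  # idealized OV of each a_i
        from_bos = w_bos * x  # idealized OV of BOS: an increasing line
        logits.append(from_tokens + from_bos)
    return logits


def predict(a: Sequence[int], k: int) -> int:
    # bos_ratio = k + 0.5 makes the output proportional to f_k(x) = f(x) + (k + 0.5) x.
    logits = head_output(a, bos_ratio=k + 0.5)
    return max(range(len(logits)), key=lambda x: logits[x])


a = [5, 1, 8, 1, 3]
print([predict(a, k) for k in range(len(a))])  # [1, 1, 3, 5, 8]
```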
Q: Why Does the Causal Model Perform So Much Worse?
A: We think that the model from the first experiment is using a mix of the refined version of algorithm 1 and algorithm 2, though we're not 100% sure.
The following are the QK matrices for query and key token pairs. We can see that for the upper-left subsquare (query and key tokens 0~4), attention is rather equally distributed, which is consistent with algorithm 2, but for the lower-right subsquare (query and key tokens 5~9), we are seeing some upper-triangular patterns as in algorithm 1.
Hypothesis: The bidirectional model does so much better than the causal model because it is only able to learn algorithm 2. On the other hand, the causal model has the ability to learn both algorithms 1 and 2. Algorithm 1 is easier to start learning, but is harder to get right when duplicates are allowed. Informally, the causal model sees both a slow and a fast route to the answer and ends up taking the slow route, while the bidirectional model takes the fast route to the answer because it is the only one accessible to it.
Evidence 1: We can also plot how the QK matrix changes during training. The upper-triangular pattern emerges early in training, before the model tries to shift to a constant-attention pattern. This shift seems to partly succeed for the first rows, but little changes for the later rows.
Evidence 2: If we force the causal model to ignore the previously generated tokens by using the same query vector regardless of the actual token at the query position, we get a smaller loss (~0.0003) and 100% accuracy. This confirms that having access to previously generated tokens affects the causal model's performance.
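from transformer_lens import utils
model.reset_hooks()
def q_hook(value, hook):
    # Replace every query vector with the same vector (row 0 of head 0's W_Q, reused as a
    # learnable constant), so the query no longer depends on the previously generated token.
    return model.blocks[0].attn.W_Q[0, 0].expand(value.shape)
model.add_hook(utils.get_act_name("q", 0), q_hook)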
\(\color{red}{\text{Warning for TransformerLens users:}}\) At the time of writing, TransformerLens has an LRU cache enabled on model.W_Q, model.b_Q, etc., so if you're still training the model, the results might be unexpected, as weights from previous versions of the model might be returned from the cache. Be very careful if you're fiddling with the weights before or during training!
Again we see a piecewise linear pattern in the token contributions, similar to the one in the bidirectional model, confirming the equivalence. One slight difference is that MOS now also plays a role similar to BOS, since it also sits at a fixed position.