# Sorting With Tiny Transformers

Authors: Benjamin Qi, Ziqian Zhong

*Acknowledgement:* This work was done as part of the 2023 [CBAI Winter ML Bootcamp](https://www.cbai.ai/winter-ml-bootcamp) (1/16 - 1/19).

## Overview

Inspired by Neel's [200 Concrete Open Problems](https://docs.google.com/document/d/1WONBzNqfKIxERejrrPlQMyKqg7jSFW92x5UMXNrMdPo/edit#heading=h.ckqyehschys6) document, we investigated how tiny transformers sort lists of small integers. Unexpectedly, we found that bidirectional attention significantly outperformed causal attention, even though causal attention has at least as much expressive power as bidirectional attention in our setup. Using mechanistic analysis, we are able to provide a complete interpretation of these models' behaviors.

Notation-wise, we closely follow [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html), so check that out if you're unsure of something!

## Setups

We considered one-layer attention-only transformers with a single head and no normalization, as implemented in the [TransformerLens](https://github.com/neelnanda-io/TransformerLens) library. The accompanying code can be found [here](https://github.com/bqi343/transformer-sorting/tree/main).

### Experiment 1: Sorting a Fixed-Length List

Attention Direction: Causal

Input: a list of $10$ integers $a_0,\dots,a_9$, where each $a_i$ is in the range $[0,9]$.

Output: At each position $p$, unnormalized logits $\ell_0,\dots,\ell_9$ for every token in the vocabulary, where $\ell_i$ corresponds to the probability that $i$ is the token at position $p+1$ in the sorted list.

```
Token Vocab:   BOS, MOS, 0-9
Input Format:  BOS a_0 a_1 ... a_9 MOS         sorted(a)_0 sorted(a)_1 ... sorted(a)_9    (22 total)
Output Format: ?   ?   ?   ... ?   sorted(a)_0 sorted(a)_1 sorted(a)_2 ... ?
```

Here, $?$ means that we disregard the model's output for that position.

### Experiment 2: Sorting a Fixed-Length List (Easy)

Attention Direction: Causal

This task is the same as the one above, except that all $a_i$ must be distinct, and the range of the $a_i$ is slightly larger ($[0,14]$) to keep the task somewhat nontrivial.

One thing to keep in mind: since there are only $\binom{15}{10}=3003$ distinct sorted sequences, the model will likely see every sorted sequence at least once during training. However, the model structure and weight decay keep the model from merely memorizing, as we will see later.

### Experiment 3: Sorting a Variable-Length List

Attention Direction: Bidirectional

Input: a list of integers $a_0,\dots,a_{n-1}$, where each $a_i$ is in the range $[0,9]$ and $1\le n\le 10$.

Output: At each position $p$ of the input sequence, unnormalized logits $\ell_0,\dots,\ell_9$ for every token in the vocabulary, where $\ell_i$ corresponds to the probability that $i$ is the token at position $p$ in the sorted list.

```
Token Vocab:   BOS, EOS, PAD, 0-9
Input Format:  BOS a_0         a_1         ... a_{n-1}         EOS PAD PAD ...    (12 total)
Output Format: ?   sorted(a)_0 sorted(a)_1 ... sorted(a)_{n-1} ?   ...
```

## Data Generation

### Generating the List Length

The length of the input list is fixed in the fixed-length tasks (experiments 1 and 2). For the variable-length task (experiment 3), we chose $n$ uniformly at random from $[1,10]$.
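To make the formats above concrete, here is a minimal sketch of how a single experiment-3 training example could be assembled. This is not the actual data pipeline; the token IDs, the padding convention, and the ignore index are assumptions.

```python
import random

BOS, EOS, PAD = 10, 11, 12   # assumed token IDs; digits 0-9 map to themselves
SEQ_LEN = 12                 # BOS + up to 10 digits + EOS, padded to 12 tokens
IGNORE = -100                # positions whose outputs are disregarded in the loss

def make_example():
    n = random.randint(1, 10)                     # list length, uniform in [1, 10]
    a = [random.randint(0, 9) for _ in range(n)]  # placeholder contents (see the next section)
    tokens = [BOS] + a + [EOS]
    tokens += [PAD] * (SEQ_LEN - len(tokens))
    # the target at position 1 + i is the i-th element of sorted(a);
    # every other position is ignored
    targets = [IGNORE] * SEQ_LEN
    for i, v in enumerate(sorted(a)):
        targets[1 + i] = v
    return tokens, targets
```

For the fixed-length causal tasks, the analogous example would append MOS and the sorted list to the input and shift the targets by one position, matching the next-token format above.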
### Generating the List Contents

For the first task, we initially tried generating all the inputs uniformly in the range $[0,9]$, but we found that the resulting causal model performed poorly on inputs with many occurrences of the same number, and especially on inputs where all the numbers are very close together (such as `0 0 0 0 0 1 1 1 1 1`). Specifically, near the middle of the sequence, it would often predict `0` instead of `1` or `1` instead of `0`. So instead, we used the following procedure to generate "harder" inputs:

```python=
import random

def generate_contents(n: int) -> list[int]:
    if random.random() < 2 / 3:
        while True:
            # p: a float chosen uniformly at random from (0, 1)
            p = random.random()
            # S: each integer from 0 to 9 is included independently with probability p
            S = [v for v in range(10) if random.random() < p]
            if S:  # retry until S is non-empty
                break
    else:
        # S: all integers from min(x, y) to max(x, y)
        x = random.randint(0, 9)
        y = random.randint(0, 9)
        S = list(range(min(x, y), max(x, y) + 1))
    # return a list of n elements sampled from S (with replacement)
    return random.choices(S, k=n)
```

## Training Details

All models have hidden dimension $56$. We used cross-entropy loss and the Adam optimizer, with the learning rate going from $10^{-3}$ to $10^{-4}$ (decreased manually after the loss stops decreasing) and weight decay $10^{-4}$. We trained on $5000$ to $10000$ batches, where each batch consists of $1024$ lists generated by the procedure described in the previous section. Training for $5000$ batches takes around $200$ seconds on an A100 GPU; the bottleneck is generating the batches (which is not vectorized) rather than actually training the models.

## Results

For each model, we report the final mean cross-entropy loss, the fraction of $4000$ randomly generated sequences sorted correctly, and the fraction of $4000$ sequences generated by the procedure above sorted correctly. Notice that our hidden dimension is only $56$ and our models have only around $15000$ parameters; the accuracy and loss do improve with a larger hidden dimension.

### Variable-Length, Bidirectional

Average loss 0.0003 after 5000 batches and 100% accuracy in initial testing. However, after generating $10^6$ more test cases using the procedure described above, we were able to find *two* sequences that our model failed to sort correctly.

```
Input List:  [9, 0, 9, 9, 9, 9, 9, 9]
Output List: [0, 5, 9, 9, 9, 9, 9, 9]

Input List:  [9, 9, 9, 9, 8, 9, 9, 9, 9, 8]
Output List: [8, 9, 9, 9, 9, 9, 9, 9, 9, 9]
```

![](https://i.imgur.com/uAokFXM.png)
![](https://i.imgur.com/9GEFctM.png)

For each of these input lists,

- The model produces an incorrect output only at the $1$st position (0-indexed).
- The logit corresponding to the correct output is only slightly lower than the logit corresponding to the actual output, so the model isn't very far off.

We could probably eliminate this incorrect behavior by training on more data consisting of exactly two distinct numbers.

### Fixed-Length, Causal Model, No Duplicates

Average loss 0.0006 after ~5000 batches. 100% accuracy (all possible cases up to reordering passed).
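Throughout this section, a sequence counts as correctly sorted only if every output position matches. A minimal sketch of such an exact-match check for the fixed-length causal setup, assuming teacher forcing and a hypothetical `model` that maps a `[batch, 22]` token tensor to `[batch, 22, vocab]` logits (this is not the evaluation code from the repository):

```python
import torch

def fraction_sorted_correctly(model, tokens: torch.Tensor) -> float:
    """tokens: [batch, 22] sequences in the format BOS a_0..a_9 MOS sorted(a)_0..sorted(a)_9."""
    with torch.inference_mode():
        logits = model(tokens)                    # [batch, 22, vocab]
        # position p predicts the token at position p + 1, so positions 11..20
        # (MOS and the first nine sorted tokens) predict sorted(a)_0..sorted(a)_9
        preds = logits[:, 11:21].argmax(dim=-1)   # [batch, 10]
        answers = tokens[:, 12:22]                # [batch, 10]
        correct = (preds == answers).all(dim=-1)  # exact match per sequence
    return correct.float().mean().item()
```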
### Fixed-Length, Causal Model

Average loss 0.015 after ~9000 batches. 98.6% accuracy on the random dataset and 96.8% accuracy on the hard dataset.

The following are some example cases where this model failed:

```
Input List:  [1, 9, 9, 9, 1, 9, 1, 1, 9, 1]
Output List: [1, 1, 1, 1, 1, 8, 9, 9, 9, 9]

Input List:  [8, 8, 8, 8, 8, 9, 9, 9, 9, 8]
Output List: [8, 8, 8, 8, 8, 8, 8, 9, 9, 9]
```

![](https://i.imgur.com/MzIxopS.png)
![](https://i.imgur.com/N9Pv4RH.png)

The first one seems to be only slightly off, as the output logit of $9$ is also high. The second example shows that the model has trouble deciding when to switch from outputting $8$ to outputting $9$.

:::spoiler We seem to be able to improve the accuracy with some trick though...
We clip the individual cross-entropy losses at 0.02, so the model aims for good all-around performance rather than perfection. With that, we got 97.1% accuracy on the hard dataset and 99.5% accuracy on the random dataset. For a larger hidden dimension (300), this trick boosted accuracy from ~99.9% to 100% (from a few counterexamples to none observed). We're not sure whether this is a generally applicable trick or just a matter of luck.
:::

**Remark:** It was surprising that the loss for the causal model was higher even though it was given a strictly easier task! Unlike the bidirectional model, it only needs to deal with fixed-length sequences, and it can additionally use the tokens it has previously generated as information to predict the next one.

## Algorithms

We now present two sorting algorithms that our models discovered.

### Algorithm 1 (Causal): `min{>previous}`

If a causal model is fed distinct numbers, we have the following simple algorithm:

Suppose the previously generated token is $t$. Use attention as a filter to consider only tokens $> t$ in the input list. Out of the remaining tokens, "output" the minimum by giving smaller tokens greater logits.

This is exactly what is observed in the transformer of the no-duplicates version (experiment 2).

However, this method doesn't work when there are duplicated numbers, because the next token could be $t$ itself. It is possible to refine the method by incorporating position information, but the transformers didn't seem to go this way.

:::spoiler Refinement idea we had
The model could be implementing something like the following function:

```py
def element_at_pos(s: list[int], t: int, pos: int) -> int:
    # tokens strictly greater than the previously generated token t
    greater = [v for v in s if v > t]
    if len(greater) >= len(s) - pos:
        # note: the above inequality cannot be strict
        return min(greater)
        # min(greater) is easy to obtain if the transformer only pays attention to tokens
        # with value > t: each token increases the logit corresponding to its own value,
        # and tokens with lower values result in larger increases
    else:
        return t
```

In terms of logits (and argmax) we could do:

```
out_val = BIG * (# of tokens >= val) + val    if val >  prev_val
          BIG * (len(s) - pos) + val          if val <= prev_val
```

which is possible with attention.
:::
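For reference, here is the basic rule written out as ordinary Python. This is just a restatement of the algorithm for distinct inputs (the `sort_distinct` helper is ours), not a description of the model's weights:

```python
def sort_distinct(a: list[int]) -> list[int]:
    # Algorithm 1: repeatedly output the smallest element strictly greater
    # than the previously emitted one (valid only when all elements are distinct)
    out, prev = [], -1  # -1 is below every token value
    for _ in range(len(a)):
        prev = min(v for v in a if v > prev)
        out.append(prev)
    return out

assert sort_distinct([3, 0, 14, 7, 1]) == [0, 1, 3, 7, 14]
```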
#### Evidence

The following is the actual attention pattern for the sequence $[0,1,2,3,4,5,6,7,8,9]$ in experiment 2. This should give us some confidence that the simple algorithm above is exactly what is "implemented."

![](https://i.imgur.com/afWVobi.png)

Next, we provide a full proof with mechanistic analysis.

**a) Regardless of position, the Query-Key (QK) matrices are "upper-triangular."**

The following are the QK matrices observed at different positions. "Query token" and "key token" refer to the tokens at the query and key positions, respectively. The position in each title is the query position, and for these plots we assume the key token lies at the beginning of the unsorted sequence.

The entries are standardized (a constant was added to each row so that the max became $0$) and clipped (at $-99$). Scaling and exp are not applied, and BOS and MOS are removed for simplicity.

:::spoiler Code for computing the patterns
```python
# W_E: token embedding; W_pos: positional embedding
# W_Q, b_Q, W_K, b_K: weights of the Q and K
# this is PyTorch code: @ is matrix multiplication and .T is transpose
((W_E + W_pos[11 + pos]) @ W_Q[0, 0] + b_Q[0, 0]) @ ((W_E + W_pos[1]) @ W_K[0, 0] + b_K[0, 0]).T
```
:::

![](https://i.imgur.com/NornDPx.png)

We can see that attention is only paid to key tokens greater than the query token, and that smaller key tokens receive more attention (roughly). So the attention pattern is indeed doing the filtering.

:::spoiler Additional notes on the attention patterns
For a fixed query token $q$, the model seems to pay more attention to $k=q$ than to $k\ge q+6$. This makes sense for $k\ge q+7$, because the token following $q$ in the sorted sequence must be at most $q+6$ (a consequence of the number of distinct numbers not being much larger than the length of the sequence). Since the token following $q$ could potentially equal $q+6$, it is unexpected that the model pays more attention to $k=q$ than to $k=q+6$. However, there are only $9$ possible sorted sequences with a gap of size $6$ (all of the form $[0,1,2,\dots,x,x+6,x+7,\dots,14]$). Indeed, the input on which the model had the greatest loss was a permutation of $[0,6,7,\dots,14]$. The model manages to sort all sequences of this form, most likely because it saw all of them in the training set.
:::

**b) Source sequence tokens "emit" themselves post-attention.**

The following plot shows the logits contributed by each source token, assuming it is the only token attended to.

:::spoiler Code for computing the contributions
```python
# W_E: token embedding
# W_V, b_V, W_O, b_O: weights of the V and O
# W_U, b_U: weights of the unembedding
((W_E[token] @ model.W_V + model.b_V) @ model.W_O + model.b_O).squeeze() @ model.W_U + model.b_U
```
:::

![](https://i.imgur.com/qy3NhNj.png)

The largest values lie along the main diagonal, meaning that every input token is simply emphasizing itself.

Combining the two parts, we see that the model indeed implements our desired algorithm.
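To see how these two ingredients interact, here is a toy numerical sketch with hand-picked scores rather than the trained weights: a QK filter that only admits tokens greater than the previous output, and an OV map in which each attended token boosts the logit of its own value, with smaller values boosted more.

```python
import numpy as np

def toy_next_token(tokens: list[int], prev: int) -> int:
    # (a) QK filter: only keys with value > prev get non-negligible attention
    scores = np.array([0.0 if v > prev else -1e9 for v in tokens])
    attn = np.exp(scores - scores.max())
    attn /= attn.sum()
    # (b) OV: each attended token boosts the logit of its own value,
    #     with smaller values boosted more (values 0-14 as in experiment 2)
    logits = np.zeros(15)
    for v, w in zip(tokens, attn):
        logits[v] += w * (15 - v)
    return int(logits.argmax())

# for distinct inputs, the argmax is min{> prev}
assert toy_next_token([3, 0, 14, 7, 1], prev=1) == 3
```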
:::spoiler Some extra plots just in case you need more confidence
The positions of key tokens have minimal effect on the QK matrix.

![](https://i.imgur.com/5F9gMgQ.png)

The positions of query tokens do not affect the above property of the QK matrix (this is more of an in-dataset bias that prefers larger values towards the end; the QK matrix above has large enough absolute values that the property still holds after these tweaks).

![](https://i.imgur.com/LtRbc4s.png)

The positional embeddings emit the same information regardless of position (again, assuming all attention is put on them).

![](https://i.imgur.com/M8gcTZl.png)
:::

**Note:** Although the model from experiment 1 (sorting fixed-length lists with duplicates) incorporates some ideas from this algorithm, it isn't doing quite the same thing. See [the last section](#Causal-vs-Bidirectional-Model) for more information.

### Algorithm 2 (Bidirectional)

It is possible to directly compute the $k$-th (0-indexed) element of the sorted list without first computing the $(k-1)$-th element. Furthermore, this algorithm works just as well for variable-length lists as for fixed-length lists.

Let the input list be $[a_0,a_1,\dots,a_{n-1}]$. Define the piecewise linear function $f$ as $f(x)\triangleq \sum_{i=0}^{n-1}\min(0,a_i-x)$. Note that this function is concave down. Next, define $f_k(x)\triangleq f(x)+(k+0.5)x$. Then the $k$-th element of the sorted list is precisely $\text{argmax}_x(f_k(x))$.

The diagram below shows that for $a=[3,10]$, the $k=1$ element of the sorted list is $\text{argmax}_x(f_1(x))=10$.

![](https://i.imgur.com/hmcVpMD.png)

#### Evidence

<!-- Probabilities just look like softmax(pos * (some fixed emb) + sum(input embeddings)) -->

We claim that for the model from experiment 3, the logits at the $k$-th position of the output layer correspond to $[f_k(0), f_k(1), \dots, f_k(9)]$. Specifically...

**a) The contribution of the token embedding for $a_i$ towards the final logits is given by a piecewise linear function with peak at $a_i$, corresponding to $\min(0,a_i-x)$ (up to scaling, translation, and adding a multiple of $x$).**

Similarly to how we analyzed the previous algorithm, we computed how each token affects the final logits via the Output-Value (OV) circuit, assuming it is the only token attended to. We can see from the graph below that for each $d\in [0,9]$, the effect of token $d$ on the final logits is given by a (roughly) piecewise linear function with peak at $d$. Furthermore, all of these piecewise linear functions have the same slope. Paying more attention to the BOS token increases the logits corresponding to larger output digits. Little attention is paid to the EOS and PAD tokens, so we omit them from the plot.

![](https://i.imgur.com/oUqMgoS.png)

:::spoiler Plotting Code
```python=
with torch.inference_mode():
    es = []
    for tt in range(13):
        emb = model.W_E[tt]
        if tt < 10:
            # emb = emb + model.W_pos[1]
            pass
        elif tt == 10:  # BOS
            emb = emb + model.W_pos[0]
        s = emb @ model.W_V + model.b_V
        s = s @ model.W_O + model.b_O
        s = s.squeeze()
        s = s @ model.W_U + model.b_U
        es.append(s.tolist())

    fig, ax = plt.subplots(1, 2, figsize=(12.8, 4.8))
    for label in range(2):
        inputs = "0-9"
        if label == 1:
            inputs = "BOS"
        ax[label].set_title(f"OV Circuit ({inputs})")
        ax[label].set_xlabel("Output Token")
        ax[label].set_ylabel("Logits")
        if label == 0:
            for tt in range(10):
                ax[label].plot(es[tt][:10], label=str(tt))
        else:
            for tt in range(10, 11):
                ax[label].plot(es[tt][:10], label=item_to_string(cfg, t.tensor([tt])))
        ax[label].legend(title="Input Token")
    plt.show()
```
:::

**b) For most destination positions, nearly equal attention is paid to all tokens with values in the range $[0,8]$.**

We first show the attention pattern when the input list is $[0,1,2,3,4,5,6,7,8,9]$. For the destination tokens with values $1$ through $9$, we can see that the attention paid to the source tokens with values $0$ through $8$ is approximately uniform.

![](https://i.imgur.com/dsNiiIW.png)

We verified that this pattern holds in general as follows: for each position and each query token, we computed how much attention it paid to every key token. We found that for all positions and query tokens, nearly equal attention was paid to each of tokens 0-8, while the relative attention paid to the BOS token increased monotonically with position. We conclude that the attention pattern for keys $[0,9]$ corresponds to summing all of the $\min(0,a_i-x)$ terms with equal weight, giving us $f(x)$.

![](https://i.imgur.com/cPRhuJI.png)
![](https://i.imgur.com/5OKAmRq.png)
![](https://i.imgur.com/EsUjs0m.png)
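As a sanity check on the formula itself (independent of the learned weights), it is easy to verify numerically that $\text{argmax}_x f_k(x)$ recovers the $k$-th smallest element; the `f_k` helper below is ours, not from the original code:

```python
def f_k(a: list[int], k: int, x: int) -> float:
    # f_k(x) = sum_i min(0, a_i - x) + (k + 0.5) * x
    return sum(min(0, ai - x) for ai in a) + (k + 0.5) * x

a = [3, 1, 4, 1, 5, 9, 2, 6]
for k in range(len(a)):
    # the maximizer over the output digits 0..9 is the k-th smallest element
    assert max(range(10), key=lambda x: f_k(a, k, x)) == sorted(a)[k]
```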
:::spoiler Plotting Code (for the attention patterns above)
```python=
def plot_qk_for_out(pos: Optional[int] = None):
    """pos: position you are trying to predict"""
    with t.inference_mode():
        embeds = model.W_E.clone()
        if pos is not None:
            if cfg.bidirectional:
                embeds += model.W_pos[1 + pos]
            else:
                embeds += model.W_pos[11 + pos]  # 11 is MID
        fig, ax = plt.subplots(1, cfg.num_heads, figsize=(9.6, 4.8), squeeze=False)
        ax = ax[0]
        key_embeds = model.W_E.clone()
        key_embeds[cfg.distinct_nums] += model.W_pos[0]  # BOS is always at position 0
        for head in range(cfg.num_heads):
            QK = (embeds @ model.W_Q[0, head] + model.b_Q[0, head]) @ (key_embeds @ model.W_K[0, head] + model.b_K[0, head]).T
            title = f"QK For Query Position {pos}"
            if pos is None:
                title = f"QK Head {head}"
            ax[head].set_title(title)
            if not cfg.bidirectional:
                QK = QK[:10, :11]
            else:
                QK = QK[:cfg.distinct_nums]
            xticklabels = [str(i) for i in range(10)]
            xticklabels.append("BOS")
            xticklabels.append("EOS")
            xticklabels.append("PAD")
            heatmap(QK.cpu(), ax=ax[head], xticklabels=xticklabels, prec=2)
            ax[head].set_ylabel("Query Token")
            ax[head].set_xlabel("Key Token")
        plt.tight_layout()
```
:::

**Caveats:** The source token with value $9$ is an exception; the attention paid to it is generally lower than to the other tokens. This is likely because the value of $f(x)$ on $x=0\dots 9$ remains unchanged if all terms involving $a_i=9$ are removed from it (an alternative explanation is that in our model, the embeddings of BOS and token $9$ both contribute increasing lines, so their coefficients are somewhat interchangeable).

The destination token corresponding to the zeroth position of the sorted sequence is also an exception: it appears that for this position alone, the model determines what the element at that position in the sorted list should be by paying more attention to source tokens with lower values. We can verify this conjecture by looking at the attention patterns for a random permutation of $[0,1,2,\dots,9]$:

![](https://i.imgur.com/jZNpA3G.png)

This pattern would look the same as the one for $[0,1,2,\dots,9]$ if the rows and columns were sorted in increasing order of token value.

**c) Attention on BOS accounts for the "$(k+0.5)x$" addition.**

From the plots above, we see that the attention paid to the BOS key token increases along with the position $k$. This is consistent with increasing the coefficient of $x$ in $f_k(x)$ as $k$ increases. It becomes even clearer when we plot the expected contribution ratio of BOS versus tokens 0~8 for each position: we see a straight(-ish) line! This matches our "$(k+0.5)x$" claim perfectly, up to scaling and translation.

![](https://i.imgur.com/FQgh4Pz.png)

:::spoiler How we calculated the ratio in the second graph
It's the QK value difference between BOS and the average over tokens 0~8, scaled by $1/\sqrt{\text{head size}}$ and exponentiated.
:::
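Putting parts (a), (b), and (c) together: a toy numerical sketch with hand-picked attention scores (not the trained weights), in which uniform attention over the digit tokens supplies $f(x)$ and a BOS weight growing with the output position $k$ supplies the $(k+0.5)x$ term, so the argmax over output digits recovers the $k$-th sorted element.

```python
import numpy as np

def toy_sorted_element(a: list[int], k: int) -> int:
    n = len(a)
    # attention scores: 0 for every digit token, and a BOS score chosen so that the
    # post-softmax weight on BOS is proportional to (k + 0.5)
    scores = np.concatenate([np.zeros(n), [np.log(k + 0.5)]])
    attn = np.exp(scores) / np.exp(scores).sum()
    # OV contributions to the logits over output digits x = 0..9:
    #   digit a_i contributes min(0, a_i - x); BOS contributes an increasing line x
    x = np.arange(10)
    logits = sum(w * np.minimum(0, ai - x) for ai, w in zip(a, attn[:n]))
    logits = logits + attn[n] * x   # overall: logits = f_k(x) / (n + k + 0.5)
    return int(logits.argmax())

a = [9, 0, 9, 9, 2, 9, 7, 9]
assert [toy_sorted_element(a, k) for k in range(len(a))] == sorted(a)
```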
:::spoiler Plotting Code (for the two plots above)
```py
def plot_qk_line():
    with t.inference_mode():
        avg_0_through_8 = []
        avg_bos = []
        avg_0 = []
        avg_9 = []
        single = [[] for _ in range(10)]
        fig, ax = plt.subplots(1, 2, figsize=(9.6, 4.8))
        for pos in range(1, 10):
            embeds = model.W_E.clone()
            if pos is not None:
                if cfg.bidirectional:
                    embeds += model.W_pos[1 + pos]
                else:
                    embeds += model.W_pos[11 + pos]  # 11 is MID
            key_embeds = model.W_E.clone()
            key_embeds[cfg.distinct_nums] += model.W_pos[0]  # BOS is always at position 0
            for head in range(cfg.num_heads):
                QK = (embeds @ model.W_Q[0, head] + model.b_Q[0, head]) @ (key_embeds @ model.W_K[0, head] + model.b_K[0, head]).T
                avg_0_through_8.append(QK[:10, :9].mean().item())
                for i in range(10):
                    single[i].append(QK[:10, i].mean().item())
                avg_bos.append(QK[:10, 10].mean().item())
        ax[0].plot(range(1, 10), avg_0_through_8, label="Average Over 0-8")
        for i in range(9, 10):
            ax[0].plot(range(1, 10), single[i], label=f"{i}")
        ax[0].plot(range(1, 10), avg_bos, label="BOS")
        ax[0].legend(title="Key Token")
        ax[0].set_title("QK Values Averaged Over Q")
        ax[0].set_xlabel("Position")
        ax[0].set_ylabel("Average QK Value")
        ax[1].set_title("Attention Paid to BOS Relative To 0-8")
        ax[1].plot(range(1, 10),
                   np.exp((np.array(avg_bos) - np.array(avg_0_through_8)) / np.sqrt(cfg.d_model)),
                   label="exp((BOS - Avg(0-8))/sqrt(56))")
        ax[1].set_xlabel("Position")
        ax[1].set_ylabel("Attention Paid to BOS / Attention Paid to 0-8")
        plt.tight_layout()

plot_qk_line()
```
:::

## Causal vs Bidirectional Model

**Q: Why does the causal model perform so much worse?**

**A:** We think that the model from the first experiment is using a mix of the refined version of Algorithm 1 and Algorithm 2, though we're not 100% sure.

The following are the QK matrices for query and key token pairs. We can see that for the upper-left subsquare (query and key tokens 0~4), attention is rather equally distributed, which is consistent with Algorithm 2, but for the lower-right subsquare (query and key tokens 5~9), we see some upper-triangular patterns as in Algorithm 1.

![](https://i.imgur.com/xFe8qRl.png)

**Hypothesis:** The bidirectional model does so much better than the causal model because it is only able to learn Algorithm 2. The causal model, on the other hand, has the ability to learn both Algorithms 1 and 2. Algorithm 1 is easier to start learning, but is harder to get right when duplicates are allowed. Informally, the causal model sees both a slow and a fast route to the answer and ends up taking the slow route, while the bidirectional model takes the fast route because it is the only one accessible to it.

**Evidence 1:** We can also plot how the QK matrix changes during training. We can see that the upper-triangular pattern emerges early in training, before the model tries to shift to a constant-attention pattern. There seems to be some success in the first rows, but little could be done for the later rows.

![](https://i.imgur.com/rVDw377.gif)

**Evidence 2:** If we force the causal model to ignore the previously generated tokens, by keeping the query vector the same regardless of the actual token, we obtain a smaller loss (~0.0003) and 100% accuracy. This confirms that having access to the previously generated tokens hurts the causal model's performance.
:::spoiler Implementation with the TransformerLens library
```python
model.reset_hooks()

def q_hook(value, hook):
    # replace the query activations with a fixed vector,
    # so the query no longer depends on the previously generated token
    return model.blocks[0].attn.W_Q[0, 0].expand(value.shape)

model.add_hook(utils.get_act_name("q", 0), q_hook)
```
:::

$\color{red}{\text{Warning for TransformerLens users:}}$ At the time of writing, TransformerLens has an LRU cache enabled on `model.W_Q`, `model.b_Q`, etc., so if you're still training the model, the results might be unexpected, as cached weights from a previous version of the model might be returned. Be very careful if you're trying to fiddle with the weights before or during training!

Again we see a piecewise linear pattern for the contributions of the tokens, similar to the one in the bidirectional model, confirming the equivalence. A slight difference is that MOS now also comes into play, similarly to BOS, as it sits at a fixed position.

![](https://i.imgur.com/ZZzzeQW.png)
