# Reviewer 1Hg8
**Response:**
Thank you for your review of our work. Here we address your questions:
> 1. What do you believe is the specific benefit of studying the learning behavior of Transformer models within controlled mathematical definition tasks for understanding their performance in natural language processing tasks?
We believe that abstracting away certain complexities of language and working with a mathematical model that captures its essential structure can tell us a great deal about how transformers learn and what they learn. Markov chains are a natural model to consider and, indeed, they have a long history in the modeling of natural language - see the landmark paper of Claude Shannon (1948). For instance, as we mention in our introduction, in this paper alone we find that:
1. Transformers trained on Markov chains exhibit in-context learning abilities, and they do so by forming specialized statistical induction heads, similar to what has been observed in natural-language experiments - see for instance [1, 2].
2. Transformers first learn a suboptimal in-context learning solution before eventually learning the correct one. Since the problem is mathematically well defined, we can characterize what this suboptimal solution is - and we do so! Our analysis (both mathematical and experimental) demonstrates that this suboptimal solution slows down optimization, and we find ways to speed up training based on this knowledge.
We do believe that a mathematical understanding of learning with Transformers in such simple tasks can guide engineering choices in practice - for instance, we observe that relative positional embeddings are well suited for this task.
> What specific benefits do your findings offer towards solving downstream tasks?
Our findings do not yet provide engineering guidelines for natural language tasks, but we believe they could serve as inspiration for applied researchers - for instance, in our problem setup we identify that curriculum learning helps with optimization, which could be an interesting direction for further exploration.
Our work is meant to be a contribution to the pursuit of scientific understanding of deep learning. The fact that it is a theoretical and conceptual paper on how transformers learn is not a limitation per se, as evidenced by the large number of papers studying similar questions and the ever-growing interest of the community in understanding how neural networks learn features. Thus, we kindly ask you to consider raising your score, especially if you do not identify any technical limitations or have any outstanding concerns. Thank you!
1. Olsson, C., Elhage, N., Nanda, N., Joseph, N., DasSarma, N., Henighan, T., Mann, B., Askell, A., Bai, Y., Chen, A., Conerly, T., Drain, D., Ganguli, D., Hatfield-Dodds, Z., Hernandez, D., Johnston, S., Jones, A., Kernion, J., Lovitt, L., Ndousse, K., Amodei, D., Brown, T., Clark, J., Kaplan, J., McCandlish, S., and Olah, C. In-context learning and induction heads. Transformer Circuits Thread, 2022.
2. Akyürek, E., Wang, B., Kim, Y., and Andreas, J. In-context language learning: Architectures and algorithms, 2024.
# Reviewer tLXW
**Response:**
Thank you for your review of our work. We address your comments:
> theoretical analysis applies to a simplified transformer (e.g., no MLPs), which is fully defined only in the appendix
Thank you for the feedback. We will make sure to state in the introduction that our study focuses on attention-only transformers. Note, however, that mathematically analyzing the optimization of a full transformer in full generality is a challenging open problem. For that reason, our theoretical analysis focuses on simplified linear transformers trained with the margin loss - a setting that empirically replicates what we observe in the full transformer. Please also see our reply to reviewer dHUk, where we show experiments with more complex transformers (including MLP modules) and confirm that the same interesting phenomena are still present.
> the study focuses on very simplified unigram and bigram strategies -- this is very different from the pretraining data that LLMs are exposed to
Indeed! However, even in this toy data model we identify many phenomena that puzzle researchers studying natural-language models, such as induction heads [1] and phase transitions [2, 3]. Furthermore, we are able to understand them (see for instance the simplicity bias towards learning unigrams first) and provide explanations and remedies. We believe that analyzing a simple task allows us to understand phenomena that would otherwise be challenging to study, laying the groundwork for understanding more realistic settings.
> results apply to GD using the population-level gradient, not SGD. While this is not a problem for the first step (where the result provides the expected update), it is not so obvious what this means for results regarding the second step of optimization. [perhaps the authors can explain this in their rebuttal]
This is a standard technique in the analysis of neural networks and has been adopted in a plethora of theoretical works in the field - see for instance [4, 5]. Moreover, experimentally, the batch size (which controls how close we are to SGD) does not seem to affect our results. This is why we believe our analysis provides insight into our experimental findings.
> Line 273 second column refers to Section 3.3, but is inside that section.
Good catch, fixed. We meant to say "Section 2.3". Thank you!
> The model is only introduced in the appendix. I think it would be much more transparent if the model were properly defined in the main paper.
We introduce the minimal model in Section 3.2 - eq. (6) provides the definition. We provide further details about this model in the Appendix and eq. (9), where we fully analyze its optimization trajectory. As mentioned above, we will state in the introduction that our results hold for attention-only transformers, as per your recommendation. We will also add more details to the description of the minimal model, as per the recommendation of reviewer dHUk. Thank you!
> Line 355, what does "2" refer to?
It refers to Figure 2 - thank you for the catch! We will add the word "Figure" before it. Figure 2 provides evidence that the 2nd layer of the transformer forms earlier than the 1st one.
> As the paper makes claims about ICL, it would be beneficial to be a bit more transparent about the ways in which the studied setup is highly simplified.
Thank you! We will be more transparent about the simplifications in our analysis involving the minimal model.
We hope that our answers clarified your questions. If you feel that our suggested changes will further improve the presentation of our work and you do not have any outstanding concerns, please consider raising your score. Thank you!
**Refs**
1. Nelson Elhage, Neel Nanda, Catherine Olsson, Tom Henighan, Nicholas Joseph, Ben Mann, Amanda Askell, Yuntao Bai, Anna Chen, Tom Conerly, Nova DasSarma, Dawn Drain, Deep Ganguli, Zac Hatfield-Dodds, Danny Hernandez, Andy Jones, Jackson Kernion, Liane Lovitt, Kamal Ndousse, Dario Amodei, Tom Brown, Jack Clark, Jared Kaplan, Sam McCandlish, and Chris Olah. A mathematical framework for transformer circuits. Transformer Circuits Thread, 2021. https://transformer-circuits.pub/2021/framework/index.html.
2. Angelica Chen, Ravid Shwartz-Ziv, Kyunghyun Cho, Matthew L. Leavitt, and Naomi Saphra. Sudden Drops in the Loss: Syntax Acquisition, Phase Transitions, and Simplicity Bias in MLMs, 2023.
3. Jesse Hoogland, George Wang, Matthew Farrugia-Roberts, Liam Carroll, Susan Wei, and Daniel Murfet. The Developmental Landscape of In-Context Learning, 2024.
4. Boaz Barak, Benjamin L. Edelman, Surbhi Goel, Sham Kakade, Eran Malach, and Cyril Zhang. Hidden Progress in Deep Learning: SGD Learns Parities Near the Computational Limit, 2023.
5. Eshaan Nichani, Alex Damian, and Jason D. Lee. How Transformers Learn Causal Structure with Gradient Descent, 2024.
# Reviewer dHUk
**Response:**
Thank you very much for your comprehensive review of our work and your detailed comments. We really appreciate it; your feedback has significantly helped improve the presentation of our results. We address your points:
> Needs to be more clear about their claims about transformers: they only study attention-only shallow transformers. That needs to be made more clear in the abstract/introduction.
Thank you for the suggestion. We will make sure to mention this in the introduction in various places - for example, line 60-61 "we train a transformer on sequences of tokens" -> "we train a two layer attention-only transformer on sequences of tokens".
> Proposition 2.1: It occurs without introduction or context and it confused me on a first reading that the transformer in question was somehow assumed to be of the given form. Please clarify.
Thank you for pointing this out. We will add more context and better integrate the lemma with the section: we will improve the notation, add an intuitive explanation of what each layer does before the details, and include a clearer and more formal proof in the section.
We also thank you for your careful proofreading, and we apologize for the lack of clarity. You are correct that there are some errors in the dimensions of the construction. Reviewing these definitions, we will make a number of changes to both improve clarity and better follow conventions. For completeness, we describe the construction with these changes here.
For simplicity, we assume that the transformer has the same hidden dimension as the embedding dimension (none of our claims depends on this assumption, but it makes the notation cleaner). This means that $W_Q, W_K, W_V\in \mathbb{R}^{d\times d}$, $v\in \mathbb{R}^{t\times d}$, and $P\in \mathbb{R}^{d\times k}$. The transformer definition had a slight typo and should read as follows:
$$TF = P\circ (Attn_n+I) \circ \dots \circ (Attn_1+I)$$
To better follow the conventions for relative positional encodings, we also make a change to the attention (note that this change is only notational and is symmetric to the original definition):
$$Attn(z) = \text{softmax}(\text{mask}(A))z W_V$$
$$A_{i,j} = \frac{(z_i W_Q)(z_j W_K+v_{i-j+1})^\top}{\sqrt{d}}.$$
(Technically, $v_{i-j+1}$ is undefined for $i<j$, but since $A$ is masked we do not need to consider that case.)
We briefly restate the construction with the proper dimensions. Set the internal dimension $d=3k$ and embed each token using a one-hot encoding of length $3k$, that is, $e_{i,j}=1$ if and only if $x_i=j$.
$$v^{(1)}=\begin{pmatrix}\mathbf{0}_{1\times d}\\
\mathbf{1}_{1\times d} \\
\mathbf{0}_{(t-2)\times d}\end{pmatrix}\quad W_Q^{(1)}=\begin{pmatrix}
cI^{k\times k} &\mathbf{0}&\mathbf{0}\\
\mathbf{0} &\mathbf{0}&\mathbf{0}\\
\mathbf{0} &\mathbf{0}&\mathbf{0}
\end{pmatrix} \quad W_K^{(1)}=\mathbf{0}\quad W_V^{(1)}=\begin{pmatrix}
\mathbf{0} &I^{k\times k}&\mathbf{0}\\
\mathbf{0} &\mathbf{0}&\mathbf{0}\\
\mathbf{0} &\mathbf{0}&\mathbf{0}
\end{pmatrix}$$
$$v^{(2)}=\mathbf{0}\quad W_Q^{(2)}=\begin{pmatrix}
cI^{k\times k} &\mathbf{0}&\mathbf{0}\\
\mathbf{0} &\mathbf{0}&\mathbf{0}\\
\mathbf{0} &\mathbf{0}&\mathbf{0}
\end{pmatrix} \quad W_K^{(2)}=\begin{pmatrix}
\mathbf{0} &\mathbf{0} &\mathbf{0}\\
c I^{k\times k} &\mathbf{0}&\mathbf{0}\\
\mathbf{0} &\mathbf{0}&\mathbf{0}
\end{pmatrix}\quad W_V^{(2)}=\begin{pmatrix}
\mathbf{0} &\mathbf{0}&I^{k\times k}\\
\mathbf{0} &\mathbf{0}&\mathbf{0}\\
\mathbf{0} &\mathbf{0}&\mathbf{0}
\end{pmatrix}$$
$$P=\begin{pmatrix}
\mathbf{0} \\
\mathbf{0}\\
I^{k\times k}
\end{pmatrix}$$
In the first layer, for large enough $c$, $\text{softmax}(\text{mask}(A^{(1)}))_{i,i-1}\approx 1$ and, for all $j\neq i-1$, $\text{softmax}(\text{mask}(A^{(1)}))_{i,j}\approx 0$. The model attends to the previous token, and $W_V^{(1)}$ is chosen such that the embedding of the previous token is appended after the embedding of the current token.
In the second layer, for large enough $c$, $\text{softmax}(\text{mask}(A^{(2)}))_{i,j}\approx \mathbb{1}[x_{j-1}=x_i].$ Because the token embeddings are one-hot, the sum over all the tokens being attended to is a vector counting the empirical bigram statistics of which tokens have followed the current token. $W_V^{(2)}$ and $P$ are chosen to make the output of the model these statistics.
The construction can be viewed as a special case of disentangled transformers [1] (with relative positional encodings). We use an embedding dimension three times the input dimension so that we can append the output of each layer to the previous one, instead of adding them. This allows for a relatively simple and human-understandable construction.
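To make the construction concrete, here is a minimal numpy sketch of the two-layer attention-only transformer described above (the variable names and the sampling of the Markov chain are of our choosing, and edge effects at the first position are ignored). For large enough $c$, the output at position $i$ approximates the empirical distribution of tokens that have followed $x_i$ so far.
```
import numpy as np

def softmax(A, axis=-1):
    A = A - A.max(axis=axis, keepdims=True)
    E = np.exp(A)
    return E / E.sum(axis=axis, keepdims=True)

def attn(z, W_Q, W_K, W_V, v):
    # A_{i,j} = (z_i W_Q)(z_j W_K + v_{i-j+1})^T / sqrt(d), causally masked
    t, d = z.shape
    Q, K = z @ W_Q, z @ W_K
    A = np.full((t, t), -np.inf)
    for i in range(t):
        for j in range(i + 1):
            A[i, j] = Q[i] @ (K[j] + v[i - j]) / np.sqrt(d)  # v is 0-indexed here
    return softmax(A) @ z @ W_V

k, t, c = 3, 64, 50.0
d = 3 * k
rng = np.random.default_rng(0)

# sample a random Markov chain and a sequence from it
Pi = rng.dirichlet(np.ones(k), size=k)
x = [rng.integers(k)]
for _ in range(t - 1):
    x.append(rng.choice(k, p=Pi[x[-1]]))
x = np.array(x)
E = np.zeros((t, d)); E[np.arange(t), x] = 1.0      # one-hot tokens in the first k coordinates

I, O = np.eye(k), np.zeros((k, k))
# layer 1: attend to the previous token and copy its identity into the second block
v1 = np.zeros((t, d)); v1[1] = 1.0
WQ1 = np.block([[c * I, O, O], [O, O, O], [O, O, O]])
WK1 = np.zeros((d, d))
WV1 = np.block([[O, I, O], [O, O, O], [O, O, O]])
# layer 2: attend to positions whose previous token equals the current token
v2 = np.zeros((t, d))
WQ2 = np.block([[c * I, O, O], [O, O, O], [O, O, O]])
WK2 = np.block([[O, O, O], [c * I, O, O], [O, O, O]])
WV2 = np.block([[O, O, I], [O, O, O], [O, O, O]])
P = np.vstack([O, O, I])                            # read out the third block

z1 = attn(E, WQ1, WK1, WV1, v1) + E
z2 = attn(z1, WQ2, WK2, WV2, v2) + z1
out = z2 @ P   # out[i] ~ empirical distribution of tokens that followed x_i in the prefix
```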
> Minimal model: Likewise, it would be good to explain better the minimal model. [...] Too much is being asked of the reader to parse the notation.
&
> I'm also a bit unclear about the relationship between the minimal model and the attention-only transformer. The former has only one layer but a claim is made about how it related to the two layer attention-only transformer. What is the claim exactly? It is a bit obscured by the technicality of Lemma 3.1 and it's long proof in the appendix.
Let us clarify the connection between the minimal model and a two-layer attention-only transformer. In words, the minimal model is a linear, attention-only, causal transformer whose first layer has a single learnable matrix (instead of two) with a specific parameterization that encodes the positional information, together with an identity value matrix, and whose second layer has a single learnable matrix that parameterizes the product of the key and query matrices. If $TF_{i}(E) = \text{mask}(EQ_iK_i^TE^T)EV_i^T$ is a linear attention layer (with one head), then by setting $Q_1K_1^T = E^{-1}ME^{-T}$, $V_1 = I$ ($M$ being a matrix with a specific, learnable structure - see below eq. (6) in the paper) and $Q_2K_2^T = W_K$, $V_2=I$, we obtain the two-layer transformer $TF_2(TF_1(E)) = \text{mask}(ME W_K (ME)^T)ME$. Dropping the first and the last occurrence of the matrix $M$ (for simplicity), we obtain the minimal model of eq. (6): $f(E) = \text{mask}(E W_K (ME)^T)E$. One can observe that this minimal model needs to learn the basic components of the optimal solution, as per our construction (Proposition 2.1). We see that our initial exposition in lines 231-233 was insufficient, and we will add this discussion in the revised version. Dropping the first and the last occurrence of $M$ can also be viewed as the second layer operating directly on the input of the model, which could be implemented with residual connections or by adopting the viewpoint of a disentangled transformer [1]. We will mention this connection in our revision. Thank you once again for helping us improve our presentation!
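For concreteness, a minimal numpy sketch of the forward pass of eq. (6) could look as follows. The exact parameterization of $M$ is the one given below eq. (6) in the paper; purely for illustration, we assume here a causal Toeplitz structure $M_{i,j} = w_{i-j}$ built from a learnable relative-position vector $w$.
```
import numpy as np

def minimal_model(E, W_K, w):
    # E: (T, k) token embeddings, W_K: (k, k) learnable, w: (T,) learnable relative-position weights
    T = E.shape[0]
    M = np.zeros((T, T))
    for i in range(T):
        for j in range(i + 1):
            M[i, j] = w[i - j]                           # positional mixing (assumed structure of M)
    A = E @ W_K @ (M @ E).T                              # content attention scores
    A = np.where(np.tril(np.ones((T, T))) > 0, A, 0.0)   # causal mask; linear attention, no softmax
    return A @ E                                         # next-token predictions at each position
```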
> Figure 2: Unclear what the attention pattern being displayed is mathematically (the blue lines). Is it a thresholded heat map (i.e. how blue a line is is the value of the attention and below a certain threshold, the line is invisible / not drawn?).
There is no thresholding; both the thickness and the alpha of each line depend linearly on the (normalized) attention. We normalize such that each attending token has the same maximum value among the tokens it attends to. We will add a short explanation of this in the caption of the figure.
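In code, the normalization we describe amounts to the following small sketch (function and variable names are of our choosing):
```
import numpy as np

def normalize_for_plot(attn_weights):
    # attn_weights: (T, T) attention matrix; row i holds the weights with which token i attends
    # rescale each row so that its most strongly attended token gets the maximal thickness / alpha
    return attn_weights / attn_weights.max(axis=1, keepdims=True)
```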
> Figure 4: I'm confused by the caption and takeaway. "When there is not much signal from unigrams, learning progresses faster without long plateaus". But it is the uninformative unigrams that have the plateau. Perhaps I am not understanding something. The Appendix B is not a helpful description: a bunch of disjointed formulas with no explanation. This should be rewritten for clarity.
Thank you for pointing this out. There is a typo there. Let us try to clarify.
The purple lines correspond to distributions where the unigram strategy is (close to) Bayes optimal, and this is why models trained on these distributions sometimes never improve past the performance of the unigram algorithm on the full distribution (purple lines plateau). On the other hand, the yellow lines correspond to distributions where the unigram algorithm is (close to) as bad as uniform (random) guessing, and the models quickly improve beyond this unigram strategy.
The code for sampling the transition matrix is as follows, parametrized by $p$:
```
import numpy as np

# p in [0, 1] is the parameter referenced above:
# b is centered at a when p = 0 (both rows similar, unigram close to Bayes optimal)
# and at 1 - a when p = 1 (close to doubly stochastic, unigram close to uninformative)
a = np.random.uniform()
mu = a + p * (1 - 2 * a)
factor = 0.2
b = mu + np.random.uniform(-factor, factor)
b = min(1, max(0, b))
# randomly assign the two rows
if np.random.rand() > .5:
    transition_matrix = np.array([[a, 1 - a], [b, 1 - b]])
else:
    transition_matrix = np.array([[b, 1 - b], [a, 1 - a]])
```
We rewrote the paragraph that describes these distributions in the Appendix. Thank you!
Furthermore, we agree that this plot and distribution are confusing, and we have run a new experiment which we believe backs up our claim without requiring complicated new distributions. Consider the following distributions: the *doubly stochastic distribution*, where the unigram algorithm is as bad as random guessing (which happens exactly when the transition matrix is doubly stochastic), and the *unigram distribution*, where the unigram algorithm is Bayes optimal (which happens when every row of the transition matrix is the same). We train one model on the doubly stochastic distribution, and one model on a mixture of 75% doubly stochastic and 25% unigram distribution. We then compare their losses on the full distribution and, to make the comparison fair, the x-axis is the number of doubly stochastic samples seen. Indeed, the first model converges faster than the second. This shows that the extra training samples from the unigram distribution slowed down learning, even though the second model received the same number of samples from the doubly stochastic distribution.
See the plot here: https://anonymous.4open.science/r/icmlrebuttals2024-1304/4symb_DS_mixture.png. The plot includes 95% confidence intervals, estimated from 10 runs with different random seeds for each experimental condition. This new experiment also allows us to consider more than $k=2$ tokens. For full transparency, the second model did converge to a slightly lower loss on the full distribution, mostly due to having lower loss on the unigram distribution.
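For reference, one simple way to sample such transition matrices is sketched below (this is our illustration; the exact sampler used for the plot may differ in details). A random doubly stochastic matrix can be obtained via Sinkhorn normalization, and a "unigram" matrix by repeating a single random row.
```
import numpy as np

def doubly_stochastic(k, rng, iters=200):
    # Sinkhorn-Knopp: alternately normalize columns and rows of a positive matrix;
    # ending with row normalization keeps the rows valid transition probabilities
    T = rng.uniform(0.1, 1.0, size=(k, k))
    for _ in range(iters):
        T = T / T.sum(axis=0, keepdims=True)
        T = T / T.sum(axis=1, keepdims=True)
    return T   # stationary distribution is (approximately) uniform, so the unigram strategy is uninformative

def unigram(k, rng):
    row = rng.dirichlet(np.ones(k))
    return np.tile(row, (k, 1))   # all rows equal, so the unigram strategy is Bayes optimal

rng = np.random.default_rng(0)
k = 4
# mixture used for the second model: 75% doubly stochastic, 25% unigram
T = doubly_stochastic(k, rng) if rng.random() < 0.75 else unigram(k, rng)
```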
> Even/odd positional encoding pattern: I'm unable to discern this from various graphs (e.g. middle column of Figure 3). I see a continuous band of blue. Are there supposed to be gaps or some kind of alternating teeth pattern?
Figure 3 (middle) and Figure 5 (right) display the (relative) positional embeddings of the transformer and the minimal model, respectively. They show the values of a $T$-dimensional vector, where $T$ is the length of the sequence. As can be seen in both figures (more clearly in Figure 5), the values at positions 1, 3, 5, ... grow larger than the values at the even positions (this happens during the phase transition - e.g., at $t=92$ in the transformer). It is true that this bias shrinks for later positions; we will add this to the caption. The pattern is demonstrated more clearly in Figure 11 in the Appendix for the case of $k=2$ (a 2-state Markov chain), and it is predicted by eq. (48) in the Appendix. We will add a few more lines of explanation in the main text where we refer to these figures, expand the caption of Figure 3 (adding a reference to Figure 11), and include a zoomed-in version where the phenomenon can be clearly seen. Thanks!
> Followup experiments:
> For transformers with MLPs: Do similar learning dynamics (e.g. the phase transitions) occur? Does convergence to the Bayes Optimal classifier still hold and at the same speed? Any interesting thing to say about attention patterns?
Thank you for the suggestion! We ran experiments on this, and the high-level picture remains the same: the models exhibit the same phase transitions, learning is separated into phases, and the models form similar attention patterns, albeit reaching slightly lower loss values and converging faster.
You can find plots for two-layer attention-MLP transformers trained on the original setting (ICL-Markov chains with 3 states) here:
Loss: https://anonymous.4open.science/r/icmlrebuttals2024-1304/mlploss.png
Distance from strategies: https://anonymous.4open.science/r/icmlrebuttals2024-1304/mlpphases.png
Attention patterns: https://anonymous.4open.science/r/icmlrebuttals2024-1304/mlpattention.png
Once again, we would like to thank you for your very valuable suggestions---they will help improve our paper. We hope we have addressed your concerns sufficiently for you to raise your score.
Refs:
1. Nelson Elhage, Neel Nanda, Catherine Olsson, Tom Henighan, Nicholas Joseph, Ben Mann, Amanda Askell, Yuntao Bai, Anna Chen, Tom Conerly, Nova DasSarma, Dawn Drain, Deep Ganguli, Zac Hatfield-Dodds, Danny Hernandez, Andy Jones, Jackson Kernion, Liane Lovitt, Kamal Ndousse, Dario Amodei, Tom Brown, Jack Clark, Jared Kaplan, Sam McCandlish, and Chris Olah. A mathematical framework for transformer circuits. Transformer Circuits Thread, 2021. https://transformer-circuits.pub/2021/framework/index.html.