# Reproduction Blog: *Learning Transformer Programs*
This blog was written by TU Delft students of **group 13** as part of the course Deep Learning (CS4240).
Authors:
- Dan Sochirca (5295580), D.Sochirca@student.tudelft.nl - ablation 3 - shift task
- Venelina Pocheva (5093570), V.A.Pocheva@student.tudelft.nl - ablation 3 - run-length encoding task
- Nadine Kuo (5204895), h.n.kuo@student.tudelft.nl - Introduction, ablation 1
- Joshua Azimullah (5054354), j.r.azimullah@student.tudelft.nl - Transformer Architecture, ablation 2
## Introduction
This blog post is dedicated to our reproducibility project centered around the paper: *"Learning Transformer Programs"* [^LearningTrans]. The authors of this work aim to present a methodology for training Transformers that are mechanistically interpretable by design. They show how a modified Transformer can be designed that - when trained using gradient-based optimization techniques - can be reverse-engineered into discrete, human-readable programs known as *Transformer Programs*.
In order to do so, they leverage RASP [^RASP], a conceptual framework (i.e. a programming language) for bridging the gap between Transformer components and code that humans can inspect.

To demonstrate that these Transformer Programs are easier to interpret whilst maintaining performance levels comparable to standard Transformers, the authors learn Transformer Programs for several problems. These include an in-context learning task, algorithmic RASP problems (e.g. sorting, string reversal etc.) and NLP tasks such as named entity recognition.
In this project, our aim is to reproduce and ablate the experimental results presented in Table 1 of the paper [^LearningTrans]. This includes the RASP tasks: "Reverse", "Histogram", "Double hist.", "Sort" and "Most-Freq".
Specifically, the ablations we introduce are:
1. Varying vocabulary sizes and max. sequence lengths fed to the model
2. Expanding the pool of synthesized training data in an attempt to offset any potential accuracy decline caused by ablation 1.
3. Introducing two new RASP tasks: "Shift" and "Run-Length Encoding".
Below we provide background related to the original Transformer architecture and RASP language, after which we dive into our experimental setup and reproducibility results achieved.
## Background
### Transformer Architecture
The Transformer model operates on a sequence of tokens $w = w_1, \dots, w_N \in V$ and maps it to $N$ probability distributions over $V$, one per position (e.g. for predicting the next word). The architecture comprises the following stages:
1. **Word Embeddings**: Each token is mapped into a d-dimensional space, $\mathbb{R}^{d}$. Positional embeddings are added to these word embeddings to incorporate the token's sequence position.
2. **Multihead Attention**: The process of multihead attention (MHA) involves H heads performing attention mechanisms simultaneously. For each word:
- **Query and Key Generation**: Word embeddings are transformed using matrices $W_Q \in \mathbb{R}^{d \times d_h}$ and $W_K \in \mathbb{R}^{d \times d_h}$ to generate query and key vectors, where $d_h$ is the attention (head) dimension.
- **Attention Calculation**: The dot product of the query vector of word $w_i$ and the key vector of word $w_j$, scaled by $\sqrt{d_h}$, quantifies how much $w_i$ attends to $w_j$. A softmax then normalizes these scores over all keys $j$ for each query $i$, i.e. across each row of the attention matrix.
- **Value Weighting and Aggregation**: Each normalized attention score scales the corresponding value vector, obtained by multiplying the word embeddings with $W_V \in \mathbb{R}^{d \times d_h}$. The weighted value vectors are summed and used to update the original word embedding.
Since the heads are independent of each other, they can all be computed at once: their outputs are concatenated and re-projected via $W_O \in \mathbb{R}^{H d_h \times d}$, which is equivalent to projecting each head with its own block $W^h_O \in \mathbb{R}^{d_h \times d}$ and summing the results. The full multihead attention step can be formulated as:
$\mathrm{MHA}(x) = \sum_{h=1}^{H} \mathrm{softmax}\left(\frac{xW^h_Q(xW^h_K)^\top}{\sqrt{d_h}}\right)xW^h_VW^h_O$
or, for a single head and each word $w_i$:
$z_i = \sum_j A_{ij} \cdot v_j = \sum_j \mathrm{softmax}_j\left(\frac{q_i^\top k_j}{\sqrt{d_h}}\right) \cdot v_j$
3. **Feedforward Neural Network**: After attention, the word embeddings pass through a multi-layer perceptron (MLP), which introduces non-linearity and further transforms them. With residual connections around both sublayers, each layer $i$ of the Transformer can be formulated as:
$h_i = x_{i-1} + \mathrm{MHA}_i(x_{i-1}), \qquad x_i = h_i + \mathrm{MLP}_i(h_i)$
4. **Repeat**: This combined step is performed for $L$ layers, so that each word embedding is progressively updated with information from all other word embeddings and their positions.
5. **Output Classifier**: Finally, a linear classifier maps each final embedding to a probability distribution over the vocabulary (e.g. over the next word).
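To make the equations above concrete, here is a minimal pure-Python sketch of one softmax attention head, multihead attention as a sum over per-head projections, and one Transformer layer. It assumes the standard residual form $h = x + \mathrm{MHA}(x)$, $x' = h + \mathrm{MLP}(h)$ and a two-layer ReLU MLP. All helper names are hypothetical, and this is not the authors' implementation.

```python
import math

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add(a, b):
    """Element-wise sum of two equally shaped matrices (the residual connection)."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_head(x, w_q, w_k, w_v):
    """One softmax head: z_i = sum_j softmax_j(q_i . k_j / sqrt(d_h)) * v_j."""
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    d_h = len(w_q[0])
    scores = [[sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_h) for kj in k] for qi in q]
    attn = [softmax(row) for row in scores]  # row i is normalised over all keys j
    return matmul(attn, v), attn

def mha(x, heads):
    """Sum over heads of head_h(x) @ W_O^h; `heads` is a list of (W_Q, W_K, W_V, W_O)."""
    out = [[0.0] * len(x[0]) for _ in x]
    for w_q, w_k, w_v, w_o in heads:
        z, _ = attention_head(x, w_q, w_k, w_v)
        out = add(out, matmul(z, w_o))
    return out

def mlp(x, w1, w2):
    """Two-layer ReLU MLP applied to every position independently."""
    hidden = [[max(0.0, h) for h in row] for row in matmul(x, w1)]
    return matmul(hidden, w2)

def layer(x, heads, w1, w2):
    """h = x + MHA(x); output = h + MLP(h)."""
    h = add(x, mha(x, heads))
    return add(h, mlp(h, w1, w2))

# Toy usage: two positions, d = d_h = 2, identity matrices as every weight.
I2 = [[1.0, 0.0], [0.0, 1.0]]
z, attn = attention_head(I2, I2, I2, I2)
out = layer(I2, [(I2, I2, I2, I2)], I2, I2)
```

With identity weights, each query attends most strongly to its own position, which makes for an easy sanity check of the softmax normalisation.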
### RASP
RASP [^RASP] is a programming language consisting of functions that operate on sequences. It was designed so that any (human-readable) program written in it corresponds to a Transformer computation and can be converted into a Transformer model by mapping its functions onto the weights of Transformer components. This mapping is done by a compiler called Tracr [^Tracr].
Next, we present the central operations of the language.
#### Select
In simple words, the select operation takes two arrays as input (keys and queries), applies a predicate function to each (query, key) pair, and outputs a matrix containing the results. The predicate can be any function returning a boolean.
```java
// Function definition:
select(keys, queries, predicate)
// Example: (notice the 1s on the anti-diagonal)
A = select([1,2,3,4], [4,3,2,1], ==) = [[0 0 0 1],
                                        [0 0 1 0],
                                        [0 1 0 0],
                                        [1 0 0 0]]
// same as: A[i][j] = predicate(queries[i], keys[j])
```
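The `select` pseudocode above can be written as a short runnable Python function. This is our own sketch. It follows the convention `A[i][j] = predicate(queries[i], keys[j])` from the comment, with `operator.eq` playing the role of `==`.

```python
import operator
from typing import Callable, List, Sequence

def select(keys: Sequence, queries: Sequence, predicate: Callable) -> List[List[int]]:
    """Boolean 'attention' matrix: A[i][j] = 1 if predicate(queries[i], keys[j]) else 0."""
    return [[int(predicate(q, k)) for k in keys] for q in queries]

# Same example as above: the 1s end up on the anti-diagonal.
A = select([1, 2, 3, 4], [4, 3, 2, 1], operator.eq)
```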
You can think of this operation as computing the attention matrix in a Transformer, where the keys and queries are passed as arrays. We elaborate on this in later sections.
#### Aggregate
The aggregate function takes an attention matrix `A` and an array `values`, and for each position `i` outputs the weighted average of `values`, using row `A[i]` as the weights.
```java
// Function definition:
aggregate(A, values)
// Example: (reversing the input)
A = [[0 0 0 0 1],
     [0 0 0 1 0],
     [0 0 1 0 0],
     [0 1 0 0 0],
     [1 0 0 0 0]]
aggregate(A, "hello") = "olleh"
// Row by row: output[i] = A[i] · "hello"
// A[0] = [0 0 0 0 1] selects 'o'
// A[1] = [0 0 0 1 0] selects 'l'
// A[2] = [0 0 1 0 0] selects 'l'
// A[3] = [0 1 0 0 0] selects 'e'
// A[4] = [1 0 0 0 0] selects 'h'  =>  "olleh"
```
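Here is a matching sketch of `aggregate`. As in the example above, it assumes that categorical values such as characters are only aggregated with one-hot rows, while numeric values get a true weighted average. Returning 0 for an all-zero row is our own choice, not taken from the paper.

```python
from numbers import Number
from typing import List, Sequence

def aggregate(A: List[List[int]], values: Sequence) -> list:
    """For each row A[i], average `values` weighted by that row."""
    numeric = all(isinstance(v, Number) for v in values)
    out = []
    for row in A:
        if numeric:
            total = sum(row)
            # weighted mean; an empty selection falls back to 0.0 (our assumption)
            out.append(sum(w * v for w, v in zip(row, values)) / total if total else 0.0)
        else:
            if sum(row) != 1:
                raise ValueError("categorical aggregate expects exactly one selected key per row")
            out.append(values[row.index(1)])  # one-hot row: copy the selected value
    return out

A = [[0, 0, 0, 0, 1],
     [0, 0, 0, 1, 0],
     [0, 0, 1, 0, 0],
     [0, 1, 0, 0, 0],
     [1, 0, 0, 0, 0]]
reversed_word = "".join(aggregate(A, "hello"))  # "olleh"
```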
This corresponds to computing the output of an attention head at each position: a weighted combination of the values, with weights given by the attention matrix.
#### How Does This Correspond To Transformer Attention?
The attention we are dealing with here is a bit different from the traditional transformer, as these attention scores are categorical (can be only 0 or 1).
Furthermore, to ensure that the output of each head is also categorical, the paper uses hard attention: each query token can only attend to a single key token. If no key matches a query, hard attention attends to the beginning-of-sequence token by default. If more than one key matches, the query attends to the closest match.
This type of attention can be defined by:
**Learned parameters:**
$keys \in K^N$, $queries \in Q^M$, $predicate$
**Attention output for word `i`:**
$z_i = \sum_j A_{ij} \cdot v_j = \sum_j predicate(q_i, k_j) \cdot v_j$
= `aggregate(select_closest(keys, queries, predicate), values)`
where `select_closest` performs the same `select` operation, but with the hard attention constraint: it selects the closest key satisfying the predicate, or the first token in the sequence if there is no match.
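Hard attention can be sketched in the same style. The notion of "closest" is an assumption here: we measure distance as $|i - j|$ and break remaining ties toward the earlier position. The fallback to position 0 assumes that `<bos>` sits there. `select_closest` and this `aggregate` are illustrative helpers, not the authors' code.

```python
from typing import Callable, List, Sequence

def select_closest(keys: Sequence, queries: Sequence, predicate: Callable) -> List[List[int]]:
    """One-hot attention: query i attends to the closest key j with predicate(q_i, k_j)."""
    n = len(keys)
    A = []
    for i, q in enumerate(queries):
        matches = [j for j, k in enumerate(keys) if predicate(q, k)]
        # assumption: closest = smallest |i - j|, ties broken toward the smaller j;
        # no match -> attend to position 0 (the <bos> token)
        target = min(matches, key=lambda j: (abs(i - j), j)) if matches else 0
        A.append([1 if j == target else 0 for j in range(n)])
    return A

def aggregate(A: List[List[int]], values: Sequence) -> list:
    """Categorical aggregate: each one-hot row copies the selected value."""
    return [values[row.index(1)] for row in A]

# Usage: shift tokens one position to the right by attending to the previous position.
tokens = ["<bos>", "a", "b", "c"]
positions = list(range(len(tokens)))
shifted = aggregate(select_closest(positions, positions, lambda q, k: k == q - 1), tokens)
```

In this usage, position 0 has no previous position and falls back to `<bos>`, so the result starts with two `<bos>` tokens.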
To offer some perspective, here is the traditional Softmax attention calculation:
$z_i = \sum_j A_{ij} \cdot v_j = \sum_j \mathrm{softmax}_j\left(\frac{q_i^\top k_j}{\sqrt{d_h}}\right) \cdot v_j$
Besides `select` and `aggregate`, RASP also supports arbitrary element-wise operations, which correspond to the feed-forward (MLP) layers of a Transformer. We focus on attention in this post and do not go into them here.
## Learning Transformer Programs
### Disentangled Residual Stream
**Transformer circuits** [^TransCircuits] offer a constrained, abstract way of viewing Transformer operations. In simple words, a Transformer can be seen as a series of layers that read information from, and write information to, a *residual stream*. You can think of this residual stream as a collection of memory blocks into which each layer writes its output. Each layer reads its input from the blocks written by earlier layers (including the input embeddings) and writes its output to its own dedicated block. This is where "disentangled" comes from: no two layers share the same *write* memory. The read/write operations can be expressed as matrix multiplications, but we won't go into the mathematical details here. For those, see the blog post ["A Mathematical Framework for Transformer Circuits"](https://transformer-circuits.pub/2021/framework/index.html).
Let's consider an example to show how the residual stream works. Consider the simplified Transformer architecture in the image below, consisting of two attention layers with one head each, and one-hot token and positional encodings. The vocabulary consists of letters and numbers.

The corresponding residual stream is depicted on the right. Think of it as the outputs of all Transformer layers, stored in concatenated blocks. Yellow squares mean there is a 1 at that matrix position, while blank purple space represents 0s. The first block, **tokens**, one-hot encodes the token at each input position. Token '3' has no 1 on its row because it does not occur in the input. Next, the **position** block one-hot encodes the position of each input token. The rows for positions 8 and 9 have no 1 because our input length is 8 (and counting starts from 0). The **attention** blocks contain no information yet, because no attention has been computed. Let's compute the first layer's attention.
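The initial state of this residual stream can be sketched as a concatenation of one-hot blocks. Here each row is one position (the transpose of the figure's layout). The vocabulary, the input string and the attention block sizes are hypothetical values chosen to mirror the description above.

```python
from typing import List, Sequence

def one_hot(index: int, size: int) -> List[int]:
    """Vector of length `size` with a single 1 at `index`."""
    return [1 if i == index else 0 for i in range(size)]

def initial_residual_stream(tokens: Sequence[str], vocab: Sequence[str], max_len: int,
                            n_attn_blocks: int, block_size: int) -> List[List[int]]:
    """Row p = [one-hot token | one-hot position | zeros for every attention block]."""
    stream = []
    for pos, tok in enumerate(tokens):
        row = one_hot(vocab.index(tok), len(vocab))  # "tokens" block
        row += one_hot(pos, max_len)                 # "position" block
        row += [0] * (n_attn_blocks * block_size)    # attention blocks: not written yet
        stream.append(row)
    return stream

vocab = ["<bos>", "a", "b", "c", "1", "2", "3"]       # hypothetical vocabulary
tokens = ["<bos>", "a", "1", "b", "2", "c", "a", "1"]  # hypothetical input of length 8
stream = initial_residual_stream(tokens, vocab, max_len=10, n_attn_blocks=2, block_size=10)
```

As in the figure, the column for token `"3"` and the columns for positions 8 and 9 stay empty, and so do both attention blocks.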
### "Circuits" of Attention Heads
As discussed before, the paper's version of attention is computed with: $z_i = \sum_j A_{ij} \cdot v_j = \sum_j predicate(q_i, k_j) \cdot v_j$
Or its equivalent in RASP: `aggregate(select_closest(keys, queries, predicate), values)`.
The figure below shows the output, and computation, of the first attention layer.

Let's discuss the computation. The green squares on the right form the attention matrix A. Each entry represents attention between $query_i$ (vertical axis) and $key_j$ (horizontal axis). But what does each query/key represent? This attention layer has learned to read the positions (0-9) as both the keys and the queries, and the tokens as the values. The attention matrix closely resembles a shifted identity matrix: the query at position i=1 attends to the key at position j=0, and so on. Query position 9 defaults to the `<bos>` token. Aggregating the values with this attention matrix therefore shifts the tokens one position to the right.
If this attention layer is decompiled into RASP, the resulting program is:

This executes the same behaviour: the predicate makes the queries attend to the previous key positions. Also, notice how the positions array is passed as keys and queries.
This is it! The constrained Transformer model you've seen represents a Transformer Program. Its special feature is that it can be automatically converted into human-readable RASP code, while still being trainable with gradient-based optimization. This is made possible by constraining the attention heads to be categorical, enforcing hard attention with categorical scores in {0,1}.
Of course, we only covered the basics here. Please refer to the paper [^LearningTrans] for additional details on optimization, program extraction, examples and results.
## Experiments
The authors learn Transformer Programs for a variety of tasks, one of them being the suite of algorithmic tasks introduced by [Weiss et al. (2021)](https://arxiv.org/abs/2106.06981) to illustrate the RASP language. Below are the results of the experiment we aim to ablate.

For this experiment, 20,000 inputs were sampled and partitioned into train, validation and test sets using the following split: 16,000/2,000/2,000 respectively.
The way they sampled these inputs without replacement is as follows: sample strings uniformly until the set of unique strings has the intended size. For all tasks, all inputs were prepended with a beginning of sequence token `bos`. Only for the "sort" and "reverse" tasks, an end-of-sequence token `eos` was appended as well.
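The sampling and splitting procedure described above can be sketched as follows. Drawing string lengths uniformly from 1 to `max_len`, the alphabet `"abcdefgh"` and the seeds are our assumptions. Whether `max_len` should count the special tokens is also left open here. See the authors' repository for the actual data generation.

```python
import random
from typing import List, Sequence, Tuple

def sample_unique_strings(vocab: Sequence[str], max_len: int, n_total: int,
                          seed: int = 0) -> List[Tuple[str, ...]]:
    """Sample strings uniformly until `n_total` distinct ones have been collected."""
    rng = random.Random(seed)
    unique = set()
    while len(unique) < n_total:
        length = rng.randint(1, max_len)  # assumption: lengths drawn uniformly
        unique.add(tuple(rng.choice(vocab) for _ in range(length)))
    return sorted(unique)  # sort so that the later shuffle is reproducible

def add_special_tokens(seq: Sequence[str], task: str) -> List[str]:
    """Prepend <bos> for all tasks; append <eos> only for sort and reverse."""
    out = ["<bos>"] + list(seq)
    if task in ("sort", "reverse"):
        out.append("<eos>")
    return out

def split(data: Sequence, n_train: int = 16_000, n_val: int = 2_000,
          n_test: int = 2_000, seed: int = 0):
    """Shuffle and cut into train/validation/test sets (16,000/2,000/2,000 by default)."""
    data = list(data)
    random.Random(seed).shuffle(data)
    return (data[:n_train],
            data[n_train:n_train + n_val],
            data[n_train + n_val:n_train + n_val + n_test])

strings = sample_unique_strings(list("abcdefgh"), max_len=8, n_total=20_000)
train_set, val_set, test_set = split([add_special_tokens(s, "sort") for s in strings])
```

Because the strings are unique before splitting, no input can appear in both the training and the test set.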
Each model was trained for 250 epochs with a batch size of 512 and a learning rate of 0.05. The hyperparameters `L`, `H` and `M` presented in the table above were obtained by the authors through a grid search over the hyperparameter space. Furthermore, the attention heads are evenly divided between *numerical* and *categorical* heads (and the same applies to the categorical and numerical MLPs). Moreover, we use fixed one-hot token and position embeddings, as explained earlier.
We reused training scripts and built on top of the repository created by the authors: https://github.com/princeton-nlp/TransformerPrograms. For more details on the experimental setup, see Appendix B.2 of the paper.
### Ablation 1: Different Vocabulary Sizes and Max. Input Length
#### Setup
For our first ablation, we reran the top five RASP tasks in the table above using different vocabulary sizes (`|V|`) and max. input lengths (`N`, which is set equal to the variable cardinality `k`). We experimented with the following settings:
1. `|V| = 8`, `N = 8` (default, as used to produce Table 1 above)
2. `|V| = 8`, `N = 16`
3. `|V| = 16`, `N = 16`
For the other hyperparameters (`L`, `H`, `M`), we used the values as listed in the table above.
Note that a similar ablation is reported in the paper (see Appendix C.1). The difference is that the authors tuned `L`, `H` and `M` by performing a grid search over the hyperparameter space to obtain the best possible model on the validation set. They also explicitly compare their results against the standard Transformer.
#### Results
Below are the accuracies (%) we obtained on the test set:
| | Reverse | Hist | Double-hist | Sort | Most freq |
| -------- | -------- | -------- | -------- | -------- | -------- |
| Setting 1 | 99.74 | 99.95 | 64.11 | 99.98 | 82.07 |
| Setting 2 | 62.84 | 99.92 | 81.74 | 99.12 | 74.75 |
| Setting 3 | 42.06 | 99.93 | 71.30 | 88.09 | 53.76 |
Similarly to Appendix C.1 of the paper, we observe a general trend: performance degrades when training on longer sequences (`N = 16`) and degrades further when `|V|` is also increased, making setting 3 the most difficult one. The drop is most severe for "Reverse" (99.74 to 42.06) and "Most freq" (82.07 to 53.76). This suggests that the learned Transformer Programs may not generalize well to longer input sequences and/or larger vocabularies. The exceptions are "Hist", which stays at about 99.9 in all settings, and "Double-hist", which actually scores higher in settings 2 and 3 than in setting 1.
This pattern was also observable when considering accuracies on the *train* sets.
When comparing our results in setting 1 to the original results presented in Table 1, they align for most tasks. However, for the "double histogram" task the authors reported an accuracy of 98.40, whereas we obtained 64.11. Conversely, they reported 75.69 for "most-freq", whereas we obtained a higher score of 82.07. It is possible that the authors slightly tweaked the hyperparameters without reporting this in the paper.
### Ablation 2: Extending synthesized dataset size
#### Setup
To mitigate the decline in accuracy observed with larger vocabularies `|V|` and longer maximum input lengths `N`, we repeated the experiment of ablation 1 with generated dataset sizes ranging from 5,000 to 100,000 samples (the raw accuracies are listed in the Appendix). The aim was to determine whether enlarging the training dataset could improve the model's performance under the expanded vocabulary and sequence-length settings.
#### Results

In the extended dataset size experiment, the "Histogram" task achieved consistently high accuracies across all settings, indicating robustness to dataset size. For the "Most Frequent" and "Sort" tasks, performance remained relatively stable, with only marginal improvements for "Sort" as the dataset size increased. Notably, the "Reverse" task showed a strong positive correlation between dataset size and accuracy in the harder settings (in setting 2, accuracy rose from 0.41 to 0.78), suggesting that larger datasets help for tasks of increased complexity. In setting 3, however, accuracy dropped from 0.549 at 50,000 samples to 0.510 at 100,000. Since more training data would normally reduce overfitting, this drop may reflect run-to-run variance rather than overfitting. The "Double Histogram" task displayed considerable variability with no clear trend, suggesting a high sensitivity to randomness.
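The trends above can be checked directly against the raw numbers in the Appendix. This small script is our own. It copies the Reverse table, finds the best dataset size for each setting, and computes the accuracy gain between two dataset sizes.

```python
# Reverse accuracies copied from the "Reverse accuracies" table in the Appendix.
# Tuple index 0, 1, 2 corresponds to Setting 1, 2, 3.
reverse = {
    5000:   (0.994764, 0.412554, 0.361345),
    10000:  (0.997486, 0.556917, 0.321156),
    20000:  (0.996628, 0.656025, 0.47219),
    50000:  (0.998298, 0.757171, 0.548778),
    100000: (0.999464, 0.780263, 0.509709),
}

def best_size_per_setting(table):
    """For each setting, the dataset size with the highest accuracy."""
    n_settings = len(next(iter(table.values())))
    return [max(table, key=lambda size: table[size][s]) for s in range(n_settings)]

def gain(table, setting_index, small, large):
    """Accuracy difference between two dataset sizes for one setting (0-based index)."""
    return table[large][setting_index] - table[small][setting_index]

best = best_size_per_setting(reverse)
setting2_gain = gain(reverse, 1, 5000, 100000)
```

The same two helpers work unchanged for the other Appendix tables, since they share the layout of dataset size to per-setting accuracies.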
#### Conclusions
More data clearly improves accuracy in some cases, but by less than we expected before the experiment. Further research is needed, for instance by increasing the number of heads or the MLP sizes, to find out whether these RASP tasks can be learned accurately in the more complex settings.
### Ablation 3: New RASP Tasks
#### Setup
In this section, we test whether we can learn Transformer Programs for two new algorithmic tasks. The first task is *Shift*, where characters are shifted cyclically one position to the left. The second is *Run-Length Encoding*, where each run of consecutive repeated characters is replaced by a single character followed by the number of repetitions.
We train on small-scale instances of each task, setting the maximum sequence length and vocabulary size to 8. We use fixed one-hot token and position embeddings and set the variable cardinality k equal to the maximum sequence length. For this setting, we introduce numerical attention and MLPs. We equip each model with an equal number of categorical and numerical attention heads, and of categorical and numerical MLPs, fixing the number of MLP input variables to two, and perform a grid search over the number of layers, attention heads, and MLPs per layer.
| Task name | Description | Example | k | L | H | M | Accuracy |
| --------- | ------- | -------- | -------- | -------- |-------- |-------- |-------- |
| Shift | Shift characters in a string to the left by 1 place. | "**abcd**"-> "**bcda**" | 8 | 1 | 4 | 2 | 100.0
| Run-Length encoding | Compresses a string by replacing consecutive repeated characters with a single character followed by the count of repetitions. | "**aaabcc**"->"**a3b1c2**" | 8 | 1 | 4 | 2 | 0.79
#### Shift task
We start with a simple task: each output position only needs to attend to the next input position, with the last position wrapping around to the first token. We train a Transformer Program with only one layer, 2 categorical and 2 numerical attention heads, and 2 MLPs. We also set the input length to 8 (including the `<bos>` and `<pad>` tokens). Our results show that the model achieves 100% accuracy.
Because the task is so simple, we can easily interpret what the model does internally. Consider the **input** `<bos> a b c d e f h`. The desired output is `<bos> b c d e f h a`. Here is how the 2 categorical attention heads learned to solve this task:

In the figure above, the **first attention head** makes each query position attend to the next key position. The exceptions are position 7, which attends to position 1, and position 0, which attends to itself. This is consistent with how a human would approach the problem: the `<bos>` token at position 0 stays in place, the token at position 1 moves to position 7 (the last position), and the remaining characters move one position to the left. The **second attention head** implements a different rule: it separates the `<bos>` token from the rest of the input, exploiting the fact that `<bos>` is always at position 0. Position 0 is therefore mapped to `'0'`, while all other positions are mapped to `'1'`.
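The behaviour of the two categorical heads, as described above, can be reproduced by hand. This is a hand-written reconstruction of the described attention patterns for a length-8 input without padding, not the program extracted from the trained model. `shift_attention` and `bos_flag` are hypothetical names.

```python
from typing import List, Sequence

def shift_attention(n: int) -> List[List[int]]:
    """First head as described: 0 -> 0, last -> 1, every other i -> i + 1."""
    def target(i: int) -> int:
        if i == 0:
            return 0      # <bos> attends to itself
        if i == n - 1:
            return 1      # last position wraps around to the first real token
        return i + 1      # all other positions read the next position
    return [[1 if j == target(i) else 0 for j in range(n)] for i in range(n)]

def aggregate(A: List[List[int]], values: Sequence) -> list:
    """Categorical aggregate: each one-hot row copies the selected value."""
    return [values[row.index(1)] for row in A]

tokens = ["<bos>", "a", "b", "c", "d", "e", "f", "h"]
shifted = aggregate(shift_attention(len(tokens)), tokens)

# Second head as described: position 0 (<bos>) maps to '0', all others to '1'.
bos_flag = ["0" if i == 0 else "1" for i in range(len(tokens))]
```

The first head alone already produces the target `<bos> b c d e f h a`; the second head only marks which position holds `<bos>`.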
For the full program containing the numerical attention and the MLP layers, please check [our repository](https://github.com/nadinekuo/TransformerPrograms).
#### Run-Length Encoding Task
Another task we explored was Run-Length Encoding, a classic problem in data compression. The goal is to compress a string by replacing each run of consecutive repeated characters with a single character followed by the number of repetitions. For instance, "aaabcc" is transformed into "a3b1c2". The model achieved an accuracy of 0.79 on this task: it compresses many strings correctly but, unlike in the Shift task, does not solve the task perfectly.
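For reference, the target function itself can be written in a few lines of Python with `itertools.groupby`. This sketch only defines the desired mapping from input string to output string. It does not cover how the targets are encoded as fixed-length token sequences for the model.

```python
from itertools import groupby

def run_length_encode(s: str) -> str:
    """Replace each run of a repeated character by the character and the run length."""
    return "".join(f"{char}{len(list(run))}" for char, run in groupby(s))

encoded = run_length_encode("aaabcc")  # "a3b1c2", as in the example above
```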
The implementation involved training a Transformer Program with a single layer, two categorical attention heads, two numerical attention heads, and two multilayer perceptrons (MLPs). Special tokens such as `<bos>` (beginning of sequence) and `<pad>` (padding) were used to facilitate sequence processing.
Attention lets the model focus on different parts of the input sequence when making a prediction, which is crucial for capturing the relationships between tokens in the input and output sequences.
The categorical attention heads learned to identify patterns and repetitions within the input sequence. They exhibited distinct patterns, with one head focusing on identifying repeated characters, while the other attended to transitions between characters.
The numerical attention heads, on the other hand, learned to discern the frequency and distribution of characters within the sequence. They provided complementary information to the categorical attention heads, enriching the model's understanding of the input data.
# Appendix
## Raw Accuracies Ablation 2
#### Double Histogram accuracies
| Dataset Size | Setting 1 | Setting 2 | Setting 3 |
|----------------:|-------------:|-------------:|-------------:|
| 5000 | 0.673609 | 0.493612 | 0.655454 |
| 10000 | 0.634605 | 0.532351 | 0.708046 |
| 20000 | 0.655464 | 0.709731 | 0.657033 |
| 50000 | 0.67844 | 0.55076 | 0.480676 |
| 100000 | 0.672006 | 0.636917 | 0.466006 |
#### Hist accuracies
| Dataset Size | Setting 1 | Setting 2 | Setting 3 |
|----------------:|-------------:|-------------:|-------------:|
| 5000 | 1 | 0.990843 | 0.98762 |
| 10000 | 0.999297 | 0.998771 | 0.998868 |
| 20000 | 0.999164 | 0.999482 | 0.99905 |
| 50000 | 0.999807 | 0.999319 | 0.999249 |
| 100000 | 0.999788 | 1 | 0.999788 |
#### Most Frequent accuracies
| Dataset Size | Setting 1 | Setting 2 | Setting 3 |
|----------------:|-------------:|-------------:|-------------:|
| 5000 | 0.793178 | 0.742973 | 0.486802 |
| 10000 | 0.808511 | 0.754709 | 0.495643 |
| 20000 | 0.805117 | 0.74617 | 0.490329 |
| 50000 | 0.816762 | 0.750676 | 0.505139 |
| 100000 | 0.812009 | 0.753001 | 0.505766 |
#### Reverse accuracies
| Dataset Size | Setting 1 | Setting 2 | Setting 3 |
|----------------:|-------------:|-------------:|-------------:|
| 5000 | 0.994764 | 0.412554 | 0.361345 |
| 10000 | 0.997486 | 0.556917 | 0.321156 |
| 20000 | 0.996628 | 0.656025 | 0.47219 |
| 50000 | 0.998298 | 0.757171 | 0.548778 |
| 100000 | 0.999464 | 0.780263 | 0.509709 |
#### Sort accuracies
| Dataset Size | Setting 1 | Setting 2 | Setting 3 |
|----------------:|-------------:|-------------:|-------------:|
| 5000 | 0.998504 | 0.968615 | 0.821128 |
| 10000 | 0.998922 | 0.976459 | 0.8867 |
| 20000 | 0.999875 | 0.979798 | 0.891367 |
| 50000 | 0.999905 | 0.987119 | 0.888881 |
| 100000 | 0.999821 | 0.990685 | 0.897364 |
[^LearningTrans]: Learning Transformer Programs. https://arxiv.org/abs/2306.01128
[^RASP]: Thinking Like Transformers. https://arxiv.org/abs/2106.06981
[^Tracr]: Tracr: Compiled Transformers as a Laboratory for Interpretability. https://arxiv.org/abs/2301.05062
[^TransCircuits]: A Mathematical Framework for Transformer Circuits. https://transformer-circuits.pub/2021/framework/index.html