$$
\newcommand{\R}{\mathbb{R}}
$$
# Transformer State Size Scaling Laws
There are two ways in which we can think of a standard transformer as an RNN:
* KV cache
* infinite dimensional embedding
These two formulations are "mathematical", meaning that they don't naturally translate into an RNN
architecture we can implement efficiently in hardware.
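To make the first view concrete, here is a minimal sketch (single attention head, NumPy, no batching; names and shapes are purely illustrative) of one decoding step in which the KV cache plays the role of the recurrent state. The catch is visible immediately: the state grows by one key/value pair per token, so its size is unbounded in the context length.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def attention_step(state, q, k, v):
    """One decoding step of a single attention head.

    The "RNN state" is the KV cache: it gains one (k, v) pair per token,
    so it keeps growing with the context instead of having a fixed size.
    """
    K, V = state
    K = np.concatenate([K, k[None, :]], axis=0)   # (t, d)
    V = np.concatenate([V, v[None, :]], axis=0)   # (t, d)
    weights = softmax(K @ q / np.sqrt(len(q)))    # attend over the whole cache
    return (K, V), weights @ V

# Toy usage: after t steps the state holds t key/value pairs.
d = 4
state = (np.zeros((0, d)), np.zeros((0, d)))
for _ in range(3):
    q, k, v = np.random.randn(3, d)
    state, out = attention_step(state, q, k, v)
```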
But either of these perspectives inspires small modifications of the transformer equations so that
efficient hardware implementations become possible.
* Windowed transformers. The effective context length for window size $w$ and number of layers $l$ is $wl$.
* Embedding functions $\phi$ of finite dimensionality. These architectures have infinite context length. Implementing them efficiently in hardware requires some consideration; we explored the central ideas in our previous work.
In this article we will explore how the size of this recurrent state affects the performance of the model.
The size of the state is a fundamental constraint when thinking about efficient implementations. First, if the state is too big it won't fit in GPU memory. Second, one of the major bottlenecks is the memory transfer between HBM and the cores where computation is done, and the larger the state, the more expensive those transfers become.
PLOT: state size scaling laws of 1.6B model
The most promising is the multilinear transformer, a model that we describe in detail in this associated post, where we also provide an efficient recurrent implementation.
## Embeddings and Similarities
For $Q_i, K_j \in \R^d$, a linear transformer replaces the exponential similarity of softmax attention with
$$
f(Q_i, K_j) = \phi(Q_i) \cdot \phi(K_j)
$$
for some embedding function $\phi$. The key observation is that the exponential similarity itself corresponds to an infinite-dimensional $\phi$, as its Taylor expansion shows:
$$
e^{Q_i^T K_j} = \sum_{k=0}^{\infty} \frac{1}{k!} \left( Q_i^T K_j \right)^k = \sum_{k=0}^{\infty} \frac{1}{k!} \left( \sum_{m=1}^{d} Q_{im} K_{jm} \right)^k
$$
One obvious way to obtain a finite-dimensional $\phi$ is to use a truncated Taylor approximation of the exponential.
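As a minimal sketch of that idea (not necessarily the exact $\phi$ used in the taylor transformer runs below; function names are illustrative): truncating the expansion at second order yields an explicit feature map of dimension $1 + d + d^2$ whose dot products reproduce the exponential up to second order.

```python
import numpy as np

def phi_taylor2(x):
    """Feature map from the second-order Taylor truncation of exp(q . k):
    phi(q) . phi(k) = 1 + q.k + (q.k)**2 / 2.  Output dimension: 1 + d + d^2."""
    return np.concatenate([
        np.ones(1),                           # 0th-order term
        x,                                    # 1st-order term
        np.outer(x, x).ravel() / np.sqrt(2),  # 2nd-order term
    ])

# Sanity check: close to exp(q . k) when the dot product is small.
q, k = np.random.randn(2, 8) * 0.3
print(phi_taylor2(q) @ phi_taylor2(k), np.exp(q @ k))
```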
----
In scaling laws, the focus has traditionally been on the number of parameters. But here we show that the dimensionality of the state is a critical factor controlling the performance of a transformer.
We explore a few variations of transformers for which we can vary the state size, and measure their performance on LRPJ2. The maximum context during training was 65k.
**KV cache transformer** Effectively, the state size on the x axis just controls an artificial reduction of the context size. TODO: brief explanation of arch
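A minimal sketch under our assumption that the state-size knob is a cap on the number of cached key/value pairs, so each query attends only over the most recent entries (the function name and `max_entries` parameter are hypothetical, for illustration only):

```python
import numpy as np

def capped_cache_step(state, q, k, v, max_entries):
    """Decoding step with the KV cache truncated to the `max_entries` newest pairs.

    The state now has a fixed size, at the price of an artificially reduced
    effective context (illustrative sketch; `max_entries` is the assumed knob)."""
    K, V = state
    K = np.concatenate([K, k[None, :]], axis=0)[-max_entries:]
    V = np.concatenate([V, v[None, :]], axis=0)[-max_entries:]
    logits = K @ q / np.sqrt(len(q))
    weights = np.exp(logits - logits.max())
    return (K, V), (weights / weights.sum()) @ V
```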
PLOT: state scaling law
**taylor transformer** Approximates the standard transformer with a Taylor expansion of the exponential, which turns it into a linear transformer. The second-order Taylor approximation has been used in [CITE][CITE] and found to improve the performance of linear transformers. It is further explained in the Embeddings and Similarities section above.
PLOT: state scaling law
**head-varied linear transformer** TODO: brief explanation of arch. A fundamental limitation of this way of increasing the state is that it has a hard maximum: for GPT2XL with 1.5B parameters, the largest state we can get is 122 million. Still, we can draw a scaling law:
PLOT: state scaling law
**multilinear transformer** Our new architecture, which we explain in this accompanying blog post. Mathematically, it can also be thought of as a linear transformer.
PLOT: state scaling law
There is a clear pattern of improved performance as the state size grows. Thus, one must find a good tradeoff between the size of the model and the dimensionality of the state.
All of these architectures can be unrolled for arbitrarily long sequences; they are RNNs, in a way. A central factor in the performance of any implementation is how much data you have to read. To implement them in chunked form, one must read the state from HBM to the cores at every chunk, so the size of the state becomes the bottleneck. Thus we should look for the architectures that get the best loss at a given state dimension; afterwards we will know which one is worth an efficient chunked implementation.
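To sketch what "chunked" means here (unnormalized linear attention, single head, NumPy; a simplification for illustration, not our actual kernels): the only thing carried from one chunk to the next is the state matrix `S`, so `S` is exactly the data that has to move between HBM and the cores at every chunk.

```python
import numpy as np

def chunked_linear_attention(Q, K, V, chunk):
    """Causal (unnormalized) linear attention, o_t = phi(q_t) @ sum_{s<=t} phi(k_s) v_s^T,
    computed chunk by chunk. Q and K are assumed to be already mapped through phi.
    Information crosses chunk boundaries only through the state S of shape (d_k, d_v)."""
    T, d_k = Q.shape
    d_v = V.shape[1]
    S = np.zeros((d_k, d_v))                  # the recurrent state
    O = np.zeros((T, d_v))
    for start in range(0, T, chunk):
        q = Q[start:start + chunk]
        k = K[start:start + chunk]
        v = V[start:start + chunk]
        causal = np.tril(np.ones((len(q), len(q))))
        O[start:start + chunk] = q @ S + (causal * (q @ k.T)) @ v  # past chunks + intra-chunk
        S = S + k.T @ v                       # fold this chunk into the state
    return O
```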
CENTRAL PLOT: x axis is state dim and y axis is final performance after the same number of updates. Every line corresponds to a different architecture. Perhaps add a tab that lets you pick the model size.
Since the multilinear transformer smokes the rest, you can expect a blog post where we are going to see what a highly optimized chunked multilinear transformer can do.
Seeing these state scaling laws gave us a retroactive perspective on one of the factors that made transformers so successful. In 2019, they were most likely the only models that had ever been trained with a state size (KV cache) on the order of billions. People had used equivalent resources to train an infinitude of LSTM variations, but all of them had tiny states relative to the KV cache of a transformer.
It's funny because we went directly from architectures with tiny states to the limit of state size = infinity. This was possible thanks to the attention formulation, and it came at the cost of using relatively small context sizes for training. But since RNNs are so inefficient on the GPU, they couldn't handle very long training contexts either.
## Description of the Architectures
## Context Size and State Size Laws
The more long-term structure the data has, the stronger the context-size scaling laws. From LRPJ2 we take sequences of length T and look at how the scaling laws of the XL model change as we vary T. First, at 1k we see that more state helps, but only modestly:
PLOT: take only cached-RPJ2 and multilinear at 1k. Scaling curve on the right. Selected learning curves on the left
That is very different from the laws we saw before:
PLOT: take only cached-RPJ2 and multilinear at 64k. Scaling curve on the right. Selected learning curves on the left
Browse with the slider and see how increasing the length of the documents we train on smoothly grows the magnitude of the state scaling law.
PLOT: take only cached-RPJ2 and multilinear with a slider to control T. Scaling curve on the right. Selected learning curves on the left
ALTERNATIVE PLOT: Heat map. Y axis training context size. X axis the state size. Color is the loss at 64k eval.
Rule of thumb: one should double the state size every time the context is quadrupled.
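Put differently, this rule says the state should grow roughly as the square root of the training context,
$$
\text{state size} \propto T^{1/2},
$$
so, with purely hypothetical numbers, a configuration that is well matched at a 64M state and a 16k context would call for a 128M state at a 64k context.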
Are the gains from training with an infinite-size state worth the cost of using a $t^2$ training algorithm, and therefore having to train on smaller contexts? Are linear transformer architectures that allow chunked training worth it? It will depend on how much faster than the KV-cached versions they can run at different context sizes. In the next post we will push on this issue.
## Model Size and State Size Laws
We also study how the model size affects the state size scaling laws. A priori, it seems intuitive that going from a 1M state to a 1B state should have a smaller effect on a tiny model with 1M parameters than on a large 1B transformer: the large model has the capacity to make better use of that large state. We can simply measure whether this is true.
If we only allow the runs 6 hours, we see that too much state or too much context ends up hurting because of the lower number of training updates.
PLOT: Heat map. Y axis training context size. X axis the state size. Color is the loss at 64k eval. All after 6h
But the more time we give them, the more the optimum shifts toward increasing both values.
PLOT: Heat map. Y axis training context size. X axis the state size. Color is the loss at 64k eval. All after 24h
You can play with this interactive version to see how the optimum varies depending on the training resources.
PLOT: Heat map. Y axis training context size. X axis the state size. Color is the loss at 64k eval. A slider lets you pick the training hours.
As a loose rule of thumb, we propose that **the state size should be of the same order as the parameter count**. This is very far from the case for architectures like classic RNNs, LSTMs, etc., where the state size was tied to the parameter count in a way that kept it always much smaller.
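To make the contrast concrete with a back-of-the-envelope calculation (an illustration, not a measurement from our runs): one LSTM layer with input and hidden size $h$ has about $4(2h^2 + h) \approx 8h^2$ parameters, but its state, the hidden and cell vectors, is only $2h$ numbers. At $h = 1024$ that is roughly 8M parameters against a state of about 2k values, more than three orders of magnitude smaller, whereas the KV cache of a transformer grows with the context and can easily reach the same order as, or exceed, the parameter count.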