---
title: "On Context Scaling"
format:
html:
toc: true
toc-location: left
toc-title: " "
author:
- Jacob Buckman
- Carles Gelada
date: "2024-01-05"
categories: false
draft: false
citation:
type: article # See available types here https://docs.citationstyles.org/en/stable/specification.html#appendix-iii-types
# abstract:
# editor:
publisher: "Manifest AI"
bibliography: references.bib
csl: ieee.csl
google-scholar: true # This creates metadata for easy indexing by google scholar and other engines.
---
Intuitively, extending the context should make the task of predicting the next
token strictly easier. If the extra information is relevant, the model should
learn to use it; if it isn't, the model can just ignore it. Therefore, given any
well-trained language model, we expect the average loss to be lower on
longer-context predictions. To verify this, we trained a 1.6B transformer [^2]
model with context size 8192. Figure 1.A shows the average loss of the model at
different context lengths. The same effect has been observed in the literature;
for example, figure 1.B shows a similar plot from Gemini
[@geminiteam2024gemini]. We refer to this phenomenon as *inference-time context
scaling*.
[^2]: Our architecture is almost identical to nanoGPT, with one difference: we
use rotary position embeddings [@su2023roformer] rather than fixed positional
encodings. This is so we can evaluate generalization beyond the training
context. The dataset used for training and evaluation is LongCrawl64.
:::{.column-body style="text-align: center;"}
<img src="plots/gemini_scaling.png" alt="Image" style="max-width:50%; height:auto;">
<iframe width="400" height="200" src="plots/inference_context_scaling.html"></iframe>
:::
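To make "inference-time context scaling" concrete: the curve in figure 1.A is just the per-position average of the per-token losses on a held-out set of long documents. Below is a minimal sketch (not our actual evaluation code) of how such a curve can be computed; the `per_token_loss` array here is made-up data standing in for real measurements.

```python
import numpy as np

# Hypothetical per-token cross-entropy losses on held-out documents:
# one row per document, one column per context position.
num_docs, context_len = 256, 8192
rng = np.random.default_rng(0)
per_token_loss = 3.0 + 0.1 * rng.standard_normal((num_docs, context_len))

# Averaging over documents gives the loss as a function of how many
# preceding tokens the model could condition on; plotting this against
# position is the inference-time context scaling curve.
contextwise_loss = per_token_loss.mean(axis=0)   # shape: (context_len,)

# The scalar usually reported as "the" loss is just the mean of this curve.
average_loss = contextwise_loss.mean()
```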
Inference-time context scaling is the ultimate reason that training with longer
context should be beneficial to model performance. There is a nice analogy
between parameter count and token count. It is well known that increasing the
parameter count improves learning, but there are different ways one can scale
the parameters of a model. Two major factors are the depth and width of the
network. To optimally scale a model, one must find a good balance between the
two. Similarly, it is well known that increasing the number of training tokens
improves performance. But extra tokens can come in the form of more documents or
longer documents. We expect there to be an optimal tradeoff of count vs. length,
similar to the one we see for width vs. depth.
PLOT: little visual aid for the analogy
Framed in this way, it makes a lot of sense to pick the context size the
same way we pick other hyperparameters like network depth and width: via
scaling laws. The idea is to first run small-scale experiments in order to
measure how the loss and compute cost scale with respect to the hyperparameters
of interest, then extrapolate these curves to obtain an estimate of the
performance of the model for any configuration, and finally find the optimal
setting given the available computational resources. For example,
[@kaplan2020scaling] and [@hoffmann2022training] were looking for the optimal
way to scale the model size and the amount of tokens seen during training. In
both papers the model size was controlled by varying both the depth and width
of the model. But the context size was picked independently of any laws and was
held constant throughout all the experiments. Why wouldn't we also find the
optimal value of the context size in terms of performance per $?
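As a small illustration of that workflow, here is a sketch of fitting and extrapolating a saturating power law to hypothetical (compute, loss) measurements; the data points and the exact functional form are placeholders, not results from this post.

```python
import numpy as np
from scipy.optimize import curve_fit

# Hypothetical small-scale measurements: training compute (e.g. GPU-hours)
# and the final loss reached at each budget for one hyperparameter setting.
compute = np.array([1.0, 2.0, 4.0, 8.0, 16.0])
loss = np.array([3.90, 3.70, 3.55, 3.44, 3.36])

# A saturating power law is a common functional form for scaling-law fits.
def power_law(c, a, b, irreducible):
    return a * c ** (-b) + irreducible

params, _ = curve_fit(power_law, compute, loss, p0=(1.0, 0.3, 3.0))

# Extrapolate to a larger budget, then repeat for other settings (depth,
# width, context size, ...) and pick whichever predicts the lowest loss.
print(power_law(256.0, *params))
```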
In section 3 we will study the impact of context scaling for a transformer
architecture, finding the optimal way to grow the context and model size. But
before that, there are a few challenges we must face. In section 1 we will
discuss why current datasets are inadequate and present our solution, the
LongCrawl64 dataset. Section 2 tackles another issue: comparing the training
losses of models trained with different context sizes is misleading. There are
cases where models which are strictly worse than others have a lower training
loss! This is a problem if we are trying to find the combination of
hyperparameters that attains the lowest loss at the end of training. Luckily,
there is a closely related metric that doesn't suffer from this issue and can
be used to draw scaling laws for all hyperparameters: network depth, network
width, document count, and document length.
::: {.column-margin}
This mindset of picking the context size via scaling laws stands in stark
contrast to the way a lot of the research in long context is often framed. Many
papers propose algorithmic or architectural modifications that allow training
models with very long context sizes. Of course, being able to train models with
long contexts is an important goal, but it is not enough. One must also show that
it is a smart decision to do so (in terms of performance-per-dollar).
:::
A major conclusion of section 3 is that, for the transformer architecture under
evaluation, one wants to grow the context size quite slowly compared to the
model size. It really wouldn't make economic sense to train it on very long
contexts, even for very large computational budgets.
But ultimately we don't really care about the context size used during training.
All we want is a model with good inference-time context scaling, so that users of
the model can get very good predictions by feeding lots of relevant information
into the context. Thus, having to train the model on short contexts might not be
a problem, as long as the inference-time context scaling keeps improving beyond
the training context size. Unfortunately, in the experiments of section 2 we can
see that this isn't the case for the transformer architecture that we've been
evaluating.
The combination of "it doesn't make sense to train models on long context" and
"the model performance degrades beyond the training context size" means the
particular architecture under study has a bleak prospect for long context. Until
we eimpirically test it, we will stop short of claiming that these trends apply
to more modern transformer architectures, although we suspect they do. In any
case, we believe that carefully evaluating the context scaling properties of
models will be essential to further push the fronteer of context size. We hope
that the dataset, ideas and evaluations presented in this article will prove
useful to that objective.
## Data with Long Term Structure
Below you can see an inference-time context scaling plot just like the one in the
introduction. It shows the average loss at every context length for a 1.6B-parameter
model trained using 8KiT of context on
[openwebtext](https://github.com/jcpeterson/openwebtext).
:::{.column-body style="text-align: center;"}
<iframe width="700" height="300" src="plots/openwebtext64_inference_scaling.html"></iframe>
:::
As you can see, at the beginning this plot looks like the previous one. But
very quickly (at around 2KiT) the performance tapers off and there is
no benefit from including extra information in the context. To understand the
reason, we just need to look at the document length distribution of openwebtext.
:::{.column-body style="text-align: center;"}
<iframe width="700" height="250" src="plots/owt_histogram.html"></iframe>
:::
The vast majority of the documents are less than 2k tokens long. So, when we want
to train 8k-context models on this dataset, we are forced to artificially construct
longer documents; in our experiments, we concatenated multiple documents into a
longer one. But the resulting documents do not contain any long-term structure,
so there is no benefit to seeing more tokens at inference time.
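For clarity, here is a minimal sketch of the kind of concatenation we mean, assuming a stream of already-tokenized documents and GPT-2's end-of-text token as the separator; the function is ours, purely for illustration.

```python
import numpy as np

EOT_TOKEN = 50256    # GPT-2's end-of-text id, used as a document separator
CONTEXT_LEN = 8192

def pack_documents(token_docs):
    """Concatenate short tokenized documents into fixed-length training rows.

    The rows are CONTEXT_LEN tokens long, but anything that crosses a
    document boundary is spurious: past the separator, the earlier tokens
    carry no information about the current document.
    """
    buffer, rows = [], []
    for doc in token_docs:
        buffer.extend(doc)
        buffer.append(EOT_TOKEN)
        while len(buffer) >= CONTEXT_LEN:
            rows.append(buffer[:CONTEXT_LEN])
            buffer = buffer[CONTEXT_LEN:]
    return np.array(rows, dtype=np.uint16)
```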
This problem is not restricted to openwebtext. Most popular datasets, like
[][][][], have similar document length distributions. So clearly, in order to
study the context scaling properties of any algorithm, we are going to need
better data. That is why we created [LongCrawl64](), a large natural language
dataset composed entirely of documents of length 64k.
To construct it, we distilled
[RedPajama-v2](https://github.com/togethercomputer/RedPajama-Data) down to its
documents of length >= 65,536. We then tokenized each one with [OpenAI's
TikToken](https://github.com/openai/tiktoken) tokenizer for GPT-2, which uses
BPE with a vocabulary size of 50304. The end result is a 6,661,465 × 65,536
[Zarr]() array of uint16s, representing 6,661,465 documents. The total token
count is ~435 billion, two orders of magnitude larger than openwebtext (~6
billion). Read the associated article for all the details around the
construction and usage of the dataset, like how we use the 64k-long documents
when training models with shorter context lengths.
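As a rough sketch of what working with the dataset looks like, here is how one might open the Zarr array and slice fixed-length training sequences out of the 64KiT documents. The path, and the choice to slice prefixes, are assumptions for illustration; see the dataset article for the actual layout and recommended usage.

```python
import numpy as np
import zarr

# Hypothetical local path to the training split; shape (num_docs, 65536), dtype uint16.
docs = zarr.open("longcrawl64/train.zarr", mode="r")

CONTEXT_LEN = 8192

def sample_batch(rng, batch_size):
    """Draw a batch of training sequences from random long documents.

    Slicing a contiguous CONTEXT_LEN-token window out of a single 64KiT
    document keeps genuine long-range structure inside every sequence,
    unlike the concatenation workaround needed for short-document corpora.
    """
    doc_ids = rng.integers(0, docs.shape[0], size=batch_size)
    return np.stack([docs[i, :CONTEXT_LEN] for i in doc_ids]).astype(np.int32)

# Example usage (requires the dataset on disk):
# batch = sample_batch(np.random.default_rng(0), batch_size=32)
```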
Armed with this new dataset, we can repeat the same experiment. The
inference-time context scaling plot of the 1.6B transformer with context size 8k
trained on LongCrawl64 looks like this:
:::{.column-body style="text-align: center;"}
<iframe width="700" height="300" src="plots/longcrawl64_inference_scaling.html"></iframe>
:::
PLOT: inference time context scaling with Long Crawl and openwebtext. Two
different colors for each dataset.
We can see clear benefits from increasing the context size all the way up to 8k.
With this issue resolved, let's move on to the next one.
## The Training Loss is Misleading
Look at the inference-time context scaling for two 1.6B transformers after
training on LongCrawl64 for 2 days on a machine with 8xH100 GPUs. The only
difference is that one was trained with a context size of 8KiT and the other with 32KiT.
PLOT: contextwise losses for the same model trained with different context sizes;
only show the losses up to the training context size. Consider
putting a dotted horizontal line for the average loss of each one.
Clearly, the 8KiT model is superior. It makes better predictions than the 32KiT
model at every context length where they can be compared, and at the tail the
8KiT model reaches a lower loss than the 32KiT model ever does. And yet, the
average training loss of the 32KiT model is X, while the 8KiT model has loss Y.
This is because the training loss is effectively the average of the
inference-time context scaling curve, which means that for the 32KiT model a
much larger fraction of the training data comes from situations where the model
has a large amount of information in its context. **When comparing models
trained with different context sizes, the training loss is a poor heuristic for
their performance.** So, if we want to use the standard scaling-laws mindset to
select hyperparameters (and the training context size is among them), we must
find a better metric to optimize.
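The effect is easy to reproduce with made-up numbers. In the sketch below, the "32k" curve is strictly worse at every shared position, yet its average (its training loss) is lower, simply because more of its positions sit in the long-context, low-loss regime. The curves are invented for illustration, not measured.

```python
import numpy as np

# Made-up contextwise loss curves: the "32k" model is 0.01 nats worse than
# the "8k" model at every position they share.
t_8k = np.arange(1, 8192 + 1)
t_32k = np.arange(1, 32768 + 1)
loss_8k = 2.00 + 1.5 / np.sqrt(t_8k)
loss_32k = 2.01 + 1.5 / np.sqrt(t_32k)

print(loss_8k.mean())                # ~2.033: higher average training loss
print(loss_32k.mean())               # ~2.027: lower, despite being worse pointwise
print(loss_8k[-1], loss_32k.min())   # ~2.017 vs ~2.018: the 8k tail still wins
```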
::: {.column-margin}
At inference time, even if we had 32KiT of context to give to the model, instead
of using the 32KiT model we would be better off _throwing away_ the first 24KiT
and feeding the remainder to our 8KiT model. That would result in better
predictions.
:::
Intuitively, we want our metric to reflect how good the model would be at
predicting the next token at inference time. If we make the assumption that the
users of the model have access to arbitrarily many tokens to put in the context,
then a good metric would be the lowest loss that the model attains at any
context size. We refer to this as the **best-context loss**. The idea is that
users would provide however much context is necessary to get the best
predictions from the model. The most natural way to measure it would be to start
with the contextwise loss curve on a dataset made of very long documents; the
best-context loss is then just the minimum of that curve.
Since the transformer we've been working with uses rotary embeddings, we can
evaluate it beyond its training context. And, with the LongCrawl64
dataset, we have data with long-term structure up to 64KiT. Thus, we can
extend the inference-time context scaling plots from before up to 64KiT.
:::{.column-body style="text-align: center;"}
<img src="plots/inference_time_law_beyond_training.png" alt="Image" style="max-width:70%; height:auto;">
:::
PLOT: Include lower bound line denoting the best possible performance of the
model.
One interesting thing to note is that when we pass the line demarcating the
context size used during training, there is a rapid deterioration of the
prediction quality. Clearly, this GPT2+RotaryEmbeddings model does not generalize
very well beyond its training context. We've observed this exact same phenomenon
for transformers of all sizes trained on all sorts of context sizes.
::: {.column-margin}
Even though there have been a lot of claims of generalization beyond the training
context size [][][], to the best of our knowledge nobody has shown a model for
which the loss monotonically decreases with the context size on natural data,
approaching a limit loss. One would want to see that the model makes much
better predictions when it has all the information it needs than it ever did
during training. We consider this to be **true generalization beyond the
training context**.
:::
The fact that models attain their best performance right at the end of the
training context simplifies measuring the best-context loss a lot. There is no
need to evaluate our model on a different dataset with longer contexts, because
the training process already produces the measurement we need: instead of just
logging the training loss, we also log the average final loss of the
minibatch. [^7] So, as long as we are working with algorithms that generalize
poorly beyond the training context, we can use this trick to measure the
best-context loss.
[^7]: In practice, we take the average loss for the last 5% of the training context.
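In code, this trick amounts to logging one extra scalar per step. The sketch below shows one way to do it, given the (batch, context_len) matrix of per-token losses that the training step already computes; the function name and the made-up minibatch are ours.

```python
import numpy as np

def best_context_loss_proxy(per_token_loss, tail_fraction=0.05):
    """Average loss over the last few percent of context positions.

    Because the models studied here reach their lowest loss right at the end
    of the training context, this tail average approximates the minimum of
    the contextwise loss curve, i.e. the best-context loss.
    """
    context_len = per_token_loss.shape[1]
    tail_start = int(context_len * (1.0 - tail_fraction))
    return per_token_loss[:, tail_start:].mean()

# Hypothetical minibatch of per-token losses, shaped (batch, context_len).
rng = np.random.default_rng(0)
fake = 2.0 + 1.5 / np.sqrt(np.arange(1, 8193)) + 0.1 * rng.standard_normal((32, 8192))
print(fake.mean())                     # the usual training loss
print(best_context_loss_proxy(fake))   # the extra scalar we log
```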
## Context Scaling Experiments
In the last two sections, we have seen a clear benefit of a longer training context on the tail loss. But there are also some clear tradeoffs:
* A backprop reason: bigger batch sizes. The batch size (in tokens) is always a
multiple of the context size, so long contexts force large batches, which goes
against the whole idea of SGD. It's worth noting that even though this is a
general drawback of large context sizes, it isn't showing up in this set of
experiments because we kept the batch size constant. That decision can be seen
as "unfair" to small context sizes, since they are using much larger batch
sizes than optimal.
* A transformer reason: $O(t^2)$ costs. This wouldn't be the case for RNN
architectures like [linear
transformers](https://manifestai.com/blogposts/faster-after-all/), which we
will tackle in a future post.
* A fundamental reason: less data diversity. Intuitively, in the limit of
training on a single, extremely long document, the model would not turn out to
be a good one.
It is clear that there will be an optimal tradeoff, and our goal now is to empirically
measure this sweet spot and get a sense of how it changes as we vary the amount of training resources.
The basic experiment we conducted is: train GPT-2 + rotary with flash attention at scales
small, medium, large, and XL, with context sizes 128, 512, 2048, 8k, and 32k. Each run used 8 H100 GPUs with data parallelism, for 2 days or 50k gradient updates. We kept the batch size constant (in number of tokens per gradient step), so that, across the different context sizes, the number of documents in each update varied from 2048 down to 8.
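The stated endpoints (2048 documents per step at a 128 context, 8 at 32k) correspond to a fixed budget of 128 × 2048 = 262,144 tokens per gradient step; the snippet below just spells out that arithmetic across the swept context sizes.

```python
# Fixed token budget per gradient step, implied by the endpoints above.
TOKENS_PER_STEP = 128 * 2048  # = 262,144

for context_len in (128, 512, 2048, 8192, 32768):
    docs_per_step = TOKENS_PER_STEP // context_len
    print(f"context {context_len:>6}: {docs_per_step:>4} documents per step")
```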
Let's first look at the small transformer:
PLOT: Final losses for the small transformer
Takeaways:
* Context size tends to trace a U-shaped curve at all resource levels. Picking too small or too large a context size results in severely degraded performance.
* The optimal context size grows with the training resources. Towards 50h it finally makes sense to train the model with a 32k context size.
We can plot a line like this for every model scale, so we can think about the combination of model and context scale.
PLOT: the tail loss at the end of training for GPT2 with varying context sizes, each model size gets its own line. Include slider that picks the number of training hours.
Takeaways:
* After 15h, the best choice is the largest model we trained (XL). Including even larger models in the experiment would have required some engineering effort we didn't want to take on.
* With few GPU hours, smaller models have better performance (because they did a lot more updates). Between 1h and 15h, neither the largest nor the smallest model is the optimal choice. It's only after 15h that it makes sense to train the XL model. This is the standard parameter scaling behaviour. [^10]
* The optimal context size for bigger models tends to be smaller than the one for smaller models. This implies that when you jointly optimize the combination of model and context scale, the optimal context size grows much more slowly than it does at a fixed model scale.
We can plot the data in a different way. For each model size, we can plot the optimal context size on the y-axis. (At the boundaries, the optimal context size tends to jump around, so we smoothed it.) This gives us a sense of how one wants to grow the context size as the model size increases.
PLOT: with train time on the x axis, plot the optimal-model-at-each-moment's depth/width ratio vs doclength/count ratio. We will see the first is (linearly?) growing, but the second has an asymptote.
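As a sketch of how these curves are extracted, the snippet below takes a hypothetical grid of final best-context losses and reads off the optimal context size for each (model size, budget) pair; the loss values are invented placeholders, not our measurements.

```python
import numpy as np

context_sizes = np.array([128, 512, 2048, 8192, 32768])

# Hypothetical final best-context losses for each swept context size,
# keyed by (model_size, training_hours). Real values come from the sweep.
results = {
    ("small", 12): np.array([3.42, 3.31, 3.28, 3.30, 3.39]),
    ("small", 48): np.array([3.35, 3.22, 3.15, 3.14, 3.18]),
}

# The "optimal context size" curves are just the argmin along the context axis.
for (model, hours), losses in results.items():
    best = context_sizes[np.argmin(losses)]
    print(f"{model} model, {hours}h: optimal context size = {best}")
```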
## Final Thoughts On Context Scaling
Every additional order of magnitude of context enables new applications of the technology. For example:
* **kilotoken** scale: Read & write emails. Hold a short chatbot-style conversation.
Customize behavior with a prompt.
Few-shot learning based on up to a few dozen examples.
* **megatoken** scale: Write books. Review news articles. Read & edit code.
Correctly source answers from a large scientific literature.
Few-shot skill acquisition from demonstration.
* **gigatoken** scale: Read everything tweeted in a day and synthesize global opinion.
Execute a full software engineering workflow.
In-context learning of entire datasets (replacing fine-tuning).
* **teratoken** scale:
* **petatoken** scale:
The combination of "it doesn't make sense to train models on long context" and
"the model performance deteriorates beyond the training context size" paints a
bleak picture for the future of long context, at least for the current LM stack
based on transformers trained with backprop.
If we want to unlock the true potential of neural network sequence models, we
must make progress on these two problems. We hope that LongCrawl and the general
ideas discussed in this article will help toward this end goal.
```{=html}
<form
action="https://buttondown.email/api/emails/embed-subscribe/manifestai"
method="post"
target="popupwindow"
onsubmit="window.open('https://buttondown.email/manifestai', 'popupwindow')"
class="embeddable-buttondown-form"
>
<label for="bd-email" style="font-weight:bold;margin-right:20px;margin-top:20px;">
Subscribe to be notified of new posts:
</label>
<input style="width: 35%;" type="email" name="email" id="bd-email" placeholder="Email"/>
<input style="width: 80px;" type="submit" value="Subscribe" />
</form>
```
### Acknowledgments {.appendix}
We would like to thank:
Jono Ridgway for helping to prepare the release;
...