<style>
img {
display: block;
margin-left: auto;
margin-right: auto;
}
</style>
> [Paper link](https://arxiv.org/abs/2102.02557) | [Note link](https://zhuanlan.zhihu.com/p/455152768) | TACL 2021
:::success
**Thoughts**
This paper uses not only long-term memory but also short-term memory (the cached hidden states of the extended context $\tilde{\boldsymbol{x}}_{\le t - N}$) to generate the next token. They use ScaNN to speed up MIPS. However, the model is not trained end to end in this paper (the key encoder is kept fixed).
:::
## Abstract
They present a language model that combines a large parametric neural network with a non-parametric episodic memory component in an integrated architecture.
They extend their model with **short-term context by caching local hidden states** and **global long-term memory by retrieving a set of nearest neighbor tokens at each timestep.**
So, they design a **gating function** to adaptively combine multiple information sources to make a prediction. This mechanism allows the model to use either local context, short-term memory, or long-term memory (or any combination of them) on an ad hoc basis depending on the context.
## Introduction
In this paper, inspired by the **modular design of human memory systems**, they present a language model architecture (SPALM) with storage modules that resemble working and episodic memory systems, which it combines with a large parametric neural network that is responsible for computation.
Their hypothesis is that encouraging each component to focus on a specific function (e.g., storing long-term information, capturing extended context, modeling local information) facilitates easier training that produces an overall better language model.
- Short-term memory: transformer-XL
- Long-term memory: a persistent key-value database with sparse retrieval via (approximate) $k$-nearest neighbors
## Model
The model takes as input a sequence of words $\boldsymbol{x}_{\le t} = \{x_0, \dots, x_t\}$ and outputs a probability distribution of the next word $p(x_{t+1} \mid \boldsymbol{x}_{\le t} ; \mathbf{W})$. Given a corpus of $T$ words, the log likelihood of the corpus is:
$$
\mathcal{L} = \sum_{t=0}^{T} \log p(x_{t+1} \mid \boldsymbol{x}_{\le t} ; \mathbf{W})
$$
where $x_0$ is the start of sentence symbol.
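As a minimal illustration, this objective can be computed as follows; the array names `probs` and `targets` are assumptions for this sketch, not notation from the paper.

```python
import numpy as np

def corpus_log_likelihood(probs, targets):
    """Sum of log p(x_{t+1} | x_{<=t}; W) over a corpus.

    probs:   (T, V) array; row t is the model's distribution over the
             vocabulary after reading x_0 .. x_t (x_0 is the <s> symbol).
    targets: (T,) array; targets[t] is the index of the observed x_{t+1}.
    """
    return float(np.sum(np.log(probs[np.arange(len(targets)), targets])))
```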

### Base model
The transformer is their base model. A core limitation of the transformer is that its computational complexity is quadratic in the input sequence length.
As a result, instead of considering all previous tokens $\boldsymbol{x}_{\le t}$, the transformer truncates the input to the most recent $N$ words $\tilde{\boldsymbol{x}}_{\le t} = \{ x_{t-N+1}, \dots, x_t\}$ and only operates on this fixed-length window in practice.
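Concretely, the truncation is just a sliding window over the token sequence (a trivial sketch; the function name is mine, not the paper's):

```python
def truncate_context(tokens, t, N):
    """Keep only the most recent N tokens x_{t-N+1} .. x_t.

    A vanilla transformer conditions on this fixed-length window instead of
    the full history x_{<=t} (the window is shorter near the corpus start).
    """
    return tokens[max(0, t - N + 1): t + 1]

truncate_context(list(range(10)), t=7, N=4)  # -> [4, 5, 6, 7]
```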
### Short-term memory
Transformer-XL is their working memory model.
Given the current context $\tilde{\boldsymbol{x}}_{\le t}$, denote the extended context of length $M$ by $\tilde{\boldsymbol{x}}_{\le t - N} = \{ x_{t-N-M+1}, \dots, x_{t-N} \}$. In Transformer-XL, the hidden states for $\tilde{\boldsymbol{x}}_{\le t - N}$ are cached.
These cached states can be attended to during the forward pass when computing hidden states for the current context $\tilde{\boldsymbol{x}}_{\le t}$, but they are not updated during the backward pass, which saves computation time.
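A simplified, single-layer PyTorch sketch of this caching pattern is below; `attn_layer` is a placeholder for any attention callable and is not the paper's code.

```python
import torch

def segment_forward(attn_layer, h_current, mem):
    """One simplified Transformer-XL-style segment update (sketch).

    attn_layer: callable attn_layer(query, key_value) -> new hidden states.
    h_current:  (N, d) hidden states of the current segment x~_{<=t}.
    mem:        (M, d) cached hidden states of the extended context x~_{<=t-N}.
    """
    key_value = torch.cat([mem.detach(), h_current], dim=0)  # attend over cache + current segment
    h_out = attn_layer(h_current, key_value)                 # queries come only from the current segment
    new_mem = h_current.detach()                             # cache for the next segment; no gradient flows into it
    return h_out, new_mem
```

In the full model each layer keeps its own cache, but the key point is the same: the memory is detached, so it is read during the forward pass and never updated by backpropagation.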
### Long-term memory
The long-term memory module is implemented as a key-value database. The key is a vector representation of a context $\tilde{\boldsymbol{x}}_{\le i}$.
Each context is paired with its output token $x_{i+1}$, which is stored as the value. In their experiments, they store a key-value entry for each context-token pair in the training corpus.
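A hedged sketch of how such a datastore can be built; the function and argument names are mine, not the paper's.

```python
import numpy as np

def build_long_term_memory(contexts, next_tokens, key_encoder):
    """Build the key-value episodic memory (sketch).

    contexts:    iterable of training contexts x~_{<=i}.
    next_tokens: token id x_{i+1} that followed each context.
    key_encoder: a (pre)trained LM mapping a context to its key vector d_i.
    """
    keys = np.stack([key_encoder(c) for c in contexts])  # one key vector per context
    values = np.asarray(next_tokens)                     # the output token ids
    return keys, values
```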
The gating mechanism is as follows:
$$
\begin{aligned}
& \mathbf{m}_t=\sum_{k=1}^K \frac{\exp \mathbf{y}_k^{\top} \mathbf{h}_t^R}{\sum_{j=1}^K \exp \mathbf{y}_j^{\top} \mathbf{h}_t^R} \mathbf{y}_k \\
& \mathbf{g}_t=\sigma\left(\mathbf{w}_g^{\top} \mathbf{h}_t^R\right) \\
& \mathbf{z}_t=\left(1-\mathbf{g}_t\right) \odot \mathbf{m}_t+\mathbf{g}_t \odot \mathbf{h}_t^R \\
& p\left(x_{t+1} \mid \boldsymbol{x}_{\leq t}\right)=\operatorname{softmax}\left(\mathbf{z}_t ; \mathbf{W}\right),
\end{aligned}
$$
where $\mathbf{y}_1, \dots, \mathbf{y}_K$ are the embeddings of the $K$ output tokens retrieved from the database, $\mathbf{h}_t^R$ is the transformer's final-layer hidden state at position $t$, $\mathbf{w}_g$ is a parameter vector, $\sigma$ is the sigmoid function, and $\mathbf{W}$ is the word embedding matrix that is shared between the input and output embeddings.
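A minimal NumPy sketch of this aggregation step; the names are mine, and it assumes $\mathbf{h}_t^R$ and the retrieved embeddings $\mathbf{y}_k$ are already computed.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def spalm_head(h_t, Y, w_g, W):
    """Combine retrieved neighbors with the transformer output (sketch).

    h_t: (d,)   final hidden state h_t^R for the current position.
    Y:   (K, d) embeddings y_1 .. y_K of the K retrieved output tokens.
    w_g: (d,)   gate parameter vector.
    W:   (V, d) shared input/output word embedding matrix.
    """
    attn = softmax(Y @ h_t)                   # attention weights over the K neighbors
    m_t = attn @ Y                            # weighted sum of neighbor embeddings
    g_t = 1.0 / (1.0 + np.exp(-(w_g @ h_t)))  # sigmoid gate (a scalar here)
    z_t = (1.0 - g_t) * m_t + g_t * h_t       # gated combination
    return softmax(W @ z_t)                   # p(x_{t+1} | x_{<=t})
```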
### Training details
1. Train a standard transformer language model and use it as an encoder to compute key representations $\mathbf{d}_i$ for the episodic memory database.
2. The key representations are not updated when training the overall model, which allows them to fix the set of nearest neighbors for each token.
3. The value encoder, on the other hand, is updated during training, since the word embedding matrix is used to represent the values (the output tokens).
4. They use ScaNN to perform fast approximate maximum inner product search (MIPS) when retrieving nearest neighbors; see the sketch below.
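A hedged sketch of how such approximate MIPS retrieval can be set up with the ScaNN library; the configuration values follow ScaNN's documented example and the file names are placeholders, not the paper's setup.

```python
import numpy as np
import scann  # pip install scann

# Hypothetical inputs: keys are the stored context representations d_i,
# queries are hidden states h_t^R for which we want K neighbors.
keys = np.load("memory_keys.npy")      # (num_entries, d) -- placeholder file name
queries = np.load("query_states.npy")  # (num_queries, d) -- placeholder file name
K = 4                                  # illustrative number of neighbors

searcher = (
    scann.scann_ops_pybind.builder(keys, K, "dot_product")
    .tree(num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000)
    .score_ah(2, anisotropic_quantization_threshold=0.2)
    .reorder(100)
    .build()
)
neighbor_ids, scores = searcher.search_batched(queries)
# neighbor_ids index into the stored values, i.e. the output tokens y_1 .. y_K.
```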
## Comparisons to previous work
**$k$NN-LM.**
It's a language model that is augmented with a nearest neighbor retrieval mechanism.
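The main architectural difference is how the retrieved information is combined with the base LM: $k$NN-LM interpolates the retrieval distribution and the LM distribution with a fixed scalar hyperparameter $\lambda$,
$$
p(x_{t+1} \mid \boldsymbol{x}_{\le t}) = \lambda \, p_{k\mathrm{NN}}(x_{t+1} \mid \boldsymbol{x}_{\le t}) + (1-\lambda) \, p_{\mathrm{LM}}(x_{t+1} \mid \boldsymbol{x}_{\le t}),
$$
whereas SPALM's gate $\mathbf{g}_t$ is learned and conditioned on the current hidden state, so the mixing weight can vary from token to token.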
**Cache-based language models and pointer networks.**
Cache-based language models store pairs of hidden states and output tokens from previously seen tokens (within a limited context length) in a cache.
## Experiments
They use word-based and character-based English language modeling datasets (WikiText-103, WMT, and enwik8) to evaluate their proposed method.
**WikiText-103**

**WMT**

**enwik8**

## Discussion
They present a semi-parametric language model (SPALM) that combines local context, short-term memory, and long-term memory to make predictions.
The biggest limitation is the need to retrieve neighbors for every training token. Such a process, even though it can be fully parallelized, is time-consuming.