<style>
img {
    display: block;
    margin-left: auto;
    margin-right: auto;
}
</style>

> [Paper link](https://arxiv.org/abs/2002.08909) | [Note link](https://zhuanlan.zhihu.com/p/360635601) | [Code link](https://github.com/google-research/language/tree/master/language/realm) | ICML 2020

:::success
**Thoughts**

This paper uses **asynchronous MIPS (maximum inner product search) refreshes** to train the model. The backbone model is encoder-based, e.g. BERT. They pre-train the backbone model together with the retriever, and perform inference with a dense retriever plus a Transformer.
:::

## Abstract

To capture knowledge in a more modular and interpretable way, they augment language model pre-training with a latent *knowledge retriever*.

They train the knowledge retriever in an **unsupervised manner**, using

- Masked language modeling as the learning signal
- Backpropagation through a retrieval step that considers millions of documents

## Introduction

In some previous language models, the learned world knowledge is stored implicitly in the parameters of the underlying neural network. This makes it difficult to determine what knowledge is stored in the network and where. Furthermore, storage space is limited by the size of the network.

To capture knowledge in a more interpretable and modular way, they propose a novel framework, REALM. It *explicitly* exposes the role of world knowledge by asking the model to decide what knowledge to retrieve and use during inference.

![](https://hackmd.io/_uploads/SkcTczAih.png)

The key intuition of REALM is to train the retriever using a *performance-based* signal from unsupervised text: a retrieval that *improves* the language model’s perplexity is helpful and should be rewarded, while an uninformative retrieval should be penalized.

Previous work did not apply this retrieve-then-predict framework to language model pre-training, and it relied on non-learned retrievers to handle large-scale document collections. REALM’s retriever is learned, is designed to transfer to other tasks, and retrieves plain text rather than labeled examples.

## Background

**Language model pre-training**

In this paper, they focus on the *masked language model* (MLM) variant of pre-training popularized by BERT. Given an unlabeled pre-training corpus $\mathcal{X}$, a training example $(x, y)$ can be generated by randomly masking tokens in a sampled piece of text. The model uses its representation of the masked input $x$ to predict the token that should go in each mask.

**Open-domain question answering (Open-QA)**

The “open” part of Open-QA refers to the fact that the model does not receive a pre-identified document that is known to contain the answer. This paper focuses on Open-QA systems that utilize a *textual knowledge corpus* $\mathcal{Z}$ as the knowledge source.

**Many of these systems employ a retrieval-based approach: given a question $x$, retrieve potentially relevant documents $z$ from the corpus $\mathcal{Z}$, and then extract an answer $y$ from the documents.** REALM is inspired by this paradigm and extends it to language model pre-training.

## Approach

![](https://hackmd.io/_uploads/SyloI4Cs2.png)

### REALM’s generative process

For both pre-training and fine-tuning, REALM takes some input $x$ and learns a distribution $p(y \mid x)$ over possible outputs $y$.

**Pre-training**

- $x$: a sentence from a pre-training corpus $\mathcal{X}$ with some tokens masked out
- $y$: the missing tokens, which the model must predict

REALM decomposes $p(y \mid x)$ into two steps:

- Retrieve:
    - Given an input $x$, retrieve possibly helpful documents $z$ from a knowledge corpus $\mathcal{Z}$
    - Model this as a sample from the distribution $p(z \mid x)$
- Predict:
    - Condition on both the retrieved $z$ and the original input $x$ to generate the output $y$, modeled as $p(y \mid z, x)$

**Fine-tuning**

- $x$: a question
- $y$: the answer

In both settings, $z$ is treated as a latent variable and marginalized over all documents in the corpus to obtain the overall likelihood:

$$
\tag{1}
p(y \mid x) = \sum_{z \in \mathcal{Z}} p(y \mid z, x) \, p(z \mid x)
$$

### Model architecture

**Knowledge Retriever**

The retriever is defined using a dense inner product model:

$$
\begin{aligned}
p(z \mid x) & = \frac{\exp f(x, z)}{\sum_{z^{\prime}} \exp f\left(x, z^{\prime}\right)} \\
f(x, z) & = \operatorname{Embed}_{\text{input}}(x)^{\top} \operatorname{Embed}_{\text{doc}}(z)
\end{aligned}
$$

where $\text{Embed}_{\text{input}}$ and $\text{Embed}_{\text{doc}}$ are embedding functions that map $x$ and $z$ respectively to $d$-dimensional vectors. The *relevance score* $f(x, z)$ between $x$ and $z$ is defined as the inner product of the vector embeddings.

Where do $\text{Embed}_{\text{input}}$ and $\text{Embed}_{\text{doc}}$ come from? Both are implemented with BERT-style Transformers. The text is first joined with $[\mathrm{CLS}]$ and $[\mathrm{SEP}]$ tokens:

$$
\begin{aligned}
\operatorname{join}_{\mathrm{BERT}}(x) & = [\mathrm{CLS}] \, x \, [\mathrm{SEP}] \\
\operatorname{join}_{\mathrm{BERT}}\left(x_1, x_2\right) & = [\mathrm{CLS}] \, x_1 \, [\mathrm{SEP}] \, x_2 \, [\mathrm{SEP}]
\end{aligned}
$$

The Transformer output at the $[\mathrm{CLS}]$ position is then linearly projected by a matrix $\mathbf{W}$:

$$
\begin{aligned}
\operatorname{Embed}_{\text{input}}(x) & = \mathbf{W}_{\text{input}} \operatorname{BERT}_{\text{CLS}}\left(\operatorname{join}_{\mathrm{BERT}}(x)\right) \\
\operatorname{Embed}_{\text{doc}}(z) & = \mathbf{W}_{\text{doc}} \operatorname{BERT}_{\text{CLS}}\left(\operatorname{join}_{\mathrm{BERT}}\left(z_{\text{title}}, z_{\text{body}}\right)\right)
\end{aligned}
$$

**Knowledge-Augmented Encoder**

Given an input $x$ and a retrieved document $z$, the knowledge-augmented encoder defines $p(y \mid z, x)$.

For the masked language model pre-training task, the model must predict the original value of each $[\text{MASK}]$ token in $x$.
They therefore use the same masked language modeling (MLM) loss as BERT:

$$
\begin{aligned}
p(y \mid z, x) & = \prod_{j=1}^{J_x} p\left(y_j \mid z, x\right) \\
p\left(y_j \mid z, x\right) & \propto \exp \left(w_j^{\top} \operatorname{BERT}_{\text{MASK}(j)}\left(\operatorname{join}_{\mathrm{BERT}}\left(x, z_{\text{body}}\right)\right)\right)
\end{aligned}
$$

where $J_x$ is the total number of $[\text{MASK}]$ tokens in $x$ and $w_j$ is a learned word embedding for token $y_j$.

For Open-QA fine-tuning, they assume that the answer $y$ appears as a contiguous sequence of tokens in some document $z$. Let $S(z, y)$ be the set of spans matching $y$ in $z$. Then:

$$
\begin{aligned}
p(y \mid z, x) & \propto \sum_{s \in S(z, y)} \exp \left(\operatorname{MLP}\left(\left[h_{\mathrm{START}(s)} ; h_{\mathrm{END}(s)}\right]\right)\right) \\
h_{\mathrm{START}(s)} & = \operatorname{BERT}_{\mathrm{START}(s)}\left(\operatorname{join}_{\mathrm{BERT}}\left(x, z_{\text{body}}\right)\right) \\
h_{\mathrm{END}(s)} & = \operatorname{BERT}_{\mathrm{END}(s)}\left(\operatorname{join}_{\mathrm{BERT}}\left(x, z_{\text{body}}\right)\right)
\end{aligned}
$$

where $\operatorname{BERT}_{\mathrm{START}(s)}$ and $\operatorname{BERT}_{\mathrm{END}(s)}$ denote the Transformer output vectors at the start and end tokens of span $s$, and $\operatorname{MLP}$ is a feed-forward network.

### Training

For both pre-training and fine-tuning, they train by maximizing the log-likelihood $\log p(y \mid x)$ of the correct output $y$. Since both the knowledge retriever (parameterized by $\theta$) and the knowledge-augmented encoder (parameterized by $\phi$) are differentiable neural networks, the gradient of $\log p(y \mid x)$ can be computed with respect to $\theta$ and $\phi$ and optimized with stochastic gradient descent.

The main computational challenge is that the marginal $p(y \mid x)$ sums over every document in $\mathcal{Z}$. They approximate it by summing only over the top $k$ documents under $p(z \mid x)$, which can be found efficiently with maximum inner product search (MIPS). Because the document embeddings change whenever $\theta$ is updated, the MIPS index goes stale, so they “refresh” the index by asynchronously re-embedding and re-indexing all documents every several hundred training steps.

![](https://hackmd.io/_uploads/BkuwGC1nn.png)

**What does the retriever learn?**

For a given query $x$ and document $z$, recall that $f(x, z)$ is the “relevance score” that the knowledge retriever assigns to document $z$. A single step of gradient descent during REALM pre-training alters this score as follows:

$$
\begin{aligned}
\nabla \log p(y \mid x) & = \sum_{z \in \mathcal{Z}} r(z) \nabla f(x, z) \\
r(z) & = \left[\frac{p(y \mid z, x)}{p(y \mid x)} - 1\right] p(z \mid x)
\end{aligned}
$$

For each document $z$, the gradient encourages the retriever to change the score $f(x, z)$ by $r(z)$: increasing it if $r(z)$ is positive, and decreasing it if negative. The multiplier $r(z)$ is positive if and only if $p(y \mid z, x) > p(y \mid x)$. Here $p(y \mid z, x)$ is the probability of predicting the correct output $y$ when using document $z$, while $p(y \mid x)$ is its expected value when documents are sampled from $p(z \mid x)$. A document therefore receives a positive update whenever it performs better than expected.

### Injecting inductive biases into pre-training

**Salient span masking**

During REALM pre-training, they want to focus on examples $x$ that require world knowledge to predict the masked tokens. They therefore mask salient spans such as “United Kingdom” or “July 1969” rather than random tokens.

**Null document**

They add an empty null document $∅$ to the top $k$ retrieved documents, allowing appropriate credit to be assigned to a consistent sink when no retrieval is necessary.

**Prohibiting trivial retrievals**

If the pre-training corpus $\mathcal{X}$ and the knowledge corpus $\mathcal{Z}$ are the same, there is a trivial retrieval candidate: the document $z$ that the masked sentence $x$ came from, which contains the unmasked version of $x$. If this candidate is retrieved too often, the retriever learns to look for exact string matches instead of useful knowledge, so they exclude it during pre-training.

**Initialization**

They warm-start $\text{Embed}_{\text{input}}$ and $\text{Embed}_{\text{doc}}$ using a simple training objective known as the Inverse Cloze Task (ICT) where, given a sentence, the model is trained to retrieve the document where that sentence came from. They warm-start the knowledge-augmented encoder with BERT pre-training, specifically the uncased BERT-base model.

## Experiments

![](https://hackmd.io/_uploads/rkRzjCy3h.png)

## Future Work

The authors are interested in generalizing REALM to:

1. Structured knowledge, which would generalize prior work so that the model also learns the decision of which entities are informative
2. The multi-lingual setting, e.g., retrieving knowledge in a high-resource language to better represent text in a low-resource language
3. The multi-modal setting, e.g., retrieving images or videos that can provide knowledge rarely observed in text
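
### Appendix: numerical check of the retriever gradient

As a rough sanity check of the retriever gradient from the Training section, the sketch below builds a toy problem with five documents, computes $p(z \mid x)$ as a softmax over inner-product scores, forms the marginal $p(y \mid x)$ from Equation (1), and compares the multiplier $r(z)$ against a finite-difference estimate of $\partial \log p(y \mid x) / \partial f(x, z)$. The embeddings and the per-document likelihoods $p(y \mid z, x)$ are random stand-ins chosen for illustration, not outputs of REALM's BERT encoders, and all function names are hypothetical. NumPy is assumed to be installed.

```python
import numpy as np


def softmax(scores):
    """Numerically stable softmax: p(z|x) from relevance scores f(x, z)."""
    shifted = scores - scores.max()
    exp = np.exp(shifted)
    return exp / exp.sum()


def log_marginal(scores, p_y_given_zx):
    """log p(y|x) = log sum_z p(y|z,x) p(z|x), as in Equation (1)."""
    return np.log(np.dot(p_y_given_zx, softmax(scores)))


def retrieval_multiplier(scores, p_y_given_zx):
    """r(z) = [p(y|z,x) / p(y|x) - 1] * p(z|x)."""
    p_z = softmax(scores)
    p_y = np.dot(p_y_given_zx, p_z)
    return (p_y_given_zx / p_y - 1.0) * p_z


# Toy setup (assumption): random vectors stand in for Embed_input(x) and Embed_doc(z).
rng = np.random.default_rng(0)
num_docs, dim = 5, 8
x_emb = rng.normal(size=dim)
doc_emb = rng.normal(size=(num_docs, dim))
scores = doc_emb @ x_emb  # f(x, z) = inner product of the two embeddings

# Toy stand-in for the encoder's p(y|z,x), one value per document.
p_y_given_zx = rng.uniform(0.05, 0.95, size=num_docs)

r = retrieval_multiplier(scores, p_y_given_zx)

# Central finite differences of log p(y|x) with respect to each score f(x, z).
eps = 1e-6
numeric = np.zeros(num_docs)
for i in range(num_docs):
    up, down = scores.copy(), scores.copy()
    up[i] += eps
    down[i] -= eps
    numeric[i] = (
        log_marginal(up, p_y_given_zx) - log_marginal(down, p_y_given_zx)
    ) / (2 * eps)

print("analytic r(z):     ", r)
print("finite difference: ", numeric)
```

In this toy setting the two printed vectors should agree closely, and documents whose $p(y \mid z, x)$ exceeds the marginal $p(y \mid x)$ get a positive multiplier. In REALM itself the sum runs only over the top $k$ retrieved documents rather than a full toy corpus.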