## Retrieval Augmented LM with Sparse KV Cache

### Key Idea & Motivation

Implement a sparse KV cache that can be used for both retrieval and generation, addressing the need for **infinite context** and a **unified retrieval and generation** process.

### Formulation

Given a text of length $N$ with token representations $X \in \mathbb{R}^{N \times d}$, the features after the projection layers in the attention modules are the key representations $K \in \mathbb{R}^{N \times d}$, query representations $Q \in \mathbb{R}^{N \times d}$, and value representations $V \in \mathbb{R}^{N \times d}$, respectively.

The main overhead for generation comes from the massive storage required for the KV cache, as well as the increasingly expensive attention computation over it as $N$ grows. For efficient generation, a line of work has focused on sparse KV cache solutions ([FastGen (Ge et al., ICLR '24)](https://openreview.net/pdf?id=uNrFpDPMyo), [StreamingLLM (Xiao et al., ICLR '24)](https://arxiv.org/pdf/2309.17453)). These methods sparsify the KV cache by dropping unimportant tokens based on heuristics or attention scores, thereby keeping the cache small during generation.

Based on our earlier study of attention, the attention computation also exhibits dimension-wise sparsity. The pre-softmax attention scores are additive across dimensions, $\sum_i q_i k_i^T$, where $q_i \in \mathbb{R}^{N \times 1}$ and $k_i \in \mathbb{R}^{N \times 1}$ are the $i$-th columns (a specific dimension) of $Q$ and $K$. Each element of $q_i$ and $k_i$ is a matching score (inner product) between the corresponding token representation and the $i$-th column of the projection weights $W_q$ and $W_k$. Apart from cases of near-uniform attention distributions, elements of $K$ with small magnitude will not trigger high attention scores and can therefore be dropped for sparsity.

![image](https://hackmd.io/_uploads/BJNQo-u2C.jpg)

Specifically, without loss of generality, we can generate a binary sparsity mask $M$ from the magnitudes of $K$ as $M = \mathbf{1}_{|K| \geq \delta} \in \mathbb{R}^{N \times d}$, where $\delta$ is a hyper-parameter controlling sparsity (other design choices, such as dimension-wise top-$k$ selection, are possible). We then store the sparse KV cache as $\hat{K} = K \circ M$ and $\hat{V} = V \circ M_\text{col-max}$, where $\circ$ denotes the element-wise product and $M_\text{col-max} \in \mathbb{R}^{N \times d}$ is the mask with each token's maximum broadcast to all dimensions.

#### Generation

The generation process is straightforward: conditional generation based on the retrieved KV cache. One design choice to consider is whether the KV cache needs to be stored and injected for a single layer or for multiple layers.

#### Retrieval

We perform retrieval as an approximation to the attention in the original transformer, using sparse computation so that retrieval and generation unify more naturally. We are essentially performing a second pass of sparsification given the additional information in $Q$. Note that $Q$ is also sparse, with certain dimensions not activated for the specific query; chunks that share no active dimensions with the query can therefore be dropped. Checking only a few non-zero elements is reminiscent of sparse retrieval approaches (e.g., SPLADE, BM25). We can implement it in a similar way, with each dimension acting as an abstract term in an inverted index, e.g., $i \rightarrow (\text{docID}_j, \max \hat{k}^j_i) \rightarrow \dots$, where $\hat{k}^j_i$ denotes the $i$-th column of $\hat{K}$ for the $j$-th document.
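Below is a minimal sketch of the cache construction and the dimension-as-term retrieval described above, assuming a single attention head, one global threshold (called `delta` here), and per-document key/value chunks; all function and variable names are illustrative and not part of any existing API.

```python
import torch


def build_sparse_kv(K: torch.Tensor, V: torch.Tensor, delta: float):
    """Sparsify an (N, d) key/value cache by magnitude thresholding on K.

    M keeps only key entries with |K| >= delta; V is kept for any token
    that has at least one surviving key dimension (the column-max broadcast).
    """
    M = (K.abs() >= delta).to(K.dtype)              # (N, d) binary mask
    M_col_max = M.max(dim=-1, keepdim=True).values  # (N, 1): 1 if the token survives at all
    K_hat = K * M
    V_hat = V * M_col_max
    return K_hat, V_hat, M


def build_inverted_index(doc_K_hats):
    """Posting lists: dimension i -> [(doc_id, max |k_hat_i| within that doc), ...]."""
    index = {}
    for doc_id, K_hat in enumerate(doc_K_hats):
        max_per_dim = K_hat.abs().amax(dim=0)       # (d,)
        for i, score in enumerate(max_per_dim.tolist()):
            if score > 0.0:                         # keep only dimensions activated in this doc
                index.setdefault(i, []).append((doc_id, score))
    return index


def retrieve(q: torch.Tensor, index, delta: float, top_k: int = 2):
    """Score docs over the query's active dimensions only.

    Uses |q_i| * max |k_i| per dimension, a crude upper bound on that
    dimension's contribution to the attention logits.
    """
    active_dims = (q.abs() >= delta).nonzero(as_tuple=True)[0].tolist()
    scores = {}
    for i in active_dims:
        for doc_id, max_k in index.get(i, []):
            scores[doc_id] = scores.get(doc_id, 0.0) + abs(q[i].item()) * max_k
    return sorted(scores, key=scores.get, reverse=True)[:top_k]


if __name__ == "__main__":
    # Toy usage with random stand-ins for two documents' keys and values.
    torch.manual_seed(0)
    d, delta = 16, 1.0
    keys = [torch.randn(8, d) for _ in range(2)]
    vals = [torch.randn(8, d) for _ in range(2)]
    caches = [build_sparse_kv(K, V, delta) for K, V in zip(keys, vals)]
    index = build_inverted_index([K_hat for K_hat, _, _ in caches])
    q = torch.randn(d)                              # a single query-token representation
    print(retrieve(q, index, delta))
```

A practical index would keep per-token postings rather than one per-document maximum, so that the retrieved $\hat{K}$ and $\hat{V}$ rows can be injected directly into attention during generation.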
### Main Features and Expected Contributions

#### Reduced Computation and Storage

#### Adaptive Compression

- Compression ratios become adaptive per chunk, based on information density.
- Information-dense chunks: each token may have at least one activated dimension.
- Low-information chunks: most tokens have weak activations in all dimensions.

#### Cross-doc Interactions

Compared to previous unified retrieval-and-generation models such as [GritLM (Muennighoff et al. '24)](https://arxiv.org/pdf/2402.09906), this approach has the potential to model cross-doc interactions by reserving a few layers for global interactions among the retrieved sparse representations.

#### End-to-end Optimization for RAG

For fine-tuning on RAG-oriented scenarios, besides the option of separate objectives for retrieval (a contrastive loss with a proper scoring function) and generation (NLL loss), we can tune the model with an end-to-end objective (a minimal sketch of this objective is given at the end of this note):

$$\mathcal{L} = -\log p(a \mid C^+, C^-, q) + \beta \cdot \|\texttt{ATTN}(C^{-}, q)\|.$$

$C^+$, $C^-$, $q$, and $a$ are the positive context, negative context, question, and answer, respectively. $\beta$ is the regularization factor penalizing attention to the negative context.

### Potential Challenges

- Conflicts between utilizing existing pre-trained weights and offloading cross-doc interactions to only the deeper layers
- Balancing sparsity and model performance

### Preliminary Study

**Context** The Great Wall of China is an ancient series of walls and fortifications located in northern China, built around 500 years ago. Contrary to popular belief, the Great Wall is not a single continuous wall but a collection of walls built by different Chinese dynasties. The most well-known and best-preserved section was built during the Ming Dynasty (1368-1644).

**Query** The Great Wall of China was built around

**Baselines** We prompt Llama3.1-8B for text continuation. The top output tokens (with probabilities) at each step, with and without the context, are:

With context:

| Step 1 | Step 2 | Step 3 |
|--------|--------|--------|
| 500 (0.7025) | years (0.6992) | ago (0.9441) |
| 2 (0.0600) | BC (0.0893) | The (0.0083) |
| 5 (0.0303) | B (0.0783) | old (0.0049) |

Without context:

| Step 1 | Step 2 | Step 3 |
|--------|--------|--------|
| 200 (0.3337) | 0 (0.3258) | years (0.6454) |
| 220 (0.2545) | BC (0.2861) | B (0.1324) |
| 221 (0.0614) | B (0.2307) | BC (0.0731) |

#### Can we keep KV cache only in specific layers?

**Remove connections to KV in shallow layers (first layer)**

| Step 1 | Step 2 | Step 3 |
|--------|--------|--------|
| 6 (0.1923) | , (0.8492) | 000 (0.4240) |
| \xa0 (0.1426) | . (0.0581) | (0.1166) |
| 2 (0.0588) | meters (0.0153) | 400 (0.0972) |

**Remove connections to KV in deeper layers (second half)**

| Step 1 | Step 2 | Step 3 |
|--------|--------|--------|
| 500 (0.3187) | years (0.4465) | ago (0.9633) |
| 2 (0.1533) | BC (0.2220) | before (0.0042) |
| 220 (0.1342) | B (0.1022) | BC (0.0039) |

**Remove connections to KV in interleaving layers (2, 4, 6, ...)**

| Step 1 | Step 2 | Step 3 |
|--------|--------|--------|
| 500 (0.3187) | years (0.5976) | ago (0.9515) |
| 220 (0.1177) | B (0.1482) | before (0.0051) |
| 200 (0.0820) | BC (0.1219) | . (0.0042) |

Disabling the cache in the shallow layers results in invalid attention to the context. Shallow-layer interactions may be more significant than expected because of dependencies propagating along the depth of the network. A toy sketch of this layer-wise masking is given below.
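To make the ablation setup concrete, here is a toy sketch of where the per-layer mask enters; it uses random tensors and a bare `scaled_dot_product_attention` stack rather than Llama3.1-8B, and the names (`layer_attention_mask`, `run_toy_stack`) are illustrative. It is meant only to mirror the "remove connections to KV" settings above.

```python
import torch
import torch.nn.functional as F


def layer_attention_mask(n_ctx: int, n_query: int, ablate_context: bool) -> torch.Tensor:
    """Additive attention mask for one layer over a [context | continuation] sequence.

    Continuation tokens are causal among themselves; if `ablate_context` is True,
    they are additionally blocked from attending to the cached context KV,
    mimicking the "remove connections to KV" setting for that layer.
    """
    total = n_ctx + n_query
    mask = torch.zeros(n_query, total)
    causal = torch.triu(torch.ones(n_query, n_query, dtype=torch.bool), diagonal=1)
    mask[:, n_ctx:] = mask[:, n_ctx:].masked_fill(causal, float("-inf"))
    if ablate_context:
        mask[:, :n_ctx] = float("-inf")  # cut the connection to the context KV cache
    return mask


def run_toy_stack(qs, ks, vs, n_ctx, ablated_layers):
    """Push per-layer (q, k, v) through a stack of bare attention ops.

    qs[i]: (n_query, d) continuation queries at layer i;
    ks[i], vs[i]: (n_ctx + n_query, d) cached keys/values at layer i.
    Only meant to show where the per-layer mask enters, not a real model.
    """
    outputs = []
    for layer_idx, (q, k, v) in enumerate(zip(qs, ks, vs)):
        mask = layer_attention_mask(n_ctx, q.shape[0], layer_idx in ablated_layers)
        out = F.scaled_dot_product_attention(
            q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attn_mask=mask
        )
        outputs.append(out.squeeze(0))
    return outputs


# Example: ablate the interleaving layers (2, 4, 6, ...) of an 8-layer toy stack.
n_layers, n_ctx, n_query, d = 8, 6, 3, 16
qs = [torch.randn(n_query, d) for _ in range(n_layers)]
ks = [torch.randn(n_ctx + n_query, d) for _ in range(n_layers)]
vs = [torch.randn(n_ctx + n_query, d) for _ in range(n_layers)]
outs = run_toy_stack(qs, ks, vs, n_ctx, ablated_layers={1, 3, 5, 7})  # 0-indexed
```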
#### Can KV cache be sparse without affecting the generation?

| Step 1 | Step 2 | Step 3 |
|--------|--------|--------|
| 500 (0.2862) | years (0.6981) | ago (0.9237) |
| 220 (0.2219) | B (0.0833) | before (0.0112) |
| 200 (0.0765) | BC (0.0757) | , (0.0089) |

Adopting a threshold of $\delta = 1.8$ removes more than 75% of the elements in $K$ without affecting the correct output. However, almost no token in the sequence is dropped entirely, since for most tokens at least one dimension survives the mask. Dimension- and layer-dependent thresholds may be required for more aggressive sparsification (a sketch of such a sparsity sweep is given after the TODO list below).

#### TODO

- In-depth analysis of the first-layer attention
- Mean-pooling effect
- Cross-doc reasoning
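As referenced above, a small sketch of the kind of sparsity measurement behind the $\delta$ experiment, using a random stand-in for $K$; the real measurement would use the per-layer keys extracted from the model on the Great Wall context, and the thresholds and shapes here are illustrative.

```python
import torch


def sparsity_report(K: torch.Tensor, delta: float):
    """Fraction of masked elements and fully dropped tokens for an (N, d) key cache."""
    M = K.abs() >= delta                                          # keep mask
    frac_masked = 1.0 - M.float().mean().item()                   # elements removed
    frac_dropped_tokens = (~M.any(dim=-1)).float().mean().item()  # tokens with no surviving dim
    return frac_masked, frac_dropped_tokens


K = torch.randn(512, 128) * 1.5  # random stand-in for a key cache
for delta in (0.5, 1.0, 1.8, 2.5):
    masked, dropped = sparsity_report(K, delta)
    print(f"delta={delta}: {masked:.1%} elements masked, {dropped:.1%} tokens fully dropped")
```

And a minimal sketch of the end-to-end RAG objective from the Main Features section, assuming the post-softmax attention matrices are already available and that the sequence is laid out as $[C^+ \,|\, C^- \,|\, q \,|\, a]$; the choice of layers/heads to penalize and the value of $\beta$ are open design choices, and all names are illustrative.

```python
import torch


def negative_context_attention(attn: torch.Tensor, neg_span: slice, qa_span: slice) -> torch.Tensor:
    """Attention mass that question/answer tokens place on the negative context.

    attn: (n_heads, seq_len, seq_len) post-softmax attention from one layer;
    which layers/heads to include is an open design choice.
    """
    return attn[:, qa_span, neg_span].reshape(-1)


def rag_e2e_loss(answer_nll: torch.Tensor, neg_attn: torch.Tensor, beta: float = 0.1) -> torch.Tensor:
    """L = -log p(a | C+, C-, q) + beta * ||ATTN(C-, q)||."""
    return answer_nll + beta * neg_attn.norm()


# Toy shapes: positions 10-19 hold C-, positions 20-31 hold q and a.
n_heads, seq_len = 4, 32
attn = torch.softmax(torch.randn(n_heads, seq_len, seq_len), dim=-1)
loss = rag_e2e_loss(
    answer_nll=torch.tensor(2.3),  # placeholder NLL of the answer tokens
    neg_attn=negative_context_attention(attn, neg_span=slice(10, 20), qa_span=slice(20, 32)),
    beta=0.1,
)
```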