> **Background**: **a)** SnapTra's encoder uses standard transformer layers to process a long sequence of length $N$ by splitting it into segments of length $n$. A special memory token is appended to each segment to compress its information, and the resulting encoded representations are called snapshots. **b)** The decoder consists of several layers. It updates token representations based solely on previous snapshots and the normal tokens of the current segment. **c)** SnapTra is initialized from a pretrained standard transformer rather than trained from scratch.

---

> **Lemma 1**: Since the parameters in the MLP layers of a pretrained Transformer store abstract knowledge about the data [\[1\]](https://arxiv.org/abs/2012.14913), it is preferable to leave these parameters unchanged, or to tune them as little as possible, when further training SnapTra.

Given Lemma 1, we focus only on the attention layers when approximating the computation of a vanilla Transformer.

## 1 Snapshot as Context

> We start from a simplified setting with a single snapshot: the generation of $x_{n+1}$ is based only on the snapshot $s_{1:n}$ of the first segment.

> To simplify further, we first consider how SnapTra can approximate a single-layer Transformer.

**Standard Transformer Setting**

In a standard transformer, for a sequence of $n$ tokens, the representation of the next token $x_{n+1}$ is

$$x^{(1)}_{n+1} = f\Big(\sum_{i=1}^{n} \alpha(W_q x^{(0)}_{n+1}, W_k x^{(0)}_i)\, W_v x^{(0)}_i + \alpha(W_q x^{(0)}_{n+1}, W_k x^{(0)}_{n+1})\, W_v x^{(0)}_{n+1}\Big).$$

**SnapTra Approximation**

In the encoder, the representation of the snapshot is updated with

$$s^{(1)}_{1:n} = f\Big(\sum_{i=1}^{n} \alpha(W_q s^{(0)}_{1:n}, W'_k x^{(0)}_i)\, W'_v x^{(0)}_i\Big).
$$

In the decoder, the updated representation is

$$x^{(2)}_{n+1} = f\Big(\alpha(W_q x^{(1)}_{n+1}, W''_k s^{(1)}_{1:n})\, W_v s^{(1)}_{1:n} + \alpha(W_q x^{(1)}_{n+1}, W''_k x^{(1)}_{n+1})\, W_v x^{(1)}_{n+1}\Big).$$

Ignoring the attention of $x_{n+1}$ to itself, the attention of the original transformer is straightforward to approximate. We adjust the initial embedding of the snapshot $s^{(0)}_{1:n}$ and tune the query projection. This aligns the attention distribution over the tokens of the first segment with that of the original transformer.

> Approximating a multi-layer transformer is much harder, because the attention distribution over the first segment varies from layer to layer.

We start with the original transformer attention and its SnapTra approximation.

**Original Transformer Attention**

For token $x_{n+1}$ at layer $l$:

$$a^{(l)}_{n+1,i} = \text{softmax}_i\left(\frac{(W_q^{(l)} x^{(l-1)}_{n+1})^\top (W_k^{(l)} x^{(l-1)}_i)}{\sqrt{d_k}}\right)$$

$$x^{(l)}_{n+1} = f\Big(\sum_{i=1}^{n} a^{(l)}_{n+1,i}\, W_v^{(l)} x^{(l-1)}_i\Big)$$

Here $a^{(l)}_{n+1,i}$ is the attention weight from token $n+1$ to token $i$ at layer $l$.

**SnapTra Approximation**

Encoder (creating the snapshot):

$$s^{(1)}_{1:n} = f\Big(\sum_{i=1}^{n} \alpha_i W'_v x^{(0)}_i\Big)$$

Decoder (using the snapshot):

$$x^{(l)}_{n+1} = f\big(\beta^{(l)} W_v^{(l)} s^{(l-1)}_{1:n}\big)$$

Here $\alpha_i$ are fixed weights determined in the encoder, and $\beta^{(l)}$ is a scalar attention weight in the decoder. For simplicity, we again ignore the attention of $x_{n+1}$ to itself.

We now show why the two are not equivalent.

**Degrees of Freedom**

Original Transformer: for each layer $l$ and token $n+1$, there are $n$ attention weights $a^{(l)}_{n+1,i}$, chosen separately at every layer. Their distributions usually differ across layers, because different layers serve different functions.

SnapTra: there is only one fixed set of weights $\alpha_i$ in the encoder and one scalar $\beta^{(l)}$ per layer in the decoder. Each decoder layer attends only to the snapshot (and to $x_{n+1}$ itself), so the implicit distribution over the segment's tokens cannot change from layer to layer.
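The following is a minimal Python sketch of the degrees-of-freedom argument. It uses the linearized view from above: it ignores $f$ and the value projections, and treats the snapshot as a fixed $\alpha$-weighted mixture of the segment. The values `alpha`, `beta` and `original` are made-up toy numbers. The helpers `implied_distribution` and `total_variation` are hypothetical and not part of SnapTra. The sketch shows that rescaling by $\beta^{(l)}$ cannot change the shape of the distribution that each decoder layer sees.

```python
import math

# Hypothetical toy values (not from the text): n = 4 segment tokens, L = 3 layers.
alpha = [0.1, 0.2, 0.3, 0.4]  # encoder weights over the segment, fixed once the snapshot is built
beta = [0.9, 0.5, 0.2]        # one scalar per decoder layer for attending to the snapshot

# Original transformer: each layer may use its own distribution over the same tokens.
original = [
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.1, 0.1, 0.7],
    [0.25, 0.25, 0.25, 0.25],
]


def implied_distribution(alpha, beta_l):
    """Distribution over segment tokens that one decoder layer sees through the snapshot.

    Linearized view: the layer puts weight beta_l * alpha_i on token i.
    Renormalizing cancels beta_l, so only the shape of alpha survives.
    """
    weights = [beta_l * a for a in alpha]
    total = sum(weights)
    return [w / total for w in weights]


def total_variation(p, q):
    """Total variation distance between two distributions on the same support."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


implied = [implied_distribution(alpha, b) for b in beta]

# Every layer sees the same distribution, whatever its beta is.
assert all(math.isclose(p, q) for row in implied for p, q in zip(row, implied[0]))

for layer, (p, target) in enumerate(zip(implied, original), start=1):
    print(f"layer {layer}: TV distance to original attention = {total_variation(p, target):.3f}")
```

Every layer inherits the same shape $\alpha$. Without local context, SnapTra therefore has to commit to one distribution for all layers. The printed distances show how far that single shape is from each layer-specific target.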
## Why Local Context Helps Resolve This Issue

1. **Original Transformer setup**: for $L$ layers, the attention weights used for next-token prediction can be represented as a 3D tensor $A$ of shape $[L, 1, N]$, where $N$ is the sequence length. Each layer has its own set of $N$ attention weights.
2. **SnapTra without local context**: the basic SnapTra setup above has a fixed set of weights $\alpha$ in the encoder and a single scalar $\beta^{(l)}$ per layer in the decoder, i.e. a matrix of shape $[L, 1]$. The effective weight that layer $l$ places on segment token $i$ is $\beta^{(l)}\alpha_i$. The effective $[L, n]$ attention matrix is therefore the outer product $\beta\alpha^\top$, which has rank at most 1. The original attention can reach rank $\min(L, N)$, which is far higher.
3. **SnapTra with a local context buffer**: now add a local context buffer of size $k$. The decoder update becomes
   $$x^{(l)}_{n+1} = f\Big(\beta^{(l)} W_v^{(l)} s^{(l-1)}_{1:n} + \sum_{i=n-k+1}^{n} \gamma^{(l)}_i W_v^{(l)} x^{(l-1)}_i\Big),$$
   where $\gamma^{(l)}_i$ are the attention weights over the local context at layer $l$.
4. **Matrix rank perspective**: with this modification, we have:
   - a fixed set of weights $\alpha$ for the snapshot (encoder);
   - a scalar $\beta^{(l)}$ per layer for attending to the snapshot (decoder);
   - $k$ attention weights $\gamma^{(l)}_i$ per layer for the local context (decoder).

   These can be collected into a matrix of shape $[L, k+1]$ whose $l$-th row is $[\beta^{(l)}, \gamma^{(l)}_{n-k+1}, \dots, \gamma^{(l)}_{n}]$.
5. **Why this helps**: the new matrix can have rank up to $\min(L, k+1)$, much higher than the rank-1 structure without local context. This gives more flexibility in approximating the original attention patterns:
   - The snapshot (through $\beta^{(l)}$) captures long-range dependencies and global information.
   - The local context weights $\gamma^{(l)}_i$ allow fine-grained attention patterns over recent tokens, and these can vary across layers.

Together, the two let SnapTra better approximate the attention distributions of the original Transformer, which vary from layer to layer.
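The rank argument can be checked numerically. This minimal sketch uses the same linearized view. The snapshot path contributes $\beta^{(l)}\alpha_i$ to every segment token. The local buffer adds $\gamma^{(l)}_i$ on the last $k$ tokens, which also belong to the segment. The sizes and weights are hypothetical. They are unnormalized integers, so the exact rational arithmetic in the helper `rank` is not disturbed by floating-point error. `effective_attention` is likewise an illustrative helper, not part of SnapTra.

```python
from fractions import Fraction

# Hypothetical toy sizes (not from the text): n = 6 segment tokens, local window k = 2, L = 5 layers.
# Unnormalized integer weights keep the rank computation exact.
alpha = [1, 2, 3, 4, 5, 6]  # fixed encoder weights over the segment
beta = [3, 1, 4, 1, 5]      # per-layer scalar attention to the snapshot
k = 2
gamma = [[1, 0], [0, 1], [2, 3], [5, 1], [1, 4]]  # per-layer weights on the last k tokens


def effective_attention(alpha, beta, gamma, k):
    """L x n matrix of the weight each decoder layer implicitly puts on each segment token.

    The snapshot path contributes beta[l] * alpha[i] to every token i;
    the local buffer adds gamma[l][j] to the last k tokens only.
    """
    n = len(alpha)
    rows = []
    for b, g in zip(beta, gamma):
        row = [b * a for a in alpha]
        for j in range(k):
            row[n - k + j] += g[j]
        rows.append(row)
    return rows


def rank(matrix):
    """Matrix rank by Gauss-Jordan elimination over exact rationals."""
    rows = [[Fraction(x) for x in row] for row in matrix]
    r = 0
    for c in range(len(rows[0])):
        pivot = next((i for i in range(r, len(rows)) if rows[i][c] != 0), None)
        if pivot is None:
            continue
        rows[r], rows[pivot] = rows[pivot], rows[r]
        for i in range(len(rows)):
            if i != r and rows[i][c] != 0:
                factor = rows[i][c] / rows[r][c]
                rows[i] = [x - factor * y for x, y in zip(rows[i], rows[r])]
        r += 1
        if r == len(rows):
            break
    return r


no_local = effective_attention(alpha, beta, [[0] * k for _ in beta], k)
with_local = effective_attention(alpha, beta, gamma, k)
print("rank without local context:", rank(no_local))  # every row is a multiple of alpha: rank 1
print("rank with local context:", rank(with_local))   # bounded by k + 1
```

Without the buffer, every row is a multiple of $\alpha$ and the rank is 1. With the buffer, the rank rises to $k+1 = 3$ for these values. The first $n-k$ columns of `with_local` are still proportional across layers. The extra flexibility is therefore concentrated on tokens near the prediction point.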
The local context provides the degrees of freedom needed to capture layer-specific attention patterns, while the snapshot retains the ability to attend to distant tokens. In terms of expressiveness, this setup is closer to the original Transformer's $[L, 1, N]$ attention tensor, especially for tokens near the prediction point. It is still a simplification (since $k < N$), but it strikes a balance between computational efficiency and approximation quality.

<!-- ![image](https://hackmd.io/_uploads/S1A5O-9KA.png) -->
![image](https://hackmd.io/_uploads/Hy0g9-5KC.png)

**1. Information Perspective Without Local Context**

Without local context, the snapshot $s_{1:n}$ must capture all the information from the preceding sequence $x_{1:n}$ that is relevant for predicting $x_{n+1}$. In information-theoretic terms, we want to maximize $I(x_{n+1}; s_{1:n})$, ideally reaching

$$I(x_{n+1}; s_{1:n}) \approx I(x_{n+1}; x_{1:n}),$$

where $I(\cdot;\cdot)$ denotes mutual information. Since $s_{1:n}$ is computed from $x_{1:n}$, the data processing inequality makes the right-hand side an upper bound.

**2. Information Perspective With Local Context**

With a local context buffer of $k$ tokens, the prediction draws on two sources of information:

- the snapshot $s_{1:n}$;
- the local context $x_{n-k+1:n}$.

The relevant quantity is now the joint mutual information

$$I(x_{n+1}; s_{1:n}, x_{n-k+1:n}).$$

**3. Delta Mutual Information**

The key insight is that, once local context is available, the snapshot no longer needs to capture all the information. It only needs to capture what the local context does not already provide. This is the *delta mutual information*:

$$\Delta I = I(x_{n+1}; s_{1:n} \mid x_{n-k+1:n}),$$

which is the additional information the snapshot provides about $x_{n+1}$ given the local context. By the chain rule of mutual information, the two sources decompose as

$$I(x_{n+1}; s_{1:n}, x_{n-k+1:n}) = I(x_{n+1}; x_{n-k+1:n}) + \Delta I.$$

**4. Implications for Snapshot Modeling**

The snapshot's role is therefore to model this delta mutual information.
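A toy discrete example illustrates why the snapshot only needs to model $\Delta I$. This is a hypothetical construction, not taken from SnapTra. A far token $a$ and a recent token $b$ are independent fair bits, the next token is $x = a \oplus b$, and $b$ lies inside the local window. The helpers `entropy` and `mi` compute exact mutual information and conditional mutual information in bits from the joint table, using $I(X;Y \mid Z) = H(X,Z) + H(Y,Z) - H(X,Y,Z) - H(Z)$.

```python
from collections import defaultdict
from itertools import product
from math import log2

# Hypothetical toy distribution (not from the text): far token a and recent token b are
# independent fair bits, and the next token is x = a XOR b.
# Outcome tuple positions: 0 -> a (far), 1 -> b (inside the local window), 2 -> x.
joint = {(a, b, a ^ b): 0.25 for a, b in product((0, 1), repeat=2)}
A, B, X = (0,), (1,), (2,)


def entropy(joint, idx):
    """Entropy in bits of the marginal over the variables at positions idx."""
    marginal = defaultdict(float)
    for outcome, p in joint.items():
        marginal[tuple(outcome[i] for i in idx)] += p
    return -sum(p * log2(p) for p in marginal.values() if p > 0)


def mi(joint, x, y, z=()):
    """Conditional mutual information I(X; Y | Z) = H(X,Z) + H(Y,Z) - H(X,Y,Z) - H(Z)."""
    return (entropy(joint, x + z) + entropy(joint, y + z)
            - entropy(joint, x + y + z) - entropy(joint, z))


# Without local context, the snapshot must store both a and b to be useful.
print("I(x; s=a)        =", mi(joint, X, A))         # far token alone: 0 bits
print("I(x; s=(a,b))    =", mi(joint, X, A + B))     # whole segment: 1 bit
# With b in the local window, a snapshot that keeps only a already suffices.
print("Delta I, s=a     =", mi(joint, X, A, B))      # 1 bit
print("Delta I, s=(a,b) =", mi(joint, X, A + B, B))  # still 1 bit
```

Without local context, a snapshot holding only $a$ carries 0 bits about $x$. The snapshot must keep both tokens to reach 1 bit. With $b$ in the local window, keeping only $a$ already yields $\Delta I = 1$ bit, and also storing $b$ adds nothing. The snapshot only has to carry what the local window cannot supply.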
Concretely, the snapshot needs to capture:

- long-range dependencies;
- global information;
- any other relevant information not contained in the most recent $k$ tokens.

This is a more efficient use of the snapshot's capacity, which no longer has to be spent on information already available in the local context.

**5. Information Compression**

From this perspective, the snapshot performs a kind of adaptive compression. It compresses the information in $x_{1:n}$ in a way that complements what the local context already provides.

**6. Efficiency Gains**

This approach is more efficient because it:

- reduces redundancy between the snapshot and the local context;
- lets the snapshot focus on capturing truly long-range dependencies;
- enables more precise modeling of recent information through the local context.

**7. Balancing Act**

The choice of $k$ (the size of the local context) becomes crucial. A larger $k$ reduces the burden on the snapshot but increases computation. The optimal $k$ balances:

- the amount of recent information needed for accurate prediction;
- the computational cost of processing the local context;
- the capacity of the snapshot to capture long-range dependencies.

In essence, adding local context lets the snapshot specialize in capturing the delta mutual information, i.e. the additional, complementary information not present in the recent tokens. This makes the overall system more efficient and potentially more powerful at modeling both short-term and long-term dependencies in the sequence.

---

**1. Segment Coherence**

For the compressed architecture to work effectively, segments should ideally be coherent units of information. This implies

$$I(x_i; x_j \mid x_k) \approx 0, \quad i, j \in S_m,\; k \in S_n,\; m \neq n,$$

where $I(x_i; x_j \mid x_k)$ is the conditional mutual information between tokens $x_i$ and $x_j$ given $x_k$, and $S_m$ and $S_n$ are different segments.
This condition suggests that tokens within a segment should be more strongly related to each other than to tokens in other segments, given the information in other segments.

**2. Snapshot Representativeness**

The snapshot of each segment should be a good summary of the information in that segment:

$$H(S_k \mid s_k) \ll H(S_k),$$

where $H(S_k)$ is the entropy of segment $k$ and $H(S_k \mid s_k)$ is the conditional entropy of the segment given its snapshot. Equivalently, $I(S_k; s_k) = H(S_k) - H(S_k \mid s_k) \approx H(S_k)$: the snapshot should capture most of the relevant information in the segment.

**3. Inter-segment Information Flow**

Segments should be relatively self-contained, but information must still be able to flow between them:

$$I(S_i; S_j \mid \{s_k\}_{k \neq i,j}) > 0.$$

That is, different segments should share some mutual information even after conditioning on all other snapshots. This ensures that the model can still capture long-range dependencies.

**4. Segment Size Distribution**

The distribution of segment sizes should be relatively uniform to ensure balanced compression:

$$\text{Var}(|S_k|) \ll \mathbb{E}[|S_k|]^2,$$

where $|S_k|$ is the size of segment $k$. This helps maintain a consistent information density across snapshots.

**5. Temporal or Spatial Locality**

Many types of data (e.g., text, time series) exhibit natural temporal or spatial locality, which we might express as

$$I(x_i; x_j) \approx f(|i-j|),$$

where $f$ is a decreasing function. This property can guide the segmentation process, since nearby tokens are more likely to belong to the same segment.

**6. Cross-segment Attention Sparsity**

For the compressed architecture to be computationally efficient, cross-segment attention should be relatively sparse:

$$\sum_k \mathbb{1}(\beta_{ik} > \epsilon) \ll K,$$

where $\beta_{ik}$ is the attention weight from token $i$ to the snapshot of segment $k$, $K$ is the total number of segments, $\epsilon$ is a small threshold, and $\mathbb{1}$ is the indicator function.
This condition says that each token should attend strongly to only a small number of snapshots.

These requirements on the data and on segment relationships provide a framework for designing effective segmentation strategies. They also help explain when this compressed architecture is likely to perform well, and they highlight the trade-offs between computational efficiency and model expressiveness.

### Overlap for Independence

If complete independence among segments cannot be ensured, introducing some overlap between segments can help.

**Overlapping Segments**

Instead of strictly disjoint segments, we can define overlapping segments:

$$S_k \cap S_{k+1} \neq \emptyset,$$

where $S_k$ and $S_{k+1}$ are adjacent segments.

**Modified Segment Coherence Condition**

We can relax the original condition to account for the overlap:

$$I(x_i; x_j \mid x_k, O_{m,n}) < \epsilon,$$

where

- $i \in S_m \setminus O_{m,n}$,
- $j \in S_n \setminus O_{m,n}$,
- $k \in O_{m,n}$,
- $O_{m,n} = S_m \cap S_n$ is the overlap between segments $m$ and $n$,
- $\epsilon$ is a small threshold.

This condition states that the mutual information between the non-overlapping parts of different segments should be small, given the information in the overlap.

**Overlap Size**

The size of the overlap can be defined as a fraction of the segment size:

$$|O_{k,k+1}| = \alpha |S_k|,$$

where $0 < \alpha < 1$ is the overlap fraction.

**Information Retention in Snapshots**

With overlapping segments, snapshots should retain the information in the overlaps:

$$I(s_k; O_{k-1,k}, O_{k,k+1}) \approx H(O_{k-1,k}, O_{k,k+1}).$$

This condition ensures that the snapshot $s_k$ captures most of the information in its overlaps with the adjacent segments.
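The overlapping-segment definitions above can be sketched as a sliding window. This is a minimal, hypothetical implementation, not SnapTra's actual segmentation. `overlapping_segments` uses windows of length $n$ with stride $n - \text{round}(\alpha n)$, so adjacent windows share $|O_{k,k+1}| = \text{round}(\alpha n)$ positions. `overlaps` lists the shared positions. `size_dispersion` computes the ratio $\text{Var}(|S_k|)/\mathbb{E}[|S_k|]^2$ from the segment-size condition. All three helpers are illustrative names.

```python
from statistics import mean, pvariance


def overlapping_segments(N, n, alpha):
    """Split positions 0..N-1 into windows of length n (the last may be shorter),
    where consecutive windows share round(alpha * n) positions."""
    overlap = round(alpha * n)
    if not 0 < overlap < n:
        raise ValueError("alpha * n must round to an overlap between 1 and n - 1")
    stride = n - overlap
    segments, start = [], 0
    while True:
        end = min(start + n, N)
        segments.append(range(start, end))
        if end == N:
            break
        start += stride
    return segments


def overlaps(segments):
    """O_{k,k+1}: positions shared by each pair of adjacent segments."""
    return [sorted(set(s) & set(t)) for s, t in zip(segments, segments[1:])]


def size_dispersion(segments):
    """Var(|S_k|) / E[|S_k|]^2; small values mean near-uniform segment sizes."""
    sizes = [len(s) for s in segments]
    return pvariance(sizes) / mean(sizes) ** 2


segs = overlapping_segments(N=22, n=8, alpha=0.25)
print([(s.start, s.stop) for s in segs])  # [(0, 8), (6, 14), (12, 20), (18, 22)]
print(overlaps(segs))                     # [[6, 7], [12, 13], [18, 19]]
print(size_dispersion(segs))
```

In this example only the final window is shorter than $n$, so the size dispersion stays small. A larger $\alpha$ increases the overlap, and with it the redundant computation, for the same $N$.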
**Attention to Overlapping Regions**

The attention mechanism can be modified to give special consideration to overlapping regions:

$$\beta_{ik} = f(\beta'_{ik}, \gamma_{i,O_{k-1,k}}, \gamma_{i,O_{k,k+1}}),$$

where

- $\beta'_{ik}$ is the base attention to the snapshot,
- $\gamma_{i,O}$ is an additional attention term for the overlap region $O$,
- $f$ is a function that combines these attention components.

**Continuity Across Segments**

Overlapping segments can help maintain continuity of information flow, which we can express as

$$\lVert \phi(x_i) - \phi(x_j) \rVert < \delta \quad \text{for } i \in S_k,\; j \in S_{k+1},\; |i-j| < |O_{k,k+1}|,$$

where $\phi(x)$ is some representation of token $x$ and $\delta$ is a small threshold. This condition ensures smooth transitions between segments.

**Benefits of Overlapping Segments**

- Improved information flow between segments
- Reduced loss of context at segment boundaries
- Greater robustness to suboptimal segmentation
- Potential for better handling of long-range dependencies

**Challenges**

- Increased computational cost due to redundant processing of overlapping regions
- More complex attention mechanisms to handle overlaps
- Potential for information duplication in snapshots

Overlaps between segments can make the model more robust, so that it better approximates the full attention mechanism of a standard transformer. This matters most when strict independence between segments cannot be guaranteed. The key is to balance the size of the overlap against the computational cost and to design appropriate mechanisms for integrating information from overlapping regions.

## References

[1] M. Geva, R. Schuster, J. Berant, and O. Levy. Transformer Feed-Forward Layers Are Key-Value Memories. EMNLP 2021. https://arxiv.org/abs/2012.14913