Part of paper:

![](https://hackmd.io/_uploads/r1gvbSLc3.png)

# Analysis of Different Retention Representations

The paper states that "A hybrid form of parallel representation and recurrent representation is available to accelerate training." This implies that all three representations should yield exactly the same output. Let's verify this.

## Different Representations

### 1. Recurrent Representation

The recurrent form can be written as:

$$
S_n = \gamma S_{n-1} + K_n^\mathsf{T}V_n
$$

which expands to:

$$
S_n = \sum_{i=1}^n \gamma^{n-i} K_i^\mathsf{T}V_i
$$

### 2. Chunk-wise Representation

For chunk size $B$, we define:

$$
\begin{aligned}
K_{[i]} &= K_{B(i-1)+1:Bi} \\
V_{[i]} &= V_{B(i-1)+1:Bi}
\end{aligned}
$$

The chunk-wise recurrent computation is:

$$
R_{[i]} = K_{[i]}^\mathsf{T}V_{[i]} + \gamma^B R_{[i-1]}, \quad R_{[i]} \in \mathbb{R}^{d_{\sf{qk}} \times d_{\sf{v}}}
$$

which expands to:

$$
\begin{aligned}
R_{[i]} = &\sum_{j=1}^B K_{B(i-1)+j}^\mathsf{T}V_{B(i-1)+j} + \\
&\gamma^B \sum_{j=1}^B K_{B(i-2)+j}^\mathsf{T}V_{B(i-2)+j} + \\
&\cdots + \\
&\gamma^{B(i-1)} \sum_{j=1}^B K_{j}^\mathsf{T}V_{j}
\end{aligned}
$$

## Example Analysis

Let's consider the case $n = 8$ with chunk size $B = 2$.

### Recurrent Form

$$
\begin{aligned}
S_8 = &K_8^\mathsf{T}V_8 + \gamma K_{7}^\mathsf{T}V_{7} + \gamma^2 K_6^\mathsf{T}V_6 + \gamma^3 K_5^\mathsf{T}V_5 + \\
&\gamma^4 K_4^\mathsf{T}V_4 + \gamma^5 K_3^\mathsf{T}V_3 + \gamma^6 K_2^\mathsf{T}V_2 + \gamma^7 K_1^\mathsf{T}V_1
\end{aligned}
$$

$$
\text{Retention}(X_8) = Q_8S_8
$$

### Chunk-wise Form

For chunk $[3]$ (indices 5-6):

$$
\begin{aligned}
R_{[3]} = &K_6^\mathsf{T}V_6 + K_{5}^\mathsf{T}V_{5} + \\
&\gamma^2(K_4^\mathsf{T}V_4 + K_{3}^\mathsf{T}V_{3}) + \\
&\gamma^4(K_2^\mathsf{T}V_2 + K_{1}^\mathsf{T}V_{1})
\end{aligned}
$$

For the final chunk $[4]$ (indices 7-8):

$$
\begin{aligned}
\text{Retention}(X_{[4]}) &= (Q_{[4]}K_{[4]}^\mathsf{T} \odot D)V_{[4]} + (Q_{[4]}R_{[3]})\odot\xi \\
&= [\text{Retention}(X_{7}), \text{Retention}(X_{8})]
\end{aligned}
$$

where:

$$
\begin{aligned}
\text{Retention}(X_{7}) &= Q_7S_7, \quad S_7 = K_7^{\mathsf{T}}V_7 + R_{[3]} \odot\xi_1, \quad \xi_1 = \gamma^2 \\
\text{Retention}(X_{8}) &= Q_8S_8, \quad S_8 = K_8^{\mathsf{T}}V_8 + \gamma K_7^{\mathsf{T}}V_7 + R_{[3]} \odot\xi_2, \quad \xi_2 = \gamma^3
\end{aligned}
$$

This does not match the recurrent form: position 7 requires $S_7 = K_7^\mathsf{T}V_7 + \gamma S_6$, but $R_{[3]}$ as computed above is missing the intra-chunk decay on $K_5$, $K_3$, and $K_1$, so no single scalar $\xi_1$ can rescale $R_{[3]}$ into $\gamma S_6$. The chunk-wise state itself has to change.

## Fixing the Discrepancy

### 1. Modified R and ξ Calculation

$$
\begin{aligned}
R_{[i]} &= (\xi \odot K_{[i]})^\mathsf{T}V_{[i]} + \gamma^B R_{[i-1]} \\
\xi_{ij} &= \gamma^{B-i} \\
R_{[i]} &\in \mathbb{R}^{d_{\sf{qk}} \times d_{\sf{v}}}
\end{aligned}
$$

so that, within each chunk, the most recent position carries weight $\gamma^0$ and the first position carries weight $\gamma^{B-1}$. For chunk $[3]$:

$$
\begin{aligned}
R_{[3]} &= K_6^\mathsf{T}V_6 + \gamma K_{5}^\mathsf{T}V_{5} + \\
&\gamma^2(K_4^\mathsf{T}V_4 + \gamma K_{3}^\mathsf{T}V_{3}) + \\
&\gamma^4(K_2^\mathsf{T}V_2 + \gamma K_{1}^\mathsf{T}V_{1}) = S_6
\end{aligned}
$$

### 2. Modified Retention Equation

$$
\text{Retention}(X_{[i]}) = (Q_{[i]}K_{[i]}^\mathsf{T} \odot D)V_{[i]} + (Q_{[i]}R_{[i-1]})\odot\xi
$$

where the cross-chunk decay is $\xi_{ij} = \gamma^{i}$, so the $i$-th position inside the chunk decays the previous chunk's state $R_{[i-1]} = S_{B(i-1)}$ by $\gamma^{i}$ (this is the `inner_decay` in the implementation below).
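As a quick numerical check of these two fixes, here is a minimal sketch (assuming PyTorch; the tensor names, seed, and head dimension $d = 4$ are illustrative, not from the paper) that runs the $n = 8$, $B = 2$ example and confirms the modified state $R_{[i]}$ equals the recurrent state $S_{Bi}$ at every chunk boundary:

```python
import torch

torch.manual_seed(0)
gamma, n, B, d = 0.9, 8, 2, 4                 # decay, sequence length, chunk size, head dim
K, V = torch.randn(n, d), torch.randn(n, d)

# Recurrent form: S_t = gamma * S_{t-1} + K_t^T V_t
S = torch.zeros(d, d)
S_all = []
for t in range(n):
    S = gamma * S + torch.outer(K[t], V[t])
    S_all.append(S.clone())

# Modified chunk-wise form: R_[i] = (xi * K_[i])^T V_[i] + gamma^B * R_[i-1],
# with xi = gamma^(B-1), ..., gamma^0 over the positions inside the chunk.
xi = gamma ** torch.arange(B - 1, -1, -1, dtype=torch.float)
R = torch.zeros(d, d)
for c in range(n // B):
    K_chunk, V_chunk = K[c * B:(c + 1) * B], V[c * B:(c + 1) * B]
    R = (xi[:, None] * K_chunk).T @ V_chunk + gamma ** B * R
    # The modified state should equal the recurrent state at the chunk boundary.
    assert torch.allclose(R, S_all[(c + 1) * B - 1], atol=1e-5)
print("modified chunk-wise state matches the recurrent state")
```

Dropping `xi` from the update (the unmodified recurrence) reproduces the discrepancy above: the assertion then fails.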
### 3. Final Verification

Substituting $R_{[3]} = S_6$ into the modified retention equation for $X_7$ and $X_8$ reproduces exactly the recurrent states:

$$
\begin{aligned}
S_7 &= \sum_{j=1}^7 \gamma^{j-1} K_{8-j}^\mathsf{T}V_{8-j} \\
S_8 &= \sum_{j=1}^8 \gamma^{j-1} K_{9-j}^\mathsf{T}V_{9-j}
\end{aligned}
$$

### Python Implementation

```python
def chunked_retention(
    q, k, v,        # bsz, heads, chunk_size, dim
    past_kv,        # bsz, heads, dim, dim            (the previous state R_[i-1])
    decay_mask,     # heads, chunk_size, chunk_size   (the D matrix)
    chunk_decay,    # heads, 1, 1                     (gamma^B)
    inner_decay,    # heads, chunk_size, 1            (cross-chunk decay gamma^i)
    current_decay,  # heads, chunk_size, 1            (intra-chunk weighting gamma^(B-i))
):
    # Inner-chunk (parallel) term: (Q_[i] K_[i]^T * D) V_[i]
    retention = q @ k.transpose(-1, -2)
    retention = retention * decay_mask
    inner_retention = retention @ v
    # Cross-chunk (recurrent) term: (Q_[i] R_[i-1]) * xi
    cross_retention = (q @ past_kv) * inner_decay
    output = inner_retention + cross_retention
    # State update: R_[i] = gamma^B R_[i-1] + (xi * K_[i])^T V_[i]
    current_kv = chunk_decay * past_kv + (k * current_decay).transpose(-1, -2) @ v
    return output, current_kv
```

This implementation combines the inner-chunk and cross-chunk retention terms and updates the chunk state with the intra-chunk decay, matching the corrected formulation above.
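As a usage sketch (again assuming PyTorch and the `chunked_retention` function above; the decay-tensor construction is an illustrative reading of the corrected formulas, not code from the paper), the decay tensors can be built once per chunk size and the function applied chunk by chunk, with a token-by-token recurrent retention as the reference:

```python
import torch

torch.manual_seed(0)
gamma, n, B, d = 0.9, 8, 2, 4                      # decay, sequence length, chunk size, head dim
q = torch.randn(1, 1, n, d)                        # bsz=1, heads=1
k = torch.randn(1, 1, n, d)
v = torch.randn(1, 1, n, d)

# Decay tensors for one chunk (illustrative construction).
idx = torch.arange(B, dtype=torch.float)
decay_mask = (gamma ** (idx[:, None] - idx[None, :])) * (idx[:, None] >= idx[None, :])
decay_mask = decay_mask[None]                      # heads, B, B         (the D matrix)
chunk_decay = torch.full((1, 1, 1), gamma ** B)    # gamma^B
inner_decay = (gamma ** (idx + 1))[None, :, None]  # gamma^1 .. gamma^B  (cross-chunk xi)
current_decay = (gamma ** (B - 1 - idx))[None, :, None]  # gamma^(B-1) .. gamma^0 (intra-chunk xi)

# Chunk-wise pass.
past_kv = torch.zeros(1, 1, d, d)
outputs = []
for c in range(n // B):
    sl = slice(c * B, (c + 1) * B)
    out, past_kv = chunked_retention(
        q[..., sl, :], k[..., sl, :], v[..., sl, :],
        past_kv, decay_mask, chunk_decay, inner_decay, current_decay,
    )
    outputs.append(out)
chunked = torch.cat(outputs, dim=-2)

# Reference: token-by-token recurrent retention, Retention(X_t) = Q_t S_t.
S = torch.zeros(d, d)
recurrent = []
for t in range(n):
    S = gamma * S + torch.outer(k[0, 0, t], v[0, 0, t])
    recurrent.append(q[0, 0, t] @ S)
assert torch.allclose(chunked[0, 0], torch.stack(recurrent), atol=1e-5)
print("chunk-wise and recurrent retention agree")
```

The check passes because `current_decay` carries the intra-chunk decay into the state and `inner_decay` applies the remaining $\gamma^i$ when that state is consumed by the next chunk.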