# Analysis of Different Retention Representations
The paper states that "A hybrid form of parallel representation and recurrent representation is available to accelerate training." This implies that all three representations (parallel, recurrent, and chunk-wise) should yield exactly the same output. Let's verify this.
## Different Representations
### 1. Recurrent Representation
The recurrent form can be written as:
$$
S_n = \gamma S_{n-1} + K_n^\mathsf{T}V_n
$$
Which expands to:
$$
S_n = \sum_{i=1}^n \gamma^{n-i} K_i^\mathsf{T}V_i
$$
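A minimal sketch of this recurrence in code (tensor shapes and the helper name are illustrative assumptions, not taken from the paper):
```python
import torch

def recurrent_retention_step(q_n, k_n, v_n, s_prev, gamma):
    """One step of the recurrent form: S_n = gamma * S_{n-1} + K_n^T V_n.

    q_n, k_n: (1, d_qk); v_n: (1, d_v); s_prev: (d_qk, d_v); gamma: scalar decay.
    Returns Retention(X_n) = Q_n S_n and the updated state S_n.
    """
    s_n = gamma * s_prev + k_n.transpose(-1, -2) @ v_n   # (d_qk, d_v)
    out_n = q_n @ s_n                                     # (1, d_v)
    return out_n, s_n

# Example usage: run 8 steps with d_qk = d_v = 4
s = torch.zeros(4, 4)
for t in range(8):
    q_t, k_t, v_t = (torch.randn(1, 4) for _ in range(3))
    out_t, s = recurrent_retention_step(q_t, k_t, v_t, s, gamma=0.9)
```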
### 2. Chunk-wise Representation
For chunk size $B$, we define:
$$
\begin{aligned}
K_{[i]} &= K_{B(i-1)+1:Bi} \\
V_{[i]} &= V_{B(i-1)+1:Bi}
\end{aligned}
$$
The chunk-wise recurrent computation:
$$
R_{[i]} = K_{[i]}^\mathsf{T}V_{[i]} + \gamma^BR_{[i-1]}, \quad R_{[i]} \in \mathbb{R}^{d_{\sf{qk}} \times d_{\sf{v}}}
$$
Which expands to:
$$
\begin{aligned}
R_{[i]} = &\sum_{j=1}^B K_{B(i-1)+j}^\mathsf{T}V_{B(i-1)+j} + \\
&\gamma^B \sum_{j=1}^B K_{B(i-2)+j}^\mathsf{T}V_{B(i-2)+j} + \\
&\cdots + \\
&\gamma^{B(i-1)} \sum_{j=1}^B K_{j}^\mathsf{T}V_{j}
\end{aligned}
$$
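Read literally, this chunk-level recurrence could be sketched as follows (a minimal sketch; the chunked tensor layout and the helper name `naive_chunk_states` are assumptions for illustration):
```python
import torch

def naive_chunk_states(k_chunks, v_chunks, gamma):
    """Literal reading of R_[i] = K_[i]^T V_[i] + gamma^B R_[i-1].

    k_chunks: (num_chunks, B, d_qk); v_chunks: (num_chunks, B, d_v).
    Returns the state R_[i] after each chunk.
    """
    B = k_chunks.shape[1]
    r = torch.zeros(k_chunks.shape[-1], v_chunks.shape[-1])
    states = []
    for k_c, v_c in zip(k_chunks, v_chunks):
        # Note: no intra-chunk decay is applied to the keys here; the example
        # analysis below shows this disagrees with the recurrent form.
        r = k_c.transpose(-1, -2) @ v_c + (gamma ** B) * r
        states.append(r)
    return states
```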
## Example Analysis
Let's consider a case where $n = 8$ and chunk size $B = 2$.
### Recurrent Form
$$
\begin{aligned}
S_8 = &K_8^\mathsf{T}V_8 + \gamma K_{7}^\mathsf{T}V_{7} + \gamma^2 K_6^\mathsf{T}V_6 + \gamma^3 K_5^\mathsf{T}V_5 + \\
&\gamma^4 K_4^\mathsf{T}V_4 + \gamma^5 K_3^\mathsf{T}V_3 + \gamma^6 K_2^\mathsf{T}V_2 + \gamma^7 K_1^\mathsf{T}V_1
\end{aligned}
$$
$$
\text{Retention}(X_8) = Q_8S_8
$$
### Chunk-wise Form
For chunk $[3]$ (indices 5-6):
$$
\begin{aligned}
R_{[3]} = &K_6^\mathsf{T}V_6 + K_{5}^\mathsf{T}V_{5} + \\
&\gamma^2(K_4^\mathsf{T}V_4 + K_{3}^\mathsf{T}V_{3}) + \\
&\gamma^4(K_2^\mathsf{T}V_2 + K_{1}^\mathsf{T}V_{1})
\end{aligned}
$$
For the final chunk $[4]$ (indices 7-8):
$$
\begin{aligned}
\text{Retention}(X_{[4]}) &= (Q_{[4]}K_{[4]}^\mathsf{T} \odot D)V_{[4]} + (Q_{[4]}R_{[3]})\odot\xi \\
&= [\text{Retention}(X_{7}), \text{Retention}(X_{8})]
\end{aligned}
$$
Where $D$ is the intra-chunk decay mask ($D_{nm} = \gamma^{n-m}$ for $n \geq m$, $0$ otherwise), and:
$$
\begin{aligned}
\text{Retention}(X_{7}) &= Q_7S_7, \quad S_7 = K_7^{\mathsf{T}}V_7 + R_{[3]} \odot\xi_1, \quad \xi_1 = \gamma^2 \\
\text{Retention}(X_{8}) &= Q_8S_8, \quad S_8 = K_8^{\mathsf{T}}V_8 + \gamma K_7^{\mathsf{T}}V_7 + R_{[3]} \odot\xi_2, \quad \xi_2 = \gamma^3
\end{aligned}
$$
Substituting $R_{[3]}$ from above gives $S_7 = K_7^\mathsf{T}V_7 + \gamma^2 K_6^\mathsf{T}V_6 + \gamma^2 K_5^\mathsf{T}V_5 + \gamma^4 K_4^\mathsf{T}V_4 + \cdots$, which does not match the recurrent expansion: $K_6^\mathsf{T}V_6$ should carry a factor of $\gamma$, not $\gamma^2$, and $K_4^\mathsf{T}V_4$ should carry $\gamma^3$, not $\gamma^4$.
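To make the mismatch concrete, here is a quick scalar check (illustrative only, with every $k_i = v_i = 1$ so each term reduces to a power of $\gamma$) comparing the recurrent value of $S_7$ with the chunk-wise value as originally formulated:
```python
gamma, B = 0.9, 2

# Recurrent form: S_7 = sum_{i=1..7} gamma^(7-i) k_i v_i
s7_recurrent = sum(gamma ** (7 - i) for i in range(1, 8))

# Chunk-wise form as written: R_[i] = K_[i]^T V_[i] + gamma^B R_[i-1] (no intra-chunk decay),
# then S_7 = K_7^T V_7 + xi_1 * R_[3] with xi_1 = gamma^2
r = 0.0
for _ in range(3):                      # chunks [1], [2], [3] cover positions 1..6
    r = 2.0 + (gamma ** B) * r
s7_chunkwise = 1.0 + (gamma ** 2) * r

print(s7_recurrent, s7_chunkwise)       # 5.2170... vs 4.9950...; the two forms disagree
```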
## Fixing the Discrepancy
### 1. Modified R and ξ Calculation
$$
\begin{aligned}
R_{[i]} &= (\xi \odot K_{[i]})^\mathsf{T}V_{[i]} + \gamma^B R_{[i-1]} \\
\xi_{ij} &= \gamma^{B-i} \\
R_{[i]} &\in \mathbb{R}^{d_{\sf{qk}} \times d_{\sf{v}}}
\end{aligned}
$$
Here $i$ indexes rows within the chunk, so the most recent position carries no decay and the oldest carries $\gamma^{B-1}$.
For chunk $[3]$:
$$
\begin{aligned}
R_{[3]} &= K_6^\mathsf{T}V_6 + \gamma K_{5}^\mathsf{T}V_{5} + \\
&\gamma^2(K_4^\mathsf{T}V_4 + \gamma K_{3}^\mathsf{T}V_{3}) + \\
&\gamma^4(K_2^\mathsf{T}V_2 + \gamma K_{1}^\mathsf{T}V_{1}) = S_6
\end{aligned}
$$
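Repeating the earlier scalar check with the intra-chunk decay in place (again every $k_i = v_i = 1$) confirms that the corrected state matches the recurrent one:
```python
gamma, B = 0.9, 2

# Recurrent form: S_6 = sum_{i=1..6} gamma^(6-i) k_i v_i
s6_recurrent = sum(gamma ** (6 - i) for i in range(1, 7))

# Corrected chunk-wise state: row j of each chunk is weighted by gamma^(B-j)
r = 0.0
for _ in range(3):                                    # chunks [1], [2], [3]
    intra = sum(gamma ** (B - j) for j in range(1, B + 1))
    r = intra + (gamma ** B) * r

print(abs(r - s6_recurrent) < 1e-12)                  # True: R_[3] equals S_6
```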
### 2. Modified Retention Equation
$$
\text{Retention}(X_{[i]}) = (Q_{[i]}K_{[i]}^\mathsf{T} \odot D)V_{[i]} + (Q_{[i]}R_{[i-1]})\odot\xi, \quad \xi_{ij} = \gamma^{i}
$$
With $R_{[i-1]}$ now equal to the running state at the end of the previous chunk, the cross-chunk decay for row $i$ of the chunk is $\gamma^{i}$ (for $B = 2$: $\xi_1 = \gamma$, $\xi_2 = \gamma^2$) rather than $\gamma^2$ and $\gamma^3$ as before.
### 3. Final Verification
For $X_7$ and $X_8$, the corrected chunk-wise computation gives:
$$
\begin{aligned}
S_7 &= \sum_{j=1}^7 \gamma^{j-1} K_{8-j}^\mathsf{T}V_{8-j} \\
S_8 &= \sum_{j=1}^8 \gamma^{j-1} K_{9-j}^\mathsf{T}V_{9-j}
\end{aligned}
$$
Both sums match the recurrent expansion exactly, so the corrected chunk-wise form is equivalent to the recurrent form.
### Python Implementation
```python
def chunked_retention(
    q, k, v,        # (bsz, heads, chunk_size, dim)
    past_kv,        # (bsz, heads, dim, dim): R_[i-1], the state after the previous chunk
    decay_mask,     # (heads, chunk_size, chunk_size): D, D_nm = gamma^(n-m) for n >= m, else 0
    chunk_decay,    # (heads, 1, 1): gamma^B
    inner_decay,    # (heads, chunk_size, 1): cross-chunk decay, gamma^i for row i (1-indexed)
    current_decay,  # (heads, chunk_size, 1): intra-chunk decay, gamma^(B-i) for row i (1-indexed)
):
    # Inner-chunk term: (Q_[i] K_[i]^T ⊙ D) V_[i]
    retention = q @ k.transpose(-1, -2)
    retention = retention * decay_mask
    inner_retention = retention @ v
    # Cross-chunk term: (Q_[i] R_[i-1]) ⊙ xi
    cross_retention = (q @ past_kv) * inner_decay
    output = inner_retention + cross_retention
    # State update: R_[i] = gamma^B R_[i-1] + (intra-decayed K_[i])^T V_[i]
    current_kv = chunk_decay * past_kv + (k * current_decay).transpose(-1, -2) @ v
    return output, current_kv
```
This implementation combines the corrected formulas into a single chunk-level step: it computes the inner-chunk and cross-chunk retention terms and returns the updated state for the next chunk.
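As a final sanity check, the sketch below (assumptions: a single head, random inputs, equal key/value dimensions, decay tensors built from the corrected formulas, and a hypothetical `recurrent_reference` helper) compares `chunked_retention` against the recurrent form:
```python
import torch

def recurrent_reference(q, k, v, gamma):
    # Recurrent form: S_n = gamma * S_{n-1} + K_n^T V_n, Retention(X_n) = Q_n S_n
    bsz, heads, n, dim = q.shape
    s = torch.zeros(bsz, heads, dim, dim)
    outs = []
    for t in range(n):
        s = gamma * s + k[:, :, t:t + 1].transpose(-1, -2) @ v[:, :, t:t + 1]
        outs.append(q[:, :, t:t + 1] @ s)
    return torch.cat(outs, dim=2)

bsz, heads, n, dim, B, gamma = 1, 1, 8, 4, 2, 0.9
q, k, v = (torch.randn(bsz, heads, n, dim) for _ in range(3))

# Decay tensors per the corrected formulas (1-indexed row i within a chunk)
idx = torch.arange(B, dtype=torch.float)
decay_mask = (gamma ** (idx[:, None] - idx[None, :])) * (idx[:, None] >= idx[None, :]).float()
decay_mask = decay_mask[None]                            # D: (heads, B, B)
chunk_decay = torch.full((1, 1, 1), gamma ** B)          # gamma^B
inner_decay = (gamma ** (idx + 1)).view(1, B, 1)         # cross-chunk decay: gamma^i
current_decay = (gamma ** (B - 1 - idx)).view(1, B, 1)   # intra-chunk decay: gamma^(B-i)

past_kv = torch.zeros(bsz, heads, dim, dim)
outputs = []
for c in range(n // B):
    sl = slice(c * B, (c + 1) * B)
    out, past_kv = chunked_retention(q[:, :, sl], k[:, :, sl], v[:, :, sl], past_kv,
                                     decay_mask, chunk_decay, inner_decay, current_decay)
    outputs.append(out)

# Prints True if the chunk-wise and recurrent outputs agree
print(torch.allclose(torch.cat(outputs, dim=2), recurrent_reference(q, k, v, gamma), atol=1e-5))
```
With the decay tensors built this way, the chunk-wise outputs should match the recurrent outputs for every position.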