# Mamba & RWKV
Authors: Weimin Wu, Xinran Li, Xinze Liu, Wenxin Zhang
## Table of Contents
- **RWKV**
- Motivation for RWKV
- Architecture of RNN and Transformer
- Architecture of RWKV
- Parallel Method for RWKV
- Experimental Performance of RWKV
- **Mamba**
- Introduction to SSM
- Selective Structured State Space Model (S6)
- Architecture of Mamba
- Hardware-aware Algorithm of Mamba
- Experimental Performance of Mamba
- **Comparison between RWKV and Mamba**
**We refer to the following**:
[1]. Gu, Albert, et al. "Hippo: Recurrent memory with optimal polynomial projections." Advances in neural information processing systems 33 (2020): 1474-1487.
[2]. Gu, Albert, Karan Goel, and Christopher Ré. "Efficiently modeling long sequences with structured state spaces." arXiv preprint arXiv:2111.00396 (2021).
[3]. Gu, A., Johnson, I., Goel, K., Saab, K., Dao, T., Rudra, A., & Ré, C. (2021). Combining recurrent, convolutional, and continuous-time models with linear state space layers. Advances in neural information processing systems, 34, 572-585.
[4]. Gu, Albert, et al. "How to train your hippo: State space models with generalized orthogonal basis projections." arXiv preprint arXiv:2206.12037 (2022).
[5]. Gupta, Ankit, Albert Gu, and Jonathan Berant. "Diagonal state spaces are as effective as structured state spaces." Advances in Neural Information Processing Systems 35 (2022): 22982-22994.
[6]. Fu, Daniel Y., et al. "Hungry hungry hippos: Towards language modeling with state space models." arXiv preprint arXiv:2212.14052 (2022).
[7]. Gu, Albert, and Tri Dao. "Mamba: Linear-time sequence modeling with selective state spaces." arXiv preprint arXiv:2312.00752 (2023).
Lieber, Opher, et al. "Jamba: A hybrid transformer-mamba language model." arXiv preprint arXiv:2403.19887 (2024).
[8]. Jamil, U. (2024, January 7). Mamba and S4 explained: Architecture, parallel scan, kernel fusion, recurrent, convolution, math. YouTube. https://www.youtube.com/watch?v=8Q_tqwpTpVU.
[9]. Gu, Albert, et al. "The Annotated S4", [annotated S4 website](https://srush.github.io/annotated-s4/#part-2-implementing-s4) (2023).
[10]. Preetham, Freedom. "Comprehensive Breakdown of Selective Structured State Space Model — Mamba (S5)." Medium, [Comprehensive breakdown of Mamba](https://medium.com/autonomous-agents/comprehensive-breakdown-of-selective-structured-state-space-model-mamba-s5-441e8b94ecaf), 2024.
[11]. Maarten Grootendorst, [A Visual Guide to Mamba and State Space Models](https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-mamba-and-state), 2024
[12]. Danielle Ensign, SrGonao, Adrià Garriga-alonso, [Ophiology (or, how the Mamba architecture works)](https://www.lesswrong.com/posts/TYLQ8gAMAmpeFcwXN/ophiology-or-how-the-mamba-architecture-works), 2024
[13]. Peng, Bo, et al. "Rwkv: Reinventing rnns for the transformer era." arXiv preprint arXiv:2305.13048 (2023).
[14]. Olah, Christopher. "Understanding lstm networks." (2015).
[15]. Johan Sokrates Wind, ["How the RWKV language model works."](https://johanwind.github.io/2023/03/23/rwkv_details.html) (2023).
[16]. Peng, Bo, et al. "Eagle and Finch: RWKV with matrix-valued states and dynamic recurrence." arXiv preprint arXiv:2404.05892 (2024).
[17]. Grazzi, Riccardo, et al. "Is Mamba Capable of In-Context Learning?." arXiv preprint arXiv:2402.03170 (2024).
## RWKV
### 1. Motivation for RWKV
The RWKV model presents a novel architecture that combines the strengths of RNNs and Transformers, addressing their respective limitations. This model achieves linear computational and memory complexity during inference, a significant improvement over the quadratic complexity of standard Transformers. RWKV incorporates a linear attention mechanism, allowing it to function as both a Transformer and an RNN, thus harnessing the benefits of parallelizable training and efficient inference.
### 2. Architecture of RNN and Transformer
Here, we first review the RNN and the Transformer.
#### (i). RNN
The following content refers to [14].
The most recognized RNN architectures are the LSTM and GRU.
- #### Limitation:
RNNs inherently depend on the previous time step's output, which prevents parallelizing computation across the sequence.
- LSTM

The first step in our LSTM is to decide what information we’re going to throw away from the cell state. This decision is made by a sigmoid layer called the “forget gate layer.” It looks at $h_{t−1}$ and $x_t$, and outputs a number between 0 and 1 for each number in the cell state $C_{t−1}$. A 1 represents “completely keep this” while a 0 represents “completely get rid of this.”
Language model example: when language model is trying to predict the next word based on all the previous ones. In such a problem, the cell state might include the gender of the present subject, so that the correct pronouns can be used. When we see a new subject, we want to forget the gender of the old subject.

The next step is to decide what new information we’re going to store in the cell state. This has two parts. First, a sigmoid layer called the “input gate layer” decides which values we’ll update. Next, a $tanh$ layer creates a vector of new candidate values, $\tilde{C}_{t}$ , that could be added to the state. In the next step, we’ll combine these two to create an update to the state.
Language model: we’d want to add the gender of the new subject to the cell state, to replace the old one we’re forgetting.

It’s now time to update the old cell state, $C_{t−1}$, into the new cell state $C_t$. The previous steps already decided what to do, we just need to actually do it.
We multiply the old state by $f_t$, forgetting the things we decided to forget earlier. Then we add $i_t∗ \tilde{C}_t$. This is the new candidate values, scaled by how much we decided to update each state value.
Language model example: this is where we’d actually drop the information about the old subject’s gender and add the new information, as we decided in the previous steps.

Finally, we need to decide what we’re going to output. This output will be based on our cell state, but will be a filtered version. First, we run a sigmoid layer which decides what parts of the cell state we’re going to output. Then, we put the cell state through $tanh$ (to push the values to be between −1 and 1) and multiply it by the output of the sigmoid gate, so that we only output the parts we decided to.
For the language model example, since it just saw a subject, it might want to output information relevant to a verb, in case that’s what is coming next. For example, it might output whether the subject is singular or plural, so that we know what form a verb should be conjugated into if that’s what follows next.
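
To make the gate equations concrete, below is a minimal NumPy sketch of a single LSTM step; the weight names (`W_f`, `W_i`, `W_c`, `W_o`) and the concatenated-input convention are our own illustrative choices, not tied to any particular library.
```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, W_f, W_i, W_c, W_o, b_f, b_i, b_c, b_o):
    """One LSTM step following the gate equations above (illustrative shapes)."""
    z = np.concatenate([h_prev, x_t])      # [h_{t-1}, x_t]
    f_t = sigmoid(W_f @ z + b_f)           # forget gate: what to discard from C_{t-1}
    i_t = sigmoid(W_i @ z + b_i)           # input gate: which values to update
    C_tilde = np.tanh(W_c @ z + b_c)       # candidate cell values
    C_t = f_t * C_prev + i_t * C_tilde     # new cell state
    o_t = sigmoid(W_o @ z + b_o)           # output gate
    h_t = o_t * np.tanh(C_t)               # filtered cell state becomes the output
    return h_t, C_t
```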

For more information about LSTM, refer to [colah's blog](https://colah.github.io/posts/2015-08-Understanding-LSTMs/).
- #### GRU

For more information about LSTM vs GRU performance, please refer to [here](https://medium.com/mindboard/lstm-vs-gru-experimental-comparison-955820c21e8b).
#### (ii). Transformer
The following content refers to [13].
Transformers have become the dominant architecture for sequence modeling. Instead of operating on sequences step by step like RNNs, Transformers use attention mechanisms to capture relationships between all pairs of input and output tokens:
$$
\operatorname{Attn}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V.
$$
The core $Q K^{\top}$ multiplication represents pairwise attention scores across all tokens, which can be decomposed as vector operations:
$$
\operatorname{Attn}(Q, K, V)_t=\frac{\sum_{i=1}^T e^{q_t^{\top} k_i} \odot v_i}{\sum_{i=1}^T e^{q_t^{\top} k_i}} .
$$
The Attention Free Transformer (AFT) alternatively formulates attention as
$$
\operatorname{Attn}^{+}(W, K, V)_t=\frac{\sum_{i=1}^t e^{w_{t,~ i}+k_i} \odot v_i}{\sum_{i=1}^t e^{w_{t, i}+k_i}},
$$
where $\left\{w_{t, i}\right\} \in \mathbb{R}^{T \times T}$ are learned pairwise position biases, and each $w_{t, i}$ is a scalar. For instance, the element of the matrix $w$ at position $(1,3)$ indicates the weight with which position 1 should attend to position 3.
Inspired by AFT, RWKV takes a similar approach but modifies the interaction weights so that the model can be transformed into an RNN. Each $w_{t, i}$ in RWKV is now a channel-wise time-decay vector scaled by the relative position, decaying backward from the current time:
$$
w_{t, i}=-(t-i) w,
$$
where $w \in\left(\mathbb{R}_{\geq 0}\right)^d$, with $d$ the number of channels. We require $w$ to be non-negative to ensure that $e^{w_{t, i}} \leq 1$ and the per-channel weights decay backwards in time.
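
As a small illustration (our own variable names, not code from [13]), the decay weights $w_{t,i} = -(t-i)w$ for all pairs of positions can be assembled as follows; masking future positions with $-\infty$ is our own choice so that they receive zero weight after exponentiation.
```python
import numpy as np

def decay_weights(w, T):
    """w: per-channel decay vector of shape (d,), non-negative.
    Returns W of shape (T, T, d) with W[t, i] = -(t - i) * w for i <= t."""
    t_idx = np.arange(T)[:, None]   # (T, 1)
    i_idx = np.arange(T)[None, :]   # (1, T)
    rel = t_idx - i_idx             # rel[t, i] = t - i
    W = -rel[..., None] * w         # (T, T, d): w_{t,i} = -(t - i) * w
    W[rel < 0] = -np.inf            # future positions (i > t) get zero weight: e^{-inf} = 0
    return W
```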
### 3. Architecture of RWKV
The following content refers to [13].
Based on the RNN and Transformer, we introduce the RWKV here.
RWKV includes four parts:
- Receptance vector $R$ : This vector is pivotal in receiving and integrating information from past data inputs, functioning as a memory conduit in the system.
- Weight vector $W$ : Acting as a positional weight decay vector, $W$ is a trainable parameter that adjusts how much past information influences current data processing, thereby managing the retention or decay of information over time.
- Key vector $K$ : Similar to its role in traditional attention mechanisms, the Key vector helps in aligning and comparing different parts of the input data, assisting in determining the relevance of various data points.
- Value vector $V$ : The Value vector is crucial for encoding the actual information of the input data, which will be used for further processing and decision-making in the network.
The RWKV model is composed of stacked residual blocks. Each block consists of a time-mixing and a channel-mixing sub-block, embodying recurrent structures to leverage past information. The time-mixing is similar to the attention block in Transformer, and the channel-mixing is similar to the feed-forward block in Transformer.

The RWKV includes three techniques: Token shift, WKV operator and Output gating.
#### Token Shift
Token shifting is a novel technique in RWKV that enhances model flexibility and responsiveness. By interpolating between the current and previous timestep inputs, all vectors involved in the computation ($R, K, V$ for time-mixing and $R', K'$ for channel-mixing) are adjusted dynamically. This method ensures that the model smoothly transitions and integrates changes over sequential data inputs, enhancing its predictive accuracy and contextual understanding. The equations for this process are:
$$
\begin{aligned}
r_t & =W_r \cdot\left(\mu_r \odot x_t+\left(1-\mu_r\right) \odot x_{t-1}\right), \\
k_t & =W_k \cdot\left(\mu_k \odot x_t+\left(1-\mu_k\right) \odot x_{t-1}\right), \\
v_t & =W_v \cdot\left(\mu_v \odot x_t+\left(1-\mu_v\right) \odot x_{t-1}\right),
\end{aligned}
$$
as are the channel-mixing inputs:
$$
\begin{aligned}
r_t^{\prime} & =W_r^{\prime} \cdot\left(\mu_r^{\prime} \odot x_t+\left(1-\mu_r^{\prime}\right) \odot x_{t-1}\right), \\
k_t^{\prime} & =W_k^{\prime} \cdot\left(\mu_k^{\prime} \odot x_t+\left(1-\mu_k^{\prime}\right) \odot x_{t-1}\right) .
\end{aligned}
$$

#### WKV Operator
The RWKV model employs a unique computational approach called the WKV operator, which parallels the method used in Attention-Free Transformers (AFT). However, unlike the AFT where weights are applied pairwise, RWKV treats the weight as a channel-wise vector that adjusts based on relative positions. This recurrent behavior is crucial for the model as it updates the vectors over time to better capture and represent temporal dynamics. The formal computation is given by:
$$
w k v_t=\frac{\sum_{i=1}^{t-1} e^{-(t-1-i) w+k_i} \odot v_i+e^{u+k_t} \odot v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i) w+k_i}+e^{u+k_t}} .
$$
The vector $u$ is a learned per-channel bonus applied to the current token; it ensures that the model not only accounts for historical data but also emphasizes the current input, making it robust against information decay over long sequences.
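
To see how the WKV recurrence unfolds over a sequence, here is a naive NumPy sketch computed directly from the definition above (variable names are ours); the efficient implementation instead maintains a running numerator and denominator, as shown in Section 4.
```python
import numpy as np

def wkv_naive(k, v, w, u):
    """k, v: (T, d) key/value sequences; w, u: (d,) decay and bonus vectors.
    Returns wkv of shape (T, d), computed directly from the definition."""
    T, d = k.shape
    out = np.zeros((T, d))
    for t in range(T):
        num = np.exp(u + k[t]) * v[t]              # bonus term for the current token
        den = np.exp(u + k[t])
        for i in range(t):                         # past tokens i = 0, ..., t-1
            decay = np.exp(-(t - 1 - i) * w + k[i])
            num += decay * v[i]
            den += decay
        out[t] = num / den
    return out
```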
#### Output Gating
To refine the output further, RWKV incorporates an output gating mechanism using the sigmoid function of the receptance vector. This gating mechanism allows the network to control the flow of information effectively, ensuring that only relevant and processed data contributes to the final output. For both time-mixing and channel-mixing blocks, this gating strategy regulates the contribution of each component, maintaining a balance between preserving important historical information and adapting to new inputs. The output vectors are computed as follows:
$$
o_t=W_o \cdot\left(\sigma\left(r_t\right) \odot w k v_t\right) .
$$
In the channel-mixing block, a similar operation is performed:
$$
o_t^{\prime}=\sigma\left(r_t^{\prime}\right) \odot\left(W_v^{\prime} \cdot \max \left(k_t^{\prime}, 0\right)^2\right),
$$
where we adopt the squared ReLU activation function.
### 4. Parallel Method for RWKV
The following content refers to Johan's blog [15].
Thanks to this architecture, computation in RWKV is largely parallelizable. Why do we say "largely" rather than "fully" parallelizable? You will see shortly.
To better understand the computation in RWKV, it helps to read the formulas alongside the code. The following snippet is a minimal implementation of a relatively small RWKV model (430 million parameters) that generates text; see [RWKV in 150 lines](https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py).
#### Channel Mixing:
We can first examine the computation during the channel-mixing, which is easier to understand.
```python
def channel_mixing(x, last_x, mix_k, mix_r, Wk, Wr, Wv):
    # x: current token embedding (1024,); last_x: previous token embedding (1024,)
    # np, exp and sigmoid come from the surrounding script (NumPy and an elementwise sigmoid)
    # Wk - (4096, 1024)
    # Wr - (1024, 1024)
    # Wv - (1024, 4096)
    k = Wk @ ( x * mix_k + last_x * (1 - mix_k) )  # (4096, 1024) @ (1024,) -> (4096,)
    r = Wr @ ( x * mix_r + last_x * (1 - mix_r) )  # (1024, 1024) @ (1024,) -> (1024,)
    vk = Wv @ np.maximum(k, 0)**2                  # (1024, 4096) @ (4096,) -> (1024,)
    return sigmoid(r) * vk, x                      # x is stored as last_x for the next step
```
Benefiting from the token-shift mechanism, we only need access to the previous token during channel mixing. Since the entire sequence is already available during training (during inference we still process tokens one by one), the computations at different time steps do not depend on each other's outputs. This is a significant difference from traditional RNNs, which must wait for the previous hidden state and can only process sequentially. As a result, these computations can be carried out independently, enabling parallel computation.
#### Time Mixing:
Now, let's look at the time-mixing part in RWKV.
```python
def time_mixing(x, last_x, last_num, last_den, decay, bonus, mix_k, mix_v, mix_r, Wk, Wv, Wr, Wout):
    # Wk - (1024, 1024)
    # Wr - (1024, 1024)
    # Wv - (1024, 1024)
    # Wout - (1024, 1024)
    k = Wk @ ( x * mix_k + last_x * (1 - mix_k) )  # (1024, 1024) @ (1024,) -> (1024,)
    v = Wv @ ( x * mix_v + last_x * (1 - mix_v) )  # (1024, 1024) @ (1024,) -> (1024,)
    r = Wr @ ( x * mix_r + last_x * (1 - mix_r) )  # (1024, 1024) @ (1024,) -> (1024,)

    # num is the running numerator and den the running denominator of the wkv formula
    wkv = (last_num + exp(bonus + k) * v) / \
          (last_den + exp(bonus + k))
    rwkv = sigmoid(r) * wkv

    # update the running numerator and denominator for the next step
    num = exp(-exp(decay)) * last_num + exp(k) * v
    den = exp(-exp(decay)) * last_den + exp(k)
    return Wout @ rwkv, (x, num, den)
```
Similar to channel mixing, we only need the previous token to calculate the $k$, $v$, and $r$ values, so the first three projections in the code above are parallelizable across time steps. However, you may notice that we need `last_num` and `last_den` from the previous state when we calculate the $wkv$ attention. This is why we said "largely" rather than "fully" parallelizable at the beginning of this section.
Recalling the formula of the wkv attention, we still need the $k$ and $v$ values from previous steps, which reflects the inherently sequential nature of RWKV. The authors use a trick here: the calculation of $wkv$ (and of `rwkv` in the code above) involves no matrix multiplication, only point-wise operations, which are cheaper by roughly a factor of 1024 (the channel dimension), while all the heavy matrix multiplications (for $k$, $v$, $r$) are parallelizable. Together, these properties make RWKV faster to compute than traditional RNNs. For a more specific code-based implementation, please refer to [this repository](https://github.com/BlinkDL/RWKV-CUDA).
### 5. Experimental Performance of RWKV
The following content refers to [13].
In this part, we present some experimental results for RWKV, including text generation time and memory usage during inference, the scaling law, and performance on zero-shot tasks.
### 5.1 Inference Results
First, we examine the inference time of RWKV in comparison to other models. As previously discussed, RWKV reformulates the attention mechanism using a variant of linear attention. The paper benchmarks the inference requirements based on model size and family, specifically evaluating text generation speed and memory usage on typical compute platforms, including CPU (x86) and GPU (NVIDIA A100 80 GB).
<p style="text-align:center;">
<img src="https://hackmd.io/_uploads/rJ8xZYQVR.png" width="400" height="300" alt="RWKV Inference">
</p>
The "Cumulative Time vs. Number of Tokens" graph shows that, unlike traditional transformers, RWKV exhibits linear scaling for cumulative times in text generation.
<p style="text-align:center;">
<img src="https://hackmd.io/_uploads/BJXjUjm40.png" alt="RWKV Memory Usage">
</p>
The "Peak Memory(GB) vs. Number of Parameters" graph shows that RWKV uses less CPU memory compared to other models.
### 5.2 Scaling Law
Scaling laws in language models describe how performance changes with factors like model size, dataset size, or compute budget. These laws are crucial for predicting and planning the costs and performance of large models before training, and for identifying areas for future research.
Previous work suggested that LSTMs do not follow the same log-log linear scaling as transformers. The paper trained 45 RWKV models across various datasets and parameters and found that RWKV adheres to the general scaling law form established for transformers.
<p style="text-align:center;">
<img src="https://hackmd.io/_uploads/B1v31tQ40.png" width="400" height="300" alt="RWKV scaling">
</p>
The figure "Scaling laws curves for RWKV models" shows the results for loss as a function of compute, with the linear fit to the Pareto optimal points holding an $r^2$ value of 0.994. Even when extrapolating an additional order of magnitude (blue), the fit remains excellent with an $r^2$ of 0.875.
### 5.3 Language Modeling Evaluation
Having demonstrated the scalability of RWKV models in the previous section, we now evaluate their competitiveness with traditional transformers. We focus on two key questions:
1. **Competitiveness:** Is RWKV competitive against quadratic transformer architectures with the same compute resources?
2. **Long Context:** Does increasing the context length of RWKV improve language modeling loss, particularly for lengths that most open-sourced quadratic transformers cannot efficiently handle?
#### 5.3.1 Competitiveness with Traditional Transformers
To demonstrate RWKV's competitiveness in NLP tasks, the paper compares it with similarly sized models trained on a similar number of tokens, including Pythia, OPT, and BLOOM. All RWKV models were trained for one epoch on the Pile (330B tokens), which is close but not identical to the training data for Pythia, OPT, and BLOOM models.
<p style="text-align:center;">
<img src="https://hackmd.io/_uploads/HkOnlt7E0.png" alt="RWKV Language Modeling">
</p>
The results on the following benchmarks are shown above: ARC (Easy and Challenge), BoolQ, COPA, HeadQA, HellaSwag, LAMBADA, OpenBookQA, PIQA, ReCoRD, SciQ, and Winogrande. With more compute (exaFLOP), RWKV performs better in 4 out of 6 benchmarks, excelling particularly in ARC (Challenge).
#### 5.3.2 Long Context - Long Range Arena
The paper evaluates RWKV's ability to handle very long sequences by comparing it to state-of-the-art long sequence models on the Long Range Arena (LRA) benchmark. LRA is designed to assess model performance in handling lengthy contexts, including tasks with sequences ranging from 1,000 to 16,000 tokens, covering various data types like text, natural language, synthetic images, and mathematical expressions.
<p style="text-align:center;">
<img src="https://hackmd.io/_uploads/HkLuZtQEA.png" alt="RWKV Long Range Arena Evaluation">
</p>
The results show that RWKV performs second only to the S4 model on five datasets.
## Mamba
Mamba is an advanced model based on the state space model (SSM). We first introduce the SSM briefly, and then introduce the Selective Structured State Space Model (S6) built on top of it. After that, we describe the architecture, the hardware-aware algorithm, and the experimental performance of Mamba.
### 1. Introduction to SSM
For a long time, the Transformer has been favored for its attention mechanism and strong performance. However, the Transformer does not scale well as the sequence length grows: every token attends to every other token when making a prediction, which makes the cost quadratic in sequence length. To address this, state space models (SSMs) were proposed to achieve the linear-time training seen in models such as recurrent neural networks (RNNs). Next, we give a detailed comparison between these three models: Transformer, RNN, and SSM.
### 1.1 SSM Framework
The SSM is based on a linear time-invariant (LTI) system. We first give the SSM equations for a single feature, then extend them to multiple features.
**Single feature**:
We use $x(t) \in \mathbb{R}$ to denote the input and $y(t) \in \mathbb{R}$ to denote the output. The model has four learnable matrices $A, B, C, D$:
- $A \in \mathbb{R}^{N \times N}$ is the state matrix (controlling how the latent state evolves),
- $B \in \mathbb{R}^{N \times 1}$ is the input matrix,
- $C \in \mathbb{R}^{1 \times N}$ is the output matrix,
- $D \in \mathbb{R}^{1 \times 1}$ is the skip (feedthrough) term.
Then we have:
$$
\begin{aligned}
h^\prime(t) & =A h(t)+B x(t), \\
y(t) & =C h(t)+D x(t),
\end{aligned}
$$
where $h(t) \in \mathbb{R}^{N \times 1}$ is the latent state.
<p style="text-align:center;">
<img src="https://hackmd.io/_uploads/rykcOfN4R.png" width="300" height="200">
</p>
**Multiple features**:
For $F$-dimensional features, we repeat the above equations $F$ times (one independent SSM per feature channel).
### 1.2 Discretization of SSM
The following content refers to [12].
To find the output signal $y(t)$ at time $t$, we first need a function $h(t)$ that describes the state of the system at all times, which can be hard to solve analytically. In practice, we rarely work with continuous signals in machine learning: any continuous signal from the real world is sampled and stored as discrete data in a computer. Thus, to produce $y(t)$ for a discrete signal, the first step is to discretize the system itself. Finding an approximate solution of the differential equation means finding a sequence $h(0)$, $h(\Delta)$, $h(2\Delta)$, $\dots$ that describes the evolution of the system over time. So instead of finding $h(t)$ we want to find $h(t_k) = h(k\Delta)$, where $\Delta$ is our step size.
We first discretize the state space models to calculate the evolution of the state over time by using a recurrent formulation.
- By using the definition of derivative we know that: $h(t+\Delta) \approx \Delta h'(t) + h(t).$
- This is the continuous state space model: $h'(t) = Ah(t) + Bx(t).$
- We can substitute the state space model into the first expression to get the following
$$
\begin{align*}
h(t+\Delta) &\approx \Delta(Ah(t) + Bx(t)) + h(t) \\
&= (I + \Delta A) h(t) + \Delta Bx(t) \\
&= \bar{A}h(t) + \bar{B}x(t).
\end{align*}
$$
Thus we arrive at the discretization of $A$ and $B$ using Euler's Method. This discretization allows us to calculate the state of the system one step at a time, knowing the state at the previous time step. The matrices $\bar{A}$, $\bar{B}$ are discretized parameters of the model.
Thus, we can go from
$$
\begin{align*}
h'(t) &= Ah(t) + Bx(t) \\
y(t) &= Ch(t) + Dx(t).
\end{align*}
$$
to
$$
\begin{align*}
h_t &= \bar{A}h_{t-1} + \bar{B}x_t \\
y_t &= Ch_t + Dx_t.
\end{align*}
$$
In practice, we can treat $Dx_t$ as a residual block, and for simplicity, we omit $Dx_t$ from the subsequent analysis.
$$
\begin{align*}
h_t &= \bar{A}h_{t-1} + \bar{B}x_t \\
y_t &= Ch_t.
\end{align*}
$$
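As a minimal sketch (our own helper names), the Euler discretization derived above and the resulting step-by-step recurrence can be written as:
```python
import numpy as np

def discretize_euler(A, B, delta):
    """Euler discretization: A_bar = I + delta*A, B_bar = delta*B."""
    return np.eye(A.shape[0]) + delta * A, delta * B

def ssm_recurrence(A_bar, B_bar, C, x):
    """Run h_t = A_bar h_{t-1} + B_bar x_t, y_t = C h_t for a scalar input sequence x."""
    h = np.zeros((A_bar.shape[0], 1))
    y = np.zeros(len(x))
    for t, x_t in enumerate(x):
        h = A_bar @ h + B_bar * x_t   # B_bar has shape (N, 1)
        y[t] = (C @ h).item()         # C has shape (1, N)
    return y
```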
**Zero Order Hold rule**
As mentioned in [A Visual Guide to Mamba and State Space Models](https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-mamba-and-state), we can instead discretize the system using the Zero-Order Hold (ZOH) rule rather than Euler's method. ZOH means we assume $x(t)$ is constant on the interval $[t, t+\Delta]$.

Note that during the interval of length $\Delta$, the input is held constant at its value at time $t$. Starting from time $t$, the solution at time $t+\Delta$ is
$$h(t+\Delta) = e^{A((t+\Delta)-t)}h(t) + \left(\int_{t}^{t+\Delta} e^{A((t+\Delta)-\tau)}\,d\tau\right) B\, x(t).$$
Then we get
$$h(t+\Delta) = \overline{A}h(t) + \overline{B}x(t), $$
where $$\overline{A} = e^{A((t+\Delta)-t)} = e^{\Delta A },$$
$$\overline{B} = (\Delta A)^{-1} (e^{\Delta A}-I) \Delta B.$$
For a more detailed treatment, please refer to this [lecture](https://faculty.washington.edu/chx/teaching/me547/1-8_zohSS_slides.pdf).
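A small sketch of the ZOH discretization for the diagonal $A$ used in Mamba, assuming $A$ is stored as a vector of its diagonal entries (function and variable names are ours):
```python
import numpy as np

def discretize_zoh_diag(A_diag, B, delta):
    """ZOH discretization for diagonal A (stored as its diagonal entries).
    A_bar = exp(delta*A); B_bar = (delta*A)^{-1} (exp(delta*A) - 1) * delta * B, elementwise."""
    dA = delta * A_diag                    # (N,)
    A_bar = np.exp(dA)
    B_bar = (A_bar - 1.0) / dA * delta * B  # elementwise version of the formula above
    return A_bar, B_bar
```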
### 1.3 Structure of A
To decrease the computational cost, S4 uses a diagonal-plus-low-rank structure for the matrix $A$; this applies to any matrix $A$ that can be decomposed as normal plus low-rank (NPLR).
To further simplify the structure of $A$, the Diagonal State Space (DSS) model uses a **fully diagonal** parameterization of the state spaces.
Mamba adopts this simplified fully diagonal structure.
### 2. Selective Structured State Space Model (S6)
The following content refers to [7].
We introduce the Selective Structured State Space Model (S6) here.
### 2.1 Motivation
A fundamental problem of sequence modeling is compressing context into a smaller state. In fact, we can view the tradeoffs of popular sequence models from this point of view. For example, attention is both effective and inefficient because it explicitly does not compress context at all. This can be seen from the fact that autoregressive inference requires explicitly storing the entire context, which directly causes the slow linear-time inference and quadratic-time training of Transformers. On the other hand, recurrent models are efficient because they have a finite state, implying constant-time inference and linear-time training. However, their effectiveness is limited by how well this state has compressed the context. To understand this principle, we focus on two running examples of synthetic tasks.
- The Selective Copying task modifies the popular Copying task by varying the position of the tokens to memorize. It requires content-aware reasoning to be able to memorize the relevant tokens (colored) and filter out the irrelevant ones (white).
- The Induction Heads task is a well-known mechanism hypothesized to explain the majority of in-context learning abilities of LLMs. It requires context-aware reasoning to know when to produce the correct output in the appropriate context (black).

These tasks reveal the failure mode of LTI models. From the recurrent view, their constant dynamics (e.g., the $(A, B)$ transitions) do not let them select the correct information from their context or affect the hidden state passed along the sequence in an input-dependent way. From the convolutional view, global convolutions are known to solve the vanilla Copying task because it only requires time-awareness, but they have difficulty with the Selective Copying task because of their lack of content-awareness.
More concretely, the spacing between inputs and outputs varies and cannot be modeled by static convolution kernels. In summary, the efficiency vs. effectiveness tradeoff of sequence models is characterized by how well they compress their state: efficient models must have a small state, while effective models must have a state that contains all necessary information from the context. The paper therefore proposes that a fundamental principle for building sequence models is selectivity: the context-aware ability to focus on or filter out inputs into a sequential state. In particular, a selection mechanism controls how information propagates or interacts along the sequence dimension.

### 2.2 Selection Mechanism
One method of incorporating a selection mechanism into models is by letting their parameters that affect interactions along the sequence (e.g. the recurrent dynamics of an RNN or the convolution kernel of a CNN) be input-dependent.
In general, the selective mechanism is like wrapping an additional linear layer around the $B$, $C$, and $\Delta$ parameters to allow the model to select (or select to forget, which is similar to the idea of LSTMs). We specifically choose
$$
\begin{aligned}
& s_B(x) = \operatorname{Linear}_N(x), \\
& s_C(x) = \operatorname{Linear}_N(x), \\
& s_\Delta (x) = \operatorname{Broadcast}_D(\operatorname{Linear}_1(x)), \\
& \tau_\Delta = \operatorname{softplus}. \\
\end{aligned}
$$
It is worth noting that adding the dimension $L$ (sequence length) to $\Delta, B, C$ introduces a time element into the model, so the model is now time-varying. Thus, we can only use the recurrent form; the convolution form no longer applies because it assumes fixed parameters shared across all positions. (For those unfamiliar with $\operatorname{Broadcast}$, it simply means expanding the dimensions of the smaller array.)
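A shape-level sketch of this selection step, assuming batch size $B$, sequence length $L$, model dimension $D$, and state size $N$; the projection matrices below are hypothetical stand-ins for the learned linear layers:
```python
import numpy as np

def softplus(z):
    return np.log1p(np.exp(z))

def select_parameters(x, W_B, W_C, W_delta, delta_bias):
    """x: (B, L, D); W_B, W_C: (D, N); W_delta: (D, 1); delta_bias: scalar.
    Returns input-dependent B, C of shape (B, L, N) and delta of shape (B, L, D)."""
    B_sel = x @ W_B                                 # s_B(x) = Linear_N(x)
    C_sel = x @ W_C                                 # s_C(x) = Linear_N(x)
    d = x @ W_delta + delta_bias                    # Linear_1(x): (B, L, 1)
    delta = softplus(np.broadcast_to(d, x.shape))   # Broadcast_D, then tau_delta = softplus
    return B_sel, C_sel, delta
```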
#### Specific Roles of $\Delta$, $A$, $B$, and $C$
Looking at the parameterization above, you may wonder why adding these extra linear layers accomplishes the "selective" mechanism. The following discussion of the roles of $\Delta$, $A$, $B$, and $C$ should make the "magic" behind it clearer.
**Role of $\Delta$**:
$\Delta$ controls the weight of the input and determines whether to focus more on the input or on the hidden state. When $\Delta$ is small, $x$ is treated as unimportant and we focus more on the hidden state (e.g., unimportant words such as "um" can effectively be filtered out). On the other hand, a large $\Delta$ resets the hidden state and focuses on the current $x$. This is why the input $x$ is projected down to one dimension and then broadcast to form $\Delta$: a single scalar per position decides how much the current $x$ counts, down to ignoring it completely. $\Delta$ can also be seen as the timestep in discretization: a large $\Delta$ means the spacing between two discretized points is larger, so we focus more on the input while ignoring the state (viewing the hidden state in continuous form, a larger step means fewer intermediate hidden states), while a small timestep means we focus more on the state.
**Role of $A$**:
Since $A$ only interacts with $\Delta$ (through $\bar{A} = \exp(\Delta A)$), and $\Delta$ already includes selectivity, adding a selection mechanism to $A$ may or may not help. The authors hypothesized that it might help but did not run experiments on it.
**Role of $B$ & $C$**:
$B$ controls whether to let input $x_t$ into state $h_t$ or not. $C$ controls whether to let $h_t$ into output $y_t$. For $B$, it’s more like determining how much of the input can affect the context, and for $C$, it’s like determining how much of the context can affect the output.

Combining them all together, the whole story goes like this: say at $h_{t-1}$ I feel tired reading the Mamba paper, and the input $x_t$ is eating a Big Mac. Let's assume the current state $h_t$ is that I am now energetic. In this case, if we believe that having a Big Mac can make us energetic, we should use a large $\Delta$ to reset the previous state and focus more on the input (having a Big Mac), consistent with the role of $\Delta$ described above. $B$ then controls how much having a Big Mac will affect my current state of feeling energetic; if we don't believe a Big Mac is enough to change a person's state, we may use a smaller $B$. As for $C$, it determines how much my state, in terms of feeling energetic, can influence the output; say the output $y_t$ is continuing to read the Mamba paper. If we believe that a good state of mind eases the understanding of the paper, then we may use a larger $C$.
### 3. Architecture of Mamba
The following content refers to [7].
Now, we introduce the details of Mamba architecture.

The Mamba architecture is similar to the Transformer, but it replaces the attention blocks with Mamba blocks. The architecture is shown in the figure below, where SiLU (Sigmoid Linear Unit) is the activation function and SSM stands for state space model.

The Mamba block is inspired by the H3 block and the gated MLP, with some modifications. To understand the Mamba block, we first introduce the H3 block:
### 3.1 H3 Block
The H3 architecture, proposed in 2023 [6], is the basis for many of the best-known SSM architectures. It generally consists of a block inspired by linear attention interleaved with an MLP (multi-layer perceptron) block.
<p style="text-align:center;">
<img src="https://hackmd.io/_uploads/SkElCfVNC.png">
</p>
### 3.2 Mamba vs H3
As shown below, there are three main differences between Mamba and H3:
- Mamba combines the two leftmost linear modules into one.
- Mamba expands the input dimension of the SSM from $F$ to $E\cdot F$ with a linear projection (in practice $E=2$).
- An additional optional LayerNorm is added to the last layer.
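
As a rough structural sketch (not the reference implementation), a Mamba block can be pictured as one input projection split into an SSM branch (convolution, SiLU, selective SSM) and a gating branch (SiLU), whose product is projected back to the model dimension; `selective_ssm` below is only a placeholder for the S6 scan of Section 2.
```python
import torch
import torch.nn as nn

class MambaBlockSketch(nn.Module):
    """High-level sketch of a Mamba block: one input projection feeding an SSM
    branch and a gating branch; the selective_ssm call is a placeholder."""
    def __init__(self, dim, expand=2, d_conv=4):
        super().__init__()
        inner = expand * dim
        self.in_proj = nn.Linear(dim, 2 * inner)   # single projection, split into two branches
        self.conv1d = nn.Conv1d(inner, inner, d_conv, groups=inner, padding=d_conv - 1)
        self.act = nn.SiLU()
        self.out_proj = nn.Linear(inner, dim)

    def selective_ssm(self, x):                    # placeholder for the input-dependent SSM scan
        return x

    def forward(self, x):                          # x: (batch, length, dim)
        u, gate = self.in_proj(x).chunk(2, dim=-1)
        u = self.conv1d(u.transpose(1, 2))[..., : x.shape[1]].transpose(1, 2)  # causal conv
        u = self.selective_ssm(self.act(u))
        return self.out_proj(u * self.act(gate))   # gated output, projected back to dim
```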

### 4. Hardware-aware Algorithm of Mamba
The following content refers to [7].
Because of the selection mechanism that S6 uses, Mamba is no longer time-invariant, which means convolution, a linear time-invariant operation, can no longer be used to compute the output. Instead, Mamba leverages hardware-aware efficient algorithms, namely parallel scan, kernel fusion, and recomputation, to optimize the selective SSM scan, ensuring both speed and memory efficiency. With these techniques, Mamba's SSM scan is up to 7× faster than traditional attention mechanisms at a sequence length of 32K, while maintaining memory efficiency comparable to the most efficient attention implementations such as FlashAttention, an IO-aware exact attention algorithm that minimizes the number of memory reads/writes between the GPU's high-bandwidth memory (HBM) and on-chip SRAM and is designed to have lower IO complexity than standard attention implementations.
### 4.1 Parallel Scan: Computational Efficiency
As mentioned above, we can perform a parallel associative scan within SRAM, which allows us to carry out the calculations in parallel. To understand how a parallel scan works in Mamba, we start by explaining what a "scan" actually means.
By "scan" we mean the scan operation: each output value is computed from the current input and the previous output. Recall the prefix sum array, whose element $i$ is the sum of the array elements from index $0$ to $i$. For example, the array $[1, 2, 3, 4, 5, 6, 7, 8]$ has prefix sums $[1, 3, 6, 10, 15, 21, 28, 36]$.

For each element $i$ of the prefix sum, we only need the current value and the prefix sum at index $i - 1$, because the prefix sum at $i - 1$ already represents the sum of all elements from $0$ to $i - 1$. Looking back at the recurrent formula in Mamba, it has the same property: to calculate the output we only need the current input at time step $t$ and the hidden state at $t - 1$.
The recurrence of the SSM, $h_t = \bar{A}h_{t-1} + \bar{B} x_t$, $y_t = Ch_t$, can therefore also be viewed as a scan operation in which each hidden state is a combination of the (decayed) previous state and the current input.
It seems that the scan must take $O(N)$ sequential time, since there are $N$ elements and seemingly $N$ dependent iterations, but the scan can be parallelized: because the combining operation is associative, the order in which partial results are combined does not matter. Let us explain the details of the parallel scan.

The parallel scan has two stages, an up-sweep and a down-sweep, each of which can be viewed as traversing a binary tree. In the up-sweep stage, we add adjacent pairs of numbers, which yields the array [3, 7, 11, 15]; repeating this gives [10, 26] and finally the root, [36]. In the down-sweep stage, we first set the root to zero. Then each node passes its value to its left child, and the right child is computed by adding the node's value to the left child's value from the up-sweep stage. As illustrated above, we obtain the array [0, 1, 3, 6, 10, 15, 21, 28] (the exclusive prefix sums), from which the ordinary prefix sum array follows by adding the original array. Every level of both the up-sweep and the down-sweep can be executed in parallel, which reduces the parallel time from $O(N)$ to $O(\log N)$ (with $O(N)$ total work). A sketch of this procedure follows below.
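Below is a small sequential sketch of the up-sweep/down-sweep (Blelloch) scan for prefix sums; in practice every level of the tree is executed in parallel on the GPU, and the array length is assumed to be a power of two.
```python
def blelloch_scan(a):
    """Exclusive prefix sum via up-sweep / down-sweep; len(a) must be a power of two."""
    x = list(a)
    n = len(x)
    # Up-sweep: build partial sums in a binary tree (each level is parallelizable).
    step = 1
    while step < n:
        for i in range(2 * step - 1, n, 2 * step):
            x[i] += x[i - step]
        step *= 2
    # Down-sweep: set the root to 0, then push the sums back down the tree.
    x[n - 1] = 0
    step = n // 2
    while step >= 1:
        for i in range(2 * step - 1, n, 2 * step):
            left = x[i - step]
            x[i - step] = x[i]
            x[i] += left
        step //= 2
    return x

# Example matching the walkthrough above:
# blelloch_scan([1, 2, 3, 4, 5, 6, 7, 8]) -> [0, 1, 3, 6, 10, 15, 21, 28]
```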

The graph above benchmarks the speed of the SSM scan operation. The baseline, Scan (PyTorch), represents the parallel scan implemented in PyTorch without kernel fusion. Mamba's approach, Scan (ours), integrates kernel fusion with the parallel scan, achieving the best performance among all tested methods. Mamba's Scan is faster than the best known attention implementation, FlashAttention-2, for sequence lengths beyond 2K and is up to 20-40× faster than a standard scan implementation in PyTorch.
### 4.2 Kernel Fusion: I/O Optimization
Kernel fusion optimizes computation speed by combining multiple operations into a single kernel, reducing memory input/output (IO) operations. This is crucial because modern GPUs are often limited by memory bandwidth rather than computational power.
<p style="text-align:center;">
<img src="https://hackmd.io/_uploads/SJG_ymNVC.png">
</p>
In deep learning frameworks, tensors are loaded into the GPU's fast memory (SRAM), operations are performed, and results are saved back to high-bandwidth memory (HBM). Multiple operations on the same tensor involve several rounds of reading and writing to HBM, which is inefficient.
For example, what if we do 3 operations on the same tensor? Then, by default, the following process will occur:
- Load the input from HBM to SRAM, compute the first operation (the CUDA kernel corresponding to the first operation), and then save the result back to HBM.
- Load the previous result from HBM to SRAM, compute the second operation, and then save the result back to HBM.
- Load the previous result from HBM to SRAM, compute the third operation, and then save the result back to HBM.
Mamba addresses this by combining multiple operations into one custom CUDA kernel, performing all operations sequentially within SRAM. This minimizes memory IOs and leverages the GPU's fast computation capabilities, significantly speeding up operations. Specifically, Mamba’s approach to kernel fusion includes:
- Loading $O(BLF + FN)$ bytes of memory ($\Delta$, $A$, $B$, $C$) into fast SRAM from HBM.
- Discretizing to produce $\bar{A}$, $\bar{B}$ of size ($B$, $L$, $F$, $N$) within SRAM.
- Performing the parallel scan discussed above within SRAM to yield intermediate states of size ($B$, $L$, $F$, $N$)
- Multiplying and summing the results with $C$ of size ($B$, $L$, $N$), then writing them back to HBM.
This streamlined process achieves substantial speedups, making the scan operation 20-40 times faster compared to standard implementations.
### 4.3 Recomputation: Memory Efficiency
When we train a deep learning model, it gets converted into a computation graph. When we perform backpropagation, in order to calculate the gradients at each node, we need to cache the output values of the forward step, as shown below.

However, this can require significant memory. Recomputation reduces memory usage by not storing these intermediate states: instead of saving them during the forward pass, Mamba recomputes them as needed during the backward pass, minimizing the memory footprint while maintaining computational efficiency. Specifically, Mamba's approach to recomputation includes:
- Inputs and outputs are read and written from HBM to SRAM.
- Intermediate states required for the backward pass are recomputed rather than stored.
The memory cost of storing intermediate states is avoided, reducing the overall memory usage.

As a result, the fused selective scan layer has the same memory requirements as an optimized transformer implementation with FlashAttention, demonstrating the efficiency of Mamba's approach.
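The same trade of recomputation for memory is exposed in PyTorch as gradient checkpointing; the sketch below only illustrates the general technique with `torch.utils.checkpoint`, not Mamba's fused CUDA kernel.
```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    """Wraps a sub-network so its activations are recomputed during backward
    instead of being stored during forward."""
    def __init__(self, dim=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        # Activations inside self.net are not cached; they are rebuilt in the backward pass.
        return checkpoint(self.net, x, use_reentrant=False)

x = torch.randn(8, 64, requires_grad=True)
y = CheckpointedBlock()(x).sum()
y.backward()   # the forward pass of self.net runs again here to recompute intermediates
```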
### 5. Experimental Performance of Mamba
The following content refers to [7].
In this part, we present some experimental results for Mamba, including the scaling law, performance on zero-shot tasks, and inference throughput.
Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. This part focuses on its performance as a language model.
The paper evaluates the Mamba architecture on standard autoregressive language modeling against other architectures, considering both pretraining metrics (perplexity) and zero-shot evaluations. The model sizes (depth and width) are set to mirror GPT-3 specifications, using the Pile dataset.
### 5.1 Scaling Law

For baselines, the paper compares against the standard Transformer architecture (GPT-3) and the strongest Transformer recipe available (referred to as Transformer++), which includes enhancements based on the PaLM and LLaMA architectures (e.g., rotary embedding, SwiGLU MLP, RMSNorm instead of LayerNorm, no linear bias, and higher learning rates). The above images show scaling laws under the Chinchilla protocol for models ranging from ≈125M to ≈1.3B parameters. As the results show, Mamba is the first attention-free model to match the performance of a very strong Transformer recipe (Transformer++), especially as sequence length grows.
### 5.2 Downstream Evaluations

The above table shows the performance of Mamba on a range of popular downstream zero-shot evaluation tasks. As the results show, the 3B-parameter Mamba model excels in benchmark evaluations. Compared to open-source models at the 3B scale trained on the same token count (300B), Mamba surpasses all of them on every evaluation metric.
Notably, Mamba is competitive even with models at the 7B scale. For example, when comparing Mamba (2.8B) to OPT, Pythia, and RWKV (7B), Mamba not only achieves the best average score but also the best or second-best score on every benchmark.
### 5.3 Throughput

Here we show the inference throughput of a Mamba 1.4B model and an (untrained) Mamba 6.9B model against a standard Transformer (GPT-3 architecture) at the 1.3B and 6.7B sizes. With the Transformer baseline using the standard implementation from the Hugging Face Transformers library, Mamba's architecture and hardware-aware implementation yield significantly faster inference, achieving up to 5× higher throughput than the Transformer.
## Comparison between RWKV and Mamba
Here we show an experimental comparison between RWKV and Mamba, covering multilingual and English-focused benchmarks as well as in-context learning.
### 1. Multilingual Benchmark
The following content refers to [16].
Multilingual benchmarks include LAMBADA Multilingual (lmb.m), XCOPA, XNLI, PAWS-X, XStoryCloze (xsClz), xWinogrande (xwin).
On these benchmarks, neither model is consistently better than the other.

### 2. English-focused Benchmark
The following content refers to [16].
English-focused benchmarks include LAMBADA (OpenAI) (lmb.o), HellaSwag (hella), PIQA, AI2 ARC (arc), GLUE, Winogrande (winG), SciQ, and COPA.
On these benchmarks, Mamba is consistently better than RWKV.

### 3. In-context Learning
The following content refers to [17].
We show results on 27 NLP tasks spanning a wide range of categories, including algorithmic tasks (e.g., list element extraction), translation (e.g., English to Spanish), linguistic tasks (e.g., singular to plural conversion), and knowledge-based tasks (e.g., identifying country-capital pairs).
The following figure shows average task accuracy as the context length increases; Mamba consistently outperforms RWKV at comparable parameter sizes.
