# Mixture of Experts (MoE)
Authors: Chiao-Wei Hsu, Yuchen Wang, Ding Zhang, and Yufeng Zou
## MoE Introduction
MoE (Mixture of Experts) models have shown promise in various domains in recent years, including natural language processing, computer vision, and reinforcement learning. Traditionally, MoE is a neural network architecture that combines multiple expert networks to solve complex problems. Each expert is trained to leverage specialized knowledge for a different aspect of the task, and a gating network determines which expert to activate for a given input, ensuring that the most appropriate expert handles it. In the past decade, with tremendous advances in deep neural networks and significantly increased amounts of data and model complexity, it has been shown that a single end-to-end model can also be trained to learn different areas of expertise, different tasks, and different modalities. This blog post explores notable pieces of work, architectures, and training methodologies of MoE models, highlighting their potential for improving model performance and efficiency.
## Outline
The outline of this blog is as follows. We first give a general introduction to the fundamental concepts that led to the rise of Mixture of Experts models in the deep learning era, namely **sparsity** and the **gating mechanism**. We then discuss **training methodologies** for MoE models, including specialized loss functions, regularization techniques, and fine-tuning strategies. Next, we delve into the **Mistral and Mixtral** models by Mistral AI, exploring their unique features and design choices: sliding window attention, grouped query attention, the KV-cache, and the prefill and chunking mechanisms. We also discuss other famous MoE models, for example the **GShard** and **Switch Transformer** models developed by Google. We conclude with some future research topics in the area of MoE.
## Sparsity in MoEs
Mixture of Experts models trace back to the early 1990s. In 1991, Jacobs et al. first used the idea of a mixture of experts in the paper titled [Adaptive Mixtures of Local Experts](https://www.cs.toronto.edu/~hinton/absps/jjnh91.pdf?ref=magicslabblog.ghost.io) in the supervised learning setting. The idea was to model a prediction $y$ as a weighted sum of experts $E$, where the weights are determined by a gating mechanism $G$. It is essentially a divide-and-conquer idea: split a large, complex problem into distinct, smaller sub-problems. The authors demonstrated this on a vowel discrimination problem, with different experts specializing in different decision boundaries. However, the problem with a vanilla implementation of the gating mechanism is that every expert network is evaluated for every input, which significantly increases computational complexity and makes training costs enormous.
The game changer came in 2017, when the paper [Outrageously Large Neural Networks](https://arxiv.org/pdf/1701.06538?ref=magicslabblog.ghost.io) first introduced **top-k routing**, which makes the network far more "sparse" (also the notion of conditional computation). Whereas dense models use all parameters for every input, sparsity activates only specific parts of the network for each input. This keeps the number of FLOPs roughly constant as more and more experts are added, allowing one to scale up the size of the model without increasing the computation cost. Nowadays, the MoE layers used by most modern LLMs follow the generic architecture proposed in that paper, consisting of two components:
* Experts: each layer has several "experts", which are neural network modules or layers, each with its own independent set of parameters.
* Router: a parametric (and learnable) gating mechanism that selects a sparse subset of experts from all the experts to process a given input.
The next several milestone MoE models, like the Switch Transformer, Mixtral, etc., all utilize this sparsity idea. In this [blog](https://towardsdatascience.com/understanding-the-spare-mixture-of-experts-smoe-layer-in-mixtral-687ab36457e2?ref=magicslabblog.ghost.io), the author gives a detailed explanation of the fundamental ideas behind MoE architectures.
## Gating Function in MoE
### Gating mechanism
How do we choose which experts to use? Many different strategies have been proposed for routing within an MoE. The simplest approach would be to multiply our input by a weight matrix and apply softmax; see section ***Sparsity in MoEs***. However, this approach does not guarantee that the selection of experts will be sparse. To solve this issue, the authors of the aforementioned "***Outrageously Large Neural Networks***" propose a modified gating mechanism that adds sparsity and noise to this simplistic softmax gating mechanism.

(image from: https://cameronrwolfe.substack.com/p/conditional-computation-the-birth#footnote-anchor-1-142423094)
The gating mechanism performs routing similarly to the softmax gating mechanism, but it includes two additional steps:
* An adjustable amount of Gaussian noise is added to the output of the router prior to applying softmax.
* All but the output of the top-K experts are masked (i.e., set to -∞) to ensure that the selection of experts is sparse.
In the Mixtral paper, as well as in other popular MoE implementations, a learned gating network decides which experts process the input:
$$
y=\sum_{i=1}^{n} G(x)_i E_i(x)
$$
$$
G(x):=\text{Softmax}(\text{TopK}(x \cdot W_{g}))
$$
Here, $G(x)_i$ denotes the $i$-th component of the gating network's $n$-dimensional output, and $E_i(x)$ is the output of the $i$-th expert. In the dense formulation this is a weighted sum in which every expert is run on every input; but whenever $G(x)_i$ equals 0, we can skip the corresponding expert entirely and save computation. As in the formulas above, the gating function used in the Mixtral paper is a softmax applied to the top-K logits of a linear layer. Note that the softmax is applied after selecting the top-k experts. If we applied the softmax directly to all of the gate's logits, we would obtain a probability distribution over all experts, which is incorrect because only the top-k experts are actually used. Normalizing over the selected experts also helps when comparing two models trained with different numbers of experts, since the weights applied to the expert outputs always sum to 1 regardless of how many experts the gate can choose from.
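To make these equations concrete, here is a minimal PyTorch sketch of a top-k gated MoE layer. This is our own illustration, not the Mixtral implementation: the module names `TopKGate` and `MoELayer`, the expert MLP sizes, and the per-expert loop are illustrative choices, and the optional noise term from the "Outrageously Large Neural Networks" gate is omitted.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKGate(nn.Module):
    """Keep the k largest router logits, mask the rest to -inf, then softmax."""
    def __init__(self, d_model: int, n_experts: int, k: int = 2):
        super().__init__()
        self.w_g = nn.Linear(d_model, n_experts, bias=False)  # plays the role of W_g above
        self.k = k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.w_g(x)                               # (n_tokens, n_experts)
        topk_vals, topk_idx = logits.topk(self.k, dim=-1)  # keep only the top-k logits
        masked = torch.full_like(logits, float("-inf"))
        masked.scatter_(-1, topk_idx, topk_vals)           # unselected experts get -inf
        return F.softmax(masked, dim=-1)                   # so their gate weight is exactly 0

class MoELayer(nn.Module):
    """y = sum_i G(x)_i * E_i(x); experts with zero gate weight are skipped."""
    def __init__(self, d_model: int, n_experts: int = 8, k: int = 2):
        super().__init__()
        self.gate = TopKGate(d_model, n_experts, k)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))
            for _ in range(n_experts)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (n_tokens, d_model), i.e. the batch and sequence dims are flattened
        gates = self.gate(x)                               # (n_tokens, n_experts)
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            token_idx = gates[:, i].nonzero(as_tuple=True)[0]  # tokens routed to expert i
            if token_idx.numel() > 0:
                out[token_idx] += gates[token_idx, i].unsqueeze(-1) * expert(x[token_idx])
        return out
```
Note that the loop over experts only touches the tokens routed to each expert, which is where the computational saving of sparsity comes from.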
### Top-K experts in gating
With a low $k$ (e.g. $k=2$ in Mixtral 8x7B), this kind of routing keeps the computational cost under control because only a small subset of the experts is activated for each input, which makes the model more efficient during both training and inference compared to dense models where all parameters are always active. Why not simply choose the single top-performing expert, i.e. $k=1$ as in the Switch Transformer architecture, instead of two? The choice of $k=2$ is likely a balance between model capacity and computational efficiency. Top-2 routing in Mixtral yields a sparse activation pattern where only the two experts with the highest gating scores are used for a given input. This trade-off allows for more diversity and better expert utilization than top-1 routing, potentially increasing the representational capacity of the model without the full computational load of activating all experts or a larger number of them (we give a more thorough introduction in the following sections). In practical applications, such as improving large language models (LLMs), top-k routing with $k=2$ has been used to merge domain-specific expert models, trained separately on specialized datasets, into a single model that can draw on the specialized knowledge of each expert where applicable. For any given input, the two most relevant experts are used, letting the model leverage specialized knowledge without overwhelming computational costs.
(Image from: https://cameronrwolfe.substack.com/p/conditional-computation-the-birth#footnote-anchor-1-142423094)
In summary, the gating function is an elegant way to focus only on the experts one wants, fully exploiting the idea of sparsity. It is key to the success of Mixture of Experts models.
## Training MoEs
As we discussed before, the Mixture of Experts (MoE) model is a paradigm shift in machine learning, offering a divide-and-conquer approach to complex problems. By dividing the task among multiple specialized "experts", a MoE model can process data more efficiently and effectively. So far, we have introduced how a MoE is typically constructed and implemented practically. This section delves into the training methodologies for different types of MoEs, providing insights into the intricacies of this innovative architecture.
At its core, an MoE model consists of two primary components: the expert networks and the gating network. The goal of training is for each expert network to become specialized, adept at handling specific aspects of the overall task, while the gating network, often another neural network, learns to act as a traffic controller, directing input data to the most appropriate expert based on learned parameters. A common misconception about training the MoE architecture is that the gating network and the expert networks are trained separately. This may have been true in the past, but it is no longer the usual practice. Instead, training is done on the whole monolithic architecture that contains both the gating network and the expert networks, all trained simultaneously. In fact, this shift from separate pipeline training to end-to-end training has been observed across the entire field of deep learning, not only in MoEs, since it is less hand-crafted, simpler to implement, and tends to lead to a better local minimum.
The training of the entire networks follows a standard deep learning approach, typically involving gradient descent methods. This method has two phases: forward propagation and backward propagation.
During the forward pass, the gating network assigns a probability to each expert based on the hidden representation of the input (zero probability for experts outside the top-k). The rest of the architecture then produces the corresponding outputs, a final prediction is obtained, and a training loss is calculated by comparing it with the ground truth. During the backward pass, since each expert inherently sees only a subset of the data, it is trained to develop a unique specialization. Different losses can be used to train MoEs, and some of them are applied simultaneously. A common training loss is the cross-entropy loss, which measures the difference between the predicted and actual output probabilities. Since this loss is based on the final prediction of the whole architecture, the gradient is backpropagated to both the gating network and the expert networks, yielding a simultaneous training process that propagates the gradient all the way from the loss to the gate and each expert.

One thing worth mentioning is that, depending on the input it sees (a token or a whole sentence, depending on the architecture), the gating network generates different probability distributions over the experts. For example, the probability of an expert will be higher if the training example is aligned with its specialization, and lower if the example is "out of its league". In other words, each expert sees a different portion of the examples. This process ensures that when the model encounters a specific type of input, it can leverage the expertise of the most qualified sub-network.
## Specialized Loss Functions and Regularization
### Balancing the Experts
One of the challenges in training MoEs is balancing the load among experts. A well-known problem of the MoE structure is that the network often comes to rely on the same few experts during training. Rather than evenly distributing the workload among all experts, the gating mechanism tends to favor a specific subset, consistently selecting them for every input. This creates a feedback loop: the more frequently an expert is chosen, the more quickly it is trained, and the more likely it is to be selected again over the others. To address this imbalance, a straightforward "soft" constraint can be added to the training loss, as described below.

(image from: https://cameronrwolfe.substack.com/p/conditional-computation-the-birth#footnote-anchor-1-142423094)
The way to solve this issue involves ensuring that no single expert becomes a bottleneck, which could lead to inefficiencies. Additionally, scaling the number of experts impacts pretraining, as it requires careful consideration of the model's capacity and the computational resources available. The goal is to achieve a balance where the model scales effectively without compromising performance.
This load-balancing objective takes the form of an auxiliary loss that is added to the task loss introduced above:
$$
\text{Total Loss} = \text{Task Loss} + \sum \text{All Specialized Losses}
$$
There are multiple lines of work implementing different variants. Essentially, most of them involve reducing the variance (e.g., by minimizing the coefficient of variation, CV) of the activation probability of each expert. As an example, we illustrate the solution presented in the aforementioned "Outrageously Large Neural Networks" paper. First, we define an "importance" score for each expert over a batch of input data. This score is calculated by summing the gate values for each expert across the batch; in simple terms, experts that are frequently selected within the batch will have a high importance score. Next, we compute an auxiliary loss using the squared coefficient of variation (CV) of the expert importance scores: if the importance scores of the experts are very similar, the CV will be small, and if they differ significantly, the CV will be large. This auxiliary loss term can then be added to the model's training objective, promoting equal importance for all experts within each batch.
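As a concrete illustration, here is a minimal sketch of such an importance-based auxiliary loss in PyTorch. This is our own simplification, not the paper's exact implementation; the function name `importance_aux_loss` and the `weight` hyperparameter are illustrative.
```python
import torch

def importance_aux_loss(gates: torch.Tensor, weight: float = 0.01) -> torch.Tensor:
    """Squared coefficient of variation of per-expert importance over a batch.

    gates: (n_tokens, n_experts) gate values after the (sparse) softmax.
    """
    importance = gates.sum(dim=0)                                   # how much each expert was used
    cv_squared = importance.var() / (importance.mean() ** 2 + 1e-10)  # CV^2 = var / mean^2
    return weight * cv_squared
```
The returned value is simply added to the task loss, so the gradient pushes the router toward giving all experts comparable total importance within each batch.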
Beyond load balancing, other specialized loss functions and regularization techniques also play a crucial role in training MoE models. For instance, the router Z-loss introduced in ST-MoE penalizes large logits entering the gating network, which stabilizes training; its weight can be tuned to control how strongly it constrains the router. Standard regularization techniques, such as dropout or L2 regularization, are also used to prevent overfitting and improve the generalization of the MoE model.
$$
L_z(x)= \frac{1}{B} \sum_{i=1}^{B} \left( \log \sum_{j=1}^{N} e^{x_j^{(i)}} \right)^2
$$
where $B$ is the number of tokens, $N$ is the number of experts, and $x \in \mathbb{R}^{B\times N}$ are the logits going into the router.
Note how the router Z-loss has the form of an L2 penalty, but unlike traditional L2 regularization, which is applied to the model weights (values in $(-\infty, +\infty)$), it is applied to the log-sum-exp of the logits entering the gating softmax, encouraging those logits to remain small. This improves the numerical stability of training. By incorporating such specialized loss functions alongside the load-balancing loss, MoEs can achieve both stable training and better load balancing.
## Mistral and Mixtral
We have introduced the key design concepts and training procedures for a general Mixture of Experts model. In the next few sections, we will discuss the current state-of-the-art MoE models in the market and illustrate the technical details of these models. We will start with the latest model that we have also covered in class, Mistral-7B.
**Mistral**, or more formally **Mistral-7B**, was first introduced in [this blog post](https://mistral.ai/news/announcing-mistral-7b/) and the accompanying [paper](https://arxiv.org/abs/2310.06825) by Albert Jiang et al. The model is open-source, and it is also the first large language model (LLM) released by the company Mistral AI.
Mistral-7B is a transformer model designed for handling fast inference and longer sequences. It is a decoder-only Transformer with the following architectural choices:
* Sliding Window Attention
* Grouped Query Attention (GQA)
* KV Cache
* Pre-fill and Chunking
With these carefully designed components, Mistral-7B is able to handle tasks with longer sequences more effectively and at a reduced cost. It takes a significant step toward achieving high performance while keeping the large language model efficient.
The company then took one step further and introduced **Mixtral 8x7B**, a Sparse Mixture of Experts language model. It employs a mixture-of-experts architecture that dynamically selects a fixed number of experts for each token via a gating mechanism, which allows it to use a large number of parameters efficiently during inference. Specifically, while the model has a total of 46.7 billion parameters, it only uses around 13 billion active parameters per token, which improves both speed and efficiency compared to other large models like GPT-3.5 and Llama 2. (Side note: the name suggests a total of 8 × 7 = 56 billion parameters; the reason the model is smaller is that an MoE is not simply an ensemble of eight 7B models; rather, only some layers of the model are replicated.)

(Image from:https://cameronrwolfe.substack.com/p/conditional-computation-the-birth)
MoE modifies the decoder-only transformer architecture used by most generative LLMs, shown above. The main modification is that the feed-forward sub-layer is replaced with an MoE layer. This MoE layer is composed of several experts (from a few to thousands), where each expert is its own feed-forward sub-layer with an independent set of parameters. The modified architecture is shown below.

(Image from:https://cameronrwolfe.substack.com/p/conditional-computation-the-birth)
In the next few sections of this blog, we will provide a detailed explanation on the components of the MoE architecture, and the reasons behind these designs. Both Mistral and Mixtral models are open-source, available for download on HuggingFace.
### Sliding Window Attention
**Problem with long input tokens**
Recall that the success of transformers relies heavily on the self-attention mechanism. However, classical Transformer models such as BERT limit the input size to 512 tokens. In a self-attention layer, the input tokens serve both as "keys" (the sequence representations) and as "queries" that attend to those keys, so the sequence attends to itself. For example, for a 5-token input sequence, letting every token attend to every key (fully connected) requires quadratic, $O(n^2)$, memory per attention layer. This is known as full attention or quadratic attention. A good way to think about it is to represent the layer's connectivity as an $n \times n$ matrix: the memory requirement is the number of rows ($n$) times the number of columns ($n$), which is indeed $O(n^2)$. Thus, when the attention layer receives a long input sequence, the quadratic complexity makes the computation significantly inefficient. In some cases, the output may also depend on long-distance attention between document tokens (a word in the first chapter being referenced in the fifth chapter, for example). Such long-range attention is not achievable in BERT-like models.
**Sliding Window Attention**

*Sliding Window Attention, [image source](https://medium.com/@gopalgoyal612002/mistral-llm-architectural-details-8dc0447fea62).*
This [blog](https://ahelhady.medium.com/understanding-longformers-sliding-window-attention-mechanism-f5d61048a907) gives a comprehensive explanation of the mechanism of Sliding Window Attention. Sliding window attention is an attention pattern for attention-based models, first proposed in the [Longformer paper](https://arxiv.org/abs/2004.05150v2). The mechanism tries to overcome the limited input sequence length of classical transformer models like BERT by using a convolution-like architecture for the attention mechanism. It defines a window of size $W$, such that a query node is allowed to attend only to the $W$ keys inside its window. In the figure below, we show an attention window of size 3, where the highlighted node in green is allowed to attend to its peer key (the middle node) and its immediate neighbours on the left and on the right.

The key assumption behind sliding window attention is that the most important information for a word comes from its local neighbours, within a window of size $W$. This reduces the memory complexity to $O(nW)$, which is significantly more efficient when $W \ll n$.
However, one may wonder: doesn't applying this sliding window attention lose information from key nodes outside the window of size $W$? How does sliding window attention solve the aforementioned problem of two words that are far apart in the document but still have a non-negligible relationship? If you look at a single attention layer, that is indeed what happens. But when we stack multiple attention layers, at higher layers the query node gains attention information from far-away neighbours, albeit in a different representation. The idea is very similar to the **receptive field** in CNNs. In a single attention layer, key nodes outside the window of size $W$ are discarded; but in the next layer, each node contains aggregated information from the nodes propagated from the previous layer. Thus, we end up with a conical structure for each token's attention: it starts with local attention to its $W$ neighbours, and at higher layers it gains information from tokens far away from it (global attention).
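To make this concrete, the sketch below (an illustration, not Mistral's code) builds the boolean attention mask for sliding window attention: token $i$ may attend to token $j$ only if $j \le i$ and $i - j < W$.
```python
import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    """True where query i may attend key j: j <= i (causal) and i - j < window (local)."""
    i = torch.arange(seq_len).unsqueeze(1)   # query positions, column vector
    j = torch.arange(seq_len).unsqueeze(0)   # key positions, row vector
    return (j <= i) & (i - j < window)

mask = sliding_window_causal_mask(seq_len=6, window=3)
# Each row has at most `window` True entries: a token attends to itself and its
# (window - 1) most recent predecessors; older keys are masked out.
print(mask.int())
```
Stacking $L$ such layers lets information flow up to roughly $L \times W$ positions back, which is the "receptive field" argument above.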

### Grouped Query Attention
As the name suggests, Grouped Query Attention groups query heads together instead of treating each one independently. It is a technique introduced to optimize the balance between computational efficiency and model performance, interpolating between the speed of multi-query attention (MQA) and the quality of multi-head attention (MHA).
In Grouped Query Attention, the query heads are divided into groups, each of which shares a single key head and value head. For example, GQA with one group (and therefore one key and value head) is equivalent to MQA, while GQA with as many groups as query heads is equivalent to MHA.


GQA significantly reduces the computational load and memory requirements, making it possible to work with larger models or data without sacrificing performance. With lower resource demands, models using GQA can scale more effectively, handling more extensive and complex tasks. Please refer to this [blog](https://klu.ai/glossary/grouped-query-attention) for more information.
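The sketch below is our own illustration of the core idea (not the Mistral implementation, and the causal mask is omitted for brevity): each group of query heads shares one key/value head, which is simply repeated so the shapes line up.
```python
import torch

def grouped_query_attention(q, k, v, n_groups: int):
    """q: (batch, n_q_heads, seq, d); k, v: (batch, n_groups, seq, d).

    n_groups == 1 recovers MQA, n_groups == n_q_heads recovers standard MHA.
    """
    b, n_q_heads, s, d = q.shape
    heads_per_group = n_q_heads // n_groups
    # Repeat each shared K/V head so it lines up with the query heads in its group
    k = k.repeat_interleave(heads_per_group, dim=1)
    v = v.repeat_interleave(heads_per_group, dim=1)
    attn = torch.softmax(q @ k.transpose(-2, -1) / d ** 0.5, dim=-1)
    return attn @ v

q = torch.randn(1, 8, 16, 64)   # 8 query heads
k = torch.randn(1, 2, 16, 64)   # 2 shared key heads -> 4 query heads per group
v = torch.randn(1, 2, 16, 64)
out = grouped_query_attention(q, k, v, n_groups=2)   # (1, 8, 16, 64)
```
The saving comes from storing and moving only `n_groups` key/value heads instead of `n_q_heads`, which also shrinks the KV-cache discussed next.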
### KV-Cache
In this section, we explain what a KV-cache is. Here, K stands for key and V stands for value, so a KV cache is a key-value caching system: data is stored in the form of key-value pairs, where each key is unique and maps to a specific value. When a key-value pair is cached, the key acts as an identifier used to quickly retrieve the corresponding value from the cache, without needing to compute the value again or fetch it from slower storage.
The intuition behind the KV cache is that during autoregressive generation we are only interested in the last token output by the model; the KV cache reduces the computation needed to produce that last token during inference.

Above is a GIF illustrating the KV-cache computation. We can clearly see that, without a cache, the model stores more values and performs many unnecessary computations, whereas with the cache only the last row of the results is kept, because we do not care about recomputing attention for the previous tokens.
A KV cache is often employed by the generative process of large language models (LLMs) to speed up output generation. The technique stores previously computed key/value vectors from the attention calculation and reuses them when generating new tokens, thus bypassing the need to recompute them for past tokens.
We can verify how the KV cache speeds up the process through some ***computations***. The calculations below follow [Transformer Inference Arithmetic](https://kipp.ly/transformer-inference-arithmetic/#kv-cache).
But before we start, we first need to familiarize ourselves with some basic terminology used in the subsequent discussion:
* ***Floating point operations per second*** ([flops](https://en.wikipedia.org/wiki/FLOPS)): a flop serves as a basic unit of computation, denoting one addition, subtraction, multiplication, or division of floating point numbers. Note that the flop count is just a rough measure of how expensive an algorithm can be.
* ***flops bound***: the computation is limited by the available flops; there are periods when nothing is being passed through memory.
* ***memory bound***: the computation is limited by memory bandwidth; there are periods when no floating point operations are occurring, e.g. while loading weights consumes memory bandwidth.
Per token, the number of values we store is
$$
2 \cdot n_{\text{layers}} \cdot n_{\text{heads}} \cdot d_{\text{head}}
$$
, where the 2 accounts for the two vectors, k and v. We store one KV pair per layer, and each of k and v is an $n_{\text{heads}} \times d_{\text{head}}$ matrix (at 16-bit precision, each entry takes 2 bytes).
The flops to compute k and v for all our layers is
$$
2 \cdot 2 \cdot n_{\text{layers}} \cdot d_{\text{model}}^2
$$
It takes $2 \cdot d_{\text{model}}^2$ flops to multiply a token embedding by a weight matrix. We have another factor of 2 because we do it twice, once each for k and v, and then we repeat for each of the $n_{\text{layers}}$ layers.
This means that for a ***52B***-parameter model (taking Anthropic's, where $d_{\text{model}}$ = ***8192*** and $n_{\text{layers}}$ = ***64***), the flops are
$$
2 \cdot 2 \cdot 64 \cdot 8192^2 = 17,179,869,184
$$
Say we have an NVIDIA A100 GPU, which does ***312e12*** flops per second and moves ***1.5e12*** bytes per second of memory bandwidth. The following are the numbers for just the kv weights and computations.
$$
memory = \frac{2 \cdot 2 \cdot n_{\text{layers}} \cdot d_{\text{model}}^2}{1.5e12 }
$$
$$
compute = \frac{2 \cdot 2 \cdot n_{\text{layers}} \cdot d_{\text{model}}^2}{312e12 }
$$
So here we get a ratio of 208 (result of $\frac{312e12}{1.5e12}$) given this hardware specification. This means if we're going to compute kv for one token, it'll take the same amount of time to compute for up to 208 tokens! For fewer than 208 tokens, the system is memory bandwidth-bound, implying that we cannot fully leverage the computational operations. Beyond 208 tokens, we are computation-bound, meaning that memory is not fully utilized.
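The arithmetic above can be reproduced in a few lines. The numbers follow the blog's 52B example; the variable names are ours and these are back-of-the-envelope estimates, not measurements.
```python
# Reproducing the arithmetic above for the 52B example model
n_layers, d_model = 64, 8192
flops_per_sec = 312e12      # A100 peak compute
bytes_per_sec = 1.5e12      # A100 memory bandwidth

kv_flops_per_token = 2 * 2 * n_layers * d_model ** 2
print(f"{kv_flops_per_token:,}")                                  # 17,179,869,184 flops

compute_time = kv_flops_per_token / flops_per_sec                 # time to compute k and v for one token
memory_time = (2 * 2 * n_layers * d_model ** 2) / bytes_per_sec   # time to move the kv weights once

print(round(memory_time / compute_time))                          # 208: tokens needed before compute dominates
```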
The intersection of the below diagram is at 208, though in reality the memory line does have a slight slope due to memory cost of intermediate calculations.

For a full forwards pass of the 52B model, the time is $\frac{12 \cdot 2 \cdot n_{\text{layers}} \cdot d_{\text{model}}^2}{1.5e12} \approx 69$ milliseconds for up to 208 tokens. If we had 416 (double) tokens in the context, it would take twice as long, and 312 tokens would take 1.5 times as long.
Calculating kv for a cached token takes exactly 1/6 of the compute of passing the token through the full model. In general, these forwards passes (what we experience when getting logits, embeddings, and during training) are very cheap because of the parallelism that is possible, as opposed to sampling, where we are forced to read through all the weights for each token and do the autoregression.
This doesn't mean that only 1/6 of the time is saved! Let's assume we are flops bound. Then at each sample step we save $\frac{2 \cdot 2 \cdot n_{\text{layers}} \cdot d_{\text{model}}^2}{312e12}$ seconds of flops per past token, while the decoding step itself costs $\frac{12 \cdot 2 \cdot n_{\text{layers}} \cdot d_{\text{model}}^2}{312e12}$ seconds. Thus at each step we save 1/6 of the flops time multiplied by the number of tokens already in our sequence (big!), and this grows as we sample more tokens. Without a kv cache, sampling would be quadratic in time complexity as we increase the number of tokens.
However, while the KV cache is an effective strategy for efficiency, it significantly raises memory usage. This becomes more pronounced with larger models and longer text generations, leading to substantial demands on device memory resources ([Model Tells You What to Discard: Adaptive KV Cache Compression for LLMs](https://arxiv.org/pdf/2310.01801.pdf)).
### Rolling Buffer Cache
In the original [Mistral paper](https://arxiv.org/abs/2310.06825), the authors describe a rolling buffer cache. To effectively manage cache size, they employ a rolling buffer cache that leverages the fixed attention span. The cache has a fixed size of $W$, and the keys and values at timestep $i$ are stored at position $i \bmod W$. This means that once $i$ exceeds $W$, older values in the cache are overwritten, preventing the cache from growing indefinitely. For example, as illustrated below with $W = 3$, this method reduces cache memory usage by 8x for sequences as long as 32k tokens, all without impacting the model's performance.

*([Image source from the Mistral paper](https://arxiv.org/abs/2310.06825))*
The blocks colored orange are the hidden states responsible for generating the next token. At timestep $i+2$, we can see that 'of' overwrites 'this' from timestep $i+1$ in the first sentence.
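Here is a minimal sketch of the rolling-buffer idea. It is our own illustration with made-up tensor sizes, not Mistral's `RotatingBufferCache` implementation.
```python
import torch

W = 3                                  # fixed attention span / cache size
cache_k = torch.zeros(W, 8, 64)        # (window, n_kv_heads, head_dim); sizes are illustrative
cache_v = torch.zeros(W, 8, 64)

def write_to_cache(i: int, k: torch.Tensor, v: torch.Tensor):
    """Keys/values for timestep i land at slot i mod W, overwriting the oldest entry."""
    slot = i % W
    cache_k[slot] = k
    cache_v[slot] = v

for i in range(5):                     # after step 3, slot 0 holds timestep 3, and so on
    write_to_cache(i, torch.randn(8, 64), torch.randn(8, 64))
```
Because sliding window attention never looks further back than $W$ tokens, nothing useful is lost when a slot is overwritten.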
### Prefill and Chunking
In the original [Mistral paper](https://arxiv.org/abs/2310.06825), the authors mention the trick of prefill and chunking. When generating a sequence given a prompt, since the prompt is known in advance, we can ***prefill*** the KV-cache with it. But what if the prompt is very large? In that case, prefilling the cache with the full prompt at once may not work well. We could instead feed the tokens into the cache one by one, but that is extremely computationally inefficient. Is it possible to strike a balance between filling the KV cache one token at a time and exhausting it with the full prompt? The answer is ***chunking***: we split the prompt into chunks of the window size and prefill the KV-cache chunk by chunk, and for each chunk we compute the attention score over the cache and over the chunk. In the paper's example, the sentence is "The cat sat on the mat and saw the dog go to".

*([Image source from the Mistral paper](https://arxiv.org/abs/2310.06825))*
With a window size of 4, the sequence is processed as three chunks: "The cat sat on", "the mat and saw", and "the dog go to". The figure above shows what happens when the third chunk ("the dog go to") is processed: it attends to itself using the usual causal mask (rightmost block), attends to the cached previous chunk ("the mat and saw") using sliding window attention (center block), and does not attend to older tokens, since they fall outside the sliding window (leftmost block).
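A tiny sketch of the chunking step on the paper's example sentence (illustrative only; the real implementation operates on token ids rather than whitespace-split words):
```python
prompt_tokens = "The cat sat on the mat and saw the dog go to".split()
W = 4                                              # window / chunk size

chunks = [prompt_tokens[i:i + W] for i in range(0, len(prompt_tokens), W)]
print(chunks)
# [['The', 'cat', 'sat', 'on'], ['the', 'mat', 'and', 'saw'], ['the', 'dog', 'go', 'to']]
# Each chunk prefills the KV-cache in one forward pass: it attends to the cached
# keys/values of the previous chunk (sliding window) plus to itself with a causal mask.
```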
## Model Sharding / Expert Parallelism
Model sharding is a technique used in deep learning to handle large models that cannot fit entirely into the memory of a single computing device, such as a GPU or CPU. This situation often arises with large-scale deep learning models, such as those found in natural language processing (e.g., GPT-3) and computer vision. Sharding involves dividing the model's parameters across multiple devices, allowing parallel processing and memory distribution.
The model's parameters are split into distinct subsets, called shards. Each shard contains a portion of the model's total parameters, and each computing device or node in a distributed system handles one or more shards. This setup means that each part of the model is processed in parallel across different devices, effectively utilizing the combined memory and computational power of the system. While model sharding allows for the training of very large models by leveraging multiple devices, it also introduces the need for these devices to communicate with each other. This communication typically involves synchronizing the gradients or parameters during the training process, which can be a significant overhead.
We then need model parallelism to run the divided model in parallel. For more information, you can refer to [How to Parallelize Deep Learning on GPUs Part 2/2: Model Parallelism](https://timdettmers.com/2014/11/09/model-parallelism-deep-learning/).
Model parallelism is when you split the model among GPUs and use the same data for each shard, so each GPU works on a part of the model rather than a part of the data. In deep learning, one approach is to split the weights: e.g., a 1000×1000 weight matrix would be split into four 1000×250 matrices if you use ***four*** GPUs.
However, naive model parallelism is not the most efficient way to parallelize, as discussed in [Model Parallelism](https://huggingface.co/transformers/v4.10.1/parallelism.html). The problem is that all but one GPU are idle at any given moment. So if 4 GPUs are used, it is almost identical to quadrupling the amount of memory of a single GPU while ignoring the rest of the hardware, plus there is the overhead of copying data between devices. So 4x 6GB cards will be able to accommodate the same model size as 1x 24GB card using naive MP, except the latter will complete the training faster, since it doesn't have the data copying overhead. But, say, if you have 40GB cards and need to fit a 45GB model, you can with 4x 40GB cards (though barely, because of the gradient and optimizer states).
This figure is from [Introducing GPipe, an Open Source Library for Efficiently Training Large-scale Neural Network Models](https://research.google/blog/introducing-gpipe-an-open-source-library-for-efficiently-training-large-scale-neural-network-models/); the top shows model parallelism (MP) and the bottom shows pipeline parallelism (PP).

We can see that PP has fewer zones where GPUs are idle. The idle parts are referred to as the "bubble".
Here GPU0 performs the same forward path on chunks 0, 1, 2, and 3 (F0,0, F0,1, F0,2, F0,3); it then waits for the other GPUs to do their work, and only when their work is nearing completion does GPU0 start working again, doing the backward path for chunks 3, 2, 1, and 0 (B0,3, B0,2, B0,1, B0,0).
Because of the chunks, PP introduces the concept of micro-batches (MBS). DP splits the global data batch size into mini-batches, so if you have a data-parallel degree of 4 (meaning 4 GPUs), a global batch size of 1024 gets split into 4 mini-batches of 256 each (1024/4). And if the number of chunks is 32, we end up with a micro-batch size of 8 (256/32). Each pipeline stage works with a single micro-batch at a time.
With chunks = 1 you get naive MP. With a very large number of chunks, you end up with tiny micro-batch sizes, which may not be very efficient either. So the number of chunks needs to be tuned to achieve the most efficient utilization of the GPUs.
However, in this scheme, only one micro-batch is processed at a time, which results in all stages except one being idle. After completing the final stage, the process reverses to begin the backward pass, again leaving all stages but one idle. Once the entire batch is computed (after ramp-down), a weight update occurs. Consequently, the utilization rate is at most 1/𝑁 of full capacity, indicating a significant inefficiency.
We can use a ***grouped pipeline*** to achieve more efficient performance:

(image from: https://afmck.in/posts/2023-02-26-parallelism/)
In the grouped pipeline method, the process is divided into stages that work on different micro-batches simultaneously. This contrasts with the previous method where only one micro-batch was in process at a time.
This grouped pipeline approach ensures that multiple stages are utilized simultaneously, minimizing idle time and significantly improving efficiency compared to the initial scheme where only one stage was active at a time.
Another approach is to ***interleave*** the forwards and backwards passes:

(image from: https://afmck.in/posts/2023-02-26-parallelism/)
At any moment, half of the stages are performing a forward pass, while the other half are executing a backward pass. This arrangement shortens the ramp-up and ramp-down phases, leading to faster achievement of maximum utilization.
The grouped and interleaved schemes each have their own advantages and disadvantages:
* The grouped scheme processes twice as many mini-batches simultaneously compared to the interleaved scheme, which requires more memory for storing activations.
* The grouped scheme processes all forward and backward passes together, reducing the frequency of communication. In contrast, the interleaved scheme processes them separately, leading to more communication and some idle time when forward passes wait for backward passes, which generally take longer. As a result, grouped schemes are typically faster than interleaved ones.
* The interleaved scheme has ramp-up and ramp-down times that are roughly twice as fast as the grouped scheme, allowing it to reach full utilization more quickly.
### Other Parallelism
* Tensor Parallelism
Besides pipeline parallelism, we also have tensor parallelism, which shards individual layers of the model into smaller, independent blocks of computation that can be executed on different devices.

(images from: https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/)
In Figure (a), we split the tensor A column-wise into two tensors A1 and A2; then, given an input X, we calculate X·A1 and X·A2 simultaneously on different devices and concatenate the results. Figure (b) is an example of two-way tensor parallelism in the self-attention layer. A minimal sketch of this column-wise split is shown at the end of this section.
* Data Parallelism
Data parallelism is the simplest form of parallelism. The model, together with its optimizer state, is replicated on every device, and each device processes a different micro-batch of the data. The gradients are then synchronized between devices, followed by an optimizer step that updates the parameters.
One thing to note is that data parallelism requires copying all of the model to every device, so it cannot by itself help with larger models; it only helps with larger batch sizes.
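Returning to the tensor-parallel column split of Figure (a), here is a minimal sketch of the idea, simulated on a single device; in real tensor parallelism `A1` and `A2` would live on different GPUs and the partial results would be gathered with a collective communication step.
```python
import torch

d_in, d_out = 1024, 4096
X = torch.randn(8, d_in)                   # a batch of activations
A = torch.randn(d_in, d_out)               # the full weight matrix

# Column-wise split across two (simulated) devices
A1, A2 = A.chunk(2, dim=1)                 # each is (d_in, d_out / 2)
Y1 = X @ A1                                # computed on device 0
Y2 = X @ A2                                # computed on device 1
Y = torch.cat([Y1, Y2], dim=1)             # gather: identical to X @ A
assert torch.allclose(Y, X @ A, atol=1e-5)
```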
## GShard
GShard, introduced in a [paper](https://arxiv.org/abs/2006.16668) in 2020, is a model architecture developed by Google for scaling up neural networks, particularly for tasks like natural language processing and translation. The key innovation of GShard is its ability to efficiently train very large models by utilizing a technique called "model sharding", as an alternative to [GPipe](https://arxiv.org/abs/1811.06965). This technique distributes different parts of a neural network model across multiple GPUs or TPUs, which allows the model to be scaled up significantly without running into memory or computational constraints on individual devices.
GShard implements a form of conditional computation that activates only certain parts of the network based on the input it receives. This selective activation helps optimize computation resources and speeds up the training process, enabling sublinear scaling of the computation cost. In both the encoder and the decoder of GShard, Transformers are sparsely scaled with conditional computation by replacing every other feed-forward layer with a Position-wise Mixture of Experts (MoE) layer. The MoE layers adopt random **top-2** gating, where the top expert is always picked but the second expert is picked with probability proportional to its weight. Techniques like capacity limit on the number of input tokens each expert can process and auxiliary loss for balancing the load across experts are also employed. These will be further discussed in the next section. When scaling to multiple devices, the MoE layer is sharded across devices while all the other layers are replicated, as illustrated in the figure below.

***Illustration of scaling of Transformer Encoder with MoE Layers. The MoE layer replaces every other feed-forward layer. Decoder modification is similar. When scaling to multiple devices, the MoE layer is sharded across devices, while all other layers are replicated. Image is from the GShard paper.***
A key feature of GShard is its simple API for annotations. These annotations can be added to existing TensorFlow code to indicate which computations should be parallelized. In addition, GShard includes a compiler extension in [XLA](https://openxla.org/xla) (Accelerated Linear Algebra) that handles the automatic parallelization of these computations.
One of the notable applications of GShard is in machine translation, where it demonstrated state-of-the-art performance by scaling up to models with over 600B parameters. By leveraging the capacity of massive models, this approach enables more effective handling of multiple languages and improves translation quality. Some key observations are made from the results of the multilingual translation experiments, which are plotted in the chart below:
- Deeper models bring consistent quality gains across the board.
- Relaxing the capacity bottleneck grants pronounced quality gains. With the depth of 12L, increasing the number of experts per-layer from 128 to 512 yields +3.3 average BLEU score gain across 100 languages. However, scaling the number of experts per-layer from 512 to 2048 yields only +1.3 average BLEU scores gain.
- Having more experts improves quality especially for high-resourced tasks because of increased model capacity. While adding more experts relaxes the capacity bottleneck, it reduces the amount of transfer due to a reduction of the shared sub-networks, negatively affecting the quality for low-resourced tasks. This pattern can be observed by comparing the performance of the MoE(2048E, 36L) and MoE(512E, 36L) models.
- Deep-dense models are better at positive transfer towards low-resource tasks because of sufficient parameter sharing, e.g., T(96L) model.

***Translation quality comparison of multilingual MoE Transformer models trained with GShard. Positions along the x-axis represent languages from high to low resource. The y-axis represents the quality gain of a multilingual model compared to a monolingual Transformer model trained and tuned for a specific language. Dashed line represents a 96-layer multilingual Transformer model T(96L) trained with GPipe on same dataset. Image is from the GShard paper.***
## Switch Transformers
The [Switch Transformer](https://arxiv.org/abs/2101.03961) model, published by Google in 2022, employs a sparse [T5](https://en.wikipedia.org/wiki/T5_(language_model)) encoder-decoder architecture, substituting the MLP with a Mixture of Experts (MoE) layer. Its architecture features a routing mechanism (**top-1** in this case) that associates each token to an expert, each being a dense MLP. Despite having a lot more weights than equivalent dense models, the Switch Transformer's sparsity enables better performance at larger scales throughout both pretraining and finetuning phases.
Switch Transformers scale from the base version containing <7B parameters to the large version containing 1.6T parameters with 2048 experts. During a forward pass, only a fraction of the weights are used. The routing mechanism allows the model to select relevant weights on the fly which increases the model capacity without increasing the number of operations. Experiments show that Switch Transformers have significant speedup in pretraining over T5 counterparts and achieve better performance on downstream tasks such as multilingual learning.

***Illustration of a Switch Transformer encoder block. Image is from the Switch Transformer paper.***
Several important aspects of this work:
**Switch routing**  routes the input to the top-1 expert only, so an MoE layer is also called a Switch layer in this work. Suppose the gate value for expert $i$ is $G_i(x)$, then the output of the Switch layer is $y=G_j(x)E_j(x)$, where $j=\text{arg}\max_i G_i(x)$. This reduces the computation of experts and communication costs.
**Expert capacity**  sets a limit on the number of inputs each expert can process, determined by $\text{expert capacity}=\frac{\text{# tokens per batch}}{\text{# experts}}\times\text{capacity factor}$. When an input is routed to an overloaded expert, the token representation is passed directly to the next layer through the residual connection. A capacity factor greater than 1 creates additional buffer to accommodate for the case when tokens are not perfectly balanced across experts. Increasing the capacity improves the quality but leads to more communication costs and memory of activations. Switch Transformers perform well at low capacity factors (1-1.25).
**Load balancing loss**  A differentiable load balancing loss is introduced to encourage a balanced load across experts. For each Switch layer, the auxiliary loss is added to the total model loss during training. This loss encourages uniform routing and can be weighted using a hyperparameter. Please see the [corresponding section](#specialized-loss-functions-and-regularization) for more discussions.
**Training and finetuning techniques**
* Selective precision for efficiency, such as training the experts with [*bfloat16*](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) while using full precision for the rest of the computations.
* Smaller parameter initialization for training stability.
* Regularization during finetuning by setting a smaller dropout rate (0.1) at non-expert layers and a much larger dropout rate (0.4) at expert layers to mitigate overfitting.
**Parallel computation**  A combination of data, model, and expert parallelism can be employed to balance the FLOPs per token, communication costs, and memory per core.

***Data and weight partitioning strategies. The 4x4 dotted-line grid represents 16 cores and the shaded squares are the data contained on that core (either model weights or batch of tokens). Image is from the Switch Transformer paper.***
The impressive scaling properties of Switch Transformers is demonstrated through the pretraining performance on C4, a large corpus with over 180B target tokens (figures below are from the Switch Transformer paper):
* By increasing the number of experts, the FLOPs per token stays constant, and the test loss for a fixed number of steps consistently decreases.

* By comparing the sample efficiency of a dense model variant and four FLOP-matched sparse variants, it is observed that increasing the number of experts leads to more sample efficient models. In particular, Switch-Base 64 expert model achieves at step 60k the same performance of the T5-Base model at step 450k, which is a 7.5x speedup in terms of step time.

* While the compared models have roughly the same amount of FLOPs per token, extra experts incur additional communication costs across devices and the extra computation of the routing mechanism. Therefore, the increased sample efficiency observed on a step basis doesn't necessarily translate to a better model quality as measured by wall-clock time. For a fixed training duration and computational budget, Switch Transformers still yield a substantial speed-up. Switch-Base 64 expert model trains in one-seventh the time that would take the T5-Base to get similar perplexity.

## Capacity Factor and Communication costs
The capacity factor is a critical parameter in MoE models that determines the maximum number of tokens that can be processed by an expert. By setting an appropriate capacity factor, the model can balance the computational load across experts, preventing bottlenecks and ensuring efficient resource utilization. This section delves into the concept of the capacity factor and its impact on communication costs in MoE models.
The capacity factor in MoE models determines the maximum number of tokens that an expert can process. From the Switch Transformer paper: the expert capacity, i.e., the number of tokens each expert computes, is set by evenly dividing the number of tokens in the batch across the number of experts, and then further expanding by a capacity factor.
$$
\text{expert capacity} = \frac {\text{tokens per batch}} { \text{number of experts}} \times \text{capacity factor}
$$
A capacity factor greater than 1.0 creates additional buffer to accommodate for when tokens are not perfectly balanced across experts. If too many tokens are routed to an expert (referred to later as dropped tokens), computation is skipped and the token representation is passed directly to the next layer through the residual connection. Increasing the expert capacity is not without drawbacks, however, since high values will result in wasted computation and memory.
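A quick illustration of this formula with made-up numbers (the function name and the batch/expert counts are ours, not from the paper):
```python
def expert_capacity(tokens_per_batch: int, num_experts: int, capacity_factor: float) -> int:
    return int((tokens_per_batch / num_experts) * capacity_factor)

# A batch of 4096 tokens routed to 8 experts
print(expert_capacity(4096, 8, 1.0))    # 512 tokens per expert with no slack
print(expert_capacity(4096, 8, 1.25))   # 640 tokens per expert; extra buffer for imbalanced routing
```
The extra 128 slots per expert in the second case are exactly the buffer that absorbs imperfect load balancing, at the cost of more memory and communication.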
## Open-Source MoE Models
The table below lists recent popular MoE models with open-source code:
| MoE Model|Task | Link to Code |
| -------- |---- |------- |
|**Switch Transformers**|translation|[code](https://huggingface.co/docs/transformers/en/model_doc/switch_transformers)|
|NLLB-MoE|translation|[code](https://huggingface.co/facebook/nllb-moe-54b)|
|**Mixtral**|text generation|[code](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)|
|BlackMamba|text generation|[code](https://huggingface.co/papers/2402.01771)|
|Grok-1|text generation|[code](https://huggingface.co/xai-org/grok-1)|
|DBRX|text generation|[code](https://huggingface.co/databricks/dbrx-instruct)|
|Qwen|text generation|[code](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B)|
|DeepSeek-V2|text generation|[code](https://huggingface.co/deepseek-ai/DeepSeek-V2)|
## Future Work
The future of MoE models is promising, with ongoing research focusing on enhancing their performance and efficiency. Some key areas of interest include:
* Scalability: Developing techniques to scale MoE models to handle even larger datasets and more complex tasks.
* Interpretability: Improving the interpretability of MoE models to understand how decisions are made by the experts and the gating network.
* Quantization: Exploring quantization techniques to reduce the memory and computational requirements of MoE models while maintaining performance.
* Transfer Learning: Investigating transfer learning methods to leverage pretrained MoE models for a wide range of tasks and domains.
* Efficient Training: Developing efficient training methodologies for MoE models to reduce training time and resource consumption.
By addressing these challenges and opportunities, researchers and practitioners can unlock the full potential of MoE models and leverage their capabilities to solve a wide range of machine learning tasks.
## Conclusions
Mixture of Experts (MoE) models represent a significant advancement in neural network architecture, offering a powerful approach to handling complex tasks. By combining multiple expert networks with a gating mechanism, MoE models can leverage specialized knowledge and improve performance. The architecture and training methodologies of MoE models have been refined over time, leading to notable advancements in natural language processing, computer vision, and other domains. As MoE models continue to evolve, they hold great promise for enhancing model efficiency, scalability, and interpretability, making them a valuable tool for machine learning practitioners.
# References
## code
* https://github.com/hkproj/mistral-src-commented/blob/main/mistral/cache.py
* https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py
## Images
* https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/
* https://afmck.in/posts/2023-02-26-parallelism/
* https://cameronrwolfe.substack.com/p/conditional-computation-the-birth#footnote-anchor-1-142423094
* https://cameronrwolfe.substack.com/p/conditional-computation-the-birth
* https://medium.com/@gopalgoyal612002/mistral-llm-architectural-details-8dc0447fea62
* https://arxiv.org/abs/2310.06825
* https://research.google/blog/introducing-gpipe-an-open-source-library-for-efficiently-training-large-scale-neural-network-models/
* https://arxiv.org/abs/2006.16668
* https://arxiv.org/abs/2101.03961
## Blogs and Papers
* https://arxiv.org/abs/2101.03961
* https://arxiv.org/abs/2006.16668
* https://arxiv.org/pdf/1701.06538
* https://arxiv.org/abs/2310.06825
* https://arxiv.org/pdf/2310.01801
* https://arxiv.org/abs/2004.05150v2
* https://www.cs.toronto.edu/~hinton/absps/jjnh91.pdf
* https://mistral.ai/news/announcing-mistral-7b/
* https://mistral.ai/
* https://cameronrwolfe.substack.com/p/conditional-computation-the-birth
* https://kipp.ly/transformer-inference-arithmetic/#kv-cache
* https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/
* https://afmck.in/posts/2023-02-26-parallelism/
* https://research.google/blog/introducing-gpipe-an-open-source-library-for-efficiently-training-large-scale-neural-network-models/
* https://huggingface.co/transformers/v4.10.1/parallelism.html
* https://timdettmers.com/2014/11/09/model-parallelism-deep-learning/
* https://huggingface.co/blog/moe#a-brief-history-of-moes
* https://arxiv.org/abs/2401.04088
## Appendix
### Code of Model Sharding
Here we only include the code related to model sharding, for easier explanation. This sharding approach allows for efficient handling of larger models by distributing the computational load across multiple GPUs, enabling parallel processing and reducing memory constraints.
```python!
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs, pipeline_rank: int = 0, num_pipeline_ranks: int = 1):
        super().__init__()
        self.args = args
        self.pipeline_rank = pipeline_rank
        self.num_pipeline_ranks = num_pipeline_ranks
        # Define the number of layers and assign layers to ranks
        layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
        num_layers_per_rank = math.ceil(self.n_layers / self.num_pipeline_ranks)
        offset = self.pipeline_rank * num_layers_per_rank
        end = min(self.n_layers, offset + num_layers_per_rank)
        self.layers = nn.ModuleDict({str(i): layers[i] for i in range(offset, end)})
        self.n_local_layers = len(self.layers)
        ...

    def forward(self, input_ids: torch.Tensor, seqlens: List[int], cache: Optional[RotatingBufferCache] = None) -> torch.Tensor:
        h = self.forward_partial(input_ids, seqlens, cache=cache)
        if self.pipeline_rank < self.num_pipeline_ranks - 1:
            # Ignore the intermediate activations as we'll get the final output from the last stage
            outs = torch.empty(h.shape[0], self.vocab_size, device=h.device, dtype=h.dtype)
        else:
            assert self.output is not None
            outs = self.output(h)  # Apply the output linear projection of the embeddings to the vocabulary size
        if self.num_pipeline_ranks > 1:
            torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1)
        return outs.float()
```
In this code snippet, the specific rank of the current model instance in the pipeline (***pipeline_rank***), the total number of ranks (GPUs) in the pipeline (***num_pipeline_ranks***), the number of layers each rank (GPU) will handle (***num_layers_per_rank***), and the number of layers handled by the current rank (***self.n_local_layers***) are defined in the initialization.
In the ***forward*** function:
1. The ***forward_partial*** method processes the input through the layers assigned to the current rank (GPU) and returns the intermediate activations.
2. If the current rank (GPU) is not the last in the pipeline (self.pipeline_rank < self.num_pipeline_ranks - 1), an empty tensor outs is created to hold the final output. The final rank applies the output linear projection (self.output) to convert the embeddings to the vocabulary size.
3. If there are multiple ranks, the final output (outs) is broadcast from the last rank so that all ranks have access to the final output tensor.
4. The method returns the final output as a floating-point tensor.
By implementing model sharding with pipeline parallelism, this ***Transformer*** class can efficiently train and infer on large models that would otherwise be difficult to manage on a single GPU.
### Code Explanation for KV-Cache
We explain some of the KV-cache code in this section to provide more detail and aid understanding.
The code below is from: https://github.com/hkproj/mistral-src-commented/blob/main/mistral/cache.py.
Updating the cache:
```python!
def update(self, xk: torch.Tensor, xv: torch.Tensor):
    """
    to_cache_mask masks the last [sliding_window] tokens in each sequence
    """
    n_kv_heads, head_dim = self.cache_k.shape[-2:]
    flat_cache_k = self.cache_k.view(-1, n_kv_heads, head_dim)  # (Max_Batch_Size, Sliding_Window_Size, N_Heads_KV, Head_Dim) --> (Max_Batch_Size * Sliding_Window_Size, N_Heads_KV, Head_Dim)
    flat_cache_v = self.cache_v.view(-1, n_kv_heads, head_dim)  # (Max_Batch_Size, Sliding_Window_Size, N_Heads_KV, Head_Dim) --> (Max_Batch_Size * Sliding_Window_Size, N_Heads_KV, Head_Dim)
    # Copies from the xk and xv tensors to the cache tensors, based on the cache positions and the items to cache (to_cache_mask)
    flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk[self.metadata.to_cache_mask])
    flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv[self.metadata.to_cache_mask])
```
Here **self.metadata.to_cache_mask** is a boolean mask indicating which elements of the input tensors xk and xv should be cached; it typically masks the last [sliding_window] tokens in each sequence (xk is a tensor of the new keys to be added to the cache, and xv is a tensor of the new values). So these lines select, via the boolean mask, the elements of xk and xv that need to be cached.
Interleaving the KV-Cache:
```python!
def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    This is a naive implementation and not optimized for speed.
    """
    assert xk.ndim == xv.ndim == 3  # (B * T, H, D)
    assert xk.shape == xv.shape
    if all([s == 0 for s in self.metadata.seqlens]):
        # No cache to interleave
        return xk, xv

    # Make it a list of [(Seq, N_Heads_KV, Head_Dim)]
    xk = torch.split(xk, self.metadata.seqlens)  # (Seq1+Seq2+Seq3, N_Heads_KV, Head_Dim) --> [(Seq1, N_Heads_KV, Head_Dim), (Seq2, N_Heads_KV, Head_Dim), (Seq3, N_Heads_KV, Head_Dim)]
    xv = torch.split(xv, self.metadata.seqlens)  # (Seq1+Seq2+Seq3, N_Heads_KV, Head_Dim) --> [(Seq1, N_Heads_KV, Head_Dim), (Seq2, N_Heads_KV, Head_Dim), (Seq3, N_Heads_KV, Head_Dim)]
    assert len(xk) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(xk)}"

    # Order elements in cache by position by unrotating
    cache_k = [unrotate(t, s) for t, s in zip(self.cache_k, self.kv_seqlens)]  # Currently cached elements, already unrotated, one for each prompt
    cache_v = [unrotate(t, s) for t, s in zip(self.cache_v, self.kv_seqlens)]  # Currently cached elements, already unrotated, one for each prompt

    interleaved_k = interleave_list(cache_k, xk)  # Appends the incoming keys to the currently cached elements (one for each prompt)
    interleaved_v = interleave_list(cache_v, xv)  # Appends the incoming values to the currently cached elements (one for each prompt)

    return torch.cat(interleaved_k, dim=0), torch.cat(interleaved_v, dim=0)
```
In this code snippet we could see that the incoming keys and values are appended to the currently cached elements. By doing this, the function efficiently updates the cache with the latest information without discarding existing valuable context. This maintains a balance between old and new information.
Retrieving a value from the KV-Cache:
```python!
@property
def key(self) -> torch.Tensor:
    return self.cache_k[:len(self.kv_seqlens)]

@property
def value(self) -> torch.Tensor:
    return self.cache_v[:len(self.kv_seqlens)]
```
Retrieving a value from the key-value cache in this code snippet has a time complexity of O(1), making it an efficient operation. This efficiency is due to the nature of tensor slicing in PyTorch, which does not copy data but instead creates a new view on the existing data.
From this code snippet, it is clear that retrieving a value from the KV-cache is very efficient: it saves both time and space compared with recomputing the value or fetching it from slower storage.
### Code for Specialized Loss Functions in MoEs
Below is a simplified example of how the router Z-loss function can be implemented in PyTorch:
```python
import torch

# Define the router Z-loss function
def router_z_loss(logits):  # logits are the inputs to the gating network, ranging from -inf to inf
    # squared log-sum-exp of the logits, averaged over the batch
    return torch.mean(torch.logsumexp(logits, dim=1) ** 2)

# Generate sample data
logits = torch.randn(3, 5, requires_grad=True)  # assuming a batch of 3 tokens and 5 experts
# Calculate the loss
loss = router_z_loss(logits)
print(loss)
```
, where [the original code snippet](https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py) from the paper [ST-MoE: Designing Stable and Transferable Sparse Expert Models](https://arxiv.org/pdf/2202.08906.pdf) is shown below for your reference:
```python
def _router_z_loss(logits, experts_dim, num_microbatches, importance=None):
    """Loss that encourages router logits to remain small and improves stability.

    Args:
        logits: a tensor with shape [<batch_dims>, experts_dim]
        experts_dim: a Dimension (the number of experts)
        num_microbatches: number of microbatches
        importance: an optional tensor with shape [<batch_dims>, group_size_dim]

    Returns:
        z_loss: scalar loss only applied by non-padded tokens and normalized by
            num_microbatches.
    """
    log_z = mtf.reduce_logsumexp(logits, experts_dim)
    z_loss = mtf.square(log_z)
    if importance is not None:
        z_loss *= mtf.cast(mtf.equal(importance, 1.0), dtype=z_loss.dtype)
    denom = mtf.reduce_sum(
        mtf.cast(mtf.equal(importance, 1.0), dtype=z_loss.dtype))
    z_loss = mtf.reduce_sum(z_loss) / (denom * num_microbatches)
    return z_loss
```