# Fused Attention Kernels from Scratch

$$ \newcommand{\R}{\mathbb{R}} \newcommand{\Lin}{\text{Lin}} \newcommand{\trace}{\text{Tr}} \newcommand{\der}{\partial} \newcommand{\st}[2]{\stackrel{#1}{#2}} \newcommand{\stK}[1]{\stackrel{#1}{K}} \newcommand{\stQ}[1]{\st{#1}{Q}} \newcommand{\stP}[1]{\st{#1}{P}} \newcommand{\sf}{\text{sf}} $$

Note: This post relies heavily on concepts and notation explained in our previous posts. We recommend reading:

* Linear Transformers post
* Beautiful Derivatives post: to understand the notation and techniques used throughout
* Kernel Fusion with Pallas: it explains the principles of kernel fusion, introduces Pallas, and shows how to write backward kernels and plug them into the JAX autodiff system.

To understand why kernel fusion makes sense for an attention layer, let's start by taking a closer look at a naive implementation of a linear attention layer. Later we will also look at standard attention with a softmax.

$$
B = \left( Q K^T \right) \odot M \;\;\;\;\;\; A_{ij} = \frac{B_{ij}}{\sum_k B_{ik}} \;\;\;\;\;\; Y = A V
$$

From this we get the following naive implementation:

FIGURE

In JAX, it corresponds to the following code:

```python!
import jax.numpy as jnp

def mask(t):
    # Causal mask: M[i, j] = 1 if j <= i, else 0
    return jnp.tril(jnp.ones([t, t]))

def reference_attention(Q, K, V):
    B = (Q @ K.T) * mask(Q.shape[0])  # preattention, [t, t]
    s = B.sum(axis=1, keepdims=True)  # row normalizers, [t, 1]
    A = B / s                         # each row of A sums to 1
    Y = A @ V                         # outputs, [t, d]
    return Y

t, d = 128, 16
Q, K, V = [jnp.ones([t, d])] * 3
Y = reference_attention(Q, K, V)
```

The three steps of computation (computing $B$, normalizing it, and multiplying by $V$) read a total of

$$ 2td + t^2 + (t^2+td) = O(t^2 + td) $$

entries from memory, and if you look at the writes you see that they are also $O(t^2 + td)$. This is a problem because $d$ for a large model is often 64 or 128, while $t$ will certainly be > 1k, more likely 8k or even 32k, so the $t^2$ terms dominate. This naive implementation runs into two key problems.

**Memory allocation.** Consider a single $t\times t$ array for $t=32k$: it has around 1B entries. But GPT-2 XL has 25 heads in each layer, and each head has its own matrix.
So the object would have around 25B entries; at 2 bytes per entry, 50GB of our 80GB are gone just to hold it. Since we also have to store weights, optimizer states, other activations, and gradients, we wouldn't even have 50GB available, and we would get OOM errors even with batch size 1. All the experiments we've done in past blog posts would have been impossible.

**Memory accesses.** But even if we reduced $t$ to something more manageable like $4k$, reading and writing the $t^2$-sized object would take a major portion of the total time spent computing the entire architecture. In this past blog post you can see more data comparing the speed of various implementations.

Writing fused kernels solves both problems at once.

## Fused Forward Pass

The crucial idea that motivates writing fused kernels is that we never need to store the matrices $A$ or $B$. Both have size $t^2$, but they can be computed out of $Q$ and $K$, which have size `[t,d]` each. Whenever we need an entry of the matrix, we can compute its value instead of reading it. This avoids allocating $t^2$ memory.

To compute the outputs we can use the equation

$$Y_i = \sum_j A_{ij} V_j$$

Can we compute $A_{ij}$ on the fly? Let's see. Assume we are doing the simplex normalization, so $A_{ij} = \frac{B_{ij}}{\sum_k B_{ik}}$. The $B_{ij}$ part is easy: we just need to load $Q_i, K_j$ and do a dot product. The problem is the normalization term, but that is fine: we can loop over $j$, computing the unnormalized output

$$\bar Y_i = \sum_j B_{ij} V_j$$

Since we are computing all the $B_{ij}$ terms anyway, we can also accumulate $s_i = \sum_j B_{ij}$. Finally, we set $Y_i = \frac{\bar Y_i}{s_i}$. A similar trick works for the softmax. The following Pallas kernel implements this idea:

```python!
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def attention_v1_kernel(Q_ref, K_ref, V_ref, Y_ref, s_ref):
    t, d = Q_ref.shape
    i = pl.program_id(0)                     # this program computes output row i
    Y_i = jnp.zeros([d])
    s_i = 0.
    Q_i = Q_ref[i, :]                        # d reads
    for j in range(t):
        K_j, V_j = K_ref[j, :], V_ref[j, :]  # 2d reads
        QKT_ij = jnp.dot(Q_i, K_j)           # inner product of two vectors
        M_ij = (j <= i)                      # causal mask
        B_ij = QKT_ij * M_ij
        Y_i += B_ij * V_j
        s_i += B_ij
    Y_i = Y_i / s_i
    Y_ref[i, :] = Y_i.astype(Y_ref.dtype)    # d writes
    # We also store the normalizations;
    # we will use them in the backward pass
    s_ref[i] = s_i.astype(s_ref.dtype)       # 1 write

def attention_v1(Q, K, V):
    t, d = Q.shape
    kernel = pl.pallas_call(
        attention_v1_kernel,
        grid=(t,),
        out_shape=(jax.ShapeDtypeStruct((t, d), Q.dtype),
                   jax.ShapeDtypeStruct((t,), Q.dtype)))
    Y, s = kernel(Q, K, V)
    return Y, s
```

The kernel has a grid of size `[t]`. Essentially, this means that there will be `t` independent programs running on the cores of the GPU. In main memory, this kernel doesn't allocate anything except its outputs, of sizes `[t,d]` and `[t]`. But let's analyze the memory accesses. At the beginning, each of the `t` programs loads `d` things. At each step of the inner loop it loads `2d` things. Finally, it writes `d` things. Thus, the total number of memory accesses is $t(d + 2td + d) = 2td(1+t) = O(dt^2)$. So even though we don't allocate $t^2$ things, we still perform a number of reads and writes proportional to $t^2 d$. It turns out that these memory accesses, rather than the FLOPs, are the performance bottleneck. (TODO: use off-the-shelf estimates of the FLOPs and bandwidth of an H100 GPU to show that we should be bounded by reads and writes more than by FLOPs.) Note also that we aren't doing any matmuls, so we aren't using the tensor cores, and the FLOPs we achieve are far below the GPU's theoretical maximum.

## Block Kernel

This second kernel not only avoids allocating $t^2$ memory, it also performs far fewer memory accesses! The idea comes from a block matrix formulation of attention:

FIGURE

Here we have broken down the attention matrix $A$ into blocks of size `Bq` by `Bk`. The outputs are then `t//Bq` blocks, each of size `[Bq, d]`, so we can break down the computation of the output into `t//Bq` independent programs, one per block of queries:

```python!
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

Bq, Bk = 128, 16

def attention_v2_kernel(Q_ref, K_ref, V_ref, Y_ref, name):
    t, d = Q_ref.shape
    i = pl.program_id(0)                   # this program computes output block i
    acc = jnp.zeros([Bq, d])               # unnormalized outputs of the block
    s_acc = jnp.zeros([Bq])                # one normalizer per query row
    Q_i = Q_ref[pl.ds(i * Bq, Bq), :]      # read Bq*d things
    rows = i * Bq + jnp.arange(Bq)         # query indices of this block
    for j in range(t // Bk):
        K_j = K_ref[j * Bk:(j + 1) * Bk, :]    # read Bk*d things
        V_j = V_ref[j * Bk:(j + 1) * Bk, :]    # read Bk*d things
        QKT_ij = Q_i @ K_j.T                   # [Bq, Bk], uses the tensor cores!
        cols = j * Bk + jnp.arange(Bk)         # key indices of this block
        M_ij = cols[None, :] <= rows[:, None]  # causal mask, [Bq, Bk]
        B_ij = QKT_ij * M_ij
        if name == 'Simplex':
            acc += B_ij @ V_j
            s_acc += B_ij.sum(axis=1)
        elif name == 'Softmax':
            z = jnp.where(M_ij, jnp.exp(QKT_ij), 0.)  # masked entries give 0, not exp(0)
            acc += z @ V_j
            s_acc += z.sum(axis=1)
    Y_ref[pl.ds(i * Bq, Bq), :] = (acc / s_acc[:, None]).astype(Y_ref.dtype)  # write Bq*d things

def attention_v2(Q, K, V, name='Simplex'):
    t, d = Q.shape
    kernel = pl.pallas_call(
        functools.partial(attention_v2_kernel, name=name),
        grid=(t // Bq,),
        out_shape=jax.ShapeDtypeStruct((t, d), Q.dtype))
    Y = kernel(Q, K, V)
    return Y
```

Ok, let's analyze the memory accesses. We have `t//Bq` programs. At first, each loads `Bq*d` things. The loop runs for `t//Bk` steps, and each step loads `2*Bk*d` things. Finally, each program writes `Bq*d` things. So the total is `t//Bq * (Bq*d + t//Bk*2*Bk*d + Bq*d)`, that is

$$ \frac{t}{B_q} \left(B_q d + \frac{t}{B_k}(2B_k d) + B_q d\right) = 2td + \frac{2t^2d}{B_q} = O\left(\frac{t^2 d}{B_q}\right) $$

How come we load so much less into the cores? Because we compute each attention block `A_ij` out of a `Q` block and a `K` block, of sizes `[Bq, d]` and `[Bk, d]` respectively, so a single program loads `Bq * d + Bk * d` things to compute all `Bq * Bk` entries of the block. In the previous kernel, every query row had its own program, so each of the `Bq` programs loaded its own copy of the same `Bk` key vectors: roughly `Bq * Bk * d` loads for the same block. In other words, each key is now loaded once per block of `Bq` queries instead of once per query.

If all the gains come from a large `Bq`, why don't we set it to the largest value, `Bq=t`? Many factors determine the optimal values. Some performance considerations:

* The GPU cores have very little memory: around 200KB for the H100. So the maximum we could afford is to set it to X. But there are other considerations.
* Pallas throws an error if we do a matmul with a dimension smaller than 16, since the tensor cores can only be used for dimensions of at least 16.
  So if we want to use the full power of our GPUs, it is worth spending some memory on the `Bk` dimension so that it is at least 16 and the tensor cores can be used.
* GPUs hide latency by running multiple programs at the same time on the same core, so that while one is waiting for a slow HBM transfer, another can execute computations. For that to be possible, each program must use only a fraction of the core's local memory. So we don't want to set `Bq` so high that only one program can be scheduled at a time.
* There is overhead in every read and write operation, so it is more efficient to do a single large read than many small ones. This matters inside the `for` loop: if `Bk` is very small, the loop will have many steps and each will incur this overhead. The total amount read is theoretically the same, but it will be slower.

And you can be sure that a GPU expert could give 15 extra reasons why it's not easy to pick the optimal values. In the end, you just want to benchmark a few configurations and pick the fastest.

Notes: This kernel is actually pretty bad. For example, it uses a Python `for` loop instead of `jax.lax.fori_loop`. A Python loop gets unrolled during tracing, so compilation time grows with `t//Bk`. There are a few extra optimizations that we can still perform, but they are minor. We will revisit all the little tricks needed to implement a good kernel at the end of the article. For now, let's move on to the real challenge: implementing a fused backward kernel.

## Gradients wrt $V$ and $A$

Alright, we are more than ready to compute the gradients wrt $Q, K, V$. Just like in the last section, we presume that we know $\nabla_Y l$, the gradient of the loss $l$ wrt the attention outputs $Y$. That of course means that the derivative is

$$ \frac{\der l}{\der Y} (\dot Y) = \trace(\nabla_Y l^T \dot Y) $$

Let's start by computing $\nabla_V l$.
From the chain rule, it's clear that

$$
\frac{\der l}{\der V} (\dot V) = \frac{\der l}{\der Y} \left(\frac{\der Y}{\der V} (\dot V) \right) = \trace \left(\nabla_Y l^T \frac{\der Y}{\der V} (\dot V) \right)
$$

Since $Y = AV$ is a linear function of either $A$ or $V$, we have $\frac{\der Y}{\der V} (\dot V) = A \dot V$, so

$$
\frac{\der l}{\der V} (\dot V) = \trace \left(\nabla_Y l^T A \dot V \right) = \langle A^T \nabla_Y l, \dot V \rangle
$$

and so $\nabla_V l = A^T \nabla_Y l$. In the same way, $\frac{\der Y}{\der A} (\dot A) = \dot A V$, so

$$
\frac{\der l}{\der A} (\dot A) = \trace \left(\nabla_Y l^T \dot A V \right) = \trace \left(V \nabla_Y l^T \dot A \right) = \langle \nabla_Y l V^T, \dot A \rangle
$$

and therefore

$$
\nabla_A l = \nabla_Y l V^T
$$

Look at how easy that was. No confusing indices, just a few mechanical steps.

## Gradients wrt $B$

Now that we understand how the preattention affects the attention, we are ready to incorporate the results from the last section to compute the gradients of the loss wrt the preattention. The general formula would be

$$\frac{\der l}{\der B_i} = \sum_j \frac{\der l}{\der A_j} \frac{\der A_j}{\der B_i} $$

but if $i\neq j$ then $\frac{\der A_j}{\der B_i}=0$, since each row $A_j$ only depends on the row $B_j$. This simplifies the derivative a lot, but unfortunately forces us to do a little bit of the index wrangling we've been trying to avoid. It will be nicer to compute the gradient $\nabla_{B_i} l$ (the gradient of $l$ wrt the vector $B_i = K Q_i$). Then we will put back together the gradient matrix $\nabla_{B} l$, which is of course the matrix whose $i$-th row is $\nabla_{B_i} l$.
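The derivation that follows hinges on the Jacobian of the row normalization. Assuming, consistently with the division by $s$ in the reference gradient code further down, that $\alpha(B_i) = 1 / (u^T B_i)$, so that $A_i = \alpha(B_i) B_i$, that Jacobian should be $\alpha(B_i)\left(I - A_i u^T\right)$. Here is a small sketch that checks this, and the resulting row gradient, against `jax.jacobian` and `jax.vjp` on a single row; the helper `normalize` and the example values are ours, for illustration only.

```python
import jax
import jax.numpy as jnp

def normalize(b):
    # Simplex normalization of one row: A_i = B_i / (u^T B_i)
    return b / b.sum()

b = jnp.array([0.5, 1.0, 2.0, 0.25])   # a positive row B_i
a = normalize(b)                       # A_i
alpha = 1.0 / b.sum()                  # assumed alpha(B_i) = 1 / (u^T B_i)
u = jnp.ones_like(b)

# Closed form: alpha(B_i) (I - A_i u^T)
J_closed = alpha * (jnp.eye(b.shape[0]) - jnp.outer(a, u))
# Autodiff Jacobian, J_auto[p, q] = dA_i[p] / dB_i[q]
J_auto = jax.jacobian(normalize)(b)
assert jnp.allclose(J_closed, J_auto, atol=1e-6)

# Row gradient: grad_{B_i} l = alpha(B_i) (grad_{A_i} l - d_i u)
g_a = jnp.array([1.0, -2.0, 0.5, 3.0])  # arbitrary upstream gradient grad_{A_i} l
d_i = jnp.dot(g_a, a)                   # d_i = grad_{A_i} l^T A_i
g_b_closed = alpha * (g_a - d_i * u)
_, vjp_fn = jax.vjp(normalize, b)       # vector-Jacobian product of the normalization
(g_b_auto,) = vjp_fn(g_a)
assert jnp.allclose(g_b_closed, g_b_auto, atol=1e-6)
```

Since every row of $A$ sums to one, each column of this Jacobian sums to zero, which is another quick sanity check.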
\begin{align}
\frac{\der l}{\der B_i} &= \frac{\der l}{\der A_i} \frac{\der A_i}{\der B_i} \\
&= \nabla_{A_i} l^T\alpha(B_i)\left( I - A_i u^T \right) \\
&= \alpha(B_i)\big( \nabla_{A_i} l^T -\underbrace{ \nabla_{A_i} l^T A_i}_{d_i} u^T \big)
\end{align}

So

$$
\nabla_{B_i} l = \alpha(B_i)\left( \nabla_{A_i} l - d_i u \right)
$$

It's useful to define the intermediate variable $d_i\in \R$ and the corresponding vector $d\in \R^t$, because then we can reconstruct an expression for the full gradient

$$
\nabla_B l = D_{\alpha(B)}\left( \nabla_A l - d u^T \right)
$$

where $D_x$ denotes the diagonal matrix with the entries of $x$ on its diagonal, and $\alpha(B)\in \R^t$ denotes the vector with entries $\alpha(B_i)\in \R$. Note how this last equation uses two different types of broadcasting: the diagonal matmul multiplies each row by a different scalar, and $-d u^T \in \R^{t\times t}$ subtracts a copy of $d$ from each column.

Frankly, this particular gradient is one of the rare cases where the code looks nicer than the equations:

```python!
import jax
import jax.numpy as jnp
alpha = lambda b: 1.0 / b.sum()                  # alpha(B_i) = 1 / (u^T B_i) for the simplex
norms = jax.vmap(alpha)(B)                       # the vector alpha(B), shape [t]
r = jax.vmap(jnp.dot)(A, A_grad)                 # the vector d, d_i = <A_i, grad_{A_i} l>
B_grad = norms[:, None] * (A_grad - r[:, None])  # D_alpha(B) (grad_A l - d u^T)
```

## Gradients wrt $Q$ and $K$

Now that we have the gradients wrt the preattention $B$, we have all we need to compute the gradients wrt $Q$ and $K$. It's just a trivial application of the matrix derivative techniques we've been using:

$$ \frac{\der l}{\der Q}(\dot Q) = \trace\left(\nabla_B l^T \frac{\der B}{\der Q}(\dot Q) \right) = \trace(\nabla_B l^T \dot Q K^T) = \trace( K^T \nabla_B l^T \dot Q) $$

From there, we can read off the gradient:

$$ \nabla_Q l = \nabla_B l K $$

Similarly,

$$ \frac{\der l}{\der K}(\dot K) = \trace(\nabla_B l^T Q \dot K^T) = \trace(Q^T \nabla_B l \dot K) $$

where we used the fact that $\trace(A) = \trace(A^T)$, together with the cyclic property of the trace. And so the gradient is:

$$ \nabla_K l = \nabla_B l^T Q $$

To keep the equations light, we treated $B$ as $QK^T$ here. With the causal mask, $B = (QK^T)\odot M$, and the same steps give $\nabla_Q l = (\nabla_B l \odot M) K$ and $\nabla_K l = (\nabla_B l \odot M)^T Q$.

## Numerical Check of our Gradient Equations

We can implement all of the equations in JAX and compare them with the outputs of the autodiff system.

```python!
import jax
import jax.numpy as jnp

def reference_attention_gradient(Q, K, V, Y_grad):
    # Recompute the forward intermediates (mask, B, s, A)
    M = mask(Q.shape[0])
    B = (Q @ K.T) * M
    s = B.sum(axis=1)                       # row normalizers, [t]
    A = B / s[:, None]
    # Backward pass, following the equations above
    V_grad = A.T @ Y_grad
    A_grad = Y_grad @ V.T
    r = jax.vmap(jnp.dot)(A, A_grad)        # the vector d
    B_grad = (A_grad - r[:, None]) / s[:, None]
    B_grad = B_grad * M                     # the causal mask also applies to the gradient
    Q_grad = B_grad @ K
    K_grad = B_grad.T @ Q
    return Q_grad, K_grad, V_grad

def loss(Y):
    # could be anything, as long as the output is a scalar
    return Y.sum()**2

# Random positive inputs keep every s_i > 0. (With an all-ones V, Y would be
# all ones regardless of Q and K, so Q_grad and K_grad would be zero.)
t, d = 128, 16
Q, K, V = jax.random.uniform(jax.random.PRNGKey(0), (3, t, d))

# Compute the gradients manually
Y = reference_attention(Q, K, V)
Y_grad = jax.grad(loss)(Y)
Q_grad, K_grad, V_grad = reference_attention_gradient(Q, K, V, Y_grad)

# Compute the gradients with jax autodiff (argnums: one gradient per input)
QKV_grad_fn = jax.grad(lambda *args: loss(reference_attention(*args)), argnums=(0, 1, 2))
_Q_grad, _K_grad, _V_grad = QKV_grad_fn(Q, K, V)

# float32 rounding differs between the two paths, hence the loose tolerances
assert jnp.allclose(Q_grad, _Q_grad, rtol=1e-3, atol=1e-2)
assert jnp.allclose(K_grad, _K_grad, rtol=1e-3, atol=1e-2)
assert jnp.allclose(V_grad, _V_grad, rtol=1e-3, atol=1e-2)
```

Alright, so our equations seem to be right. But to train our models, we can't be calling the gradient function manually. We want to use our fused kernel like any other function and then be able to call `jax.grad` on the final loss of our model. To do that, we need to set up the interface with the JAX autodiff system.

## Fused Backward Pass

Alright, now let's apply the kernel fusion philosophy. The only things we want to store in HBM are the gradients wrt $Q, K, V$; everything else we will compute on the fly as we need it.

## Some Optimizations

* Optimize compilation times by replacing Python `for` loops with `jax.lax.fori_loop`
* The loops don't need to cover the entire range, since causal masking zeroes out every block above the diagonal
* Manage dtypes properly
* Do some benchmarking sweeps

## Attention with Softmax

The only thing that changes when we use a softmax is $A_i = \sf(B_i) = f(\exp(B_i))$, where $f$ is the simplex normalization $f(x) = x / (u^T x)$. This means that in the forward pass kernel we only need to change 2 lines. For the backward pass, we only need to change the part concerned with the derivative $\frac{\der A_i}{\der B_i}$. Let's just work it out quickly. The Jacobian $\der\sf(x)$ is...

And with that we can get the gradient of the loss wrt $B$ with a derivation similar to before.
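The Jacobian in question is the standard one, $\der\sf(x) = D_{\sf(x)} - \sf(x)\sf(x)^T$; with $A_i = \sf(B_i)$ this is the $D_{A_i} - A_i A_i^T$ factor in the derivation that follows. A quick sketch, using only `jax.nn.softmax`, `jax.jacobian`, and `jax.vjp`, that checks both this Jacobian and the row gradient $(\nabla_{A_i} l - d_i u) \odot A_i$ on a single arbitrary row (the example values are made up for illustration):

```python
import jax
import jax.numpy as jnp

b = jnp.array([0.3, -1.2, 2.0, 0.7])   # one row B_i of the preattention
a = jax.nn.softmax(b)                  # A_i = sf(B_i)

# Jacobian of the softmax: D_{A_i} - A_i A_i^T
J_closed = jnp.diag(a) - jnp.outer(a, a)
J_auto = jax.jacobian(jax.nn.softmax)(b)
assert jnp.allclose(J_closed, J_auto, atol=1e-6)

# Row gradient for an arbitrary upstream gradient grad_{A_i} l
g_a = jnp.array([1.0, -2.0, 0.5, 3.0])
d_i = jnp.dot(g_a, a)                  # d_i = grad_{A_i} l^T A_i
g_b_closed = (g_a - d_i) * a           # (grad_{A_i} l - d_i u) * A_i, elementwise
_, vjp_fn = jax.vjp(jax.nn.softmax, b) # autodiff vector-Jacobian product
(g_b_auto,) = vjp_fn(g_a)
assert jnp.allclose(g_b_closed, g_b_auto, atol=1e-6)
```

Compared with the simplex case, the only difference is that the elementwise factor $A_i$ replaces the scalar $\alpha(B_i)$.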
\begin{align}
\frac{\der l}{\der B_i} &= \frac{\der l}{\der A_i} \der \sf(B_i) \\
&= \nabla_{A_i} l^T \; \left (D_{A_i} - A_i A_i^T \right ) \\
&= \left (\nabla_{A_i} l^T D_{A_i} - d_i A_i^T \right )
\end{align}

Since $x^T D_{y} = (x \odot y)^T$, we can write the gradient wrt $B_i$ as

$$\nabla_{B_i} l = \nabla_{A_i} l \odot A_i \; - d_i A_i = (\nabla_{A_i} l - d_i u) \odot A_i$$

From this we can reconstruct the full gradient wrt $B$:

$$ \nabla_B l = (\nabla_{A} l - d u^T) \odot A $$

We just need to modify a single line in each of the kernels for `Q_grad` and `K_grad`.

Well, that was easy. But not really, because that kernel would have big numerical stability issues. The exponential can quickly blow up and become a NaN, especially when we use the low-precision formats that are so important to speed up deep learning.

The standard trick is to exploit a property of the softmax. For all $r\in \R$ and $x\in \R^n$

$$ \sf(x + r \mathbb{1}) = \sf(x) $$

which you can easily verify by yourself. This property is useful because it means we never need to evaluate the exponential on large numbers. We can just define $\bar x = x - (\max x) \mathbb{1}$ and then call $\sf(\bar x)$. But applying this idea inside a fused kernel is not straightforward, because the max of a row $B_i$ is only known after we have visited all of its blocks.

The trick to solve this was first proposed in [CITE], a predecessor of flash attention. It was only concerned with avoiding the $t^2$ memory allocation. They proposed keeping a running max $m_i$ and rescaling the accumulators whenever it grows. For each new block $J$ of key indices:

$$ m_i' = \max\left(m_i, \max_{j\in J} B_{ij}\right) $$

$$ s_i \leftarrow e^{m_i - m_i'} s_i + \sum_{j\in J} e^{B_{ij} - m_i'} \;\;\;\;\;\; \bar Y_i \leftarrow e^{m_i - m_i'} \bar Y_i + \sum_{j\in J} e^{B_{ij} - m_i'} V_j $$

after which $m_i \leftarrow m_i'$. Once all blocks have been visited, $Y_i = \bar Y_i / s_i$.

Does this have any impact on the backward pass?

Let's finally implement a version of a complete, functional fused kernel for attention with softmax:

```python!
```

And there you go. Now you also have implementations of linear and softmax attention layers. But more importantly, hopefully you now have the ability to write your own fused attention kernels for any modification that you might come up with, free from the constraints of prebundled kernels or the slowness of pure JAX implementations.
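To make the running-max idea from the softmax section concrete, here is a minimal pure-JAX sketch of the online softmax for a single query row. It is not a Pallas kernel, and not necessarily the exact formulation of the cited work. It visits the keys in blocks of `Bk`, keeps a running max `m`, and rescales the partial sums `s` and `y` whenever the max grows. The name `online_softmax_row`, the block size, and the omission of the causal mask are illustrative simplifications.

```python
import jax
import jax.numpy as jnp

def online_softmax_row(q, K, V, Bk=4):
    """Softmax attention output for one query row q, visiting keys in blocks.

    Keeps a running max m, a running normalizer s and an unnormalized output y,
    rescaling both accumulators whenever a new block raises the max.
    No causal mask, to keep the sketch short.
    """
    t, d = K.shape
    m = -jnp.inf               # running max of the scores seen so far
    s = 0.0                    # running sum of exp(score - m)
    y = jnp.zeros(d)           # running sum of exp(score - m) * V_j
    for j in range(0, t, Bk):
        b = K[j:j + Bk] @ q                 # scores of this block, shape [Bk]
        m_new = jnp.maximum(m, b.max())
        scale = jnp.exp(m - m_new)          # rescales the old partial sums (0 on the first block)
        z = jnp.exp(b - m_new)              # exponents are always <= 0
        s = s * scale + z.sum()
        y = y * scale + z @ V[j:j + Bk]
        m = m_new
    return y / s

key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
t, d = 16, 8
q = 30.0 * jax.random.normal(key_q, (d,))  # large scores, where an unshifted exp can overflow float32
K = jax.random.normal(key_k, (t, d))
V = jax.random.normal(key_v, (t, d))

reference = jax.nn.softmax(K @ q) @ V
assert jnp.allclose(online_softmax_row(q, K, V), reference, atol=1e-5)
```

Inside a block kernel, `m` and `s` would become `[Bq]` vectors and `b.max()` a row-wise max over each `[Bq, Bk]` block; the rescaling logic stays the same.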
## Problems for the Reader

Want to test your knowledge of how to write fused attention kernels? There is no better way than implementing something yourself. Here are some good exercises:

* Fuse together the `k_grad` and `v_grad` kernels.
* Avoid unnecessary masking computations: we don't need to do masking at every step of the inner loop. It's enough to do it only towards the end, and only compute the mask $B_q / B_k$ times.
* Implement the sphere norm.
* Implement a second-order Taylor approximation to exp.
* Implement a random features kernel.