$$
\newcommand{\R}{\mathbb{R}}
\newcommand{\Lin}{\text{Lin}}
\newcommand{\trace}{\text{Tr}}
\newcommand{\der}{\partial}
\newcommand{\st}[2]{\stackrel{#1}{#2}}
\newcommand{\stK}[1]{\stackrel{#1}{K}}
\newcommand{\stQ}[1]{\st{#1}{Q}}
\newcommand{\stP}[1]{\st{#1}{P}}
\newcommand{\sf}{\text{sf}}
$$
# Transformer Forward and Backward Passes
Each architecture is named by two components. The first tells you how the preattention is computed: linear or multilinear. The second specifies how the attention is computed: simplex, sphere or softmax. The standard transformer is `LinearSoftmax`, and our next release is `MultilinearSimplex`. The following Python script contains the forward and backward pass computations for any name constructed this way.
```python!
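import numpy as np

# name = ['Linear', 'Simplex']  # In our first release
# name = ['Linear', 'Sphere']  # Allows rotary embeddings, but not much point in exploring since the state dim is too small
name = ['Linear', 'Softmax']  # A normal transformer
# name = ['Multilinear', 'Simplex']  # We will present this in our next big release
# name = ['Multilinear', 'Sphere']  # The architecture we want to use long term. Behaves very nicely with gating and rotary encoders
# name = ['Multilinear', 'Softmax']  # Probably a baseline. Does the multilinear trick matter when the state space is already infinite?

# Example inputs (illustrative sizes): Q, K, V, Y_grad are [t, d] arrays, and p must divide d.
# Positive entries keep the preattention positive, which the Simplex normalization needs.
t, d, p = 8, 16, 4
rng = np.random.default_rng(0)
Q, K, V, Y_grad = (rng.uniform(0.1, 1.0, size=(t, d)) for _ in range(4))

# Forward Pass
# Start by computing the [t, t] preattention B (row i is B_i = K Q_i)
if name[0] == 'Linear':
    B = Q @ K.T
elif name[0] == 'Multilinear':
    # Split each row into p chunks: _Q[m] and _K[m] are [t, d//p] arrays
    _Q = Q.reshape([t, p, d // p]).transpose(1, 0, 2)
    _K = K.reshape([t, p, d // p]).transpose(1, 0, 2)
    _QKT = _Q @ _K.transpose(0, 2, 1)  # [p, t, t], one factor Q^m K^m^T per m
    B = _QKT.prod(axis=0)  # Hadamard product over the p factors

# Compute A, the attention matrix, by normalizing each row of B
if name[1] == 'Simplex':
    s = B.sum(axis=-1, keepdims=True)  # s_i = 1 / alpha(B_i), shape [t, 1]
    A = B / s
elif name[1] == 'Sphere':
    s = np.sqrt((B ** 2).sum(axis=-1, keepdims=True))  # s_i = 1 / beta(B_i), shape [t, 1]
    A = B / s
elif name[1] == 'Softmax':
    E = np.exp(B - B.max(axis=-1, keepdims=True))  # shift by the row max for numerical stability
    A = E / E.sum(axis=-1, keepdims=True)
Y = A @ V  # [t, d] outputs

# Backward Pass: Y_grad is a [t, d] array
# First compute the V and A gradients
V_grad = A.T @ Y_grad
A_grad = Y_grad @ V.T
# Compute gradients wrt the preattention B; dots_i is d_i = <Y_grad_i, Y_i> in the derivations
dots = (Y * Y_grad).sum(axis=-1, keepdims=True)  # [t, 1]
if name[1] == 'Simplex':
    B_grad = (A_grad - dots) / s
elif name[1] == 'Sphere':
    B_grad = (A_grad - dots * A) / s
elif name[1] == 'Softmax':
    B_grad = (A_grad - dots) * A
# Compute the Q and K gradients
if name[0] == 'Linear':
    Q_grad = B_grad @ K
    K_grad = B_grad.T @ Q
elif name[0] == 'Multilinear':
    # _U[m] = B_grad * P^m with P^m = B / (Q^m K^m^T); assumes no entry of _QKT is zero
    _U = (B_grad * B)[None] / _QKT
    _Q_grad = _U @ _K  # [p, t, d//p]
    _K_grad = _U.transpose(0, 2, 1) @ _Q
    # Undo the split: [p, t, d//p] -> [t, p, d//p] -> [t, d]
    Q_grad = _Q_grad.transpose(1, 0, 2).reshape([t, d])
    K_grad = _K_grad.transpose(1, 0, 2).reshape([t, d])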
```
The rest of the document consists of the derivations of all the equations used in this script.
## Notation
The derivative notation is the following. Given a function $f: \R^n \to \R^m$, the derivative $\partial f: \R^n \to \Lin(\R^n, \R^m)$ is a function that takes a point in the input space of $f$ and outputs a linear map that approximates the change of $f$ near that point.
$$f(x + v) \simeq f(x) + \partial f(x)(v)$$
In other words, $\partial f(x)$ is the Jacobian of $f$ at the point $x$.
To talk about a change to $x$ we could use any symbol (above I used $v$), but it's nice to just write $\dot x$. This way it's evident from the notation that $\dot x$ has the same shape as $x$. The dot does not carry any extra meaning beyond that.
Sometimes, when working out Jacobians, it's convenient to first write the equation for the Jacobian applied to an arbitrary change vector. So instead of finding the equation for the matrix $\partial f(x)$, we can find the equation for $\partial f(x)(\dot x)$. From that, it is usually very easy to extract the equation for the full Jacobian $\partial f(x)$.
The reason it's sometimes more convenient to compute $\partial f(x)(\dot x)$ is that it's very easy to break it down into derivatives wrt scalar variables. If $x\in \R^n$, then
$$
\partial f(x)(\dot x) = \sum_{i=1}^n \frac{\partial f(x)}{\partial x_i} \dot x_i
$$
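A quick numerical illustration of this decomposition, as a sketch: the function `f` below is a hypothetical example chosen only for illustration, and the partial derivatives are approximated with central finite differences. The weighted sum of partials should match a finite difference taken directly along $\dot x$.

```python
import numpy as np

def f(x):
    # Hypothetical example f: R^3 -> R^2 (for illustration only)
    return np.array([x[0] * x[1], np.sin(x[2]) + x[0] ** 2])

def jacobian_fd(f, x, eps=1e-6):
    # Column i approximates the partial derivative of f wrt x_i (central differences)
    return np.stack([(f(x + eps * e) - f(x - eps * e)) / (2 * eps)
                     for e in np.eye(x.size)], axis=1)

x = np.array([0.5, -1.0, 2.0])
x_dot = np.array([0.1, 0.3, -0.2])

J = jacobian_fd(f, x)
# Sum of partial derivatives weighted by the entries of x_dot
jvp_sum = sum(J[:, i] * x_dot[i] for i in range(x.size))
# Finite difference of f taken directly along the direction x_dot
h = 1e-6
jvp_dir = (f(x + h * x_dot) - f(x - h * x_dot)) / (2 * h)
# Analytic Jacobian of this particular f, for comparison
J_exact = np.array([[x[1], x[0], 0.0],
                    [2 * x[0], 0.0, np.cos(x[2])]])

print(np.allclose(jvp_sum, jvp_dir, atol=1e-6), np.allclose(J, J_exact, atol=1e-6))
```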
So far we've been talking about notation for derivatives. Now let's say a few words about computing gradients. The first thing to say is that gradients are a much less general object than derivatives: we can only take the gradient of a function $f:\R^n \to \R$ with scalar outputs.
Let's define the gradient. Say that $f:V \to \R$ is a differentiable function from a vector space $V$ to $\R$. We also need $V$ to have an inner product $\langle \cdot, \cdot \rangle$. Then the gradient at $x\in V$ is defined as the unique vector $\nabla f(x) \in V$ such that, for every $\dot x \in V$,
$$
\langle \nabla f(x), \dot x \rangle = \partial f(x)(\dot x)
$$
NOTE: The fact that such a unique vector always exists is a simple application of the Riesz representation theorem.
A very important property of the gradient of $f$ is that it has the "same type" as the inputs to $f$. That is what allows us to define dynamical systems where the input changes along the direction of the gradient, as we do in deep learning.
The reason we defined the gradient for a generic vector space is that we will need gradients of functions whose inputs are matrices, not just vectors. Of course we could flatten the matrices into vectors, but that turns out not to be the most elegant way of handling things.
**First case: $V = \R^n$ and $\langle v, w \rangle = v^T w$**. This is probably the most common type of gradient one computes in deep learning. Since $f: \R^n \to \R$ and $\partial f(x) \in \Lin(\R^n, \R)$, we can think of $\partial f(x) \in \R^{1\times n}$ as a $1$ by $n$ matrix, so that $\partial f(x)(v) = \partial f(x) v$ (we drop the parentheses to indicate that we went from a generic linear function to matrix notation). Then we can easily see what the gradient is
$$
\langle \nabla f(x), v \rangle = \nabla f(x)^T v = \partial f(x) (v) = \partial f(x) v
$$
So clearly $\nabla f(x) = \partial f(x)^T$. This makes sense, since a $1\times n$ matrix transposed is an $n$-vector like the inputs to $f$, which is what we expect the gradient to be.
**Second case: $V = \R^{n\times m}$ and $\langle A, B \rangle = \trace(A^T B)$**. If you have never seen it before, you might be surprised by this inner product between matrices that involves the trace. You should verify that it is indeed an inner product. In fact, it is equal to the more "natural feeling" inner product $\langle A, B \rangle = \sum_{ij} A_{ij} B_{ij}$, but we use the trace formulation because it's much more elegant and has very useful algebraic properties that we will use later.
Note that since the inputs to $f$ are matrices, the derivative $\partial f(A) \in \Lin(\R^{n \times m}, \R)$ isn't really a matrix; it's just a linear function. So there isn't a clear sense in which you can say that the gradient is the derivative transposed. You just need to take the linear function $\partial f(A)(B)$ in whatever form it is written and find the unique $\nabla f(A) \in V$ satisfying $\langle \nabla f(A), B \rangle = \partial f(A)(B)$ for all $B\in V$.
A simple example will illustrate the point. Imagine that $\partial f(A)(B) = v^T B w$ for $v\in \R^n, \; w\in \R^m$. Then $\partial f(A) (B) = v^T B w = \trace(v^T B w) = \trace(w v^T B)$, where in the last step we used the cyclic property of the trace. By definition of the gradient, $\partial f(A)(B) = \langle \nabla f(A), B \rangle = \trace(\nabla f(A)^T B)$. So $\trace(w v^T B)=\trace(\nabla f(A)^T B)$ for all $B$, and by simple pattern matching $\nabla f(A)^T = w v^T$, that is, $\nabla f(A) = v w^T \in V$.
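A small NumPy sketch checking both claims on random matrices: the trace inner product equals the elementwise sum, and $v w^T$ represents the linear map $B \mapsto v^T B w$ (which is, for example, the derivative of $f(A) = v^T A w$ at any $A$). The sizes and random seed are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 4, 3
A = rng.normal(size=(n, m))
B = rng.normal(size=(n, m))
v = rng.normal(size=n)
w = rng.normal(size=m)

# Trace inner product vs. the elementwise "natural" inner product
ip_trace = np.trace(A.T @ B)
ip_sum = np.sum(A * B)

# Candidate gradient for the derivative B -> v^T B w
grad = np.outer(v, w)          # v w^T, an n x m matrix like the inputs
lhs = np.trace(grad.T @ B)     # <v w^T, B>
rhs = v @ B @ w                # v^T B w

print(np.isclose(ip_trace, ip_sum), np.isclose(lhs, rhs))
```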
NOTE: Gradients are important because in **practice they are the only thing we will actually be computing**. Derivatives are a much more general idea that doesn't rely on an inner product and doesn't need the output to be a scalar. But in practice we only use them as a tool to get the equations for the gradients. For example, one really important thing that derivatives give us is the chain rule, which doesn't really exist in any clear way for gradients.
## Gradients wrt $V$ and $A$
Here I'll write the gradients of the full transformer layer wrt $Q,K,V$ in the formulation that is useful for flash attention. We presume that we know $\nabla_Y l$, the gradient of the loss $l$ wrt the transformer outputs $Y$. Thus, we also have the derivative, which is
$$
\frac{\der l}{\der Y} (\dot Y) = \trace(\nabla_Y l^T \dot Y)
$$
Let's start by computing $\nabla_V l$. From the chain rule, it's clear that
$$
\frac{\der l}{\der V} (\dot V) = \frac{\der l}{\der Y} \left(\frac{\der Y}{\der V} (\dot V) \right) = \trace \left(\nabla_Y l^T \frac{\der Y}{\der V} (\dot V) \right)
$$
Since $Y = AV$, which is linear in either $A$ or $V$, it should be clear that $\frac{\der Y}{\der V} (\dot V) = A \dot V$, so
$$
\frac{\der l}{\der V} (\dot V) = \trace \left(\nabla_Y l^T A \dot V \right) = \langle A^T \nabla_Y l,\dot V \rangle
$$
and so, $\nabla_V l = A^T \nabla_Y l$.
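To sanity-check $\nabla_V l = A^T \nabla_Y l$, the sketch below uses a hypothetical loss $l(Y) = \sum_{ij} \sin(Y_{ij})$, whose gradient is $\nabla_Y l = \cos(Y)$ elementwise, and compares the formula with central finite differences. The loss, sizes and seed are assumptions made only for this check.

```python
import numpy as np

rng = np.random.default_rng(1)
t, d = 5, 3
A = rng.normal(size=(t, t))
V = rng.normal(size=(t, d))

def loss(V):
    # Hypothetical scalar loss of the outputs Y = A V
    return np.sum(np.sin(A @ V))

Y_grad = np.cos(A @ V)   # gradient of the loss wrt Y
V_grad = A.T @ Y_grad    # formula derived above

# Central finite-difference gradient wrt every entry of V
eps = 1e-6
V_grad_fd = np.zeros_like(V)
for idx in np.ndindex(*V.shape):
    E = np.zeros_like(V)
    E[idx] = eps
    V_grad_fd[idx] = (loss(V + E) - loss(V - E)) / (2 * eps)

print(np.allclose(V_grad, V_grad_fd, atol=1e-5))
```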
Before we can proceed to compute $\nabla_Q l$ and $\nabla_K l$, it will be useful to compute the derivatives wrt the attention and preattention. But for an elegant derivation we need to introduce a slightly different technique. Up until now we've been computing derivatives and gradients wrt the full objects, as opposed to their rows. But the derivative of $A$ wrt $B$ has a high degree of sparsity: the preattention vector $B_i = K Q_i$ only affects the attention vector $A_i = \gamma(B_i) B_i$. And in the section on the simplex, sphere and softmax below, we work out the derivatives $\frac{\der A_i}{\der B_i}$, not $\frac{\der A}{\der B}$. So our approach will be to compute
$$\frac{\der l}{\der B_i} = \frac{\der l}{\der A_i} \frac{\der A_i}{\der B_i} $$
And from this we will reconstruct the full gradient $\nabla_B l$. Ok, first
$$
\frac{\der l}{\der A_i} = \sum_j \frac{\der l}{\der Y_j} \frac{\der Y_j}{\der A_i} = \frac{\der l}{\der Y_i} \frac{\der Y_i}{\der A_i} =
\nabla_{Y_i}l^T V^T
$$
The choice of norm affects $\frac{\der A_i}{\der B_i}$ and thus the gradients $\nabla_K l, \nabla_Q l$, but we will be able to factor things out nicely. We will write the formula for $\nabla_B l$ for each norm, and later write $\nabla_K l, \nabla_Q l$ in terms of $\nabla_B l$.
Also, because we will need it later
$$
\nabla_A l = \nabla_Y l V^T
$$
I skip the derivation because it's very easy.
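For completeness, the derivation mirrors the one for $V$: with $V$ fixed, $\frac{\der Y}{\der A}(\dot A) = \dot A V$, and the cyclic property of the trace does the rest.

$$
\frac{\der l}{\der A} (\dot A) = \trace\left(\nabla_Y l^T \dot A V\right) = \trace\left(V \nabla_Y l^T \dot A\right) = \langle \nabla_Y l V^T, \dot A \rangle
$$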
## Preattention Gradients: Simplex, Sphere and Softmax
From the preattention vectors $B_i = K Q_i$ we will get the attention vectors, which have more geometric structure by being projected to either the simplex or the sphere.
The **simplex** and **sphere** projections are mathematically similar. And the **softmax** one is built on top of the simplex one.
$A_i = \gamma(B_i) B_i$, where $\gamma: \R^t \to \R$ is a normalization function ($\alpha$ for the simplex and $\beta$ for the sphere, both defined below). Finally $Y_i = V^T A_i$.
We will use the fact that the function
$$f_\gamma(x) = \gamma(x) x$$
has derivative
$$
\partial f_\gamma (x) (\dot x) = \partial\gamma(x) (\dot x) x + \gamma(x) \dot x
$$
**Simplex** corresponds to the case where the scaling function is
$$\alpha(x) = \left(\sum_i x_i \right)^{-1}$$
You can easily verify that it maps preattention vectors with positive entries onto the simplex of dimension $t-1$.
And
$$
\partial \alpha(x) (\dot x) = -\left(\sum_i x_i\right)^{-2} \sum_i \dot x_i = - \alpha(x)^2 u^T \dot x
$$
Where $u \in \R^t$ is the vector of 1s. So we can put the whole thing together
\begin{align}
\partial f_\alpha (x)(\dot x) &= - \alpha(x)^2 u^T \dot x x + \alpha(x) \dot x \\
&= \alpha(x) ( \dot x - \alpha(x) u^T \dot x x) \\
&= \alpha(x) ( I - \alpha(x) x u^T ) \dot x \\
&= \alpha(x) ( I - f_\alpha (x) u^T ) \dot x \\
\end{align}
So $\partial f_\alpha (x) = \alpha(x) ( I - f_\alpha (x) u^T )$.
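A numerical check of $\partial f_\alpha(x) = \alpha(x)(I - f_\alpha(x) u^T)$, sketched with central finite differences on a random vector with positive entries (sizes and seed are arbitrary). Since $u^T f_\alpha(x) = 1$, the columns of this Jacobian sum to zero, reflecting that $f_\alpha$ always lands on the simplex; the sketch checks that too.

```python
import numpy as np

def f_alpha(x):
    # Simplex normalization f_alpha(x) = alpha(x) x with alpha(x) = 1 / sum_i x_i
    return x / x.sum()

def jacobian_fd(f, x, eps=1e-6):
    # Entry [i, j] approximates d f_i / d x_j (central differences)
    return np.stack([(f(x + eps * e) - f(x - eps * e)) / (2 * eps)
                     for e in np.eye(x.size)], axis=1)

rng = np.random.default_rng(2)
t = 6
x = rng.uniform(0.5, 2.0, size=t)  # positive entries
u = np.ones(t)

J_formula = (1.0 / x.sum()) * (np.eye(t) - np.outer(f_alpha(x), u))
J_fd = jacobian_fd(f_alpha, x)

print(np.allclose(J_formula, J_fd, atol=1e-7))
print(np.allclose(J_formula.sum(axis=0), 0.0))
```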
**Sphere** corresponds to the case where the scaling function is
$$ \beta(x) = \left(\sum_i x_i^2 \right)^{-\frac{1}{2}}$$
We now perform a similar computation for the sphere norm:
\begin{align}
\partial \beta(x)(\dot x) &= -\frac{1}{2} \left(\sum_i x_i^2 \right)^{-\frac{3}{2}} \sum_i 2 x_i \dot x_i \\
&= - \beta(x)^3 x^T \dot x \\
&= - \beta(x)^2 f_\beta(x)^T \dot x
\end{align}
and we put together the whole thing
\begin{align}
\partial f_\beta (x)(\dot x) &= - \beta(x)^2 f_\beta(x)^T \dot x x + \beta(x) \dot x \\
&= \beta(x) ( \dot x - \beta(x) f_\beta(x)^T \dot x x) \\
&= \beta(x) ( I - f_\beta(x) f_\beta(x)^T ) \dot x \\
\end{align}
So that $\partial f_\beta (x) = \beta(x) ( I - f_\beta(x) f_\beta(x)^T )$
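The same kind of sketch for the sphere: it compares $\beta(x)(I - f_\beta(x) f_\beta(x)^T)$ with finite differences, and also confirms that the Jacobian sends $f_\beta(x)$ to zero, i.e. moving $x$ along its own direction does not change its projection onto the sphere. Sizes and seed are arbitrary.

```python
import numpy as np

def f_beta(x):
    # Sphere normalization f_beta(x) = beta(x) x with beta(x) = 1 / ||x||
    return x / np.linalg.norm(x)

def jacobian_fd(f, x, eps=1e-6):
    # Entry [i, j] approximates d f_i / d x_j (central differences)
    return np.stack([(f(x + eps * e) - f(x - eps * e)) / (2 * eps)
                     for e in np.eye(x.size)], axis=1)

rng = np.random.default_rng(3)
t = 6
x = rng.normal(size=t)
fb = f_beta(x)

J_formula = (1.0 / np.linalg.norm(x)) * (np.eye(t) - np.outer(fb, fb))
J_fd = jacobian_fd(f_beta, x)

print(np.allclose(J_formula, J_fd, atol=1e-7))
print(np.allclose(J_formula @ fb, 0.0))
```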
**Softmax** Finally, let's compute the same derivative for the softmax. Having already computed the simplex derivative will be useful because:
$$\text{sf}(x) = f_\alpha ( \exp(x))$$
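where $\exp(x)$ is the elementwise exponential, whose derivative is the diagonal matrix $D_{\exp(x)}$ (we write $D_x$ for the diagonal matrix with the entries of $x$ on its diagonal). Then the derivative is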
\begin{align}
\der \sf(x) &= \der f_\alpha(\exp(x)) \; \der \exp (x) \\
&= \alpha(\exp(x)) ( I - f_\alpha (\exp(x)) u^T ) \;D_{\exp(x)} \\
&= \alpha(\exp(x)) D_{\exp(x)} - \alpha(\exp(x)) \, \sf(x) u^T D_{\exp(x)} \\
&= D_{\sf(x)} - \sf(x) \sf(x)^T \\
\end{align}
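A sketch checking $\der \sf(x) = D_{\sf(x)} - \sf(x)\sf(x)^T$ against finite differences. The max-shift inside `softmax` is a standard numerical-stability trick, valid because the softmax is unchanged when a constant is added to every entry. Sizes and seed are arbitrary.

```python
import numpy as np

def softmax(x):
    # sf(x) = f_alpha(exp(x)); subtracting max(x) does not change the result
    e = np.exp(x - x.max())
    return e / e.sum()

def jacobian_fd(f, x, eps=1e-6):
    # Entry [i, j] approximates d f_i / d x_j (central differences)
    return np.stack([(f(x + eps * e) - f(x - eps * e)) / (2 * eps)
                     for e in np.eye(x.size)], axis=1)

rng = np.random.default_rng(4)
t = 6
x = rng.normal(size=t)
s = softmax(x)

J_formula = np.diag(s) - np.outer(s, s)  # D_sf(x) - sf(x) sf(x)^T
J_fd = jacobian_fd(softmax, x)

print(np.allclose(J_formula, J_fd, atol=1e-7))
```

Unlike the simplex Jacobian, this one is symmetric.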
### Properties of the derivative of Attention wrt Preattention
It's easy to verify that both $\alpha$ and $\beta$ satisfy the property $\gamma(r x) = r^{-1} \gamma(x)$ for $r > 0$. This is important because then $f_\alpha(rx) = \alpha(rx) rx = \alpha(x) x = f_\alpha(x)$. So multiplication by a positive scalar is an invariance of $f_\alpha$ (and also of $f_\beta$).
If we look at the two Jacobians
$\partial f_\beta (x) = \beta(x) ( I - f_\beta(x) f_\beta(x)^T )$ and $\partial f_\alpha (x) = \alpha(x) ( I - f_\alpha (x) u^T )$, we see that they both take the form
$$\partial f_\gamma(x) = \gamma(x) C$$
where $C \in \R^{t\times t}$ is a matrix that depends on $x$ but is invariant under multiplication of $x$ by positive scalars. Thus $\partial f_\gamma(rx) = r^{-1} \partial f_\gamma(x)$ for $r > 0$: both Jacobians scale inversely with the magnitude of $x$, which gives us a very simple way to control the magnitude of the gradients.
NOTE: We want to find initialization strategies that are stable wrt $t$. So perhaps using mean inside of $\alpha$ and $\beta$ would be more stable than sum.
NOTE: The softmax behaves quite differently on its fibers. If $\sf(x) = \sf(x')$ then $\der \sf(x) = \der \sf(x')$. This follows directly from $\der \sf(x) = D_{\sf(x)} - \sf(x) \sf(x)^T$, which depends on $x$ only through $\sf(x)$.
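Both properties can be checked numerically with the closed-form Jacobians derived above. In this sketch, two points on the same softmax fiber are generated as $x$ and $x + c u$ (adding a constant to every entry leaves the softmax unchanged); the values of `r` and `c`, the size and the seed are arbitrary.

```python
import numpy as np

def jac_simplex(x):
    # alpha(x) (I - f_alpha(x) u^T)
    a = 1.0 / x.sum()
    return a * (np.eye(x.size) - np.outer(a * x, np.ones(x.size)))

def jac_sphere(x):
    # beta(x) (I - f_beta(x) f_beta(x)^T)
    b = 1.0 / np.linalg.norm(x)
    return b * (np.eye(x.size) - np.outer(b * x, b * x))

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def jac_softmax(x):
    # D_sf(x) - sf(x) sf(x)^T
    s = softmax(x)
    return np.diag(s) - np.outer(s, s)

rng = np.random.default_rng(5)
x = rng.uniform(0.5, 2.0, size=5)
r, c = 3.0, 7.0

# Scaling x by r > 0 scales the simplex and sphere Jacobians by 1 / r
print(np.allclose(jac_simplex(r * x), jac_simplex(x) / r))
print(np.allclose(jac_sphere(r * x), jac_sphere(x) / r))
# x and x + c u lie on the same softmax fiber, and their Jacobians coincide
print(np.allclose(jac_softmax(x + c), jac_softmax(x)))
```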
### Gradients wrt preattention $B$
Now that we understand how the preattention affects the attention, we are ready to incorporate the results from the last section to compute the gradients of the loss wrt the preattention.
**simplex norm**
\begin{align}
\frac{\der l}{\der B_i} &= \frac{\der l}{\der A_i} \frac{\der A_i}{\der B_i} \\
&= \nabla_{A_i} l^T\alpha(B_i)\left( I - A_i u^T \right) \\
&= \alpha(B_i)\left( \nabla_{A_i} l^T - \nabla_{Y_i}l^T V^T A_i u^T \right) \\
&= \alpha(B_i)\left( \nabla_{A_i} l^T - \nabla_{Y_i}l^T Y_i u^T \right) \\
&= \alpha(B_i)\left( \nabla_{A_i} l^T - d_i u^T \right) \\
\end{align}
Where $d_i= \nabla_{Y_i}l^T Y_i$. So
$$
\nabla_{B_i} l = \alpha(B_i)\left( \nabla_{A_i} l - d_i u \right)
$$
From this we can reconstruct the full gradient
$$
\nabla_B l = D_{\alpha(B)}\left( \nabla_A l - d u^T \right)
$$
where $D_x$ denotes the diagonal matrix with the entries of $x$ on its diagonal, $d\in \R^t$ is the vector with entries $d_i$, and $\alpha(B)\in \R^t$ denotes the vector with entries $\alpha(B_i)$. Note how this last equation uses two different types of broadcasting: the diagonal matmul scales rows, and $d u^T \in \R^{t\times t}$ is a matrix with a copy of $d$ in each column.
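The two kinds of broadcasting map directly onto NumPy: $D_{\alpha(B)}$ becomes the row scaling `alpha[:, None] * ...`, and $d u^T$ becomes `d[:, None]`. The sketch below checks the result against finite differences of a hypothetical loss $l(Y) = \sum_{ij}\sin(Y_{ij})$ pushed through $A_i = f_\alpha(B_i)$ and $Y = AV$; the loss, sizes and seed are assumptions made only for the check.

```python
import numpy as np

rng = np.random.default_rng(6)
t, dv = 5, 3
B = rng.uniform(0.5, 2.0, size=(t, t))  # positive preattention; row i is B_i
V = rng.normal(size=(t, dv))

def outputs(B):
    A = B / B.sum(axis=1, keepdims=True)  # row i is f_alpha(B_i)
    return A, A @ V

def loss(B):
    # Hypothetical scalar loss of the outputs Y
    return np.sum(np.sin(outputs(B)[1]))

A, Y = outputs(B)
Y_grad = np.cos(Y)               # gradient of the loss wrt Y
A_grad = Y_grad @ V.T
d = np.sum(Y_grad * Y, axis=1)   # d_i = grad_{Y_i} l^T Y_i
alpha = 1.0 / B.sum(axis=1)      # alpha(B_i)

# D_alpha (grad_A l - d u^T) written with explicit matrices ...
B_grad_explicit = np.diag(alpha) @ (A_grad - np.outer(d, np.ones(t)))
# ... and with broadcasting
B_grad = alpha[:, None] * (A_grad - d[:, None])

eps = 1e-6
B_grad_fd = np.zeros_like(B)
for idx in np.ndindex(*B.shape):
    E = np.zeros_like(B)
    E[idx] = eps
    B_grad_fd[idx] = (loss(B + E) - loss(B - E)) / (2 * eps)

print(np.allclose(B_grad, B_grad_explicit), np.allclose(B_grad, B_grad_fd, atol=1e-6))
```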
**sphere norm**
\begin{align}
\frac{\der l}{\der B_i} &= \frac{\der l}{\der A_i} \frac{\der A_i}{\der B_i} \\
&= \nabla_{A_i} l^T\beta(B_i)\left( I - A_i A_i^T \right) \\
&= \beta(B_i)\left( \nabla_{A_i} l^T - \nabla_{Y_i}l^T V^T A_i A_i^T \right) \\
&= \beta(B_i)\left( \nabla_{A_i} l^T - \nabla_{Y_i}l^T Y_i A_i^T \right) \\
&= \beta(B_i)\left( \nabla_{A_i} l^T - d_i A_i^T \right) \\
\end{align}
So the full gradient wrt $B$ is
$$
\nabla_B l = D_{\beta(B)}\left( \nabla_A l - D_d A \right)
$$
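An analogous sketch for the sphere, again with the hypothetical loss $l(Y) = \sum_{ij}\sin(Y_{ij})$. Here $D_d A$ scales row $i$ of $A$ by $d_i$, which in NumPy is `d[:, None] * A`. Sizes and seed are arbitrary, and every row of `B` is assumed nonzero.

```python
import numpy as np

rng = np.random.default_rng(7)
t, dv = 5, 3
B = rng.normal(size=(t, t))  # preattention of any sign; rows assumed nonzero
V = rng.normal(size=(t, dv))

def outputs(B):
    A = B / np.linalg.norm(B, axis=1, keepdims=True)  # row i is f_beta(B_i)
    return A, A @ V

def loss(B):
    # Hypothetical scalar loss of the outputs Y
    return np.sum(np.sin(outputs(B)[1]))

A, Y = outputs(B)
Y_grad = np.cos(Y)
A_grad = Y_grad @ V.T
d = np.sum(Y_grad * Y, axis=1)          # d_i = grad_{Y_i} l^T Y_i
beta = 1.0 / np.linalg.norm(B, axis=1)  # beta(B_i)

# D_beta (grad_A l - D_d A)
B_grad = beta[:, None] * (A_grad - d[:, None] * A)

eps = 1e-6
B_grad_fd = np.zeros_like(B)
for idx in np.ndindex(*B.shape):
    E = np.zeros_like(B)
    E[idx] = eps
    B_grad_fd[idx] = (loss(B + E) - loss(B - E)) / (2 * eps)

print(np.allclose(B_grad, B_grad_fd, atol=1e-6))
```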
Before we compute $\nabla_K l$ and $\nabla_Q l$, where we will handle two cases (linear and multilinear), there is one more normalization to treat: the softmax.
**Softmax**
\begin{align}
\frac{\der l}{\der B_i} &= \frac{\der l}{\der A_i} \der \sf(B_i) \\
&= \nabla_{A_i} l^T \; \left (D_{A_i} - A_i A_i^T \right ) \\
&= \nabla_{A_i} l^T D_{A_i} \; - \nabla_{Y_i} l^T V^T A_i A_i^T \\
&= (\nabla_{A_i} l \odot A_i)^T \; - \nabla_{Y_i} l^T Y_i A_i^T \\
&= (\nabla_{A_i} l \odot A_i)^T \; - d_i A_i^T \\
\end{align}
So
$$\nabla_{B_i} l = \nabla_{A_i} l \odot A_i \; - d_i A_i = (\nabla_{A_i} l - d_i u) \odot A_i$$
From this we can reconstruct the full gradient
$$
\nabla_B l = (\nabla_{A} l - d u^T) \odot A
$$
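The same finite-difference sketch for the softmax, with the hypothetical loss $l(Y) = \sum_{ij}\sin(Y_{ij})$; here the broadcast $d u^T$ is again `d[:, None]`, and $\odot$ is NumPy's `*`. Sizes and seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(8)
t, dv = 5, 3
B = rng.normal(size=(t, t))
V = rng.normal(size=(t, dv))

def outputs(B):
    E = np.exp(B - B.max(axis=1, keepdims=True))  # row-wise max shift for stability
    A = E / E.sum(axis=1, keepdims=True)           # row i is sf(B_i)
    return A, A @ V

def loss(B):
    # Hypothetical scalar loss of the outputs Y
    return np.sum(np.sin(outputs(B)[1]))

A, Y = outputs(B)
Y_grad = np.cos(Y)
A_grad = Y_grad @ V.T
d = np.sum(Y_grad * Y, axis=1)  # d_i = grad_{Y_i} l^T Y_i

# (grad_A l - d u^T) * A, elementwise
B_grad = (A_grad - d[:, None]) * A

eps = 1e-6
B_grad_fd = np.zeros_like(B)
for idx in np.ndindex(*B.shape):
    E = np.zeros_like(B)
    E[idx] = eps
    B_grad_fd[idx] = (loss(B + E) - loss(B - E)) / (2 * eps)

print(np.allclose(B_grad, B_grad_fd, atol=1e-6))
```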
## Gradients wrt $K$ and $Q$
Now that we have the gradients wrt the preattention $B$ we have all we need to compute the gradients wrt $Q$ and $K$.
**Linear case**
$$
\frac{\der l}{\der Q}(\dot Q)
= \trace\left(\nabla_B l^T \frac{\der B}{\der Q}(\dot Q) \right)
= \trace(\nabla_B l^T \dot Q K^T)
= \trace( K^T \nabla_B l^T \dot Q)
$$
So
$$
\nabla_Q l = \nabla_B l K
$$
Similarly,
$$
\frac{\der l}{\der K}(\dot K) = \trace(\nabla_B l^T Q \dot K^T) = \trace(Q^T \nabla_B l \dot K)
$$
So
$$
\nabla_K l = \nabla_B l^T Q
$$
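A sketch of a finite-difference check for the linear case. To isolate this step, it uses a hypothetical loss defined directly on the preattention, $l(B) = \sum_{ij}\sin(B_{ij})$ with $B = QK^T$, so that $\nabla_B l = \cos(B)$ elementwise. Sizes and seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(9)
t, d = 4, 3
Q = rng.normal(size=(t, d))
K = rng.normal(size=(t, d))

def loss(Q, K):
    # Hypothetical loss of the preattention B = Q K^T
    return np.sum(np.sin(Q @ K.T))

def grad_fd(g, X, eps=1e-6):
    # Central finite-difference gradient of the scalar function g at X
    G = np.zeros_like(X)
    for idx in np.ndindex(*X.shape):
        E = np.zeros_like(X)
        E[idx] = eps
        G[idx] = (g(X + E) - g(X - E)) / (2 * eps)
    return G

B_grad = np.cos(Q @ K.T)
Q_grad = B_grad @ K     # grad_Q l = grad_B l K
K_grad = B_grad.T @ Q   # grad_K l = grad_B l^T Q

print(np.allclose(Q_grad, grad_fd(lambda X: loss(X, K), Q), atol=1e-6))
print(np.allclose(K_grad, grad_fd(lambda X: loss(Q, X), K), atol=1e-6))
```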
**Multilinear case** The main difference with the multilinear case is that we have $p$ key and query matrices $\stQ 1, \cdots, \stQ p, \stK 1, \cdots, \stK p \in \R^{t \times d}$.
NOTE: When implementing this on a computer, we definitely want a single rank 3 array to represent all $p$ keys, and the same for the queries. Here I chose to separate them into different variables because it's very convenient to compute gradients wrt a matrix, like $\nabla_{\stK m} l$. It just involves linear algebra and a few other "standard" tools like traces and the Hadamard product (elementwise product). To derive the gradient wrt a rank 3 tensor in an elegant way, we would need algebraic tools that I'm not familiar with.
The preattention $B \in \R^{t\times t}$ is computed with the formula
$$
B = \bigodot_{k=1}^p \stQ k \stK k^T
$$
Using $M^{\odot -1}$ to denote the elementwise inverse of $M$, the derivative of the preattention wrt $\stQ m$ is
\begin{align}
\frac{\der B}{\der \stQ m}\left (\dot{\stQ m}\right) &= \left ( \bigodot_{k=1, k\neq m}^p \stQ k \stK k^T \right ) \odot \left( (\stQ m + \dot{\stQ m}) \stK m^T \right) - B \\
&= \left ( \bigodot_{k=1, k\neq m}^p \stQ k \stK k^T \right ) \odot \left( \dot{\stQ m} \stK m^T \right) \\
&= B \odot \left( {\stQ m} \stK m^T \right)^{\odot -1} \odot \left( \dot{\stQ m} \stK m^T \right) \\
&= \stP m \odot \left( \dot{\stQ m} \stK m^T \right)
\end{align}
Where $\stP m = B \odot \left( {\stQ m} \stK m^T \right)^{\odot -1} \in \R^{t\times t}$.
Now, to compute the gradient we will need to invoke a couple of basic facts about the Hadamard product and the trace.
**Result** (here $u$ denotes a vector of ones of the appropriate size on each side)
$$
\trace(M^T N) = u^T ( M \odot N )u
$$
**Corollary** (both sides equal $u^T ( M \odot N \odot S ) u$ by the result above)
$$
\trace(M^T (N \odot S)) = \trace((M \odot N)^T S)
$$
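Both identities are easy to check numerically on random rectangular matrices. For $M, N \in \R^{n\times m}$, the left $u$ has $n$ ones and the right one has $m$; in NumPy, `*` is the elementwise (Hadamard) product $\odot$. Sizes and seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(10)
n, m = 4, 3
M, N, S = (rng.normal(size=(n, m)) for _ in range(3))
u_n, u_m = np.ones(n), np.ones(m)  # vectors of ones of the two sizes

# Result: Tr(M^T N) = u^T (M * N) u
result_lhs = np.trace(M.T @ N)
result_rhs = u_n @ (M * N) @ u_m
# Corollary: Tr(M^T (N * S)) = Tr((M * N)^T S)
cor_lhs = np.trace(M.T @ (N * S))
cor_rhs = np.trace((M * N).T @ S)

print(np.isclose(result_lhs, result_rhs), np.isclose(cor_lhs, cor_rhs))
```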
Now we are ready to compute the gradient of the loss.
\begin{align}
\frac{\der l}{\der \stQ m}\left (\dot{\stQ m}\right) &= \trace\left( \nabla_B l^T \frac{\der B}{\der \stQ m}\left (\dot{\stQ m}\right) \right) \\
&= \trace\left( \nabla_B l^T \left( \stP m \odot \left( \dot{\stQ m} \stK m^T \right) \right) \right) \\
&= \trace\left( \left ( \nabla_B l \odot \stP m \right)^T \dot{\stQ m} \stK m^T \right) \\
&= \trace\left( \stK m^T \left ( \nabla_B l \odot \stP m \right)^T \dot{\stQ m} \right) \\
\end{align}
And so
$$
\nabla_{\stQ m} l = \left ( \nabla_B l \odot \stP m \right)\stK m
$$
The same technique first gives
$$
\frac{\der B}{\der \stK m}\left (\dot{\stK m}\right)
= \stP m \odot \left( {\stQ m} \dot{ \stK m}^T \right)
$$
And then
$$
\frac{\der l}{\der \stK m}\left (\dot{\stK m}\right) =
\trace\left( \stQ m^T \left ( \nabla_B l \odot \stP m \right) \dot{\stK m} \right)
$$
So that the gradient is
$$
\nabla_{\stK m} l = \left ( \nabla_B l \odot \stP m \right)^T \stQ m
$$
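A sketch checking the multilinear gradients with finite differences. As in the linear check, it uses a hypothetical loss on the preattention, $l(B) = \sum_{ij}\sin(B_{ij})$, and stores the $p$ queries and keys as rank 3 arrays `Qs[m]` and `Ks[m]` (shapes and seed are arbitrary). It also confirms that $\stP m$, computed by elementwise division as in the text, equals the product of the other $p-1$ factors; the division assumes no entry of $\stQ m \stK m^T$ is zero.

```python
import numpy as np

rng = np.random.default_rng(11)
p, t, k = 3, 4, 2
Qs = rng.normal(size=(p, t, k))  # Qs[m] is the m-th query matrix
Ks = rng.normal(size=(p, t, k))  # Ks[m] is the m-th key matrix

def loss(Qs, Ks):
    # Hypothetical loss of B = Hadamard product over m of Q^m K^m^T
    return np.sum(np.sin(np.prod(Qs @ Ks.transpose(0, 2, 1), axis=0)))

def grad_fd(g, X, eps=1e-6):
    # Central finite-difference gradient of the scalar function g at X
    G = np.zeros_like(X)
    for idx in np.ndindex(*X.shape):
        E = np.zeros_like(X)
        E[idx] = eps
        G[idx] = (g(X + E) - g(X - E)) / (2 * eps)
    return G

factors = Qs @ Ks.transpose(0, 2, 1)  # [p, t, t], factors[m] = Q^m K^m^T
B = factors.prod(axis=0)
B_grad = np.cos(B)
P = B[None] / factors                 # P[m] = B / (Q^m K^m^T), elementwise

U = B_grad[None] * P                  # U[m] = grad_B l * P^m
Q_grads = U @ Ks                      # (grad_B l * P^m) K^m
K_grads = U.transpose(0, 2, 1) @ Qs   # (grad_B l * P^m)^T Q^m

# P[m] should equal the product of the other p - 1 factors
P_direct = np.stack([np.prod(np.delete(factors, m, axis=0), axis=0) for m in range(p)])

print(np.allclose(P, P_direct))
print(np.allclose(Q_grads, grad_fd(lambda X: loss(X, Ks), Qs), atol=1e-5))
print(np.allclose(K_grads, grad_fd(lambda X: loss(Qs, X), Ks), atol=1e-5))
```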