$$
\newcommand{\R}{\mathbb{R}}
\newcommand{\Lin}{\text{Lin}}
\newcommand{\trace}{\text{Tr}}
\newcommand{\der}{\partial}
\newcommand{\st}[2]{\stackrel{#1}{#2}}
\newcommand{\stK}[1]{\stackrel{#1}{K}}
\newcommand{\stQ}[1]{\st{#1}{Q}}
\newcommand{\stP}[1]{\st{#1}{P}}
\newcommand{\sf}{\text{sf}}
$$
# Deep Learning Derivatives: the right way
* The notation is a mess, especially in higher dimensions. There are also multiple notations for the same thing that are useful at different moments ($dy/dx$, which prioritizes the input and output objects, and $\partial f$, which prioritizes the function).
*Note: This post serves as a support for "Fused Attention Kernels from Scratch". The full glory of the notation and techniques introduced here can't fully be appreciated until we tackle the problem of computing difficult derivatives like the ones of the attention layer.*
Deep learning sits at the intersection of two deep and important fields of math: linear algebra and calculus. These each have their own rich sets of notation and ideas. Deep learning requires that we use them together, by computing gradients involving matrices. Even people who are comfortable in each individually can struggle when using them together, where the notation [tends to be a mess](https://en.wikipedia.org/wiki/Matrix_calculus).
A big reason for this is the way that calculus is typically taught. The focus is usually placed on derivatives of functions with scalar inputs and scalar outputs. Then, when we have to tackle derivatives of functions with higher-dimensional inputs and outputs, we construct them by combining many 1-dimensional derivatives. The two main examples of this are the definitions of the Jacobian and the gradient:
**jacobian** The Jacobian of a function $f: \R^n \to \R^m$ is defined as the $\R^{m\times n}$ matrix
$$
J =
\left[
\begin{array}{ccc}
\frac{\der y_1}{\der x_1} & \cdots & \frac{\der y_1}{\der x_n} \\
\vdots & \ddots & \vdots \\
\frac{\der y_m}{\der x_1} & \cdots & \frac{\der y_m}{\der x_n} \\
\end{array}
\right]
$$
**gradient** when we have a function $f:\R^n\to\R$ that outputs scalars (think $l=f(x)$ outputs the loss), we define the gradient as:
$$
\nabla_x l =
\left[
\begin{array}{c}
\frac{\der l}{\der x_1} \\
\frac{\der l}{\der x_2} \\
\vdots
\end{array}
\right]
$$
Introducing scalar derivatives first is a smart pedagogical decision, and it's natural to initially think about higher-dimensional objects as being built out of 1-d components. But this is not the only approach we can follow. It comes with some limitations, for example:
Suppose we are trying to write code that computes the gradient of the loss wrt $x$. We then work out the expression for the term $\frac{\partial l}{\partial x_i}$. We could easily write code that computes the full gradient by having an inner loop over $i$ that evaluates the mathematical expression we worked out. But this is a horribly inefficient way to perform the computation. We want to perform the gradient computation via operations that are highly parallelizable, like matrix multiplies, elementwise operations, etc. Thus, it's not enough to work out the math for $\frac{\partial l}{\partial x_i}$; we then need to work out a direct expression for $\nabla_x l$. This is an annoying extra step that always appears when doing the math for deep learning, where we are always writing code in vector format in order to utilize the hardware.
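To make this concrete, here is a small sketch (a toy least-squares loss of my own choosing, not an example from this post) computing the same gradient once component by component in a loop and once as a single vectorized expression:

```python
import numpy as np

# Toy example (not from the post): l = 0.5 * ||W x - b||^2, so dl/dx_i = sum_j W[j, i] * r_j
# with residual r = W x - b.
rng = np.random.default_rng(0)
m, n = 64, 32
W, x, b = rng.normal(size=(m, n)), rng.normal(size=n), rng.normal(size=m)
r = W @ x - b

# Component by component: one loop iteration per gradient entry.
grad_loop = np.array([np.sum(W[:, i] * r) for i in range(n)])

# Vectorized: a single matrix-vector product, which maps directly onto fast hardware.
grad_vec = W.T @ r

assert np.allclose(grad_loop, grad_vec)
```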
But there is another way! We can think about functions $f:V\to W$ that map between two real vector spaces $V$ and $W$ and define derivatives and gradients in that setting. One advantage of doing things that way is that we can carry out all the derivations without having to think about those pesky indices all the time. But even more importantly, when we directly work with these high-dimensional objects, we end up with expressions that directly translate into efficient (parallelizable) code.
* An interesting thing that we will see is that derivatives and gradients really are different beasts. This might seem surprising since the definitions from above look so similar (almost the same).
* Gradients rely on a choice of inner product. You didn't know it, but it was always lurking beneath the surface.
* This inner product ends up being the tool you want to constantly be using for super simple derivations of gradients that directly translate into efficient code. By the end of this article you will fucking love inner products.
First we will define derivatives in this general setting. Then we will learn what gradients really are. Finally we will work out a handful of backward passes that are relevant for deep learning.
## Derivatives
The derivative notation is the following: given a function $f: \R^n \to \R^m$, the derivative $\partial f: \R^n \to \text{Lin}(\R^n, \R^m)$ is a function that takes a point in the input space of $f$ and outputs a linear map that approximates the change of $f$ near that point.
$$f(x + v) \simeq f(x) + \partial f(x)(v)$$
In other words, $\partial f(x)$ is the jacobian of $f$ at the point $x$.
To talk about a change to $x$ we could use any symbol (above I used $v$), but it's nice to just write $\dot x$. This way it's evident from the notation that $\dot x$ has the same shape as $x$. But the dot does not carry any extra meaning beyond that.
**the chain rule** If we have a composed function $h(x) = g(f(x))$ and we know the derivatives of $f$ and $g$ we can get the derivative of $h$
$$
\der h(x)(\dot x) = \der g(f(x)) (\der f(x)(\dot x))
$$
This is probably different from the way you are accustomed to thinking about the chain rule, but it is in fact the simplest way to represent the idea. It's just a matter of tracking changes.
We can also write it as the composition of linear maps
$$\der h(x) = \der g(f(x))\; \der f(x)
$$
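Here is a quick numerical sanity check of this composition (a sketch with toy functions of my own choosing, $f(x) = Wx$ and $g(u) = \tanh(u)$ applied elementwise):

```python
import numpy as np

# Sketch (toy functions, not from the post): verify ∂(g∘f)(x) = ∂g(f(x)) ∂f(x)
# with f(x) = W x and g(u) = tanh(u) applied elementwise.
rng = np.random.default_rng(1)
m, n = 4, 3
W = rng.normal(size=(m, n))
x = rng.normal(size=n)

df = W                                   # Jacobian of f at x (f is linear)
dg = np.diag(1.0 - np.tanh(W @ x) ** 2)  # Jacobian of g at f(x) (elementwise tanh)
dh = dg @ df                             # chain rule: compose the two linear maps

# Compare against a finite-difference Jacobian of h(x) = tanh(W x).
eps = 1e-6
h = lambda x: np.tanh(W @ x)
dh_fd = np.stack([(h(x + eps * e) - h(x - eps * e)) / (2 * eps) for e in np.eye(n)], axis=1)

assert np.allclose(dh, dh_fd, atol=1e-5)
```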
**functions of multiple variables** For example, our old trusty
$$f(x,W)= \sigma(Wx)$$
For those who know what this means, we should just think of $(x,W) \in \R^n \oplus \R^{m\times n}$ as being a vector in the direct sum vector space.
$$
\partial f(x, W)(\dot x, \dot W) \simeq f(x+\dot x, W + \dot W) - f(x,W)
$$
But sometimes we only want to think about how the output changes as we change one of the inputs. One technique to handle that is the indexing discussed below, but another is to define a dependent variable $y=f(x,W)$ and use the notation
$$
\frac{\partial y}{\partial x}:\R^n \to \R^m \;,\;\;\;\;
\frac{\partial y}{\partial W}:\R^{m\times n} \to \R^m
$$
When we are using this notation, we are still talking about the same derivatives as we did before. It's just that we avoid talking explicitly about the function we are taking the derivative of. For example, here we could have fixed a constant $W$ and defined
$$
h(x) = f(x, W)
$$
Then
$$\partial h(x) = \frac{\partial y}{\partial x}$$
This notation is useful for two main reasons:
* it saves us from having to define an explicit function that we take the derivative of. This is very useful in deep learning, for example, where we want to take the derivative of the loss with respect to many things inside the network. It would be very tedious to have to define an explicit function and give it a symbol for every such derivative we might want to take.
* it saves us from having to evaluate the function at a point. The point is implicit from the wrt part of the expression.
NOTE: We can only use this notation when it is completely obvious what the function is. You should only use it if it would be clear how you would define a function $h$ s.t. $\der h(x) = \frac{\partial y}{\partial x}$.
**indexing derivatives** If $f: \R^n \to \R^m$, we can also think of $f$ as a function of multiple inputs. Some jacobians are easier to work out by looking at the individual components.
$$
\partial f(x)(\dot x) = \sum_{i=1}^n \partial_i f(x) \dot x_i
$$
Each $\partial_i f(x)$ is a column of $\partial f(x)$.
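A tiny numpy sketch of this identity (the Jacobian here is just a random matrix standing in for $\partial f(x)$):

```python
import numpy as np

# Sketch: ∂f(x)(xdot) = sum_i ∂_i f(x) * xdot_i, with J standing in for ∂f(x).
rng = np.random.default_rng(2)
J = rng.normal(size=(5, 3))
xdot = rng.normal(size=3)

full_map = J @ xdot                                    # apply the linear map all at once
by_columns = sum(J[:, i] * xdot[i] for i in range(3))  # weighted sum of the columns ∂_i f(x)

assert np.allclose(full_map, by_columns)
```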
## Gradients
Until now we've been talking about notation for derivatives. Now let's say a few words about the computation of gradients. The first thing to say about gradients is that they are a much less general object than derivatives. We can only take the gradient of a function $f$ if it has a derivative, and if it satisfies two conditions:
* The function $f:V \to \R$ has scalar outputs. It only makes sense to take the gradient of functions like the loss.
* The vector space $V$ has an [inner product](https://en.wikipedia.org/wiki/Inner_product_space) structure $<\cdot, \cdot>$
**Gradient Definition** Given an inner product vector space $V$ and a differentiable function $f: V\to \R$, the gradient at $x\in V$ is defined as the unique vector $\nabla f(x) \in V$ such that
$$
<\nabla f(x), \dot x> = \partial f(x)(\dot x)
$$
----
Let's parse this definition. Since $f(x) \in \R$, the derivative at $x$ is a linear map $\der f(x): V \to \R$. But we have a name for the space of linear maps from a vector space to its underlying field (in this case the reals): it's called the dual space $V^*$. A fundamental result about the dual space is the [Riesz representation theorem](https://en.wikipedia.org/wiki/Riesz_representation_theorem#Example_in_finite_dimensions_using_matrix_transformations), which states that for every covector $\omega \in V^*$ there exists a unique $w\in V$ s.t. for all $v\in V$
$$
\omega(v) = <w, v>
$$
Essentially, we are using the inner product to turn covectors into vectors. And the Riesz theorem tells us this is in fact a 1-to-1 correspondence between $V$ and $V^*$.
Turning covectors into vectors is essential in deep learning. After all, we usually compute gradients wrt parameters so that we can move along that direction. A linear approximation of how the function $f$ changes is not immediately useful to us.
**Gradients as the optimal direction of improvement** Among all directions $\dot x$ with $<\dot x, \dot x> = 1$, the one that maximizes $\partial f(x)(\dot x) = <\nabla f(x), \dot x>$ is $\dot x = \nabla f(x) / \|\nabla f(x)\|$, by Cauchy-Schwarz. This is what justifies moving along (or against) the gradient when optimizing.
## $\R^n$ Gradients
It's very easy to overlook the nuances of the gradient, because for everyone's favourite inner product space, $\R^n$, turning a derivative into a gradient is just a matter of applying the transpose.
Let's see what we mean by this. The natural inner product on $\R^n$ is just
$$
<v, w> = v^T w \;\;\; v,w\in \R^n
$$
If $f: \R^n \to \R$ then $\partial f(x) \in \text{Lin}(\R^n, \R)$ so we can think of $\partial f(x) \in \R^{1\times n}$ as a $1\times n$ matrix. Then we can easily see what the gradient is
$$
\nabla f(x)^T v = <\nabla f(x), v> = \partial f(x) (v) = \der f(x) v
$$
Note how in the last step we did $\der f(x)(v) = \partial f(x) v$. Removing the parentheses indicates that we are using a matrix-vector product. Comparing the leftmost and rightmost expressions, we can see that
$$
\nabla f(x) = \der f(x)^T
$$
But we can't always just use the simple trick of "the gradient is just the derivative transposed". It would fail if we were using a different inner product on $\R^n$, although we don't really do that in deep learning.
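A short numerical sketch of the identity $\nabla f(x) = \der f(x)^T$ (with a toy function of my own choosing, $f(x) = \sin(w^T x)$):

```python
import numpy as np

# Sketch: f(x) = sin(w·x). Its derivative at x is the 1×n row matrix cos(w·x) w^T,
# and the gradient is the n-vector cos(w·x) w, i.e. the derivative transposed.
rng = np.random.default_rng(3)
n = 6
w, x = rng.normal(size=n), rng.normal(size=n)
f = lambda x: np.sin(w @ x)

eps = 1e-6
df = np.array([[(f(x + eps * e) - f(x - eps * e)) / (2 * eps) for e in np.eye(n)]])  # (1, n)
grad = np.cos(w @ x) * w                                                             # (n,)

assert np.allclose(df.T.ravel(), grad, atol=1e-5)
```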
## Matrix Derivatives
### Matrix Indexing Notation
Essentially, we use numpy notation. If $M \in \R^{n\times m}$ then $M_i \in \R^m$: indexing selects a row and turns it into a column.
$$
M =
\left[
\begin{array}{c}
M_1^T \\
M_2^T \\
\vdots
\end{array}
\right]
$$
Indexing a matrix $M$ at $i$ just corresponds to doing the multiplication $e_i^T M$, which gives a row vector, and transposing it. Thus $[M]_i = M^T e_i$. Later, when we are computing gradients, we will use this trick quite a bit.
For the $Q,K,V \in \R^{t \times d}$ matrices, an index must be in the range $1 \le i \le t$, and $Q_i \in \R^d$ is a (column) vector.
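A quick numpy sketch of the identity $[M]_i = M^T e_i$ (note that numpy indices are 0-based, unlike the 1-based indexing in the math):

```python
import numpy as np

# Sketch: row i of M, viewed as a column vector, equals M^T e_i.
rng = np.random.default_rng(4)
M = rng.normal(size=(4, 5))
i = 2
e_i = np.eye(4)[i]

assert np.allclose(M[i], M.T @ e_i)
```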
### The Trace
It might seem weird now, but by the time we finish computing the gradients of the attention layer, you will certainly agree that the real star of the show was the trace. Let $A\in \R^{n\times n}$, then
$$
\trace(A) = \sum_{i=1}^n A_{ii}
$$
There are a few facts about the trace that we will need.
The **trace of the transpose** If $A: \R^n\to \R^n$ then $\trace(A) = \trace(A^T)$.
The **cyclic property** If $A: \R^n \to \R^m$ and $B: \R^m\to \R^n$ then $\trace(AB) = \trace(BA)$. Note that this property doesn't imply that "matrices commute within $\trace$". For example, $\trace(ABC) = \trace(C(AB)) = \trace((BC)A)$. But $\trace(ABC) \neq \trace(BAC)$, at least not generally.
The **inner product** If $A: \R^n \to \R^m$ and $B: \R^n\to \R^m$ then $\trace(A^TB) = \sum_{i,j} A_{ij} B_{ij}$.
There is plenty more to be said about the trace, but these are the only properties we will need. They are all very easy to show, so we encourage you to prove them.
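If you want to check them numerically before proving them, here is a small sketch (random matrices of my own choosing):

```python
import numpy as np

# Sketch: check the three trace facts on random matrices.
rng = np.random.default_rng(5)
n, m = 4, 6
S = rng.normal(size=(n, n))
A = rng.normal(size=(m, n))   # A: R^n -> R^m
B = rng.normal(size=(n, m))   # B: R^m -> R^n
C = rng.normal(size=(m, n))   # same shape as A, for the inner product

assert np.isclose(np.trace(S), np.trace(S.T))        # trace of the transpose
assert np.isclose(np.trace(A @ B), np.trace(B @ A))  # cyclic property
assert np.isclose(np.trace(A.T @ C), np.sum(A * C))  # inner product = elementwise sum
```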
### Matrix Gradients
But there is something we do all the time in deep learning where these ideas become very useful: taking derivatives wrt matrices, i.e. derivatives of a function $f: \R^{n\times m} \to \R$.
As we showed before, the trace can be used to define an inner product on $\R^{n\times m}$.
$$<A, B> = \trace(A^T B)$$
And we also saw that this inner product is equivalent to the natural inner product between matrices that you would probably think to define $\trace(A^T B) = \sum_{ij} A_{ij} B_{ij}$.
When the inputs to $f$ are matrices in $\R^{n\times m}$, we can't think of the derivative $\partial f(M): \R^{n \times m} \to \R$ as a matrix; it's just a linear function that takes a change matrix $\dot M \in \R^{n\times m}$ and approximates the change in the output
$$
\der f(M)(\dot M) \simeq f(M + \dot M) - f(M)
$$
Since the derivative isn't a matrix here, it wouldn't make sense to say that the gradient is the derivative transposed. But when we are computing the gradient $\nabla f(M)$ we are just finding the matrix that satisfies
$$<\nabla f(M), N> = \partial f(M)(N)\;, \;\;\; \forall N\in \R^{n\times m}$$
This is an algebraic problem that is often very easy to solve, especially with the nice properties of the trace. Let's look at an example.
Let $f(M)= v^T M w$ for $v\in \R^n, \; w\in \R^m$. We first compute the derivative
$$\partial f(M) (\dot M) = v^T (M+\dot M) w - v^T M w = v^T \dot M w$$
And from that we can get the gradient
\begin{align}
<\nabla f(M), \dot M> &= \partial f(M) (\dot M) \\
&= v^T \dot M w \\
&= \trace(v^T \dot M w) \\
&= \trace(w v^T \dot M) \\
&= <v w^T , \dot M>
\end{align}
And thus, $\nabla f(M) = v w^T$.
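As a sanity check, here is a small sketch (my own finite-difference test, not from the post) that compares $v w^T$ against a numerical gradient of $f(M) = v^T M w$:

```python
import numpy as np

# Sketch: check that the gradient of f(M) = v^T M w is the outer product v w^T.
rng = np.random.default_rng(6)
n, m = 4, 3
v, w, M = rng.normal(size=n), rng.normal(size=m), rng.normal(size=(n, m))
f = lambda M: v @ M @ w

eps = 1e-6
grad_fd = np.zeros((n, m))
for i in range(n):
    for j in range(m):
        E = np.zeros((n, m)); E[i, j] = eps
        grad_fd[i, j] = (f(M + E) - f(M - E)) / (2 * eps)

assert np.allclose(grad_fd, np.outer(v, w), atol=1e-5)
```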
## Example 1: gradients of an MLP layer
Before we jump into computing the gradients of our attention layer, let's work out a much simpler problem: the gradients of an MLP layer
$$y = W x \;, \;\;\;\; z = \sigma(y)$$
Here we don't really care what the loss function is; we assume that somewhere else we have access to $\nabla_z l \in \R^n$, the gradient of the loss wrt $z$.
$$
\frac{\der l}{\der y} = \frac{\der l}{\der z} \frac{\der z}{\der y} = \frac{\der l}{\der z} \der \sigma(y)
$$
The particular equations for the jacobian $\der \sigma(y)$ aren't really important, but if the nonlinearity is applied elementwise, we can be sure that the jacobian will be a diagonal matrix, since a change in the $i$th input can only affect the $i$th output. The diagonal will contain the elements $\frac{\der z_i}{\der y_i}$, which in the case of a relu, $\sigma(x) = \max(x, 0)$, would just be
$$\frac{\der z_i}{\der y_i} =
\begin{cases}
1, & \text{if}\ y_i\ge 0 \\
0, & \text{otherwise}
\end{cases}
$$
But whatever the nonlinearity is, $\der \sigma(y)$ is a jacobian. Let's continue working out the two derivatives we actually care about
$$
\frac{\der l}{\der x}\;, \;\;\;\; \frac{\der l}{\der W}
$$
The first one is very easy. $y=Wx$ is a linear function of $x$, so clearly the linear approximation will be
$$
\frac{\der y}{\der x} = W
$$
Putting the whole derivative together
$$
\frac{\der l}{\der x} = \frac{\der l}{\der y} \frac{\der y}{\der x} = \frac{\der l}{\der y} W = \frac{\der l}{\der z} \der \sigma(y) W
$$
The derivative of $y$ wrt $W$ is
$$
\frac{\der y}{\der W}(\dot W) \simeq (W + \dot W)x - Wx = \dot W x
$$
And so
\begin{align}
\frac{\der l}{\der W}(\dot W) &= \frac{\der l}{\der y}
\frac{\der y}{\der W}(\dot W) \\
&= \frac{\der l}{\der z} \der \sigma(y) \dot W x \\
&= \trace\left(\frac{\der l}{\der z} \der \sigma(y) \dot W x \right ) \\
&= \trace\left(x \frac{\der l}{\der z} \der \sigma(y) \dot W \right ) \\
&= <\der \sigma(y) \frac{\der l}{\der z}^T x^T , \dot W > \\
&= <\der \sigma(y) \nabla_z l\; x^T , \dot W > \\
\end{align}
So that $\nabla_W l = \der \sigma(y) \nabla_z l \; x^T$. Note that along the way we used the fact that $\der \sigma(y) = \der \sigma (y)^T$, thanks to the matrix being diagonal.
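Translated into numpy, the whole backward pass is just a few lines. This is a sketch assuming the ReLU nonlinearity from above; `grad_z` stands for $\nabla_z l$ and is assumed to be handed to us by the rest of the backward pass:

```python
import numpy as np

# Sketch of the forward + backward pass for z = σ(W x) with σ = ReLU.
# grad_z plays the role of ∇_z l, assumed to be given by the rest of the backward pass.
rng = np.random.default_rng(7)
m, n = 5, 4
W, x = rng.normal(size=(m, n)), rng.normal(size=n)

# forward
y = W @ x
z = np.maximum(y, 0.0)

grad_z = rng.normal(size=m)      # pretend upstream gradient ∇_z l

# backward: ∂σ(y) is diagonal, so applying it is just an elementwise mask
grad_y = (y >= 0.0) * grad_z     # ∇_y l = ∂σ(y) ∇_z l
grad_x = W.T @ grad_y            # ∇_x l = W^T ∂σ(y) ∇_z l
grad_W = np.outer(grad_y, x)     # ∇_W l = ∂σ(y) ∇_z l x^T

# finite-difference check of one entry of ∇_W l, using the surrogate loss l(W) = grad_z · σ(W x)
eps = 1e-6
loss = lambda W: grad_z @ np.maximum(W @ x, 0.0)
E = np.zeros_like(W); E[1, 2] = eps
assert np.isclose((loss(W + E) - loss(W - E)) / (2 * eps), grad_W[1, 2], atol=1e-5)
```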
## Example 2: derivative of the simplex projection
From the preattention vectors $B_i = K Q_i$ we will get the attention vectors, which gain more geometric structure by being projected onto either the simplex or the sphere.
$A_i = \gamma(B_i) B_i$, where $\gamma: \R^t \to \R$ is a normalization function. Finally, $Y_i = V^T A_i$.
We will use the fact that the function
$$f_\gamma(x) = \gamma(x) x$$
has derivative
$$
\partial f_\gamma (x) (\dot x) = \partial\gamma(x) (\dot x) x + \gamma(x) \dot x
$$
**simplex** corresponds to the case where the scaling function is
$$\alpha(x) = \left(\sum_i x_i \right)^{-1}$$
You can easily verify that this projects preattention vectors lying in the positive orthant onto the simplex of dimension $t-1$.
Its derivative is
$$
\partial \alpha(x) (\dot x) = -\left(\sum_i x_i\right)^{-2} \sum_i \dot x_i = - \alpha(x)^2 u^T \dot x
$$
where $u \in \R^t$ is the vector of ones. So we can put the whole thing together:
\begin{align}
\partial f_\alpha (x)(\dot x) &= - \alpha(x)^2 (u^T \dot x)\, x + \alpha(x) \dot x \\
&= \alpha(x) \left( \dot x - \alpha(x) (u^T \dot x)\, x \right) \\
&= \alpha(x) ( I - \alpha(x)\, x u^T ) \dot x \\
&= \alpha(x) ( I - f_\alpha (x) u^T ) \dot x \\
\end{align}
So $\partial f_\alpha (x) = \alpha(x) ( I - f_\alpha (x) u^T )$.
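Here is a small numpy sketch (my own finite-difference check) of this result, written as the JVP $\partial f_\alpha(x)(\dot x) = \alpha(x)\left(\dot x - f_\alpha(x)\,(u^T\dot x)\right)$:

```python
import numpy as np

# Sketch: JVP of the simplex normalization f_α(x) = x / sum(x), checked by finite differences.
rng = np.random.default_rng(8)
t = 5
x = rng.uniform(0.1, 1.0, size=t)    # stay inside the positive orthant
xdot = rng.normal(size=t)

f = lambda x: x / np.sum(x)          # f_α(x) = α(x) x
alpha = 1.0 / np.sum(x)

# ∂f_α(x)(ẋ) = α(x) (I − f_α(x) u^T) ẋ = α(x) (ẋ − f_α(x) · sum(ẋ))
jvp = alpha * (xdot - f(x) * np.sum(xdot))

eps = 1e-6
jvp_fd = (f(x + eps * xdot) - f(x - eps * xdot)) / (2 * eps)
assert np.allclose(jvp, jvp_fd, atol=1e-5)
```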
## Example 3: derivative of the softmax
## Example 4: gradient of the NLL