[![](https://i.imgur.com/kgtJWBn.png)](https://dataflowr.github.io/website/modules/2c-jax/)
# Autodiff and Backpropagation
written by [@marc_lelarge](https://twitter.com/marc_lelarge) (part of the [deep learning course](https://www.dataflowr.com))
## Jacobian
Let ${\bf f}:\mathbb{R}^n\to \mathbb{R}^m$. We define its Jacobian as:
\begin{align*}
\newcommand{\bbx}{{\bf x}}
\newcommand{\bbv}{{\bf v}}
\newcommand{\bbw}{{\bf w}}
\newcommand{\bbu}{{\bf u}}
\newcommand{\bbf}{{\bf f}}
\newcommand{\bbg}{{\bf g}}
\frac{\partial \bbf}{\partial \bbx} = J_{\bbf}(\bbx) &= \left( \begin{array}{ccc}
\frac{\partial f_1}{\partial x_1}&\dots& \frac{\partial f_1}{\partial x_n}\\
\vdots&&\vdots\\
\frac{\partial f_m}{\partial x_1}&\dots& \frac{\partial f_m}{\partial x_n}
\end{array}\right)\\
&=\left( \frac{\partial \bbf}{\partial x_1},\dots, \frac{\partial \bbf}{\partial x_n}\right)\\
&=\left(
\begin{array}{c}
\nabla f_1(\bbx)^T\\
\vdots\\
\nabla f_m(\bbx)^T
\end{array}\right)
\end{align*}
Hence the Jacobian $J_{\bbf}(\bbx)\in \mathbb{R}^{m\times n}$ is a linear map from $\mathbb{R}^n$ to $\mathbb{R}^m$ such that for $\bbx,\bbv \in \mathbb{R}^n$ and $h\in \mathbb{R}$:
\begin{align*}
\bbf(\bbx+h\bbv) = \bbf(\bbx) + hJ_{\bbf}(\bbx)\bbv +o(h).
\end{align*}
The term $J_{\bbf}(\bbx)\bbv\in \mathbb{R}^m$ is a Jacobian Vector Product (**JVP**), corresponding to the interpretation where the Jacobian is the linear map: $J_{\bbf}(\bbx):\mathbb{R}^n \to \mathbb{R}^m$, where $J_{\bbf}(\bbx)(\bbv)=J_{\bbf}(\bbx)\bbv$.
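As a sanity check, here is a small numpy sketch (our own toy map, not part of the course code) verifying the first-order expansion above: the JVP $J_{\bbf}(\bbx)\bbv$ matches the finite difference $(\bbf(\bbx+h\bbv)-\bbf(\bbx))/h$.

```python
import numpy as np

# toy map f: R^2 -> R^2, f(x) = (x1^2, x1*x2)
def f(x):
    return np.array([x[0] ** 2, x[0] * x[1]])

# its Jacobian, computed by hand
def jacobian_f(x):
    return np.array([[2 * x[0], 0.0],
                     [x[1],     x[0]]])

x = np.array([1.0, 2.0])
v = np.array([0.5, -1.0])
h = 1e-5

# JVP: J_f(x) v, the directional derivative of f at x along v
jvp = jacobian_f(x) @ v

# finite-difference check of f(x + h v) = f(x) + h J_f(x) v + o(h)
fd = (f(x + h * v) - f(x)) / h
assert np.allclose(fd, jvp, atol=1e-4)
```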
## Chain composition
In machine learning, we compute the gradient of a loss function with respect to the parameters. The parameters are typically high-dimensional, while the loss is a single real number. Hence, consider a real-valued function $\bbf:\mathbb{R}^n\stackrel{\bbg_1}{\to}\mathbb{R}^m \stackrel{\bbg_2}{\to}\mathbb{R}^d\stackrel{h}{\to}\mathbb{R}$, so that $\bbf(\bbx) = h(\bbg_2(\bbg_1(\bbx)))\in \mathbb{R}$. We have
\begin{align*}
\underbrace{\nabla\bbf(\bbx)}_{n\times 1}=\underbrace{J_{\bbg_1}(\bbx)^T}_{n\times m}\underbrace{J_{\bbg_2}(\bbg_1(\bbx))^T}_{m\times d}\underbrace{\nabla h(\bbg_2(\bbg_1(\bbx)))}_{d\times 1}.
\end{align*}
To compute this product, we can start from the right: a matrix-vector product yields a vector (of size $m$), and a second matrix-vector product yields the result, for a total of $O(nm+md)$ operations. If instead we start from the left with the matrix-matrix multiplication, we get $O(nmd+nd)$ operations. Hence, as soon as $m\approx d$, starting from the right is much more efficient. Note however that doing the computation from the right to the left requires keeping in memory the values of $\bbg_1(\bbx)\in\mathbb{R}^m$ and $\bbx\in \mathbb{R}^n$.
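To make the ordering concrete, a minimal numpy sketch (dimensions chosen arbitrarily) checks that both multiplication orders give the same gradient; only the operation counts differ:

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, d = 5, 4, 3
A = rng.standard_normal((n, m))   # plays the role of J_g1(x)^T
B = rng.standard_normal((m, d))   # plays the role of J_g2(g1(x))^T
u = rng.standard_normal(d)        # plays the role of grad h

# right to left: two matrix-vector products, O(md + nm) operations
grad_right = A @ (B @ u)
# left to right: a matrix-matrix product first, O(nmd + nd) operations
grad_left = (A @ B) @ u

# same result either way; only the cost differs
assert np.allclose(grad_right, grad_left)
```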
**Backpropagation** is an efficient algorithm computing the gradient "from the right to the left", i.e. backward. In particular, we will need to compute quantities of the form $J_{\bbf}(\bbx)^T\bbu \in \mathbb{R}^n$ with $\bbu \in\mathbb{R}^m$, which can be rewritten as $\bbu^T J_{\bbf}(\bbx)$, a Vector Jacobian Product (**VJP**). This corresponds to the interpretation where the Jacobian is the linear map $J_{\bbf}(\bbx):\mathbb{R}^n \to \mathbb{R}^m$ composed with the linear map $\bbu:\mathbb{R}^m\to \mathbb{R}$, so that $\bbu^TJ_{\bbf}(\bbx) = \bbu \circ J_{\bbf}(\bbx)$.
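Dually to the JVP, a small numpy sketch (again a toy map of our own, not course code) checks that the VJP $J_{\bbf}(\bbx)^T\bbu$ is exactly the gradient of the scalar map $\bbx\mapsto \bbu^T\bbf(\bbx)$:

```python
import numpy as np

# toy map f: R^2 -> R^2, f(x) = (x1^2, x1*x2)
def f(x):
    return np.array([x[0] ** 2, x[0] * x[1]])

# its Jacobian, computed by hand
def jacobian_f(x):
    return np.array([[2 * x[0], 0.0],
                     [x[1],     x[0]]])

x = np.array([1.0, 2.0])
u = np.array([3.0, -1.0])

# VJP: J_f(x)^T u, i.e. the gradient of the scalar map x -> u . f(x)
vjp = jacobian_f(x).T @ u

# finite-difference gradient of x -> u . f(x), coordinate by coordinate
h = 1e-5
grad_fd = np.array([(u @ f(x + h * e) - u @ f(x)) / h
                    for e in np.eye(2)])
assert np.allclose(vjp, grad_fd, atol=1e-4)
```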
**Example:** let $\bbf(\bbx, W) = \bbx W\in \mathbb{R}^b$, where $W\in \mathbb{R}^{a\times b}$ and $\bbx\in \mathbb{R}^a$ is a row vector. We clearly have
$$
J_{\bbf}(\bbx) = W^T.
$$
Note that here, we are slightly abusing notation by considering the partial function $\bbx\mapsto \bbf(\bbx, W)$. To see this, we can write $f_j = \sum_{i}x_iW_{ij}$ so that
$$
\frac{\partial \bbf}{\partial x_i}= \left( W_{i1}\dots W_{ib}\right)^T
$$
Then recall from definitions that
$$
J_{\bbf}(\bbx) = \left( \frac{\partial \bbf}{\partial x_1},\dots, \frac{\partial \bbf}{\partial x_a}\right)=W^T.
$$
Now we clearly have
$$
J_{\bbf}(W) = \bbx, \text{ since } \bbf(\bbx,W+\Delta W) = \bbf(\bbx,W) + \bbx \Delta W.
$$
Note that writing the product with $\bbx$ on the left (i.e. multiplying by $W$ on the right) is convenient when using broadcasting: we can take a batch of input vectors of shape $\text{bs}\times a$ without modifying the math above.
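A minimal numpy sketch of the two VJPs for this linear layer (the names `vjp_x` and `vjp_W` are ours): $J_{\bbf}(\bbx)^T\bbu = W\bbu$ and, for the parameters, the outer product $\bbx\bbu^T$, with a finite-difference check and a batched input to illustrate the broadcasting remark.

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = 4, 3
x = rng.standard_normal(a)
W = rng.standard_normal((a, b))
u = rng.standard_normal(b)       # incoming gradient, shape (b,)

def f(x, W):
    return x @ W                 # shape (b,)

# J_f(x) = W^T, so vjp_x(u) = J_f(x)^T u = W u
vjp_x = W @ u                    # shape (a,)
# gradient of u . (x W) with respect to W is the outer product x u^T
vjp_W = np.outer(x, u)           # shape (a, b)

# finite-difference check of vjp_x against the gradient of x -> u . f(x, W)
h = 1e-6
grad_x = np.array([(u @ f(x + h * e, W) - u @ f(x, W)) / h
                   for e in np.eye(a)])
assert np.allclose(vjp_x, grad_x, atol=1e-3)

# broadcasting: a batch of inputs of shape (bs, a) goes through unchanged
xb = rng.standard_normal((8, a))
assert f(xb, W).shape == (8, b)
```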
## Implementation
In PyTorch, `torch.autograd` provides classes and functions implementing automatic differentiation of arbitrary scalar valued functions. To create a custom [autograd.Function](https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function), subclass this class and implement the `forward()` and `backward()` static methods. Here is an example:
```python
import torch
from torch.autograd import Function

class Exp(Function):
    @staticmethod
    def forward(ctx, i):
        # compute exp(i) and save the result for the backward pass
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # d exp(i)/di = exp(i), so we reuse the saved forward result
        result, = ctx.saved_tensors
        return grad_output * result

# Use it by calling the apply method:
output = Exp.apply(input)
```
You can have a look at [Module 2b](https://dataflowr.github.io/website/modules/2b-automatic-differentiation) to learn more about this approach as well as [MLP from scratch](https://dataflowr.github.io/website/homework/1-mlp-from-scratch/).
### Backprop the functional way
Here we will implement in `numpy` a different approach, mimicking the functional approach of [JAX](https://jax.readthedocs.io/en/latest/index.html); see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#).
Each function will take two arguments: the input `x` and the parameters `w`. For each function, we build two **vjp** functions, taking as argument a gradient $\bbu$ and corresponding to $J_{\bbf}(\bbx)$ and $J_{\bbf}(\bbw)$, so that these functions return $J_{\bbf}(\bbx)^T \bbu$ and $J_{\bbf}(\bbw)^T \bbu$ respectively. To summarize, for $\bbx \in \mathbb{R}^n$, $\bbw \in \mathbb{R}^d$, and $\bbf(\bbx,\bbw) \in \mathbb{R}^m$,
\begin{align*}
{\bf vjp}_\bbx(\bbu) &= J_{\bbf}(\bbx)^T \bbu, \text{ with } J_{\bbf}(\bbx)\in\mathbb{R}^{m\times n}, \bbu\in \mathbb{R}^m\\
{\bf vjp}_\bbw(\bbu) &= J_{\bbf}(\bbw)^T \bbu, \text{ with } J_{\bbf}(\bbw)\in\mathbb{R}^{m\times d}, \bbu\in \mathbb{R}^m
\end{align*}
Then backpropagation is simply done by first computing the gradient of the loss and then composing the **vjp** functions in the right order.
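As an illustrative sketch of this pipeline (ours, not the notebooks' actual code), consider a linear layer followed by a squared loss: the backward pass is seeded with the gradient of the loss and the **vjp** functions are applied from right to left.

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = 4, 3
x = rng.standard_normal(a)
W = rng.standard_normal((a, b))
y = rng.standard_normal(b)      # target

# forward pass: linear layer f(x, W) = x W, then squared loss
out = x @ W
loss = 0.5 * np.sum((out - y) ** 2)

# backward pass: gradient of the loss wrt its input, then vjps right to left
u = out - y                      # d loss / d out
grad_W = np.outer(x, u)          # vjp_w(u)
grad_x = W @ u                   # vjp_x(u)

# finite-difference check on one coordinate of W
h = 1e-6
W2 = W.copy()
W2[0, 0] += h
loss2 = 0.5 * np.sum((x @ W2 - y) ** 2)
assert abs((loss2 - loss) / h - grad_W[0, 0]) < 1e-3
```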
### Code
- intro to JAX: autodiff the functional way [autodiff_functional_empty.ipynb](https://github.com/dataflowr/notebooks/blob/master/Module2/autodiff_functional_empty.ipynb) and its solution [autodiff_functional_sol.ipynb](https://github.com/dataflowr/notebooks/blob/master/Module2/autodiff_functional_sol.ipynb)
- Linear regression in JAX [linear_regression_jax.ipynb](https://github.com/dataflowr/notebooks/blob/master/Module2/linear_regression_jax.ipynb)
###### tags: `public` `dataflowr` `jax` `autodiff`