[![](https://i.imgur.com/kgtJWBn.png)](https://dataflowr.github.io/website/modules/2c-jax/)

# Autodiff and Backpropagation

written by [@marc_lelarge](https://twitter.com/marc_lelarge) (part of the [deep learning course](https://www.dataflowr.com))

## Jacobian

Let ${\bf f}:\mathbb{R}^n\to \mathbb{R}^m$. We define its Jacobian as:
\begin{align*}
\newcommand{\bbx}{{\bf x}}
\newcommand{\bbv}{{\bf v}}
\newcommand{\bbw}{{\bf w}}
\newcommand{\bbu}{{\bf u}}
\newcommand{\bbf}{{\bf f}}
\newcommand{\bbg}{{\bf g}}
\frac{\partial \bbf}{\partial \bbx} = J_{\bbf}(\bbx) &= \left(
\begin{array}{ccc}
\frac{\partial f_1}{\partial x_1}&\dots& \frac{\partial f_1}{\partial x_n}\\
\vdots&&\vdots\\
\frac{\partial f_m}{\partial x_1}&\dots& \frac{\partial f_m}{\partial x_n}
\end{array}\right)\\
&=\left( \frac{\partial \bbf}{\partial x_1},\dots, \frac{\partial \bbf}{\partial x_n}\right)\\
&=\left(
\begin{array}{c}
\nabla f_1(\bbx)^T\\
\vdots\\
\nabla f_m(\bbx)^T
\end{array}\right)
\end{align*}
Hence the Jacobian $J_{\bbf}(\bbx)\in \mathbb{R}^{m\times n}$ is a linear map from $\mathbb{R}^n$ to $\mathbb{R}^m$ such that for $\bbx,\bbv \in \mathbb{R}^n$ and $h\in \mathbb{R}$:
\begin{align*}
\bbf(\bbx+h\bbv) = \bbf(\bbx) + hJ_{\bbf}(\bbx)\bbv +o(h).
\end{align*}
The term $J_{\bbf}(\bbx)\bbv\in \mathbb{R}^m$ is a Jacobian Vector Product (**JVP**), corresponding to the interpretation where the Jacobian is the linear map $J_{\bbf}(\bbx):\mathbb{R}^n \to \mathbb{R}^m$, with $J_{\bbf}(\bbx)(\bbv)=J_{\bbf}(\bbx)\bbv$.

## Chain composition

In machine learning, we compute the gradient of the loss function with respect to the parameters. The parameters are typically high-dimensional, while the loss is a real number. Hence, consider a real-valued function $\bbf:\mathbb{R}^n\stackrel{\bbg_1}{\to}\mathbb{R}^m \stackrel{\bbg_2}{\to}\mathbb{R}^d\stackrel{h}{\to}\mathbb{R}$, so that $\bbf(\bbx) = h(\bbg_2(\bbg_1(\bbx)))\in \mathbb{R}$.
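Going back briefly to the JVP defined in the Jacobian section, the first-order expansion $\bbf(\bbx+h\bbv) = \bbf(\bbx) + hJ_{\bbf}(\bbx)\bbv +o(h)$ can be checked numerically. The following is a minimal numpy sketch, assuming a toy function $\bbf:\mathbb{R}^3\to\mathbb{R}^2$ chosen purely for illustration (the names `f`, `jacobian_f` and `jvp_f` are hypothetical). It builds the Jacobian by hand, row by row from the gradients $\nabla f_i(\bbx)^T$, and compares $J_{\bbf}(\bbx)\bbv$ with the difference quotient $(\bbf(\bbx+h\bbv)-\bbf(\bbx))/h$.

```python
import numpy as np

# Toy function f: R^3 -> R^2 (hypothetical, chosen only for illustration)
def f(x):
    return np.array([x[0] * x[1], np.sin(x[2]) + x[0] ** 2])

def jacobian_f(x):
    # Row i is the transposed gradient of f_i, so the result is m x n = 2 x 3
    return np.array([
        [x[1], x[0], 0.0],
        [2 * x[0], 0.0, np.cos(x[2])],
    ])

def jvp_f(x, v):
    # Jacobian Vector Product J_f(x) v, a vector of size m = 2
    return jacobian_f(x) @ v

x = np.array([1.0, 2.0, 0.5])
v = np.array([0.3, -1.0, 2.0])
h = 1e-6

# f(x + h v) = f(x) + h J_f(x) v + o(h), so this quotient approximates the JVP
fd = (f(x + h * v) - f(x)) / h
print(jvp_f(x, v), fd)
```

Note that the JVP only needs $\bbv$ and never requires forming $J_{\bbf}(\bbx)$ explicitly in general; the explicit matrix is built here only to make the definition visible.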
We have
\begin{align*}
\underbrace{\nabla\bbf(\bbx)}_{n\times 1}=\underbrace{J_{\bbg_1}(\bbx)^T}_{n\times m}\underbrace{J_{\bbg_2}(\bbg_1(\bbx))^T}_{m\times d}\underbrace{\nabla h(\bbg_2(\bbg_1(\bbx)))}_{d\times 1}.
\end{align*}
If we do this computation starting from the right, we first multiply a matrix by a vector to obtain a vector (of size $m$), and then multiply another matrix by this vector, resulting in $O(nm+md)$ operations. If we start from the left with the matrix-matrix multiplication, we get $O(nmd+nd)$ operations. Hence starting from the right is much more efficient: the cost ratio is of order $nmd/(nm+md) = nd/(n+d)$, i.e. of order $\min(n,d)$. Note however that doing the computation from the right to the left requires keeping in memory the values $\bbg_1(\bbx)\in\mathbb{R}^m$ and $\bbx\in \mathbb{R}^n$ computed during the forward pass, since these are the points at which the Jacobians are evaluated.

**Backpropagation** is an efficient algorithm computing the gradient "from the right to the left", i.e. backward. In particular, we will need to compute quantities of the form $J_{\bbf}(\bbx)^T\bbu \in \mathbb{R}^n$ with $\bbu \in\mathbb{R}^m$. Up to transposition, this is $\bbu^T J_{\bbf}(\bbx)$, a Vector Jacobian Product (**VJP**), corresponding to the interpretation where the Jacobian is the linear map $J_{\bbf}(\bbx):\mathbb{R}^n \to \mathbb{R}^m$ composed with the linear map $\bbu^T:\mathbb{R}^m\to \mathbb{R}$, so that $\bbu^TJ_{\bbf}(\bbx) = \bbu^T \circ J_{\bbf}(\bbx)$.

**example:** let $\bbf(\bbx, W) = \bbx W\in \mathbb{R}^b$, where $W\in \mathbb{R}^{a\times b}$ and $\bbx\in \mathbb{R}^a$ is a row vector. We have
$$
J_{\bbf}(\bbx) = W^T.
$$
Note that here we are slightly abusing notation and considering the partial function $\bbx\mapsto \bbf(\bbx, W)$. To see this, write $f_j = \sum_{i}x_iW_{ij}$, so that
$$
\frac{\partial \bbf}{\partial x_i}= \left( W_{i1}\dots W_{ib}\right)^T.
$$
Then recall from the definition that
$$
J_{\bbf}(\bbx) = \left( \frac{\partial \bbf}{\partial x_1},\dots, \frac{\partial \bbf}{\partial x_a}\right)=W^T.
$$
Similarly, for the partial function $W\mapsto \bbf(\bbx, W)$, we have
$$
J_{\bbf}(W) = \bbx \text{ since } \bbf(\bbx,W+\Delta W) = \bbf(\bbx,W) + \bbx \Delta W,
$$
where $J_{\bbf}(W)$ is understood as the linear map $\Delta W\mapsto \bbx\Delta W$ from $\mathbb{R}^{a\times b}$ to $\mathbb{R}^b$.
Note that writing $\bbx$ on the left, i.e. multiplying it on the right by $W$, is convenient when using broadcasting: we can take a batch of input vectors of shape $\text{bs}\times a$ without modifying the math above.

## Implementation

In PyTorch, `torch.autograd` provides classes and functions implementing automatic differentiation of arbitrary scalar-valued functions. To create a custom [autograd.Function](https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function), subclass this class and implement the `forward()` and `backward()` static methods. Here is an example:

```python
import torch
from torch.autograd import Function

class Exp(Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        # exp(i) is also the derivative of exp at i, so save it for backward
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        # VJP: elementwise product of the incoming gradient with exp(i)
        return grad_output * result

# Use it by calling the apply method:
x = torch.randn(3, requires_grad=True)
output = Exp.apply(x)
output.sum().backward()  # x.grad now equals exp(x)
```

You can have a look at [Module 2b](https://dataflowr.github.io/website/modules/2b-automatic-differentiation) to learn more about this approach, as well as at [MLP from scratch](https://dataflowr.github.io/website/homework/1-mlp-from-scratch/).

### Backprop the functional way

Here we will implement in `numpy` a different approach, mimicking the functional approach of [JAX](https://jax.readthedocs.io/en/latest/index.html) (see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#)).

Each function will take 2 arguments: the input `x` and the parameters `w`. For each function, we build 2 **vjp** functions taking as argument a gradient $\bbu$ and corresponding to $J_{\bbf}(\bbx)$ and $J_{\bbf}(\bbw)$, so that these functions return $J_{\bbf}(\bbx)^T \bbu$ and $J_{\bbf}(\bbw)^T \bbu$ respectively.
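As an illustration of this convention, here is a minimal numpy sketch for the example $\bbf(\bbx, W) = \bbx W$ from the chain composition section (the names `linear`, `linear_vjp_x`, `linear_vjp_W` and `numerical_grad` are hypothetical, not necessarily those of the notebooks). From $J_{\bbf}(\bbx)=W^T$, the function returning $J_{\bbf}(\bbx)^T\bbu$ is $\bbu\mapsto W\bbu$. From $\bbu^T(\bbx\Delta W) = \sum_{i,j} x_i u_j \Delta W_{ij}$, the function returning $J_{\bbf}(W)^T\bbu$ is $\bbu\mapsto\bbx^T\bbu$, the outer product of $\bbx$ and $\bbu$, which has the same shape as $W$. Both are checked against central finite differences of the scalar $\bbu^T\bbf(\bbx,W)$.

```python
import numpy as np

# Hypothetical names; f(x, W) = x W with x in R^a (row vector), W in R^{a x b}
def linear(x, W):
    return x @ W

def linear_vjp_x(x, W, u):
    # J_f(x)^T u with J_f(x) = W^T, i.e. W u, a vector of size a (x is unused)
    return W @ u

def linear_vjp_W(x, W, u):
    # J_f(W)^T u: since u^T (x dW) = sum_ij x_i u_j dW_ij, this is the
    # outer product x^T u, an a x b matrix like W (W is unused)
    return np.outer(x, u)

def numerical_grad(g, z, eps=1e-6):
    # Central finite differences of a scalar function g at z (any shape)
    grad = np.zeros_like(z)
    for idx in np.ndindex(z.shape):
        dz = np.zeros_like(z)
        dz[idx] = eps
        grad[idx] = (g(z + dz) - g(z - dz)) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
a, b = 4, 3
x = rng.normal(size=a)
W = rng.normal(size=(a, b))
u = rng.normal(size=b)

# A vjp with u is the gradient of the scalar u^T f(x, W)
gx = numerical_grad(lambda z: u @ linear(z, W), x)
gW = numerical_grad(lambda Z: u @ linear(x, Z), W)
print(np.allclose(linear_vjp_x(x, W, u), gx), np.allclose(linear_vjp_W(x, W, u), gW))
```

Keeping `x` and `W` in both signatures, even when unused, gives every vjp the same shape of call, which makes composing them mechanical.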
To summarize, for $\bbx \in \mathbb{R}^n$, $\bbw \in \mathbb{R}^d$, and $\bbf(\bbx,\bbw) \in \mathbb{R}^m$,
\begin{align*}
{\bf vjp}_\bbx(\bbu) &= J_{\bbf}(\bbx)^T \bbu, \text{ with } J_{\bbf}(\bbx)\in\mathbb{R}^{m\times n}, \bbu\in \mathbb{R}^m\\
{\bf vjp}_\bbw(\bbu) &= J_{\bbf}(\bbw)^T \bbu, \text{ with } J_{\bbf}(\bbw)\in\mathbb{R}^{m\times d}, \bbu\in \mathbb{R}^m
\end{align*}
Backpropagation is then simply done by first computing the gradient of the loss and then composing the **vjp** functions in the right order, from the output back to the input.

### Code

- intro to JAX: autodiff the functional way [autodiff_functional_empty.ipynb](https://github.com/dataflowr/notebooks/blob/master/Module2/autodiff_functional_empty.ipynb) and its solution [autodiff_functional_sol.ipynb](https://github.com/dataflowr/notebooks/blob/master/Module2/autodiff_functional_sol.ipynb)
- Linear regression in JAX [linear_regression_jax.ipynb](https://github.com/dataflowr/notebooks/blob/master/Module2/linear_regression_jax.ipynb)
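As a complement to the notebooks, here is a minimal numpy sketch of backpropagation as a composition of **vjp** functions. It assumes a toy two-layer model $\bbx\mapsto \tanh(\bbx W_1)W_2$ with the squared loss $\frac12\|\text{out}-{\bf y}\|^2$; the model, the loss and all function names are illustrative choices, not taken from the notebooks. The forward pass stores the intermediate values; the backward pass starts from the gradient of the loss with respect to the output and applies the vjps from right to left. The linear-layer vjps are redefined so the block runs on its own.

```python
import numpy as np

# Hypothetical building blocks; each vjp takes the forward inputs and a gradient u
def linear(x, W):
    return x @ W

def linear_vjp_x(x, W, u):
    return W @ u            # J_f(x)^T u (x unused, kept for a uniform signature)

def linear_vjp_W(x, W, u):
    return np.outer(x, u)   # J_f(W)^T u (W unused, kept for a uniform signature)

def tanh_vjp(z, u):
    # tanh acts elementwise, so its Jacobian is diag(1 - tanh(z)^2)
    return (1.0 - np.tanh(z) ** 2) * u

def loss(out, y):
    return 0.5 * np.sum((out - y) ** 2)

def forward_backward(x, W1, W2, y):
    # Forward pass: store the intermediate values needed by the vjps
    z = linear(x, W1)
    hid = np.tanh(z)
    out = linear(hid, W2)
    L = loss(out, y)
    # Backward pass: gradient of the loss w.r.t. out, then vjps from right to left
    u = out - y
    gW2 = linear_vjp_W(hid, W2, u)
    u = linear_vjp_x(hid, W2, u)   # gradient w.r.t. hid
    u = tanh_vjp(z, u)             # gradient w.r.t. z
    gW1 = linear_vjp_W(x, W1, u)
    return L, gW1, gW2

rng = np.random.default_rng(1)
x, y = rng.normal(size=4), rng.normal(size=2)
W1, W2 = rng.normal(size=(4, 3)), rng.normal(size=(3, 2))
L, gW1, gW2 = forward_backward(x, W1, W2, y)

# Sanity check: central finite difference on one entry of W1
eps = 1e-6
E = np.zeros_like(W1)
E[0, 1] = eps
fd = (forward_backward(x, W1 + E, W2, y)[0] - forward_backward(x, W1 - E, W2, y)[0]) / (2 * eps)
print(gW1[0, 1], fd)
```

In JAX, `jax.vjp` and `jax.grad` automate exactly this bookkeeping, so the hand-written vjps above are only meant to expose the mechanics.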