# Automatic differentiation
## Basic Usage
### 1. ``jax.grad``
``grad`` is an operator applied to a **pure function** that returns a scalar. It returns the gradient of that function, which is itself a function. By default it differentiates with respect to the first argument; the ``argnums`` parameter selects which argument(s) to differentiate with respect to.
#### Example 1: the gradient of a single-variable function
```python=
import jax
import jax.numpy as jnp
from jax import grad
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
```
output:
```
0.070650816
```
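As a sanity check, the result can be compared with the closed-form derivative $\frac{d}{dx}\tanh(x) = 1 - \tanh^2(x)$. The following is a minimal sketch; the variable names `autodiff_value` and `analytic_value` are illustrative only.

```python
import jax
import jax.numpy as jnp

x = 2.0
# Derivative computed by automatic differentiation
autodiff_value = jax.grad(jnp.tanh)(x)
# Closed-form derivative: d/dx tanh(x) = 1 - tanh(x)^2
analytic_value = 1.0 - jnp.tanh(x) ** 2
print(autodiff_value, analytic_value)
```

Both values should agree up to floating-point rounding.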
#### Example 2: higher-order derivatives
```python=
import jax
import jax.numpy as jnp
from jax import grad
f = lambda x: x**3 + 2*x**2 - 3*x + 1
dfdx = jax.grad(f)
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)
print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))
```
output:
```
4.0
10.0
6.0
0.0
```
#### Example 3: multivariable function
Let $f(W, b)$ be a multivariable function from $\mathbb{R}^3 \times \mathbb{R}$ to $\mathbb{R}$ (called `loss` in the code). Then we can find its derivative with respect to $W$, with respect to $b$, or with respect to both at once.
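Example 3 and several later examples use `loss`, `predict`, `inputs`, `targets`, `W`, `b`, and `key` without defining them. One possible setup is sketched here, modeled on the logistic-regression example in the JAX Autodiff Cookbook. The toy dataset and the seed are assumptions, and the printed numbers may differ from the outputs shown in this document depending on the JAX version.

```python
import jax.numpy as jnp
from jax import grad, random

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

def predict(W, b, inputs):
    # Probability that each label is true
    return sigmoid(jnp.dot(inputs, W) + b)

# Toy dataset (assumed): four samples with three features each
inputs = jnp.array([[0.52, 1.12, 0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

def loss(W, b):
    # Negative log-likelihood of the training examples
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

# Random initial parameters (seed 0 is an assumption)
key = random.PRNGKey(0)
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
```

With these definitions in scope, `loss` maps a length-3 vector `W` and a scalar `b` to a scalar, which matches the signature $f(W, b)$ described above.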
```python=
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print(f'{W_grad=}')
# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print(f'{W_grad=}')
# But you can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print(f'{b_grad=}')
# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print(f'{W_grad=}')
print(f'{b_grad=}')
```
output:
```
W_grad=Array([-0.433146 , -0.7354605, -1.2598922], dtype=float32)
W_grad=Array([-0.433146 , -0.7354605, -1.2598922], dtype=float32)
b_grad=Array(-0.6900178, dtype=float32)
W_grad=Array([-0.433146 , -0.7354605, -1.2598922], dtype=float32)
b_grad=Array(-0.6900178, dtype=float32)
```
#### Example 4: differentiating with respect to nested lists, tuples, and dicts
```python=
def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

print(grad(loss2)({'W': W, 'b': b}))
```
output:
```
{'W': Array([-0.433146 , -0.7354605, -1.2598922], dtype=float32), 'b': Array(-0.6900178, dtype=float32)}
```
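The gradient has the same container structure as the input. A self-contained sketch with a dict that holds a tuple and a list illustrates this; `nested_loss` and the parameter values are hypothetical.

```python
import jax
import jax.numpy as jnp

def nested_loss(params):
    # `params` is a dict containing a tuple and a list of arrays
    w1, w2 = params["weights"]
    (bias,) = params["biases"]
    return jnp.sum(w1 ** 2) + jnp.sum(3.0 * w2) + bias ** 2

params = {
    "weights": (jnp.array([1.0, 2.0]), jnp.array([0.5])),
    "biases": [jnp.array(4.0)],
}

# The result mirrors the structure of `params`: dict -> tuple / list -> arrays
grads = jax.grad(nested_loss)(params)
print(grads)
```

Here `grads["weights"]` is again a tuple and `grads["biases"]` is again a list, each holding the partial derivatives for the corresponding leaf.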
### 2. ``jax.value_and_grad``
``value_and_grad`` returns both the value of the function and its gradient in a single call, so the function does not have to be evaluated separately.
#### Example:
```python=
loss_value, Wb_grad = jax.value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
```
output:
```
loss value 2.9729187
loss value 2.9729187
```
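A common use is a training loop that needs the loss value for logging and the gradient for the update. The following is a minimal sketch on a hypothetical one-dimensional objective `quadratic`; the learning rate and step count are arbitrary choices.

```python
import jax
import jax.numpy as jnp

def quadratic(x):
    # Hypothetical objective with its minimum at x = 3
    return (x - 3.0) ** 2

value_and_grad_fn = jax.value_and_grad(quadratic)

x = 0.0
learning_rate = 0.1
for step in range(100):
    # One call yields both the current loss and its gradient
    value, g = value_and_grad_fn(x)
    x = x - learning_rate * g

print(x, quadratic(x))
```

After the loop, `x` should be close to 3 and the objective close to 0.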
### 3. ``jax.jacfwd`` and ``jax.jacrev``
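``jax.jacfwd()`` and ``jax.jacrev()`` both compute the Jacobian of a function, i.e., its partial derivatives with respect to each input variable. For a scalar-valued $f$ of $x = (x_1, \dots, x_n)$, both return $$\left( \frac{\partial f}{\partial x_1}, \frac{\partial f}{\partial x_2}, \dots, \frac{\partial f}{\partial x_n} \right)$$
The two operators differ in how they compute the result: ``jax.jacfwd()`` uses forward-mode differentiation and ``jax.jacrev()`` uses reverse-mode differentiation. The values agree up to floating-point error, but the cost differs: forward mode tends to be more efficient when the function has more outputs than inputs ("tall" Jacobians), and reverse mode when it has more inputs than outputs ("wide" Jacobians).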
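To illustrate that both operators give the same values, here is a small sketch on a hypothetical map `g` from $\mathbb{R}^3$ to $\mathbb{R}^2$, whose Jacobian is a $2 \times 3$ matrix.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical vector-valued function from R^3 to R^2
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
J_fwd = jax.jacfwd(g)(x)  # forward mode
J_rev = jax.jacrev(g)(x)  # reverse mode

print(J_fwd.shape)
print(jnp.allclose(J_fwd, J_rev))
```

Each row of the Jacobian corresponds to one output of `g` and each column to one input.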
#### Example: find the Hessian matrix
```python=
def f(x):
    return jnp.dot(x, x)

def hessian(f):
    return jax.jacfwd(jax.grad(f))

hessian(f)(jnp.array([1., 2., 3.]))
```
output:
```
Array([[2., 0., 0.],
       [0., 2., 0.],
       [0., 0., 2.]], dtype=float32)
```
### 4. ``jax.lax.stop_gradient``
`stop_gradient` makes its argument behave as a constant during differentiation, similar to `detach` in PyTorch. It is useful when you want to freeze certain parameters.
#### Example:
```python=
import jax
import jax.numpy as jnp
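def f(x):
    # `y` has the value 2 * x but is treated as a constant when differentiating
    y = jax.lax.stop_gradient(x) * 2.0
    return y * x

# d/dx (y * x) with y held constant is y = 2 * 3.0
dfdx = jax.grad(f)
print(dfdx(3.0))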
```
output:
```
6.0
```
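The parameter-freezing use case can be sketched as follows. The function `model_loss` and its inputs are hypothetical; the point is only that the gradient with respect to the frozen argument comes out as zero.

```python
import jax
import jax.numpy as jnp

def model_loss(w, b):
    # Freeze `b`: it still contributes to the value, but not to the gradient
    b_frozen = jax.lax.stop_gradient(b)
    pred = w * 2.0 + b_frozen
    return (pred - 1.0) ** 2

w_grad, b_grad = jax.grad(model_loss, argnums=(0, 1))(0.5, 0.3)
print(w_grad, b_grad)
```

Only `w_grad` is nonzero, so a gradient-based update would leave `b` unchanged.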
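### 5. ``jax.jvp``
``jax.jvp`` computes a Jacobian-vector product, which implements the **push-forward** of a tangent vector through a function.
#### Example:
```python=
from jax import jvp, random
# Assumes `predict`, `W`, `b`, `inputs`, and the PRNG `key` are already defined
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))
```
Here `y` is the value `f(W)` and `u` is the Jacobian of `f` at `W` applied to the tangent vector `v`, i.e., the directional derivative of `f` along `v`.
### 6. ``jax.vjp``
``jax.vjp`` computes a vector-Jacobian product, which corresponds to the **pull-back** of a covector in mathematics.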
#### Example:
```python=
from jax import vjp, random
# Assumes `predict`, `W`, `b`, `inputs`, and the PRNG `key` are already defined
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
y, vjp_fun = vjp(f, W)
key, subkey = random.split(key)
u = random.normal(subkey, y.shape)
# Pull back the covector `u` along `f` evaluated at `W`
# `vjp_fun` returns a tuple with one entry per input of `f`
(v,) = vjp_fun(u)
```
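The relation between ``jax.jvp``, ``jax.vjp``, and the full Jacobian can be checked on a small case. The following self-contained sketch uses a hypothetical function `h` from $\mathbb{R}^3$ to $\mathbb{R}^2$: the push-forward should equal $J v$ and the pull-back should equal $u^\top J$.

```python
import jax
import jax.numpy as jnp

def h(x):
    # Hypothetical function from R^3 to R^2
    return jnp.array([x[0] * x[1], x[1] + x[2] ** 2])

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, 1.0])  # tangent vector in the input space
u = jnp.array([1.0, 2.0])       # covector in the output space

J = jax.jacfwd(h)(x)            # explicit 2 x 3 Jacobian, for comparison only

# Push-forward: returns h(x) and J @ v without materializing J
y, tangent_out = jax.jvp(h, (x,), (v,))

# Pull-back: returns h(x) and a function computing u @ J
y_again, vjp_fun = jax.vjp(h, x)
(cotangent_in,) = vjp_fun(u)

print(tangent_out, J @ v)
print(cotangent_in, u @ J)
```

The push-forward has the shape of the output of `h`, while the pull-back has the shape of its input.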
## How does automatic differentiation work?
The principle of automatic differentiation (autodiff) is the chain rule of calculus.
1. Any function you write in code is built from primitive operations (addition, multiplication, sine, exponential, etc.).
2. Each primitive has a known derivative rule, like $\frac{d}{dx}\sin(x) = \cos(x)$.
3. Autodiff breaks your program into a computation graph of these primitives.
4. It then applies the chain rule systematically across the graph to compute derivatives.
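That is all there is to it: there is no approximation error as with finite differences, and no expression swell as can happen with symbolic differentiation (for example in SymPy). Autodiff is a mechanical application of the chain rule to the sequence of operations your program executes, so the result is exact up to floating-point rounding.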
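The steps above can be sketched on the function $f(x) = \sin(x^2)$. The example below applies the chain rule by hand to the two primitives and compares the result with ``jax.grad``; `manual_derivative` is a hypothetical helper written only for this illustration. ``jax.make_jaxpr`` prints the sequence of primitives that JAX records for `f`.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x ** 2)

def manual_derivative(x):
    # Inner primitive: u = x^2, with du/dx = 2x
    u = x ** 2
    du_dx = 2.0 * x
    # Outer primitive: y = sin(u), with dy/du = cos(u)
    dy_du = jnp.cos(u)
    # Chain rule: dy/dx = dy/du * du/dx
    return dy_du * du_dx

x = 1.5
print(jax.make_jaxpr(f)(x))  # the primitives recorded for f
print(jax.grad(f)(x), manual_derivative(x))
```

The two derivative values should agree up to floating-point rounding, since ``jax.grad`` performs the same chain-rule bookkeeping automatically.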