functorch (now torch.func) is a PyTorch module that provides composable function transforms, similar to JAX. It offers transforms like grad, vmap, vjp, etc. that let users easily compute Jacobian-vector products and vectorize functions. And because the transforms compose, one can compute per-sample gradients simply by using vmap(grad(compute_loss)).
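A minimal sketch of that composition, assuming a toy squared-error loss for a linear model (the names `loss`, `weight`, `xs`, `ys` are illustrative, not from the original):

```python
import torch
from torch.func import grad, vmap

# Scalar loss for a single data point: squared error of a linear model.
def loss(weight, x, y):
    return ((x * weight).sum() - y) ** 2

weight = torch.randn(3)
xs = torch.randn(8, 3)  # batch of 8 data points
ys = torch.randn(8)

# grad(loss) gives the gradient w.r.t. weight for ONE sample;
# vmap maps it over the batch dim of xs and ys (weight is shared, so in_dims=None).
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(weight, xs, ys)
print(per_sample_grads.shape)  # one gradient per sample: torch.Size([8, 3])
```

Without vmap, getting these per-sample gradients would require a Python loop calling backward once per sample.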
Below are some of the things that I got to work on:
Adding Batching Rule for vmap
vmap is a transform that takes a function func operating on a single datapoint and returns a function that can handle a batch of data, effectively vectorizing it. Semantically, it runs a for-loop over all data points and stacks the results together (but implements this more efficiently than an actual for-loop).
Example:
import torch