functorch (now torch.func) is a PyTorch module that provides composable function transforms, similar to JAX. It offers transforms like grad, vmap, vjp, etc. that let users easily compute Jacobian-vector products and vectorize functions. And because the transforms compose, one can compute per-sample gradients simply with vmap(grad(model)).
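As a quick illustration of that composability, here is a small, hypothetical per-sample gradient example (the tiny linear model, its shapes, and the names loss, samples, targets are made up for illustration):

import torch
from torch.func import grad, vmap

# Loss of a tiny linear model on a single (sample, target) pair.
def loss(weight, sample, target):
    return ((sample @ weight - target) ** 2).mean()

weight = torch.randn(4, 2)
samples = torch.randn(8, 4)   # batch of 8 samples
targets = torch.randn(8, 2)

# grad w.r.t. weight, vectorized over the sample/target batch dimension.
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(weight, samples, targets)
print(per_sample_grads.shape)  # torch.Size([8, 4, 2])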

Below are some of the things that I got to work on:

Adding Batching Rules for vmap

vmap is a transform that takes a function func which operates on a single datapoint and returns a function that handles a batch of data, effectively vectorizing it. Semantically, it runs a for-loop over all the datapoints and stacks the results together (but does so more efficiently than the for-loop version).

Example:

import torch

# Written to handle only a single sample.
# Calling it with a batched input will
# produce an incorrect result.
def my_simple_model(input, weight, bias):
    return input.sum(0) @ weight + bias

batched_inputs = torch.randn(3, 3, 3, 3)
weight = torch.randn(3, 1) * 5
bias = torch.randn([])

# For Loop version
expected = []
for input in batched_inputs:
    expected.append(my_simple_model(input, weight, bias))
expected = torch.stack(expected)

# Vmap
actual = torch.vmap(my_simple_model, in_dims=(0, None, None))(batched_inputs, weight, bias)

# Incorrectly calling the function (silently incorrect result)
incorrect = my_simple_model(batched_inputs, weight, bias)

# Results don't match.
assert torch.abs(expected - incorrect).sum() > 1e-3

# Verify that the results match.
torch.testing.assert_close(expected, actual)

To support vmap for PyTorch operators, we need to specify a batching rule, i.e. how to apply the given operator to a batched input. This is similar to how PyTorch internally specifies gradient formulas for operators. A batching rule is essentially a function that takes one or more batched inputs and computes the batched operation. In the example above, to support vmap for my_simple_model we need batching rules for torch.sum, @/torch.matmul and +/torch.add. PyTorch has a lot of operators, and we need coverage for all of them to seamlessly support optimized vmap (yes, there is a for-loop fallback for unsupported operators, so the code doesn't crash).

PyTorch operators can be very roughly categorized as primitive (internally CompositeExplicitAutograd) vs composite (internally CompositeImplicitAutograd). Composite operators are implemented in terms of primitive operators. So, to get complete coverage, we only need batching rules for the primitive operators, and the rules for composite operators come for free.

To add a batching rule for a primitive operator, we write a function that takes each input tensor together with its batch dimension (if any), computes the batched result, and registers it against that operator.
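As a rough illustration, a batching rule for a sum-like operator could look as follows. This is just a Python sketch: the names sum_batching_rule and x_bdim are made up, and the real batching rules are written in functorch's C++ sources and registered against the corresponding aten operators.

import torch

# Hypothetical sketch of a batching rule: each input arrives together with
# the index of its batch dimension (or None if it is not batched), and the
# rule returns the batched result plus the batch dimension of the output.
def sum_batching_rule(x, x_bdim, dim):
    if x_bdim is None:
        return torch.sum(x, dim), None
    x = x.movedim(x_bdim, 0)                     # move the batch dimension to the front
    physical_dim = dim + 1 if dim >= 0 else dim  # shift the user-visible dim past the batch dim
    return torch.sum(x, physical_dim), 0

batched = torch.randn(5, 3, 4)                    # 5 is the batch dimension
res, out_bdim = sum_batching_rule(batched, 0, 0)  # user asked for a sum over dim 0
torch.testing.assert_close(res, torch.vmap(lambda t: t.sum(0))(batched))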

Composite Compliance

Above we mentioned that batching rules for composite operators come for free. But that is only true if the operator follows a few constraints: it should not access the data pointer of a tensor, it should not call out= variants of operators, etc. Unfortunately, operators which claim to be composite sometimes don't follow these constraints. That works fine in plain eager PyTorch but can lead to problems with functorch transforms (e.g. what does accessing item or data_ptr on a BatchedTensor even mean?).
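To make this concrete, here is a hypothetical pair of composite implementations (the function names are made up): the first breaks the constraints mentioned above, the second is compliant.

import torch

# Non-compliant: calls an out= variant and reads the raw data pointer.
# This runs fine in eager mode but breaks under functorch transforms,
# where the inputs may be BatchedTensors.
def my_op_noncompliant(x):
    out = torch.empty_like(x)
    torch.sin(x, out=out)   # out= variant
    _ = x.data_ptr()        # what would this mean for a BatchedTensor?
    return out + 1

# Compliant: expressed purely in terms of functional operators.
def my_op_compliant(x):
    return torch.sin(x) + 1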

Testing for Composite Compliance

The idea is to write extensive tests to verify that these constraints are met. This is achieved with a new tensor subclass that, via the __torch_dispatch__ mechanism, raises an error on non-compliant behaviour. We run the test on the actual operator, on its backward formula, and on its forward-mode AD formula. The reason for testing the backward and forward formulas is that users can write vmap(vjp(fn)) or vmap(jvp(fn)), which requires those formulas to be Composite Compliant as well.
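The actual test suite uses a dedicated tensor subclass, but the gist can be sketched with a dispatch mode. Below is a simplified, hypothetical version that only checks for out= variants; the class name NoOutVariants is made up.

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class NoOutVariants(TorchDispatchMode):
    # Intercept every aten call and error out if an out= overload is used.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func._schema.overload_name == "out":
            raise RuntimeError(f"composite non-compliance: {func} is an out= variant")
        return func(*args, **(kwargs or {}))

with NoOutVariants():
    x = torch.randn(3)
    torch.sin(x)                        # fine
    # torch.sin(x, out=torch.empty(3))  # would raise inside the mode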

Fixing the operators on a case-by-case basis

Once we had the tests and the list of failing operators, it was a matter of going through the list, figuring out why each operator was non-compliant, and devising a fix. The issue tracker can be found here.

Support for chunk_size in vmap and jacrev

Computing a Jacobian can require a lot of memory, and users raised issues about it. To mitigate this, we added support for computing jacrev and vmap in smaller chunks, controlled by a chunk_size argument, to reduce peak memory usage during the computation. With this argument, the user can specify how many rows of the Jacobian are computed at once. The same argument was added to vmap for the same purpose. This feature was added in the jacrev PR and the vmap PR.
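A minimal usage sketch (the function and shapes are chosen arbitrarily for illustration):

import torch

def f(x):
    return x.sin()

x = torch.randn(2048)

# Compute the 2048 x 2048 Jacobian 64 rows at a time to reduce peak memory.
jac = torch.func.jacrev(f, chunk_size=64)(x)

# vmap takes the same argument: process the batch in chunks of 64 samples.
out = torch.vmap(f, chunk_size=64)(x)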

Support for linearize transform

The jvp transform computes both f(x) and the Jacobian-vector product. So even when one only wants jvps at fixed inputs, the jvp transform ends up re-evaluating f(x). For such scenarios, one can use linearize, which is useful when the jvp is to be computed multiple times at fixed primals. To achieve this, however, linearize saves intermediate computations and has higher memory requirements than directly applying jvp. linearize was added in this PR.
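A small usage sketch comparing the two (the function f and the tensors are made up for illustration):

import torch
from torch.func import jvp, linearize

def f(x):
    return x.sin()

x = torch.randn(3)
t = torch.randn(3)

# linearize evaluates f(x) once and returns a function for jvps at that point.
out, jvp_fn = linearize(f, x)
tangent = jvp_fn(t)

# jvp gives the same result but re-evaluates f(x) on every call.
out2, tangent2 = jvp(f, (x,), (t,))
torch.testing.assert_close(tangent, tangent2)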

Supporting transforms for torch.compile

PyTorch 2.0 introduced a new compilation stack, torch.compile. Compared to JAX, functorch was missing a jit transform, so this opened up the ability to compile the existing transforms. To understand how we can compile these transforms, we need to understand the three layers of the compilation stack: dynamo, aot_autograd and inductor. dynamo and aot_autograd mainly deal with graph capture and lowering the captured operators into more primitive ones. inductor is the compiler that takes the captured graph and applies fusion and other optimisations before generating specialized code.

To get an idea of what happens at different stages of the stack, let's compile a simple program with debug mode.

# Run this file with `TORCH_COMPILE_DEBUG=1` env flag.
import torch

def fn(x):
    return torch.sin(x) + torch.square(x)

torch.compile(fn)(torch.randn(3, 3))

dynamo: dynamo's job is to capture the PyTorch program being traced and represent it in the FX graph format. The FX graph captured by dynamo records the PyTorch operations at the public API level (eg. torch.sin). Below is the graph captured by dynamo.

class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_

        # File: test/test_scratch.py:334, code: return torch.sin(x) + torch.square(x)
        sin = torch.sin(l_x_)
        square = torch.square(l_x_);  l_x_ = None
        add = sin + square;  sin = square = None
        return (add,)

aot_autograd: aot_autograd traces through all the PyTorch operations to generate an FX graph, but this time in terms of aten operators. It also decomposes composite operations into more primitive ones (eg. torch.square, which is composite, gets traced down to torch.pow(x, 2)). aot_autograd also generates the backward graph ahead of time if requested; that is where the name ahead-of-time autograd / aot_autograd comes from.

def forward(self, arg0_1: f32[3, 3]):
    # File: test/test_scratch.py:334, code: return torch.sin(x) + torch.square(x)
    sin: f32[3, 3] = torch.ops.aten.sin.default(arg0_1)
    pow_1: f32[3, 3] = torch.ops.aten.pow.Tensor_Scalar(arg0_1, 2);  arg0_1 = None
    add: f32[3, 3] = torch.ops.aten.add.Tensor(sin, pow_1);  sin = pow_1 = None
    return (add,)

inductor: As discussed above, it is inductor's job to apply optimisations and generate specialized code. In this case, it has fused sin and square so that they run within the same for-loop. This lets the generated program do more compute per read/write, effectively improving memory bandwidth utilization.

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch import empty_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()

cpp_fused_add_cos_sin_0 = async_compile.cpp('''
#include "/tmp/torchinductor_kshiteej/ib/cibrnuq56cxamjj4krp4zpjvsirbmlolpbnmomodzyd46huzhdw7.h"
extern "C" void kernel(const float* in_ptr0, float* out_ptr0)
{
    {
        for(long i0=static_cast<long>(0L); i0<static_cast<long>(8L); i0+=static_cast<long>(8L))
        {
            auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(i0));
            auto tmp1 = tmp0.sin();
            auto tmp2 = tmp0.cos();
            auto tmp3 = tmp1 + tmp2;
            tmp3.store(out_ptr0 + static_cast<long>(i0));
        }
        #pragma omp simd simdlen(4)
        for(long i0=static_cast<long>(8L); i0<static_cast<long>(9L); i0+=static_cast<long>(1L))
        {
            auto tmp0 = in_ptr0[static_cast<long>(i0)];
            auto tmp1 = std::sin(tmp0);
            auto tmp2 = std::cos(tmp0);
            auto tmp3 = tmp1 + tmp2;
            out_ptr0[static_cast<long>(i0)] = tmp3;
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (3, 3), (3, 1))
    buf0 = empty_strided((3, 3), (3, 1), device='cpu', dtype=torch.float32)
    cpp_fused_add_cos_sin_0(c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()))
    del arg0_1
    return (buf0, )

With a basic idea of how torch.compile works, we can now talk about how transforms are supported. Given that aot_autograd is able to trace through the transforms, we only need to teach dynamo to verify that the user function being transformed has no side effects or graph breaks. In that case, we can just put the functorch transform in the graph and let the lower parts of the stack handle the rest. However, if the function can't be traced successfully because it doesn't satisfy these constraints, we fall back to the eager implementation, and this part of the code is not compiled.

Let us have a look at what dynamo and aot_autograd generate when we compile a program with vmap.
Example

# Run this file with `TORCH_COMPILE_DEBUG=1` env flag.
import torch

# function to be vmapped.
def fn(x):
    return torch.sum(x, dim=0)

def wrapper_fn(x):
    return torch.func.vmap(fn)(x)

B = 2
torch.compile(wrapper_fn)(torch.randn(B, 3))

The dynamo output is as follows. The first GraphModule is the current program; it calls vmap on the traced representation of the user-passed function. The second GraphModule corresponds to the user-passed function.

class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_

        # File: torch/_functorch/apis.py:182, code: _check_randomness_arg(randomness)
        _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')

        # File: torch/_functorch/apis.py:188, code: return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
        select = l_x_.select(0, 0)  # implementation detail
        vmap_body_0 = self.vmap_body_0
        vmap_proxy = torch.func.vmap(vmap_body_0, (0,), 0, 'error');  vmap_body_0 = None
        call = vmap_proxy.__call__(l_x_);  vmap_proxy = l_x_ = None
        return (call,)

    class GraphModule(torch.nn.Module):
        def forward(self, select):
            # File: test/test_scratch.py:334, code: return torch.sum(x, dim=0)
            sum_1 = torch.sum(select, dim = 0);  select = None
            return sum_1

aot_autograd traces through the transformed graph, i.e. the graph generated after vmap has been applied. That is why the call to sum has dim=1 instead of the dim=0 we wrote in the user function: under vmap, there is a leading batch dimension in this case.

def forward(self, arg0_1: f32[2, 3]):
    # File: torch/_functorch/apis.py:188, code: return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    sum_1: f32[2] = torch.ops.aten.sum.dim_IntList(arg0_1, [1]);  arg0_1 = None
    return (sum_1,)

Currently, for PyTorch 2.1, we support compiling grad and vmap with some limitations. In the future, we plan to support all transforms with minimal limitations.
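As a small usage sketch, mirroring the vmap example above but with grad (the function f is made up for illustration):

import torch

def f(x):
    return torch.sin(x).sum()

def wrapper_fn(x):
    return torch.func.grad(f)(x)

x = torch.randn(3)
# grad of sin(x).sum() w.r.t. x is cos(x).
torch.testing.assert_close(torch.compile(wrapper_fn)(x), torch.cos(x))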

Closing Remarks
All of this work was made possible with help and guidance from the amazing folks on the PyTorch team at Quansight, Mario Lezcano (my team lead), and from Richard Zou and Horace He (from Meta).