kshiteejk

@kshiteejk

Joined on Mar 1, 2021

  • functorch (now torch.func) is a PyTorch module that provides composable function transforms in the style of JAX. Transforms such as grad, vmap, and vjp let users compute gradients and vector-Jacobian products and vectorize functions with little effort. Because the transforms compose, per-sample gradients can be computed simply with vmap(grad(model)). Below are some of the things that I got to work on: Adding Batching Rule for vmap: vmap is a transform that takes a function func that runs on a single datapoint and returns a function that handles a batch of data, effectively vectorizing it. Semantically, it runs a for-loop over all data points and stacks the results together (but does so more efficiently than the for-loop version). Example: import torch
  • Contributing to Open Source Software (OSS). Who Am I? Software Engineer at Quansight (developing PyTorch). Contributed to PyTorch, MXNet, Chainer, etc. Previously Machine Learning Engineer (in NLP and CV). Not from a CS background (the only language the university taught us was C 😔). Learnt Python by wandering in the wild!
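
The vmap semantics described in the functorch note above (a for-loop over data points with the results stacked) and the vmap(grad(...)) composition for per-sample gradients can be sketched as follows. This is only an illustrative sketch, not the example from the original note: it assumes PyTorch 2.0 or later, where torch.func is available, and the functions dot and loss and all tensor names are hypothetical, chosen just for this illustration.

```python
import torch
from torch.func import grad, vmap

torch.manual_seed(0)

# Hypothetical single-datapoint function: dot product of two 1-D tensors.
def dot(x, y):
    return torch.dot(x, y)

xs = torch.randn(4, 3)  # batch of 4 datapoints, 3 features each
ys = torch.randn(4, 3)

# vmap lifts `dot` to operate over the leading (batch) dimension.
batched = vmap(dot)(xs, ys)

# Semantically equivalent for-loop version: apply per datapoint, then stack.
looped = torch.stack([dot(x, y) for x, y in zip(xs, ys)])

# Hypothetical per-sample loss: squared error of a linear model with weights w.
def loss(w, x, y):
    return (torch.dot(w, x) - y) ** 2

w = torch.randn(3)
targets = torch.randn(4)

# grad differentiates w.r.t. the first argument (w); vmap maps over the
# samples (x, y) while sharing w, giving one gradient per sample.
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(w, xs, targets)
```

Here in_dims=(None, 0, 0) tells vmap not to batch over w and to map over dimension 0 of the inputs and targets, so per_sample_grads has one row of gradients for each datapoint.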