# Overview of Google's JAX Ecosystem
In contrast to TensorFlow and PyTorch, JAX has a clean NumPy-like interface that makes it easy to compute directional derivatives and higher-order derivatives, and to differentiate through an optimization procedure.
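As a quick illustration, here is a minimal sketch (the function `f` and the inputs are made up for illustration) of a directional derivative via `jax.jvp` and a second derivative obtained by composing `jax.grad`:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.array([0.1, 0.2, 0.3])
v = jnp.array([1.0, 0.0, 0.0])  # direction

# Directional derivative: a forward-mode Jacobian-vector product.
y, dfdv = jax.jvp(f, (x,), (v,))

# Higher-order derivatives by composing grad with itself.
g = lambda t: jnp.sin(t) ** 2
second_deriv = jax.grad(jax.grad(g))(0.5)
```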
There are several neural net libraries built on top of JAX. Depending on what you're trying to do, you have several options:
- For toy functions and simple architectures (e.g. multilayer perceptrons), you can use straight-up JAX so that you understand everything that's going on (a minimal sketch of this, alongside Stax, follows this list).
- [Stax](https://github.com/google/jax/blob/master/jax/experimental/README.md) is a very lightweight neural net package with easy-to-follow source code. It's good for implementing simpler architectures like CIFAR conv nets, and has the advantage that you can understand the whole control flow of the code.
- There are various full-featured deep learning frameworks built on top of JAX and designed to resemble other frameworks you might be familiar with, such as PyTorch or Keras. These are a better choice if you want all the bells and whistles of a near-state-of-the-art model. The main options are [Flax](https://github.com/google/flax), [Haiku](https://github.com/deepmind/dm-haiku), and [Objax](https://github.com/google/objax), and the choice between them might come down to which ones already have a public implementation of something you need. While some of these frameworks involve some magic for defining and training architectures, they still provide a functional API for network computations, making it easy to compute things like Hessian-vector products.
- [Neural Tangents](https://github.com/google/neural-tangents) is a library for working with the neural tangent kernel and infinite width limits of neural nets (see Lecture 6).
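To make the first two options concrete, here is a minimal sketch, with made-up layer sizes and dummy data, of an MLP written in plain JAX, followed by the same architecture expressed with Stax:

```python
import jax
import jax.numpy as jnp

# --- Plain JAX: explicit parameters and forward pass ---
def init_mlp(key, sizes):
    """Initialize weights and biases for an MLP with the given layer sizes."""
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        params.append((0.01 * jax.random.normal(sub, (d_in, d_out)), jnp.zeros(d_out)))
    return params

def mlp(params, x):
    for W, b in params[:-1]:
        x = jax.nn.relu(x @ W + b)
    W, b = params[-1]
    return x @ W + b  # logits

def loss(params, x, y):
    return -jnp.mean(jnp.sum(jax.nn.log_softmax(mlp(params, x)) * y, axis=-1))

params = init_mlp(jax.random.PRNGKey(0), [784, 64, 10])
x = jnp.ones((32, 784))                                  # dummy batch
y = jax.nn.one_hot(jnp.zeros(32, dtype=jnp.int32), 10)   # dummy labels
grads = jax.grad(loss)(params, x, y)

# --- Stax: the same architecture as an (init_fn, apply_fn) pair ---
from jax.example_libraries import stax  # on older JAX versions: jax.experimental.stax

init_fn, apply_fn = stax.serial(stax.Dense(64), stax.Relu,
                                stax.Dense(10), stax.LogSoftmax)
out_shape, stax_params = init_fn(jax.random.PRNGKey(1), (-1, 784))
log_probs = apply_fn(stax_params, x)
```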
You are welcome to use whatever language and framework you like, but keep in mind that some of the key concepts, such as directional derivatives or Hessian-vector products, might not be so straightforward to use in some frameworks.
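For instance, in JAX a Hessian-vector product takes only a couple of lines; the toy loss below is made up for illustration:

```python
import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(jnp.tanh(w) ** 2)  # stand-in for a network's training loss

def hvp(f, w, v):
    """Hessian-vector product via forward-over-reverse; never materializes the Hessian."""
    return jax.jvp(jax.grad(f), (w,), (v,))[1]

w = jnp.arange(3.0)
v = jnp.ones(3)
print(hvp(loss, w, v))
```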
The broader JAX ecosystem includes:
- Deep learning libraries:
  - [Flax](https://github.com/google/flax) (NN library)
  - [Haiku](https://github.com/deepmind/dm-haiku) (NN library)
  - [Chex](https://github.com/deepmind/chex) (JAX utilities)
  - [Equinox](https://github.com/patrick-kidger/equinox) (callable pytrees)
  - [Trax](https://github.com/google/trax) (NN library)
  - [Jraph](https://github.com/deepmind/jraph) (GNNs in JAX)
  - [Scenic](https://github.com/google-research/scenic) (vision transformer research library)
  - [Mesh Transformer JAX](https://github.com/kingoflolz/mesh-transformer-jax) (model-parallel transformers)
- Differentiable programming:
  - [Brax](https://github.com/google/brax) (differentiable physics simulation)