# MATH594: Numerical Computation with JAX
## Welcome to the Future of High-Performance Computing
---
## About Me
- **Mathematical Roots:** Pure Mathematician (Number Theorist) by training.
- **The Transition:** Applying abstract structures to solve complex, real-world problems.
- **Research Interests:** Data Science, Quantitative Finance, and Operations Research.
---
## [What is JAX?](https://www.youtube.com/watch?v=juy9nrcTBck)
### High-Performance Numerical Computing
* **The Vision:** A Python library that brings C++-level speed to a NumPy-like interface.
* **The Synergy:** Merges the familiarity of *NumPy* with the power of *Autograd* (Automatic Differentiation), as the snippet below shows.
* **The Gold Standard:** Powering state-of-the-art AI research at Google DeepMind (e.g., [Gemini](https://aistudio.google.com/models/gemini-3) and [Gemma](https://deepmind.google/models/gemma/)).
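
A minimal sketch of that NumPy familiarity; the arrays below are made up purely for illustration, and nothing beyond `numpy` and `jax` is assumed:

```python
import numpy as np
import jax.numpy as jnp

# The same identity, written once against NumPy and once against JAX's NumPy mirror.
x_np = np.linspace(0.0, 1.0, 5)
y_np = np.sin(x_np) ** 2 + np.cos(x_np) ** 2    # plain NumPy, runs on the CPU

x_jx = jnp.linspace(0.0, 1.0, 5)
y_jx = jnp.sin(x_jx) ** 2 + jnp.cos(x_jx) ** 2  # same API, dispatched through XLA

print(np.allclose(y_np, np.asarray(y_jx)))      # True: both evaluate to all ones
```

The point is not the toy computation but the surface area: most `np.` calls have a `jnp.` counterpart, so existing NumPy code often carries over almost unchanged.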
---
## [Why JAX?](https://youtu.be/fcX0yl88gRs?si=SK64giUYhhAR1eL-)
### Performance at Any Scale
* **Hardware Agnostic:** Write once, run seamlessly on *CPUs, GPUs, and TPUs* (see the backend check below).
* **The XLA Compiler:** Uses [Accelerated Linear Algebra](https://openxla.org/) to optimize and "fuse" operations, squeezing every drop of performance from your hardware.
* **Massive Scalability:** Capable of scaling from a single laptop to clusters of 50,000+ TPUs.
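
A quick way to see the hardware-agnostic claim on whatever machine you have. This is a minimal sketch; only the reported backend changes between a laptop and a Cloud TPU VM, not the code:

```python
import jax
import jax.numpy as jnp

# JAX selects the best available backend automatically (TPU > GPU > CPU).
print(jax.default_backend())     # 'cpu' on a laptop, 'gpu' or 'tpu' on accelerators
print(jax.devices())             # the individual devices JAX can see

# The numerical code itself is identical on every backend.
x = jnp.ones((1024, 1024))
y = (x @ x).block_until_ready()  # force execution so placement/timing is observable
```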
---
<style>.reveal {font-size: 32px;}</style>
## The Core Transformations
### Unleashing the "Magic"
JAX treats Python functions as mathematical objects through three primary transformations, sketched together in the example below:
1. **`jax.jit` (Just-In-Time Compilation):** Compiles code into optimized machine instructions for maximum speed.
2. **`jax.grad` (Automatic Differentiation):** Computes exact gradients—the heartbeat of modern optimization and ML.
3. **`jax.vmap` (Automatic Vectorization):** Lifts a function written for a single example to handle whole batches, with no manual loops.
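
A minimal sketch of all three transformations composed on a toy problem; the loss function, weights, and batch here are invented just to show the mechanics:

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    """Toy scalar loss: squared norm of an elementwise prediction."""
    return jnp.sum((w * x) ** 2)

grad_loss = jax.grad(loss)                             # exact dloss/dw for one example
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))  # lift to a whole batch of x
fast_batched_grad = jax.jit(batched_grad)              # compile the pipeline with XLA

w = jnp.array([1.0, 2.0, 3.0])
xs = jnp.ones((8, 3))                                  # a batch of 8 inputs
print(fast_batched_grad(w, xs).shape)                  # (8, 3): one gradient per example
```

Because each transformation is just a higher-order function, they compose in any order, which is exactly what "functions as mathematical objects" means in practice.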
---
## JAX vs. PyTorch
### A Shift in Philosophy
* **PyTorch (Object-Oriented):** Uses *stateful* objects; gradients are stored as attributes within tensors.
* **JAX (Functional Programming):**
* *Pure Functions:* Predictable outputs with no hidden side effects.
* *Immutable Arrays:* Data is never changed in place, leading to fewer bugs and easier parallelization.
* **The Takeaway:** JAX is built for researchers who want the flexibility of math and the performance of a supercomputer; the sketch below shows the functional style in practice.
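
A small sketch of what immutability and purity feel like in practice; the `update` function is a made-up example, not a pattern required by JAX:

```python
import jax.numpy as jnp

x = jnp.zeros(5)
# x[0] = 1.0 would raise an error: JAX arrays cannot be mutated in place.
y = x.at[0].set(1.0)      # returns a *new* array; x is left untouched
print(x)                  # [0. 0. 0. 0. 0.]
print(y)                  # [1. 0. 0. 0. 0.]

def update(params, grads, lr=0.1):
    """Pure function: the new state depends only on the arguments, no hidden state."""
    return params - lr * grads

params = jnp.array([1.0, 2.0])
new_params = update(params, jnp.array([0.5, 0.5]))  # old params remain valid and unchanged
```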
---
## Course Goals
### From Theory to Implementation
* **Objective:** Bridge the gap between abstract mathematical concepts and production-grade code.
* **Our Toolkit:**
* [**Differentiable Programming**](https://diffprog.github.io/): Mastering the logic of gradients.
* [**The JAX AI Stack**](https://jaxstack.ai/): Deep learning frameworks built on JAX.
* **Scientific Solvers:** Libraries like [*Lineax*](https://docs.kidger.site/lineax/) and [*Diffrax*](https://docs.kidger.site/diffrax/).
* **Probabilistic Logic:** Tools like [*Blackjax*](https://blackjax-devs.github.io/blackjax/) and [*NumPyro*](https://num.pyro.ai/en/latest/index.html).
---
## Scaling Up: Google Cloud & TPUs
### Taking JAX to the Cloud
* **The Power of TPUs:** Leverage Google’s Tensor Processing Units, purpose-built for the matrix math that JAX excels at (see the `pmap` sketch below).
* **Cloud Infrastructure:** Move beyond local limitations by utilizing [Google Cloud TPU](https://cloud.google.com/tpu/docs) environments for large-scale experiments.
* **Why Cloud?**
* *Speed:* Train models in minutes that would take days on a standard GPU.
* *Efficiency:* Seamlessly manage datasets that exceed local memory.
* *Accessibility:* Access world-class hardware directly through your browser or terminal.
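
As one concrete taste of scale-out, here is a minimal sketch assuming a multi-core Cloud TPU VM (on a single-device laptop it still runs, just with one shard): `jax.pmap` replicates a function across every local core.

```python
import jax
import jax.numpy as jnp

n = jax.device_count()        # e.g. 8 cores on a TPU v3-8 VM, 1 on a laptop CPU

@jax.pmap
def per_core_step(x):
    # Each core receives one shard of the leading axis and works independently.
    return jnp.sum(x ** 2)

shards = jnp.ones((n, 1024))  # leading axis must match the number of devices
print(per_core_step(shards))  # one partial result per core, gathered back for you
```

The same idea extends from one host to whole clusters via JAX's sharding machinery, which is where the Cloud TPU environments above come in.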
---
## Logistics & Expectations
### Collaborative Discovery
* **Format:** A *discussion-driven* course. We learn by presenting and solving together.
* **Participation:** Come ready to debate technical approaches and tackle "unsolvable" problems.
* **Deliverables:**
* **Version Control:** Collaborative work via [*GitHub*](https://github.com/).
* **Documentation:** All insights submitted via [*Markdown*](https://www.markdownguide.org/) or [*Jupyter Notebook*](https://jupyter.org/).
---
## Grading & Communication
* **Project-Based Learning (80%):**
* *Phase 1:* Final project proposal and research.
* *Phase 2 (June):* Implementation, final delivery, and presentation.
* **Technical Assignments (20%):** Mastering JAX core primitives.
* **The Hub:** All official updates and "intellectual sparring" happen on *Discord*.
---
## Getting Started (Assignment 0)
### Your First Steps
* **Introductions:** Introduce yourself and share your interests to help form diverse project groups.
* **Office Hours:** Schedule a brief meeting with me this month to discuss your goals for the course.
* **Exploration:** Start thinking: *What complex system would you like to differentiate?*
{"title":"JAX_slide","contributors":"[{\"id\":\"9e38ee55-7b6f-408d-a9e3-a3a99f1fde0e\",\"add\":13167,\"del\":8293,\"latestUpdatedAt\":1770084169221}]","description":"Previous life: Pure mathematician, specifically a number theorist."}