# MATH594: Numerical Computation with JAX
## Welcome to the Future of High-Performance Computing

---

## About Me

- **Mathematical Roots:** Pure mathematician (number theorist) by training.
- **The Transition:** Applying abstract structures to solve complex, real-world problems.
- **Research Interests:** Data Science, Quantitative Finance, and Operations Research.

---

## [What is JAX?](https://www.youtube.com/watch?v=juy9nrcTBck)
### High-Performance Numerical Computing

* **The Vision:** A Python library that brings compiled, accelerator-ready speed to a NumPy-like interface.
* **The Synergy:** Merges the familiarity of *NumPy* with the power of *Autograd* (automatic differentiation).
* **The Gold Standard:** Powers state-of-the-art AI research at Google DeepMind (e.g., [Gemini](https://aistudio.google.com/models/gemini-3) and [Gemma](https://deepmind.google/models/gemma/)).

---

## [Why JAX?](https://youtu.be/fcX0yl88gRs?si=SK64giUYhhAR1eL-)
### Performance at Any Scale

* **Hardware Agnostic:** Write once, run seamlessly on *CPUs, GPUs, and TPUs*.
* **The XLA Compiler:** Uses [Accelerated Linear Algebra](https://openxla.org/) to optimize whole programs and "fuse" chains of operations into single kernels. This cuts memory traffic and gets far more out of your hardware.
* **Massive Scalability:** Scales from a single laptop to clusters of 50,000+ TPU chips.

---

<style>.reveal {font-size: 32px;}</style>

## The Core Transformations
### Unleashing the "Magic"

JAX treats Python functions as mathematical objects through three primary transformations:

1. **`jax.jit` (Just-In-Time Compilation):** Traces a function and compiles it with XLA into optimized machine code for the target hardware.
2. **`jax.grad` (Automatic Differentiation):** Computes exact gradients (to floating-point precision, with no finite-difference approximation). Gradients are the heartbeat of modern optimization and ML.
3. **`jax.vmap` (Automatic Vectorization):** Turns a function written for a single example into one that processes whole batches, with no manual loops.

---

## JAX vs. PyTorch
### A Shift in Philosophy

* **PyTorch (Object-Oriented):** Uses *stateful* objects; gradients are stored as attributes on tensors (e.g., `.grad`).
* **JAX (Functional Programming):**
    * *Pure Functions:* The same inputs always give the same outputs, with no hidden side effects.
    * *Immutable Arrays:* Data is never changed in place, which leads to fewer bugs and easier parallelization.
* **The Takeaway:** JAX is built for researchers who want the flexibility of math and the performance of a supercomputer.

---

## Course Goals
### From Theory to Implementation

* **Objective:** Bridge the gap between abstract mathematical concepts and production-grade code.
* **Our Toolkit:**
    * [**Differentiable Programming**](https://diffprog.github.io/): Mastering the logic of gradients.
    * [**The JAX AI Stack**](https://jaxstack.ai/): Deep learning frameworks built on JAX.
    * **Scientific Solvers:** Libraries like [*Lineax*](https://docs.kidger.site/lineax/) (linear solves) and [*Diffrax*](https://docs.kidger.site/diffrax/) (differential equations).
    * **Probabilistic Logic:** Tools like [*Blackjax*](https://blackjax-devs.github.io/blackjax/) (sampling) and [*NumPyro*](https://num.pyro.ai/en/latest/index.html) (probabilistic programming).

---

## Scaling Up: Google Cloud & TPUs
### Taking JAX to the Cloud

* **The Power of TPUs:** Leverage Google’s Tensor Processing Units, purpose-built for the matrix math that JAX excels at.
* **Cloud Infrastructure:** Move beyond local limitations by using [Google Cloud TPU](https://cloud.google.com/tpu/docs) environments for large-scale experiments.
* **Why Cloud?**
    * *Speed:* Large training jobs can finish in a fraction of the time they would take on a single workstation GPU.
    * *Efficiency:* Work with datasets that exceed local memory.
    * *Accessibility:* Access world-class hardware directly through your browser or terminal.

---

## Logistics & Expectations
### Collaborative Discovery

* **Format:** A *discussion-driven* course. We learn by presenting and solving together.
* **Participation:** Come ready to debate technical approaches and tackle "unsolvable" problems.
* **Deliverables:**
    * **Version Control:** Collaborative work via [*GitHub*](https://github.com/).
    * **Documentation:** All insights submitted as [*Markdown*](https://www.markdownguide.org/) files or [*Jupyter Notebooks*](https://jupyter.org/).

---

## Grading & Communication

* **Project-Based Learning (80%):**
    * *Phase 1:* Final project proposal and research.
    * *Phase 2 (June):* Implementation, final delivery, and presentation.
* **Technical Assignments (20%):** Mastering JAX core primitives.
* **The Hub:** All official updates and "intellectual sparring" happen on *Discord*.

---

## Getting Started (Assignment 0)
### Your First Steps

* **Introductions:** Introduce yourself and share your interests to help form diverse project groups.
* **Office Hours:** Schedule a brief meeting with me this month to discuss your goals for the course.
* **Exploration:** Start thinking: *What complex system would you like to differentiate?*
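
---

## Code Sketch: What XLA Sees

The *Why JAX?* slide says XLA "fuses" operations. The sketch below makes that concrete. It assumes a recent JAX release with the ahead-of-time `jit(...).lower(...).compile()` API, and `affine_sin` is a hypothetical toy function. The printed jaxpr lists one line per primitive. The compiled HLO text is where any fusion appears, and its exact form varies by backend and JAX version.

```python
import jax
import jax.numpy as jnp


def affine_sin(x):
    """Hypothetical toy function: three elementwise ops (sin, multiply, add)."""
    return jnp.sin(x) * 2.0 + 1.0


x = jnp.linspace(0.0, 1.0, 4)

# The traced program ("jaxpr") that jax.jit lowers and hands to XLA:
# one line per primitive operation.
jaxpr = jax.make_jaxpr(affine_sin)(x)
print(jaxpr)

# Ahead-of-time lowering and compilation. The optimized HLO text shows what
# XLA will actually run; elementwise chains like this one are typically
# grouped into a single "fusion", but the exact text is backend-dependent.
compiled = jax.jit(affine_sin).lower(x).compile()
print(compiled.as_text())

# The jitted function computes the same values as the plain Python version.
out = jax.jit(affine_sin)(x)
print(out)
```

Comparing the jaxpr (three separate primitives) with the compiled HLO is a quick way to see that XLA treats the whole function as one program rather than executing each operation eagerly.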
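
---

## Code Sketch: Composing `jit`, `grad`, and `vmap`

The three transformations on *The Core Transformations* slide compose freely. This minimal sketch assumes a recent JAX release, and the names `loss`, `per_example_grads`, and the toy data are hypothetical. It differentiates a single-example squared-error loss with `jax.grad`, batches it over examples with `jax.vmap`, and compiles the result with `jax.jit`.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Hypothetical squared error of a linear model for ONE example (x is a vector)."""
    pred = jnp.dot(w, x)
    return (pred - y) ** 2


# jax.grad: gradient with respect to the first argument (w) by default.
grad_loss = jax.grad(loss)

# jax.vmap: share w across the batch (None), batch x and y along axis 0.
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jax.jit: compile the whole batched-gradient computation with XLA.
fast_per_example_grads = jax.jit(per_example_grads)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

g = fast_per_example_grads(w, xs, ys)
print(g.shape)  # (3, 2): one gradient of w per example
print(g)
```

Each transformation returns an ordinary Python function, so they can be stacked in different orders. Applying `jax.jit` outermost is a common choice because it lets XLA see the entire batched computation at once. Analytically, each row of `g` is `2 * (w·x - y) * x`, which gives a handy sanity check.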
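
---

## Code Sketch: Pure Functions & Immutable Arrays

You can check both functional ideas from the *JAX vs. PyTorch* slide directly. This minimal sketch assumes a recent JAX release, and `double` and `trace_log` are hypothetical names. It shows three things. JAX arrays reject in-place assignment. `.at[...].set(...)` returns an updated copy instead. A Python side effect inside a `jax.jit` function runs only while JAX traces it, which is why jitted functions should be pure.

```python
import jax
import jax.numpy as jnp

# 1) Immutable arrays: in-place item assignment raises a TypeError.
x = jnp.zeros(3)
try:
    x[0] = 1.0
    mutated_in_place = True
except TypeError:
    mutated_in_place = False

# Functional update: returns a NEW array and leaves x unchanged.
y = x.at[0].set(1.0)
print(x)  # [0. 0. 0.]
print(y)  # [1. 0. 0.]

# 2) Pure functions: Python side effects inside jit run only at trace time.
trace_log = []


@jax.jit
def double(v):
    trace_log.append("traced")  # side effect, NOT part of the compiled program
    return v * 2.0


double(jnp.ones(3))  # first call: traces, compiles, then runs
double(jnp.ones(3))  # same shape/dtype: reuses the cached compiled version
print(len(trace_log))  # 1
```

The functional style need not cost extra copies. According to the JAX documentation, inside jit-compiled code XLA can perform an `.at[...].set(...)` update in place when the original array is not used afterwards.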