# Just-in-time compilation

## Why do we need JIT?

JAX is designed for high-performance numerical computation, especially on GPUs and TPUs. Let's start with an example: one step of gradient descent with an Armijo backtracking line search.

```python=
import jax
import jax.numpy as jnp
from jax import grad, jit
from jax.lax import while_loop

jax.config.update("jax_enable_x64", True)

# --- Setup ---
N = 3
A = jnp.diag(jnp.arange(1, N + 1, dtype=jnp.float64))

def f(x):
    Ax = A @ x
    return jnp.dot(Ax, Ax)  # f(x) = ||Ax||^2

g = grad(f)

# --- Armijo Line Search ---
def armijo_line_search(x, grad_val, sigma=0.1, beta=0.5, alpha_init=1.0):
    p = -grad_val  # descent direction

    def cond_fun(alpha):
        # Keep shrinking alpha while the Armijo (sufficient decrease) condition is violated
        return f(x + alpha * p) > f(x) + sigma * alpha * jnp.dot(grad_val, p)

    def body_fun(alpha):
        return alpha * beta

    alpha = while_loop(cond_fun, body_fun, alpha_init)
    return alpha
```

### Version A.

```python
# --- Set up an update step ---
def step(x):
    gradient = g(x)
    step_size = armijo_line_search(x, gradient)
    x_next = x - step_size * gradient
    return x_next, f(x_next), jnp.linalg.norm(gradient)

# --- Time one gradient-descent step ---
x_current = jnp.ones(N)
step(x_current)[0].block_until_ready()  # warm-up call
%timeit step(x_current)[0].block_until_ready()
```

Execution time:

```text
81.4 ms ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

### Version B.

```python
@jit
def step(x):
    gradient = g(x)
    step_size = armijo_line_search(x, gradient)
    x_next = x - step_size * gradient
    return x_next, f(x_next), jnp.linalg.norm(gradient)

# --- Time one gradient-descent step ---
x_current = jnp.ones(N)
step(x_current)[0].block_until_ready()  # warm-up call: traces and compiles `step`
%timeit step(x_current)[0].block_until_ready()
```

Execution time:

```text
724 µs ± 298 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

The jitted version is roughly 100 times faster (about 0.72 ms versus 81 ms per step). The only difference between the two functions is the `@jit` decorator.

## JIT: The Computation Blueprint for a Factory

The Python interpreter executes operations one at a time. `jit` works differently. It first reads the whole function and draws an optimized "blueprint" of its computation. It then builds a "factory" (compiled machine code) that produces results efficiently.

### Step 1. Tracing and Drawing the Blueprint

```python=
import jax
import jax.numpy as jnp

global_list = []

def f(x):
    global_list.append(x)
    y = 2*x+1
    k = 3.0
    return y * k

print(jax.make_jaxpr(f)(1.0))
```

```text
{ lambda ; a:f64[]. let
    b:f64[] = mul 2.0 a
    c:f64[] = add b 1.0
    d:f64[] = mul c 3.0
  in (d,) }
```

Notice that the tracer records only the **shape** and **dtype** of the input, not its actual value `1.0`. Here the input appears as `a:f64[]`, a float64 scalar, because 64-bit mode is still enabled from the first example. Notice also two omissions:

* `global_list.append(x)` leaves no trace in the jaxpr.
* The Python variable `k` is simply folded in as the constant `3.0`.

This `jaxpr` is the blueprint we mentioned. It records every computation step and nothing else.

### Step 2. Compiling and Producing Rapidly

JAX then hands this blueprint to the XLA compiler, which turns it into highly efficient machine code. From then on, every call whose inputs have the same shapes and dtypes runs through this "factory" at high speed. The Python interpreter is bypassed entirely.

## Static Hints to JIT

When you first start using `jit`, you will probably run into errors like the following.

### (A.)

```python=
import jax
import jax.numpy as jnp

def abs(x):
    if x < 0:
        return -x
    else:
        return x

jax.jit(abs)(10)  # Error
```

```
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function abs at /tmp/ipython-input-1497352890.py:6 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
```

### (B.)
```python=
import jax
import jax.numpy as jnp

def g(x, n):
    i = 0
    while i < n:  # the loop condition depends on the traced value of n
        i += 1
    return x + i

jax.jit(g)(10, 20)  # Raises an error
```

```
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
```

### (C.)

```python=
import jax
import jax.numpy as jnp

def l(x, n):
    i = 0
    while jnp.log(x) >= n:  # the loop condition depends on the traced values of x and n
        x = jnp.log(x)
        i += 1
    return i

jax.jit(l)(10, 1)  # Error
```

```
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
This concrete value was not available in Python because it depends on the values of the arguments x and n.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
```

### Why does it fail?

* While tracing, `jit` needs to draw a computation blueprint with a **fixed structure**.
* Python's `if` and `while` call `bool()` on their condition to choose a branch or decide whether to keep looping. The result of the comparison therefore determines the structure of the computation itself.
* During tracing, JAX knows only the shape and dtype of each input, not its value, so it cannot evaluate the comparison.

In other words, JAX doesn't know which branch to take or how many times the loop will run. It cannot draw a blueprint with a fixed structure, so tracing fails with a `TracerBoolConversionError`.

How can we fix this?

* We can use `static_argnums` or `static_argnames` to let `jit` "peek" at the actual value of specific arguments. A static argument is treated as a compile-time constant, so it must be hashable (e.g. an `int`, `bool`, `str`, or tuple).
* However, every new value of a static argument triggers a recompilation. This method is therefore only suitable for arguments with a finite, and preferably small, number of possible values (e.g. `True`/`False`).

### (A'.)

```python=
import jax
import jax.numpy as jnp

def abs(x):
    if x < 0:
        return -x
    else:
        return x

abs_jit_correct = jax.jit(abs, static_argnums=0)  # the argument at position 0 (x) is static
print(abs_jit_correct(1))
```

```
1
```

In (A'), `static_argnums=0` marks the argument at position 0 (`x`) as static. Its concrete value is then available to the `if` condition during tracing. **Be careful**: the keyword is `static_argnums`, not `static_argnum`. For a function this simple, the branch-free `jnp.where(x < 0, -x, x)` avoids the need for a static argument and the recompilations it brings.

### (B'.)

```python=
import jax
import jax.numpy as jnp

def g(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i

g_jit_correct = jax.jit(g, static_argnames=['n'])  # n is static, x is still traced
print(g_jit_correct(10, 20))
```

```
30
```

In (B'), `static_argnames=['n']` marks an argument as static by name instead of by position.

### (B''.)
```python=
from functools import partial

import jax

@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i

print(g_jit_decorated(10, 20))
```

```
30
```

In (B''), `functools.partial` fixes the `static_argnames` argument so that `jax.jit` can be used as a decorator.

### (C'.)

```python=
import jax
import jax.numpy as jnp
from jax.lax import while_loop

def l_jit_friendly(x, n):
    initial_state = (x, 0)  # (current x, iteration count)

    def cond_fun(state):
        current_x, current_i = state
        return jnp.log(current_x) >= n

    def body_fun(state):
        current_x, current_i = state
        next_x = jnp.log(current_x)
        next_i = current_i + 1
        return (next_x, next_i)

    final_x, final_i = while_loop(cond_fun, body_fun, initial_state)
    return final_i

l_jit = jax.jit(l_jit_friendly)
print(l_jit(10000.0, 1.0))
```

```
2
```

In (C'), JAX's own control-flow primitive `jax.lax.while_loop` moves the loop into the compiled program. Neither `x` nor `n` needs to be static, and changing their values does not trigger a recompilation. This is a valid and often better solution.

## Be Friendly to JIT

The core principle of using `jit` is simple: **write pure functions**.

<div style="background-color: rgba( 150, 150, 150, 0.7); color: white; padding: 17px; border-radius: 5px;">

### Definition: Side Effects
Any change a function makes to the world outside itself, other than returning a value. Examples include modifying a global variable or printing to the screen.

### Definition: Pure Function
A function is called "pure" if it meets two conditions:
1. The same input always produces the same output.
2. It has no side effects.

</div>

For example:

```python=
import jax

g = 0

@jax.jit
def impure_function(x):
    global g
    g = 10          # Side effect: modify a global variable
    print("Hello")  # Side effect: print something
    return x + g
```

When JAX traces this function for the first time, it runs the Python code once. The global variable `g` is changed to `10`, "Hello" is printed, and the resulting blueprint is `input + 10`. In later calls with inputs of the same shape and dtype, however, JAX only executes the compiled code for `input + 10`. It will **not** modify `g` again and will **not** print "Hello" again. This can cause inconsistent behavior and unexpected bugs.

In other words, the `jaxpr` records only the computations on traced values and **ignores** every other side effect in the function.
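
The trace-once behavior can be observed directly with a minimal sketch. The function `add_ten` and the counter `trace_count` are hypothetical names chosen for illustration. `trace_count` is a deliberate Python side effect, so it only changes when JAX traces the function. For output that should appear on every call, `jax.debug.print` stages the print into the compiled program instead.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, used only to make tracing visible


@jax.jit
def add_ten(x):
    global trace_count
    trace_count += 1              # Python side effect: runs only while tracing
    jax.debug.print("x = {}", x)  # staged into the compiled program: runs on every call
    return x + 10


out1 = add_ten(jnp.ones(3))   # first call: trace + compile
out2 = add_ten(jnp.zeros(3))  # same shape/dtype: cache hit, no retracing
print(trace_count)            # 1
out3 = add_ten(jnp.ones(5))   # new shape: traced and compiled again
print(trace_count)            # 2
```

The counter behaves exactly like `g = 10` in `impure_function`: it changes once per trace, not once per call.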
The same principle extends to **immutability**: JAX arrays cannot be modified in place. Instead of `x[idx] = value`, write `x = x.at[idx].set(value)`, which returns a new array.

## Advanced Tricks and Reminders

### The JIT Cache

JAX caches compiled code using a hash-based lookup. When a jitted function is called, JAX builds a cache key from:
1. the function being compiled (identified by the Python function object, not by its source code),
2. the shapes and dtypes of the inputs, and
3. the values of any static arguments.

It then uses this key to check whether it has already compiled this exact configuration. If it has, it reuses the cached compiled code. If not, it traces and compiles a new version, stores it in the cache under this new "fingerprint," and reuses it for future calls.

This is why applying `jit` to a function defined **inside a loop** is a bad idea. Each iteration creates a new function object, so the cache lookup misses. JAX may then retrace and recompile the function on every iteration.

### The Timer and Asynchronous Dispatch

Let me explain this with a restaurant kitchen analogy:

* You (the Python code): the customer sitting at the table.
* The CPU (Host): your brain, which decides what to order.
* The GPU/TPU (Device/Accelerator): the restaurant's kitchen, which prepares the food. It can be fast or slow depending on how complex the dish is.
* The JAX dispatcher: the waiter.

#### Scenario 1: Synchronous Dispatch (The Traditional Way)

1. You (the CPU) decide on an "appetizer" and order it.
2. The waiter takes the order to the kitchen (the GPU).
3. You stop everything and wait, staring at the kitchen door until the waiter brings out the appetizer.
4. After finishing the appetizer, you order a "main course" and repeat steps 2 and 3.

The downside: while the kitchen is cooking, you are completely idle and cannot think about the next dish or do anything else. This is very inefficient.

#### Scenario 2: Asynchronous Dispatch (The JAX Way)

1. You (the CPU) plan your entire meal (appetizer, main course, and dessert) and give the complete order to the waiter all at once.
2. The waiter adds all the orders to the kitchen's to-do list (the GPU's command queue).
3. You immediately go back to what you were doing without watching the kitchen. In real life that might be chatting with friends; in Python, it might be preparing the next batch of data.
4. The kitchen works through the orders sequentially, cooking one dish after another without interruption.
5. Only when you actually need a dish on your table do you pause and ask the waiter, "Is my appetizer ready?" For example, you might want to `print()` the appetizer or save it to a file. Only then do you start waiting.

#### The Technical Explanation

* Host: typically refers to the CPU and the memory it manages.
* Device: refers to accelerators such as a GPU or TPU.

Suppose you execute a JAX operation on the Host, e.g. `result = jax_function(x)`. JAX does not wait for the Device to complete the computation. It simply places the computation in the Device's command queue. It then immediately returns a `jax.Array` to the Python variable `result`; this array behaves like a future.

At that moment, `result` may not yet contain the computed value. It is merely a handle to a piece of Device memory, meaning "the result will be placed here once the computation finishes." The Host (CPU) can keep executing subsequent Python code. Meanwhile, the Device (GPU) works in the background, pulling commands from the queue and executing them.

To summarize: for the sake of efficiency, JAX issues a command and moves on without waiting for the GPU to finish. This is why benchmarks must call `.block_until_ready()` on the output. Without it, the timer stops as soon as the work has been *dispatched*, not when it has *finished*. You would not be measuring the complete computation time.

## Checklist for using JIT

- [ ] Is the computational workload of the function large enough to be worth compiling?
- [ ] Is the function a pure function?
- [ ] Do any `if`/`while` statements depend on the values of traced inputs?
- [ ] Is `@jit` being applied to a function defined inside a loop?
- [ ] When benchmarking, am I using `.block_until_ready()` and excluding the first (compiling) call?
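
Several checklist items can be shown together in one minimal sketch. The function `damped_update` and its update rule are hypothetical, chosen only for this example. It covers four items:

* The function is pure.
* It is jitted once at module level rather than inside a loop.
* Its Python `for` loop depends only on the static argument `num_steps`.
* The timing excludes the first (compiling) call and waits on `.block_until_ready()`.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical example function: jitted once, at module level (not inside a loop).
# `num_steps` drives a Python loop, so it is marked static; each new value of
# `num_steps` produces a separate compiled version.
@partial(jax.jit, static_argnames=["num_steps"])
def damped_update(x, num_steps):
    for _ in range(num_steps):    # unrolled at trace time because num_steps is static
        x = x - 0.1 * (x**3 - x)  # pure array computation, no side effects
    return x


x = jnp.linspace(-1.0, 1.0, 1_000_000)

# Warm-up: the first call traces and compiles, so keep it out of the timing.
damped_update(x, num_steps=10).block_until_ready()

start = time.perf_counter()
result = damped_update(x, num_steps=10)  # returns almost immediately (asynchronous dispatch)
result.block_until_ready()               # wait until the device has actually finished
elapsed = time.perf_counter() - start
print(f"one call: {elapsed * 1e3:.3f} ms")
```

Calling `damped_update` with a different `num_steps` (say 20) would trigger a fresh compilation. That is exactly the cost the static-argument item warns about.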