# Quickstart: How to think in JAX

JAX is a library for array-oriented numerical computation (like [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research.

This note provides a quick overview of essential JAX features, so you can get started with JAX:

* JAX provides a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.
* JAX features built-in Just-In-Time (JIT) compilation via [OpenXLA](https://github.com/openxla), an open-source machine learning compiler ecosystem.
* JAX functions support efficient evaluation of gradients via its automatic differentiation transformations.
* JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs.

## Installation

JAX can be installed for CPU on Linux, Windows, and macOS directly from the [Python Package Index](https://pypi.org/project/jax/):

```bash
pip install jax
```

or, for NVIDIA GPU:

```bash
pip install -U "jax[cuda12]"
```

### MacBook GPU Support

For Apple Silicon MacBooks (M1/M2/M3), you can try enabling Metal GPU acceleration through Apple's experimental `jax-metal` plugin:

```bash
pip install -U jax-metal
```

Then check available devices and enable Metal:

```python=
import os
os.environ['JAX_PLATFORM_NAME'] = 'metal'  # Set before importing JAX

import jax
print("Available devices:", jax.devices())
```

**Note:** Metal support is experimental. If Metal acceleration is unavailable, JAX may fall back to CPU or raise an error, so check the `jax.devices()` output to see which backend is actually in use.

For more detailed platform-specific installation information, check out [Installation](https://docs.jax.dev/en/latest/installation.html).

## JAX vs. NumPy

**Key concepts:**

* JAX provides a NumPy-inspired interface for convenience.
* Through [duck-typing](https://en.wikipedia.org/wiki/Duck_typing), JAX arrays can often be used as drop-in replacements for NumPy arrays.
* Unlike NumPy arrays, JAX arrays are always immutable.
NumPy provides a well-known, powerful API for working with numerical data. For convenience, JAX provides [jax.numpy](https://docs.jax.dev/en/latest/jax.numpy.html) which closely mirrors the NumPy API and provides easy entry into JAX. Almost anything that can be done with `numpy` can be done with `jax.numpy`, which is typically imported under the `jnp` alias:

```python=
import jax.numpy as jnp
```

With this import, you can immediately use JAX in a similar manner to typical NumPy programs, including using NumPy-style array creation functions, Python functions and operators, and array attributes and methods:

```python=
import matplotlib.pyplot as plt

x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);
```

![image](https://hackmd.io/_uploads/H1E157Bieg.png)

The code blocks are identical to what you would expect with NumPy, aside from replacing `np` with `jnp`, and the results are the same. As we can see, JAX arrays can often be used directly in place of NumPy arrays for things like plotting.

The arrays themselves are implemented as different Python types:

```python=
import numpy as np
import jax.numpy as jnp

x_np = np.linspace(0, 10, 1000)
x_jnp = jnp.linspace(0, 10, 1000)
```

```python=+
type(x_np)
```

```
numpy.ndarray
```

```python=+
type(x_jnp)
```

```
jaxlib._jax.ArrayImpl
```

Python's duck-typing allows JAX arrays and NumPy arrays to be used interchangeably in many places. However, there is one important difference between JAX and NumPy arrays: JAX arrays are immutable, meaning that once created their contents cannot be changed.
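When you do need a specific array type, explicit conversion is straightforward. As a small sketch (these calls are our illustration rather than part of the original walkthrough), `np.asarray` turns a JAX array into a NumPy array and `jnp.asarray` goes the other way:

```python
import numpy as np
import jax
import jax.numpy as jnp

x_np = np.linspace(0, 10, 5)
x_jnp = jnp.linspace(0, 10, 5)

# JAX array -> NumPy array (the data is copied to host memory).
back_to_numpy = np.asarray(x_jnp)

# NumPy array -> JAX array.
to_jax = jnp.asarray(x_np)

print(type(back_to_numpy).__name__)       # ndarray
print(isinstance(to_jax, jax.Array))      # True
print(np.allclose(back_to_numpy, x_np))   # True
```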
### Immutable Arrays

Here is an example of mutating an array in NumPy:

```python=
# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)
```

```
[10  1  2  3  4  5  6  7  8  9]
```

The equivalent in JAX results in a `TypeError` because JAX arrays are immutable:

```python
%xmode minimal
```

```
Exception reporting mode: Minimal
```

```python=
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10
```

```
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html
```

For updating individual elements, JAX provides an [indexed update syntax](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy:

```python=+
y = x.at[0].set(10)
print(x)
print(y)
```

```
[0 1 2 3 4 5 6 7 8 9]
[10  1  2  3  4  5  6  7  8  9]
```

You'll find a few differences between JAX arrays and NumPy arrays once you begin digging in. See also:

* [Key concepts](https://docs.jax.dev/en/latest/key-concepts.html#jax-arrays-jax-array) for an introduction to the key concepts of JAX, such as transformations, tracing, jaxprs and pytrees.
* [🔪 JAX - The Sharp Bits 🔪](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html) for common gotchas when using JAX.

## JAX arrays (`jax.Array`)

**Key concepts:**

* Create arrays using JAX API functions.
* JAX array objects have a `devices` method that indicates where the array is stored.
* JAX arrays can be sharded across multiple devices for parallel computation.

The default array implementation in JAX is [`jax.Array`](https://docs.jax.dev/en/latest/_autosummary/jax.Array.html#jax.Array). In many ways it is similar to the [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) type that you may be familiar with from the NumPy package, but it has some important differences.
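One such difference is default precision: unless 64-bit mode is enabled (for example with `jax.config.update("jax_enable_x64", True)` at startup), JAX creates 32-bit arrays where NumPy would create 64-bit ones. A minimal sketch, assuming JAX's default configuration:

```python
import numpy as np
import jax.numpy as jnp

# NumPy defaults to 64-bit floats.
print(np.arange(3.0).dtype)    # float64

# JAX defaults to 32-bit types while jax_enable_x64 is off.
print(jnp.arange(3.0).dtype)   # float32
print(jnp.arange(3).dtype)     # int32
```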
### Array creation

We typically don't call the `jax.Array` constructor directly, but rather create arrays via JAX API functions. For example, [`jax.numpy`](https://docs.jax.dev/en/latest/jax.numpy.html#module-jax.numpy) provides familiar NumPy-style array construction functionality such as `jax.numpy.zeros`, `jax.numpy.linspace`, `jax.numpy.arange`, etc.

```python=
import jax
import jax.numpy as jnp

x = jnp.arange(5)
isinstance(x, jax.Array)
```

```
True
```

If you use Python type annotations in your code, `jax.Array` is the appropriate annotation for JAX array objects (see [`jax.typing`](https://docs.jax.dev/en/latest/jax.typing.html#module-jax.typing) for more discussion).

### Array devices and sharding

JAX Array objects have a `devices` method that lets you inspect where the contents of the array are stored. In the simplest cases, this will be a single CPU device:

```python=+
x.devices()
```

```
{CpuDevice(id=0)}
```

In general, an array may be [sharded](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) across multiple devices, in a manner that can be inspected via the `sharding` attribute:

```python=+
x.sharding
```

```
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
```

Here the array is on a single device, but in general a JAX array can be sharded across multiple devices, or even multiple hosts. To read more about sharded arrays and parallel computation, refer to [Introduction to parallel programming](https://docs.jax.dev/en/latest/sharded-computation.html).

## Just-in-time compilation with `jax.jit`

**Key concepts:**

* By default JAX executes operations one at a time, in sequence.
* Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.
* Not all JAX code can be JIT compiled, as it requires array shapes to be static and known at compile time.

JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one), with all JAX operations being expressed in terms of XLA. If we have a sequence of operations, we can use the [`jax.jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html) function to compile this sequence of operations together using the XLA compiler.

For example, consider this function that standardizes each column of a 2D matrix (subtracting the column mean and dividing by the column standard deviation), expressed in terms of `jax.numpy` operations:

```python=
import jax.numpy as jnp

def norm(X):
  X = X - X.mean(0)
  return X / X.std(0)
```

A just-in-time compiled version of the function can be created using the `jax.jit` transform:

```python=+
from jax import jit
norm_compiled = jit(norm)
```

This function returns the same results as the original, up to standard floating-point accuracy:

```python=+
np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)
```

```
True
```

Because of the compilation (which includes fusing of operations, avoiding the allocation of temporary arrays, and a host of other tricks), the JIT-compiled version can run much faster, sometimes by orders of magnitude, although the speedup depends on the function, the array sizes, and the hardware. We can use IPython's `%timeit` to quickly benchmark our function, using `block_until_ready()` to account for JAX's [asynchronous dispatch](https://docs.jax.dev/en/latest/async_dispatch.html):

```python=+
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
```

```
266 μs ± 3.94 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
260 μs ± 24 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```

In the run shown here, the two timings are nearly identical, so the benefit is not guaranteed for every function and machine.

That said, `jax.jit` does have limitations: in particular, it requires all arrays to have static shapes. That means that some JAX operations are incompatible with JIT compilation.
For example, this operation can be executed in op-by-op mode:

```python=
def get_negatives(x):
  return x[x < 0]

x = jnp.array(np.random.randn(10))
get_negatives(x)
```

```
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)
```

But it returns an error if you attempt to execute it in jit mode:

```python=
jit(get_negatives)(x)
```

```
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got bool[10]

See https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
```

This is because the function generates an array whose shape is not known at compile time: the size of the output depends on the values of the input array, and so it is not compatible with JIT.

For more on JIT compilation in JAX, check out [Just-in-time compilation](https://docs.jax.dev/en/latest/jit-compilation.html).

### JIT mechanics: tracing and static variables

**Key concepts:**

* JIT and other JAX transforms work by tracing a function to determine its effect on inputs of a specific shape and type.
* Variables that you don't want to be traced can be marked as static.

To use `jax.jit` effectively, it is useful to understand how it works. Let's put a few `print()` statements within a JIT-compiled function and then call the function:

```python=
@jit
def f(x, y):
  print("Running f():")
  print(f"  {x = }")
  print(f"  {y = }")
  result = jnp.dot(x + 1, y + 1)
  print(f"  {result = }")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)
```

```
Running f():
  x = JitTracer<float32[3,4]>
  y = JitTracer<float32[4]>
  result = JitTracer<float32[3]>
```

```
Array([0.25773212, 5.3623195 , 5.403243 ], dtype=float32)
```

Notice that the print statements execute, but rather than printing the data we passed to the function, they print tracer objects that stand in for them. These tracer objects are what `jax.jit` uses to extract the sequence of operations specified by the function.
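If you need to see concrete values inside a jitted function, `jax.debug.print` is one option: unlike Python's `print()`, it is staged into the compiled computation, so it runs on every call and receives the actual array values. A minimal sketch (the function `g` is a hypothetical example):

```python
import jax
import jax.numpy as jnp

@jax.jit
def g(x):
  print("traced:", x)                # runs only while tracing; shows a tracer
  jax.debug.print("runtime: {}", x)  # staged into the computation; shows actual values
  return x * 2

out = g(jnp.arange(3))   # first call: traces, compiles, then runs
out = g(jnp.arange(3))   # second call: reuses the compiled function
```

Because the second call reuses the cached compilation, only the `jax.debug.print` line is expected to produce output on that call.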
Basic tracers are stand-ins that encode the **shape** and **dtype** of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.

When we call the compiled function again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:

```python=+
x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)
```

```
Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)
```

The extracted sequence of operations is encoded in a JAX expression, or [jaxpr](https://docs.jax.dev/en/latest/key-concepts.html#jaxprs) for short. You can view the jaxpr using the `jax.make_jaxpr` transformation:

```python=
from jax import make_jaxpr

def f(x, y):
  return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)
```

```
{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0:f32[]
    d:f32[4] = add b 1.0:f32[]
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }
```

Note one consequence of this: because JIT compilation is done without information on the content of the array, control flow statements in the function cannot depend on traced values. For example, this fails:

```python=+
@jit
def f(x, neg):
  return -x if neg else x

f(1, True)
```

```
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /var/folders/sm/ws6rqp1946ng_hlfm7zkh6s40000gn/T/ipykernel_1803/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
```

If there are variables that you would not like to be traced, they can be marked as static for the purposes of JIT compilation:

```python=+
from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x

f(1, True)
```

```
Array(-1, dtype=int32, weak_type=True)
```

Note that calling a JIT-compiled function with a different static argument results in re-compilation, so the function still works as expected:

```python=+
f(1, False)
```

```
Array(1, dtype=int32, weak_type=True)
```

Understanding which values and operations will be static and which will be traced is a key part of using `jax.jit` effectively.

## Taking derivatives with `jax.grad`

**Key concepts:**

* JAX provides automatic differentiation via the `jax.grad` transformation.
* The `jax.grad` and `jax.jit` transformations compose and can be mixed arbitrarily.

In addition to transforming functions via JIT compilation, JAX also provides other transformations. One such transformation is [`jax.grad`](https://docs.jax.dev/en/latest/_autosummary/jax.grad.html), which performs [automatic differentiation (autodiff)](https://en.wikipedia.org/wiki/Automatic_differentiation):

```python=
from jax import grad

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
```

```
[0.25 0.19661197 0.10499357]
```

Let's verify with finite differences that our result is correct.

```python=+
def first_finite_differences(f, x, eps=1E-3):
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                    for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))
```

```
[0.24998187 0.1964569 0.10502338]
```

The `jax.grad` and `jax.jit` transformations compose and can be mixed arbitrarily.
For instance, while the `sum_logistic` function was differentiated directly in the previous example, it could also be JIT-compiled, and these operations can be combined. We can go further:

```python=+
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
```

```
-0.0353256
```

Beyond scalar-valued functions, the [`jax.jacobian`](https://docs.jax.dev/en/latest/_autosummary/jax.jacobian.html) transformation can be used to compute the full Jacobian matrix for vector-valued functions:

```python=+
from jax import jacobian
print(jacobian(jnp.exp)(x_small))
```

```
[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]
```

For more advanced autodiff operations, you can use [`jax.vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.vjp.html) for reverse-mode vector-Jacobian products, and [`jax.jvp`](https://docs.jax.dev/en/latest/_autosummary/jax.jvp.html) and [`jax.linearize`](https://docs.jax.dev/en/latest/_autosummary/jax.linearize.html) for forward-mode Jacobian-vector products. These can be composed arbitrarily with one another, and with other JAX transformations. For example, `jax.jvp` and `jax.vjp` are used to define [`jax.jacfwd`](https://docs.jax.dev/en/latest/_autosummary/jax.jacfwd.html) and [`jax.jacrev`](https://docs.jax.dev/en/latest/_autosummary/jax.jacrev.html), which compute Jacobians in forward and reverse mode, respectively. Here's one way to compose them to make a function that efficiently computes full Hessian matrices:

```python=+
from jax import jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

print(hessian(sum_logistic)(x_small))
```

This kind of composition produces efficient code in practice; this is more or less how JAX's built-in [`jax.hessian`](https://docs.jax.dev/en/latest/_autosummary/jax.hessian.html) function is implemented.
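As a quick sanity check of that composition (a sketch assuming float32 tolerances; `sum_logistic` and `hessian` are redefined so the block runs on its own), the hand-built `hessian` should agree with `jax.hessian` and with the analytic second derivative of the logistic function, σ''(x) = σ(x)(1 − σ(x))(1 − 2σ(x)):

```python
import jax
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def hessian(fun):
  # Forward-over-reverse composition, as in the text above.
  return jit(jacfwd(jacrev(fun)))

x_small = jnp.arange(3.)
H_custom = hessian(sum_logistic)(x_small)
H_builtin = jax.hessian(sum_logistic)(x_small)

# Each term of the sum depends on a single element of x, so the Hessian is
# diagonal, with the logistic function's second derivative on the diagonal.
s = jax.nn.sigmoid(x_small)
analytic_diag = s * (1 - s) * (1 - 2 * s)

print(jnp.allclose(H_custom, H_builtin, atol=1e-6))                 # True
print(jnp.allclose(H_custom, jnp.diag(analytic_diag), atol=1e-6))   # True
```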
For more on automatic differentiation in JAX, check out [Automatic differentiation](https://docs.jax.dev/en/latest/automatic-differentiation.html).

## Auto-vectorization with `jax.vmap`

**Key concepts:**

* JAX provides automatic vectorization via the `jax.vmap` transformation.
* `jax.vmap` can be composed with `jax.jit` to produce efficient vectorized code.

Another useful transformation is [`jax.vmap`](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html), the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of explicitly looping over function calls, it transforms the function into a natively vectorized version for better performance. When composed with `jax.jit`, it can be just as performant as manually rewriting your function to operate over an extra batch dimension.

We're going to work with a simple example, and promote matrix-vector products into matrix-matrix products using `jax.vmap`. Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions.

```python=
from jax import random

key = random.key(1701)
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply_matrix(x):
  return jnp.dot(mat, x)
```

The `apply_matrix` function maps a vector to a vector, but we may want to apply it row-wise across a matrix. We could do this by looping over the batch dimension in Python, but this usually results in poor performance.

```python=+
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
```

```
Naively batched
347 μs ± 33.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```

A programmer familiar with the `jnp.dot` function might recognize that `apply_matrix` can be rewritten to avoid explicit looping, using the built-in batching semantics of `jnp.dot`:

```python=+
import numpy as np

@jit
def batched_apply_matrix(batched_x):
  return jnp.dot(batched_x, mat.T)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
```

```
Manually batched
20 μs ± 1.23 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```

However, as functions become more complicated, this kind of manual batching becomes more difficult and error-prone. The `jax.vmap` transformation is designed to automatically transform a function into a batch-aware version:

```python=+
from jax import vmap

@jit
def vmap_batched_apply_matrix(batched_x):
  return vmap(apply_matrix)(batched_x)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
```

```
Auto-vectorized with vmap
22.9 μs ± 447 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```

As you would expect, `jax.vmap` can be arbitrarily composed with `jax.jit`, `jax.grad`, and any other JAX transformation.

For more on automatic vectorization in JAX, check out [Automatic vectorization](https://docs.jax.dev/en/latest/automatic-vectorization.html).

## Pseudorandom numbers

**Key concepts:**

* JAX uses a different model for pseudorandom number generation than NumPy.
* JAX random functions consume a random `key` that must be split to generate new independent keys.
* JAX's random key model is thread-safe and avoids issues with global state.

Generally, JAX strives to be compatible with NumPy, but pseudorandom number generation is a notable exception.
NumPy supports a method of pseudorandom number generation that is based on a global `state`, which can be set using [`numpy.random.seed`](https://numpy.org/doc/stable/reference/random/generated/numpy.random.seed.html). Global random state interacts poorly with JAX's compute model and makes it difficult to enforce reproducibility across different threads, processes, and devices. JAX instead tracks state explicitly via a random `key`:

```python=
from jax import random

key = random.key(43)
print(key)
```

```
Array((), dtype=key<fry>) overlaying:
[ 0 43]
```

The key is effectively a stand-in for NumPy's hidden state object, but we pass it explicitly to [`jax.random`](https://docs.jax.dev/en/latest/jax.random.html) functions. Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated.

```python=+
print(random.normal(key))
print(random.normal(key))
```

```
0.07520543
0.07520543
```

**The rule of thumb is: never reuse keys (unless you want identical outputs).**

In order to generate different and independent samples, you must [`jax.random.split`](https://docs.jax.dev/en/latest/_autosummary/jax.random.split.html) the key explicitly before passing it to a random function:

```python=+
key = random.key(42)  # Start with a new key for a fresh example
for i in range(3):
  key, subkey = random.split(key)
  val = random.normal(subkey)
  print(f"draw {i}: {val}")
```

```
draw 0: -1.9133632183074951
draw 1: -1.4749839305877686
draw 2: -0.36703771352767944
```

Note that this code is thread-safe, since the local random state eliminates possible race conditions involving global state. `jax.random.split` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys.

For more on pseudorandom numbers in JAX, see the [Pseudorandom numbers tutorial](https://docs.jax.dev/en/latest/random-numbers.html).

## Reference

* JAX developers. (n.d.). Thinking in JAX. JAX documentation. https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html
