Quickstart: How to think in JAX#

JAX is a library for array-oriented numerical computation (à la NumPy), with automatic differentiation and JIT compilation to enable high-performance machine learning research.

This document provides a quick overview of essential JAX features, so you can get started with JAX:

  • JAX provides a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.

  • JAX features built-in Just-In-Time (JIT) compilation via OpenXLA, an open-source machine learning compiler ecosystem.

  • JAX functions support efficient evaluation of gradients via its automatic differentiation transformations.

  • JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs.

Installation#

JAX can be installed for CPU on Linux, Windows, and macOS directly from the Python Package Index:

pip install jax

or, for NVIDIA GPU:

pip install -U "jax[cuda12]"

For more detailed platform-specific installation information, check out Installation.
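
After installing, a quick sanity check is to import JAX and list the devices it can see. This is a minimal sketch; the printed version and device list depend on your environment.

```python
import jax

# Report the installed version and the devices JAX has detected.
print(jax.__version__)
devices = jax.devices()
print(devices)  # on a CPU-only install this is typically a single CPU device
```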

JAX vs. NumPy#

Key concepts:

  • JAX provides a NumPy-inspired interface for convenience.

  • Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.

  • Unlike NumPy arrays, JAX arrays are always immutable.

NumPy provides a well-known, powerful API for working with numerical data. For convenience, JAX provides jax.numpy which closely mirrors the NumPy API and provides easy entry into JAX. Almost anything that can be done with numpy can be done with jax.numpy, which is typically imported under the jnp alias:

import jax.numpy as jnp

With this import, you can immediately use JAX in a similar manner to typical NumPy programs, including using NumPy-style array creation functions, Python functions and operators, and array attributes and methods:

import matplotlib.pyplot as plt

x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);

The code blocks are identical to what you would expect with NumPy, aside from replacing np with jnp, and the results are the same. As we can see, JAX arrays can often be used directly in place of NumPy arrays for things like plotting.

The arrays themselves are implemented as different Python types:

import numpy as np
import jax.numpy as jnp

x_np = np.linspace(0, 10, 1000)
x_jnp = jnp.linspace(0, 10, 1000)
type(x_np)
numpy.ndarray
type(x_jnp)
jaxlib._jax.ArrayImpl
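
Converting between the two array types is a common need in practice. The sketch below assumes the default JAX configuration and uses jnp.asarray and np.asarray for the round trip; the variable names are illustrative only.

```python
import numpy as np
import jax
import jax.numpy as jnp

x_np = np.linspace(0, 10, 5)

# NumPy -> JAX: the data is copied into a JAX array.
# (With default settings JAX stores floats as float32.)
x_jnp = jnp.asarray(x_np)

# JAX -> NumPy: the result is an ordinary, mutable numpy.ndarray.
x_back = np.asarray(x_jnp)

print(isinstance(x_jnp, jax.Array))
print(isinstance(x_back, np.ndarray))
```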

Python’s duck-typing allows JAX arrays and NumPy arrays to be used interchangeably in many places. However, there is one important difference between JAX and NumPy arrays: JAX arrays are immutable, meaning that once created their contents cannot be changed.

Here is an example of mutating an array in NumPy:

# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)
[10  1  2  3  4  5  6  7  8  9]

The equivalent in JAX results in an error, as JAX arrays are immutable:

%xmode minimal
Exception reporting mode: Minimal
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html

For updating individual elements, JAX provides an indexed update syntax that returns an updated copy:

y = x.at[0].set(10)
print(x)
print(y)
[0 1 2 3 4 5 6 7 8 9]
[10  1  2  3  4  5  6  7  8  9]
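
The .at property supports more than set(). As a brief sketch, here are a few of the other update methods (add, slice assignment with set, and max); each one returns a new array and leaves the original untouched.

```python
import jax.numpy as jnp

x = jnp.arange(10)

y_add = x.at[0].add(5)       # add 5 to element 0
y_slice = x.at[2:5].set(0)   # overwrite a slice with zeros
y_max = x.at[1].max(7)       # replace element 1 with max(x[1], 7)

print(y_add)
print(y_slice)
print(y_max)
print(x)  # the original array is unchanged
```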

You’ll find a few differences between JAX arrays and NumPy arrays once you begin digging in.

JAX arrays (jax.Array)#

Key concepts:

  • Create arrays using JAX API functions.

  • JAX array objects have a devices method that indicates where the array is stored.

  • JAX arrays can be sharded across multiple devices for parallel computation.

The default array implementation in JAX is jax.Array. In many ways it is similar to the numpy.ndarray type that you may be familiar with from the NumPy package, but it has some important differences.

Array creation#

We typically don’t call the jax.Array constructor directly, but rather create arrays via JAX API functions. For example, jax.numpy provides familiar NumPy-style array construction functionality such as jax.numpy.zeros, jax.numpy.linspace, jax.numpy.arange, etc.

import jax
import jax.numpy as jnp

x = jnp.arange(5)
isinstance(x, jax.Array)
True

If you use Python type annotations in your code, jax.Array is the appropriate annotation for JAX array objects (see jax.typing for more discussion).
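
As a small sketch of such annotations, the hypothetical function below (scale is an illustrative name, not part of JAX) accepts anything array-like via jax.typing.ArrayLike and declares jax.Array as its return type.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

def scale(x: ArrayLike, factor: float) -> jax.Array:
    # ArrayLike covers JAX arrays, NumPy arrays, and Python scalars;
    # lists are converted here only because jnp.asarray accepts them.
    return jnp.asarray(x) * factor

result = scale([1.0, 2.0, 3.0], 2.0)
print(result)
```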

Array devices and sharding#

JAX Array objects have a devices method that lets you inspect where the contents of the array are stored. In the simplest cases, this will be a single CPU device:

x.devices()
{CpuDevice(id=0)}

In general, an array may be sharded across multiple devices, in a manner that can be inspected via the sharding attribute:

x.sharding
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)

Here the array is on a single device, but in general a JAX array can be sharded across multiple devices, or even multiple hosts. To read more about sharded arrays and parallel computation, refer to Introduction to parallel programming.
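
To place an array on a particular device explicitly, one option is jax.device_put. The sketch below simply picks the first device reported by jax.devices(), which on a CPU-only machine is the CPU.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5)

# Pick the first available device; the choice here is only illustrative.
device = jax.devices()[0]
x_on_device = jax.device_put(x, device)

print(x_on_device.devices())
print(x_on_device.sharding)
```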

Just-in-time compilation with jax.jit#

Key concepts:

  • By default JAX executes operations one at a time, in sequence.

  • Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.

  • Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time.

JAX runs transparently on the GPU or TPU (falling back to CPU if you don’t have one), with all JAX operations being expressed in terms of XLA. If we have a sequence of operations, we can use the jax.jit function to compile this sequence of operations together using the XLA compiler.

For example, consider this function that standardizes each column of a 2D matrix (subtracting the column mean and dividing by the column standard deviation), expressed in terms of jax.numpy operations:

import jax.numpy as jnp

def norm(X):
  X = X - X.mean(0)
  return X / X.std(0)

A just-in-time compiled version of the function can be created using the jax.jit transform:

from jax import jit
norm_compiled = jit(norm)

This function returns the same results as the original, up to standard floating-point accuracy:

np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)
True
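
jax.jit can also be applied as a decorator, which is often the more convenient form. The sketch below repeats the same computation under the illustrative name norm_jit and checks that each column of the output has mean close to 0 and standard deviation close to 1.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def norm_jit(X):
    # Same computation as norm() above, compiled on first call.
    X = X - X.mean(0)
    return X / X.std(0)

X = jnp.array(np.random.default_rng(0).random((100, 10)))
out = norm_jit(X)

print(np.allclose(out.mean(0), 0.0, atol=1e-4))
print(np.allclose(out.std(0), 1.0, atol=1e-3))
```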

But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case, although the gain depends on the function, the input size, and the hardware; in the run shown here the two timings are close. We can use IPython’s %timeit to quickly benchmark our function, using block_until_ready() to account for JAX’s asynchronous dispatch:

%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
218 μs ± 9.35 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
207 μs ± 3.72 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

That said, jax.jit does have limitations: in particular, it requires all arrays to have static shapes. That means that some JAX operations are incompatible with JIT compilation.

For example, this operation can be executed in op-by-op mode:

def get_negatives(x):
  return x[x < 0]

x = jnp.array(np.random.randn(10))
get_negatives(x)
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)

But it returns an error if you attempt to execute it in jit mode:

jit(get_negatives)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got bool[10]

See https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

This is because the function generates an array whose shape is not known at compile time: the size of the output depends on the values of the input array, and so it is not compatible with JIT.
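
One common workaround, when the downstream computation allows it, is to keep the output shape fixed and mask the unwanted entries instead of removing them. The sketch below uses the three-argument form of jnp.where; the function name zero_out_non_negatives is illustrative, and this changes what the function returns (zeros in place of dropped entries), so it is an alternative formulation rather than a drop-in replacement.

```python
import jax
import jax.numpy as jnp

@jax.jit
def zero_out_non_negatives(x):
    # The output has the same shape as the input, so the shape is
    # known at compile time regardless of the values in x.
    return jnp.where(x < 0, x, 0.0)

x = jnp.array([-1.0, 2.0, -3.0, 4.0])
masked = zero_out_non_negatives(x)
print(masked)  # negatives are kept, everything else becomes 0
```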

For more on JIT compilation in JAX, check out Just-in-time compilation.

JIT mechanics: tracing and static variables#

Key concepts:

  • JIT and other JAX transforms work by tracing a function to determine its effect on inputs of a specific shape and type.

  • Variables that you don’t want to be traced can be marked as static.

To use jax.jit effectively, it is useful to understand how it works. Let’s put a few print() statements within a JIT-compiled function and then call the function:

@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)
Running f():
  x = JitTracer<float32[3,4]>
  y = JitTracer<float32[4]>
  result = JitTracer<float32[3]>
Array([0.25773212, 5.3623195 , 5.403243  ], dtype=float32)

Notice that the print statements execute, but rather than printing the data we passed to the function, they print tracer objects that stand in for them.

These tracer objects are what jax.jit uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the shape and dtype of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.

When we call the compiled function again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)
Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)
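
The caching behaviour can be observed directly by recording a Python side effect during tracing. In this sketch, trace_log and add_one are illustrative names: the list only grows when the function body is actually re-executed by Python, which happens on the first call and again when the input shape changes.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def add_one(x):
    # This append runs only while JAX traces the function.
    trace_log.append(x.shape)
    return x + 1

add_one(jnp.ones(3))    # first call: traced and compiled
add_one(jnp.zeros(3))   # same shape and dtype: cached, no tracing
n_after_same_shape = len(trace_log)

add_one(jnp.ones(4))    # new shape: traced again
n_after_new_shape = len(trace_log)

print(n_after_same_shape, n_after_new_shape)
```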

The extracted sequence of operations is encoded in a JAX expression, or jaxpr for short. You can view the jaxpr using the jax.make_jaxpr transformation:

from jax import make_jaxpr

def f(x, y):
  return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)
{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0:f32[]
    d:f32[4] = add b 1.0:f32[]
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }

Note one consequence of this: because JIT compilation is done without information on the content of the array, control flow statements in the function cannot depend on traced values. For example, this fails:

@jit
def f(x, neg):
  return -x if neg else x

f(1, True)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_4240/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

If there are variables that you would not like to be traced, they can be marked as static for the purposes of JIT compilation:

from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x

f(1, True)
Array(-1, dtype=int32, weak_type=True)

Note that calling a JIT-compiled function with a different static argument results in re-compilation, so the function still works as expected:

f(1, False)
Array(1, dtype=int32, weak_type=True)

Understanding which values and operations will be static and which will be traced is a key part of using jax.jit effectively.
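
Two variations on the example above may be useful. The first marks the argument static by name using static_argnames; the second keeps neg traced and expresses the branch with jnp.where, which avoids recompilation when neg changes. The function names below are illustrative.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Variation 1: mark the argument static by name instead of by position.
@partial(jax.jit, static_argnames=("neg",))
def negate_static(x, neg):
    return -x if neg else x

# Variation 2: keep `neg` traced and select the result with jnp.where.
@jax.jit
def negate_where(x, neg):
    return jnp.where(neg, -x, x)

print(negate_static(1.0, neg=True))
print(negate_where(1.0, True))
print(negate_where(1.0, False))
```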

Taking derivatives with jax.grad#

Key concepts:

  • JAX provides automatic differentiation via the jax.grad transformation.

  • The jax.grad and jax.jit transformations compose and can be mixed arbitrarily.

In addition to transforming functions via JIT compilation, JAX also provides other transformations. One such transformation is jax.grad(), which performs automatic differentiation (autodiff):

from jax import grad

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
[0.25       0.19661197 0.10499357]

Let’s verify with finite differences that our result is correct.

def first_finite_differences(f, x, eps=1E-3):
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1965761  0.10502338]
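
The gradient can also be checked against the closed form: the derivative of the logistic function s(x) is s(x) * (1 - s(x)). The sketch below uses jax.value_and_grad, which returns the function value and its gradient in a single call, and compares the result with the analytic expression.

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)

# Evaluate the function and its gradient together.
value, gradient = jax.value_and_grad(sum_logistic)(x_small)

# Closed-form derivative of the logistic function.
s = jax.nn.sigmoid(x_small)
analytic = s * (1.0 - s)

print(value)
print(gradient)
print(jnp.allclose(gradient, analytic, atol=1e-6))
```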

The grad() and jit() transformations compose and can be mixed arbitrarily. For instance, while the sum_logistic function was differentiated directly in the previous example, it could also be JIT-compiled, and these operations can be combined. We can go further:

print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.0353256

Beyond scalar-valued functions, the jax.jacobian() transformation can be used to compute the full Jacobian matrix for vector-valued functions:

from jax import jacobian
print(jacobian(jnp.exp)(x_small))
[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]

For more advanced autodiff operations, you can use jax.vjp() for reverse-mode vector-Jacobian products, and jax.jvp() and jax.linearize() for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. For example, jax.jvp() and jax.vjp() are used to define jax.jacfwd() and jax.jacrev(), which compute Jacobians in forward and reverse mode, respectively. Here’s one way to compose them to make a function that efficiently computes full Hessian matrices:

from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
print(hessian(sum_logistic)(x_small))
[[-0.         -0.         -0.        ]
 [-0.         -0.09085776 -0.        ]
 [-0.         -0.         -0.07996249]]

This kind of composition produces efficient code in practice; this is more-or-less how JAX’s built-in jax.hessian() function is implemented.
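
As a short sketch of the lower-level building blocks mentioned above, the code below calls jax.jvp() and jax.vjp() on jnp.exp, whose Jacobian is diagonal, so both products with a vector of ones equal exp(x). It also compares the composed Hessian with the built-in jax.hessian().

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
v = jnp.ones_like(x_small)

# Forward mode: returns f(x) and the Jacobian-vector product J @ v.
primal_out, tangent_out = jax.jvp(jnp.exp, (x_small,), (v,))

# Reverse mode: returns f(x) and a function computing v @ J.
primal_out_rev, vjp_fn = jax.vjp(jnp.exp, x_small)
(cotangent_in,) = vjp_fn(v)

# Composed Hessian versus the built-in helper.
h_manual = jax.jit(jax.jacfwd(jax.jacrev(sum_logistic)))(x_small)
h_builtin = jax.hessian(sum_logistic)(x_small)

print(tangent_out)
print(cotangent_in)
print(jnp.allclose(h_manual, h_builtin, atol=1e-6))
```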

For more on automatic differentiation in JAX, check out Automatic differentiation.

Auto-vectorization with jax.vmap()#

Key concepts:

  • JAX provides automatic vectorization via the jax.vmap transformation.

  • jax.vmap can be composed with jax.jit to produce efficient vectorized code.

Another useful transformation is vmap(), the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of explicitly looping over function calls, it transforms the function into a natively vectorized version for better performance. When composed with jit(), it can be just as performant as manually rewriting your function to operate over an extra batch dimension.

We’re going to work with a simple example, and promote matrix-vector products into matrix-matrix products using vmap(). Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions.

from jax import random

key = random.key(1701)
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply_matrix(x):
  return jnp.dot(mat, x)

The apply_matrix function maps a vector to a vector, but we may want to apply it row-wise across a matrix. We could do this by looping over the batch dimension in Python, but this usually results in poor performance.

def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
396 μs ± 1.25 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

A programmer familiar with the jnp.dot function might recognize that apply_matrix can be rewritten to avoid explicit looping, using the built-in batching semantics of jnp.dot:

import numpy as np

@jit
def batched_apply_matrix(batched_x):
  return jnp.dot(batched_x, mat.T)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched
11.8 μs ± 78.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

However, as functions become more complicated, this kind of manual batching becomes more difficult and error-prone. The vmap() transformation is designed to automatically transform a function into a batch-aware version:

from jax import vmap

@jit
def vmap_batched_apply_matrix(batched_x):
  return vmap(apply_matrix)(batched_x)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap
15.9 μs ± 59.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

As you would expect, vmap() can be arbitrarily composed with jit(), grad(), and any other JAX transformation.
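
When a function takes several arguments, the in_axes parameter of vmap() controls which ones are batched. In this sketch (the names apply_matrix_to and batched_apply are illustrative), the matrix is passed explicitly and shared across the batch with in_axes=None, while the vectors are mapped over their leading axis.

```python
import jax
import jax.numpy as jnp
from jax import random

key1, key2 = random.split(random.key(0))
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply_matrix_to(mat, x):
    # Operates on a single vector x of shape (100,).
    return jnp.dot(mat, x)

# None: do not batch over `mat`; 0: batch over axis 0 of `x`.
batched_apply = jax.jit(jax.vmap(apply_matrix_to, in_axes=(None, 0)))

out = batched_apply(mat, batched_x)
print(out.shape)
```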

For more on automatic vectorization in JAX, check out Automatic vectorization.

Pseudorandom numbers#

Key concepts:

  • JAX uses a different model for pseudorandom number generation than NumPy.

  • JAX random functions consume a random key that must be split to generate new independent keys.

  • JAX’s random key model is thread-safe and avoids issues with global state.

Generally, JAX strives to be compatible with NumPy, but pseudorandom number generation is a notable exception. NumPy supports a method of pseudorandom number generation that is based on a global state, which can be set using numpy.random.seed. Global random state interacts poorly with JAX’s compute model and makes it difficult to enforce reproducibility across different threads, processes, and devices. JAX instead tracks state explicitly via a random key:

from jax import random

key = random.key(43)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 43]

The key is effectively a stand-in for NumPy’s hidden state object, but we pass it explicitly to jax.random functions. Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated.

print(random.normal(key))
print(random.normal(key))
0.07520543
0.07520543

The rule of thumb is: never reuse keys (unless you want identical outputs).

In order to generate different and independent samples, you must jax.random.split the key explicitly before passing it to a random function:

for i in range(3):
  new_key, subkey = random.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.

  val = random.normal(subkey)
  del subkey  # The subkey is consumed by normal().

  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration.
draw 0: -1.9133632183074951
draw 1: -1.4749839305877686
draw 2: -0.36703771352767944

Note that this code is thread safe, since the local random state eliminates possible race conditions involving global state. jax.random.split is a deterministic function that converts one key into several independent (in the pseudorandomness sense) keys.
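
jax.random.split can also produce several keys in one call by passing a count. The sketch below wraps this in a hypothetical helper, draw_three, that splits a seed key into four keys, keeps the first for later use, and consumes the other three. Because everything is derived deterministically from the seed, calling it twice with the same seed reproduces the same draws.

```python
import jax.numpy as jnp
from jax import random

def draw_three(seed):
    key = random.key(seed)
    # One key to carry forward plus three subkeys to consume.
    keys = random.split(key, 4)
    new_key = keys[0]
    draws = jnp.stack([random.normal(keys[i]) for i in range(1, 4)])
    return new_key, draws

_, draws_a = draw_three(0)
_, draws_b = draw_three(0)

print(draws_a)
print(jnp.array_equal(draws_a, draws_b))  # same seed, same draws
```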

For more on pseudorandom numbers in JAX, see the Pseudorandom numbers tutorial.

This is just a taste of what JAX can do. We’re really excited to see what you do with it!