Key concepts

This section briefly introduces some key concepts of the JAX package.

Transformations

Along with functions to operate on arrays, JAX includes a number of transformations which operate on JAX functions. These include jax.jit() (just-in-time compilation), jax.vmap() (automatic vectorization), and jax.grad() (automatic differentiation), as well as several others.

Transformations accept a function as an argument, and return a new transformed function. For example, here’s how you might JIT-compile a simple SELU function:

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
print(selu_jit(1.0))
1.05
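Transformations can also be stacked. The following is a minimal sketch with arbitrary input values. It applies jax.grad() to get the derivative of selu for a scalar input, jax.vmap() to map that derivative over a batch, and jax.jit() to compile the composed result:

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# jax.grad: derivative of selu with respect to its first argument (a scalar x)
dselu = jax.grad(selu)
# jax.vmap: apply the scalar derivative elementwise over a batch of inputs
batched_dselu = jax.vmap(dselu)
# jax.jit: compile the composed transformation
fast_batched_dselu = jax.jit(batched_dselu)

xs = jnp.array([-1.0, 0.5, 2.0])
grads = fast_batched_dselu(xs)
print(grads)  # for x > 0 the derivative is lambda_, i.e. 1.05
```

Each transformation returns an ordinary Python callable, which is why they can be nested in any order that makes sense for the computation.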

Often you’ll see transformations applied using Python’s decorator syntax for convenience:

@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

Tracing

The magic behind transformations is the notion of a Tracer. Tracers are abstract stand-ins for array objects, and are passed to JAX functions in order to extract the sequence of operations that the function encodes.

You can see this by printing any array value within transformed JAX code; for example:

@jax.jit
def f(x):
  print(x)
  return x + 1

x = jnp.arange(5)
result = f(x)
JitTracer<int32[5]>

The value printed is not the array x, but a Tracer instance that represents essential attributes of x, such as its shape and dtype. By executing the function with traced values, JAX can determine the sequence of operations encoded by the function before those operations are actually executed: transformations like jit(), vmap(), and grad() can then map this sequence of input operations to a transformed sequence of operations.
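One practical consequence is that Python side effects such as print run only while JAX traces the function. They do not run every time the compiled version executes. In the sketch below, the trace_log list is purely illustrative and records one entry per trace. A second call with the same shape and dtype reuses the cached compilation, while a new shape triggers a fresh trace:

```python
import jax
import jax.numpy as jnp

trace_log = []  # illustrative helper: collects one entry per trace

@jax.jit
def double(x):
  # Python-level side effect: executes during tracing only
  trace_log.append(x.shape)
  return x * 2

double(jnp.arange(3))       # traces and compiles for shape (3,)
double(jnp.arange(3) + 10)  # same shape and dtype: cached version, no retrace
double(jnp.arange(4))       # new shape (4,): traced again
print(trace_log)            # [(3,), (4,)]
```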

Static vs traced operations

Just as values can be either static (for example, Python scalars or an array's shape) or traced (for example, the Tracer printed above), operations can be static or traced. Static operations are evaluated at compile-time in Python; traced operations are compiled and evaluated at run-time in XLA.
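The following is a minimal sketch of the distinction, and the function names are hypothetical. A Python if on x.ndim is a static operation, decided once during tracing, while the jax.numpy arithmetic is traced. Branching in Python on a traced value fails because its concrete value is unknown at trace time. Recent JAX versions raise TracerBoolConversionError, a subclass of ConcretizationTypeError, so catching the latter covers it:

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize_rows(x):
  # Static: x.ndim is a Python int, so this branch is decided during tracing
  if x.ndim == 1:
    x = x[None, :]
  # Traced: these operations are recorded and executed by XLA at run-time
  return x / jnp.sum(x, axis=1, keepdims=True)

print(normalize_rows(jnp.array([1.0, 3.0])))  # [[0.25 0.75]]

@jax.jit
def flip_if_positive(x):
  # Traced: jnp.sum(x) > 0 is a Tracer, so Python cannot branch on it
  if jnp.sum(x) > 0:
    return -x
  return x

raised = False
try:
  flip_if_positive(jnp.array([1.0, 3.0]))
except jax.errors.ConcretizationTypeError as err:
  raised = True
  print(type(err).__name__)
```

For data-dependent branching on traced values, jax.numpy.where or the structured control-flow functions in jax.lax are the usual alternatives.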

This distinction between static and traced values makes it important to think about how to keep a static value static. Consider this function:

import jax.numpy as jnp
from jax import jit

@jit
def f(x):
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 9
      6   return x.reshape(jnp.array(x.shape).prod())
      8 x = jnp.ones((2, 3))
----> 9 f(x)

    [... skipping hidden 13 frame]

Cell In[4], line 6, in f(x)
      4 @jit
      5 def f(x):
----> 6   return x.reshape(jnp.array(x.shape).prod())

TypeError: Shapes must be 1D sequences of concrete values of integer type, got [JitTracer<int32[]>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /tmp/ipykernel_1881/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[] = reduce_prod[axes=(0,)] b
    from line /tmp/ipykernel_1881/1983583872.py:6:19 (f)

This fails with an error specifying that a tracer was found instead of a 1D sequence of concrete values of integer type. Let’s add some print statements to the function to understand why this is happening:

@jit
def f(x):
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # comment this out to avoid the error:
  # return x.reshape(jnp.array(x.shape).prod())

f(x)
x = JitTracer<float32[2,3]>
x.shape = (2, 3)
jnp.array(x.shape).prod() = JitTracer<int32[]>

Notice that although x is traced, x.shape is a static value. However, when we use jnp.array and jnp.prod on this static value, it becomes a traced value, at which point it cannot be used in a function like reshape() that requires a static input (recall: array shapes must be static).

A useful pattern is to use numpy for operations that should be static (i.e. done at compile-time), and use jax.numpy for operations that should be traced (i.e. compiled and executed at run-time). For this function, it might look like this:

from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  return x.reshape((np.prod(x.shape),))

f(x)
Array([1., 1., 1., 1., 1., 1.], dtype=float32)

For this reason, a standard convention in JAX programs is to import numpy as np and import jax.numpy as jnp so that both interfaces are available for finer control over whether operations are performed in a static manner (with numpy, once at compile-time) or a traced manner (with jax.numpy, optimized at run-time).
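The error message above also suggests static_argnums. The sketch below shows two other ways to keep the shape computation static, using the same x = jnp.ones((2, 3)). The function names are illustrative. First, x.size is already a plain Python int during tracing. Second, marking an argument static with static_argnames makes its concrete value available at trace time, at the cost of recompiling for each new value:

```python
from functools import partial

import jax
import jax.numpy as jnp

@jax.jit
def flatten(x):
  # x.size is a static Python int computed from the (static) shape
  return x.reshape((x.size,))

@partial(jax.jit, static_argnames=['new_len'])
def reshape_to(x, new_len):
  # new_len is static: it is a concrete Python int while tracing
  return x.reshape((new_len,))

x = jnp.ones((2, 3))
print(flatten(x))              # [1. 1. 1. 1. 1. 1.]
print(reshape_to(x, new_len=6))
```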

Jaxprs

JAX has its own intermediate representation for sequences of operations, known as a jaxpr. A jaxpr (short for JAX exPRession) is a simple representation of a functional program, comprising a sequence of primitive operations.

For example, consider the selu function we defined above:

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

We can use the jax.make_jaxpr() utility to convert this function into a jaxpr given a particular input:

x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)
{ lambda ; a:f32[5]. let
    b:bool[5] = gt a 0.0:f32[]
    c:f32[5] = exp a
    d:f32[5] = mul 1.67:f32[] c
    e:f32[5] = sub d 1.67:f32[]
    f:f32[5] = jit[
      name=_where
      jaxpr={ lambda ; b:bool[5] a:f32[5] e:f32[5]. let
          f:f32[5] = select_n b e a
        in (f,) }
    ] b a e
    g:f32[5] = mul 1.05:f32[] f
  in (g,) }

Comparing this to the Python function definition, we see that it encodes the precise sequence of operations that the function represents. We’ll go into more depth about jaxprs later in JAX internals: The jaxpr language.
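The object returned by jax.make_jaxpr() can also be inspected programmatically. The sketch below lists the primitive name of each top-level equation and the abstract output values. The exact primitive names and nesting vary between JAX versions. For example, where may appear inside a nested jit equation, as in the printout above:

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

closed_jaxpr = jax.make_jaxpr(selu)(jnp.arange(5.0))

# Each equation applies one primitive to previously defined variables
prim_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(prim_names)

# Abstract values (shape and dtype) of the outputs
print(closed_jaxpr.out_avals)
```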

Pytrees

JAX functions and transformations fundamentally operate on arrays, but in practice it is convenient to write code that works with collections of arrays: for example, a neural network might organize its parameters in a dictionary of arrays with meaningful keys. Rather than handle such structures on a case-by-case basis, JAX relies on the pytree abstraction to treat such collections in a uniform manner.
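For instance, jax.grad() accepts a dictionary of parameters and returns gradients with the same dictionary structure. The following is a minimal sketch using a hypothetical linear model and squared-output loss:

```python
import jax
import jax.numpy as jnp

# Hypothetical linear-model parameters stored in a dict (a pytree)
params = {'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

def loss(params, x):
  y = x @ params['W'] + params['b']
  return jnp.sum(y ** 2)

x = jnp.array([1.0, 2.0])
grads = jax.grad(loss)(params, x)

# grads has the same keys and leaf shapes as params
print({name: g.shape for name, g in grads.items()})
```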

Here are some examples of objects that can be treated as pytrees:

# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]

print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]

# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
       [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]

# Named tuple of parameters
from typing import NamedTuple

class Params(NamedTuple):
  a: int
  b: float

params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]

JAX has a number of general-purpose utilities for working with pytrees; for example, jax.tree.map() can be used to map a function to every leaf in a tree, and jax.tree.reduce() can be used to apply a reduction across the leaves in a tree.
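The following is a short sketch of both utilities on a small parameter dictionary. The updates tree is a stand-in for something like gradients and carries no real training signal:

```python
import operator

import jax
import jax.numpy as jnp

params = {'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

# jax.tree.map applies a function to every leaf and keeps the structure
halved = jax.tree.map(lambda leaf: 0.5 * leaf, params)

# With several trees of matching structure, corresponding leaves are passed together
updates = jax.tree.map(jnp.ones_like, params)  # stand-in for, e.g., gradients
new_params = jax.tree.map(lambda p, u: p - 0.1 * u, params, updates)

# jax.tree.reduce folds a binary function over the leaves
num_elements = jax.tree.reduce(operator.add, jax.tree.map(lambda leaf: leaf.size, params))
print(num_elements)  # 6
```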

You can learn more in the Working with pytrees tutorial.

JAX API layering: NumPy, lax & XLA

All JAX operations are implemented in terms of operations in XLA – the Accelerated Linear Algebra compiler. If you look at the source of jax.numpy, you’ll see that all the operations are eventually expressed in terms of functions defined in jax.lax. While jax.numpy is a high-level wrapper that provides a familiar interface, you can think of jax.lax as a stricter, but often more powerful, lower-level API for working with multi-dimensional arrays.
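One way to see this layering is to trace a jax.numpy function with jax.make_jaxpr() and look at the primitives it is built from. The sketch below does this for jnp.mean. The exact printout, including any nested jit wrappers, depends on the JAX version, but the sum and the division should show up as lax-level primitives:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6.0)
jaxpr_text = str(jax.make_jaxpr(jnp.mean)(x))
print(jaxpr_text)  # expect lax primitives such as reduce_sum and div
```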

For example, while jax.numpy will implicitly promote arguments to allow operations between mixed data types, jax.lax will not:

import jax.numpy as jnp
jnp.add(1, 1.0)  # jax.numpy API implicitly promotes mixed types.
Array(2., dtype=float32, weak_type=True)

from jax import lax
lax.add(1, 1.0)  # jax.lax API requires explicit type promotion.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[13], line 2
      1 from jax import lax
----> 2 lax.add(1, 1.0)  # jax.lax API requires explicit type promotion.

TypeError: lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).

If using jax.lax directly, you’ll have to do type promotion explicitly in such cases:

lax.add(jnp.float32(1), 1.0)
Array(2., dtype=float32)
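A more general version of this pattern is to ask jax.numpy which dtype it would promote to, using jnp.result_type(), and then convert each operand explicitly with lax.convert_element_type() before calling the lax function. This is a sketch of one reasonable approach, not the only one:

```python
import jax.numpy as jnp
from jax import lax

a = jnp.int32(1)
b = jnp.float32(1.0)

# dtype that jax.numpy's implicit promotion would choose for these operands
target = jnp.result_type(a, b)

# lax requires matching dtypes, so convert both operands explicitly
total = lax.add(lax.convert_element_type(a, target),
                lax.convert_element_type(b, target))
print(total.dtype, total)
```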

Along with this strictness, jax.lax also provides efficient APIs for some operations that are more general than those supported by NumPy.

For example, consider a 1D convolution, which can be expressed in NumPy this way:

x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

Under the hood, this NumPy operation is translated to a much more general convolution implemented by lax.conv_general_dilated:

from jax import lax
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0]
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (see Convolutions in JAX for more detail on JAX convolutions).
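To illustrate the batching, the sketch below convolves two signals with the same 3-tap kernel in a single lax.conv_general_dilated call. It uses the default 1D dimension layout: (batch, channels, width) for inputs and outputs, and (out_channels, in_channels, width) for the kernel. Like XLA's convolution, this computes a cross-correlation, so the kernel is not flipped. The kernel here is symmetric, so the result matches jnp.convolve; an asymmetric kernel would need kernel[::-1] to match.

```python
import jax.numpy as jnp
from jax import lax

signals = jnp.stack([jnp.ones(10), jnp.arange(10.0)])  # shape (2, 10)
kernel = jnp.array([1.0, 2.0, 1.0])                      # symmetric 3-tap kernel

out = lax.conv_general_dilated(
    signals[:, None, :],    # lhs: (batch=2, in_channels=1, width=10)
    kernel[None, None, :],  # rhs: (out_channels=1, in_channels=1, width=3)
    window_strides=(1,),
    padding=[(2, 2)])       # len(kernel) - 1 on each side, i.e. 'full' mode
print(out.shape)  # (2, 1, 12)

# Each batch row agrees with the per-signal NumPy-style convolution
print(jnp.allclose(out[1, 0], jnp.convolve(signals[1], kernel)))
```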

At their heart, all jax.lax operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by XLA:ConvWithGeneralPadding. Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation.
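To look at this lowering step directly, a jitted function can be lowered and compiled in separate stages. The .lower() method traces the function and produces the compiler input, which is StableHLO text in recent JAX versions; the exact text varies by version. The .compile() method then hands that input to XLA. The following is a minimal sketch with a hypothetical add_one function:

```python
import jax
import jax.numpy as jnp
from jax import lax

def add_one(x):
  # lax.add requires matching shapes and dtypes, so build a ones array
  return lax.add(x, jnp.ones_like(x))

x = jnp.arange(3.0)
lowered = jax.jit(add_one).lower(x)  # trace and lower; nothing is compiled yet
hlo_text = lowered.as_text()         # compiler input as text
print(hlo_text)

compiled = lowered.compile()         # XLA compiles the lowered module
print(compiled(x))                   # [1. 2. 3.]
```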