Device-local array layout control#
The jax.experimental.layout package provides ways to control
how JAX arrays are laid out in device-local memory.
Terminology#
Array layout is tightly coupled with array sharding. Together, a layout and a sharding fully describe how an array’s values are laid out across (distributed) memories. Along these lines, we use the following terminology:
Layout: how an array’s values are laid out within each memory in which they reside (e.g., in the memory of a single device). A typical layout specification is a minor-to-major order listing of array dimensions.
Sharding: how an array’s values are distributed across different memory spaces, such as multiple device memories (e.g. described by sharding some dimensions and replicating others).
Format: the pairing of layout and sharding, providing a complete picture of an array’s memory placement.
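To make these terms concrete, every jax.Array exposes its sharding and its format, and the format in turn carries the device-local layout (the attribute chain below is the same one used in the examples later in this document; the exact values printed depend on your backend):
import jax.numpy as jnp

x = jnp.ones((4, 8))
print(x.sharding)                      # how values are distributed across device memories
print(x.format)                        # the pairing of layout and sharding
print(x.format.layout.major_to_minor)  # the dimension ordering within device memory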
Types#
There are two Python types that come up when controlling array
layouts: Layout and Format.
The Layout class is used to define the in-memory layout of an array. It has the following key attributes:
major_to_minor: A tuple of integers specifying the dimension ordering in memory. For example, for a 2-dimensional array, (0, 1) indicates row-major layout and (1, 0) indicates column-major.
_tiling: An intentionally hidden, highly experimental, optional attribute to specify a tiled layout.
AUTO: A special, static sentinel object that can be used with jax.jit to request that the compiler automatically determine a good layout for a compiled function’s input or output arrays.
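As a small sketch of these attributes (both the major_to_minor constructor argument and the Layout.AUTO sentinel appear in the examples below):
from jax.experimental.layout import Layout

row_major = Layout(major_to_minor=(0, 1))  # rows are contiguous in memory
col_major = Layout(major_to_minor=(1, 0))  # columns are contiguous in memory
auto = Layout.AUTO                         # ask the compiler to choose a layout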
The Format class carries both a Layout and a Sharding, with either one taking on a default value when it is not specified. When the layout is explicitly specified, the sharding must be as well.
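For example, a fully explicit Format pairs a concrete Layout with a Sharding, whereas Format(Layout.AUTO) names only a layout preference and leaves the rest to the compiler (both forms are taken from the examples below):
import jax
from jax.experimental.layout import Format, Layout
from jax.sharding import SingleDeviceSharding

dev = jax.devices()[0]
auto_format = Format(Layout.AUTO)                        # compiler picks the layout
explicit_format = Format(Layout(major_to_minor=(1, 0)),  # explicit column-major layout ...
                         SingleDeviceSharding(dev))      # ... requires an explicit sharding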
JAX API functions, such as jax.jit and jax.device_put, accept
Shardings for sharding control or Formats for additional layout
control. They typically do not accept Layout instances directly.
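As a brief sketch, here is jax.device_put called once with a plain Sharding and once with a Format (a single-device setup is assumed; the Format mirrors the one used later in this document):
import jax
import numpy as np
from jax.experimental.layout import Format, Layout
from jax.sharding import SingleDeviceSharding

dev = jax.devices()[0]
x = np.ones((256, 512), dtype=np.float32)

a = jax.device_put(x, SingleDeviceSharding(dev))          # sharding only: default layout
b = jax.device_put(x, Format(Layout(major_to_minor=(1, 0)),
                             SingleDeviceSharding(dev)))  # explicit layout and sharding
print(b.format.layout.major_to_minor)                     # (1, 0)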
Specifying and reading layouts#
By passing Format objects to jax.jit in place of shardings (in the
in_shardings and out_shardings arguments), you can guide the
compiler’s layout decisions. Similarly, you can pass Formats instead
of Shardings to jax.device_put to control the layout of the
resulting array.
Let’s see an example that uses both explicit and automatic layouts (via Layout.AUTO). Imagine we have two compiled functions, init_fn and apply_fn. Say we expect init_fn to be called roughly once, but apply_fn to be called many times on the output of init_fn, so we care much more about the performance of apply_fn. We may want the compiler to choose a good layout for apply_fn and then constrain init_fn to produce arrays with that layout. We can do this as follows:
import jax, jax.numpy as jnp
from jax.experimental.layout import Layout, Format
from jax.sharding import SingleDeviceSharding
import numpy as np
def init_fn(x, y):
  return x * 2, y * 3

def apply_fn(x, y):
  return x[0, :], y[:, 0]
Since apply_fn reads a contiguous column of its second argument y,
it makes sense to lay it out in column-major order (where columns are
stored contiguously). Using Layout.AUTO, we can ask the compiler to
infer good input layouts and see that it indeed chooses to request the
second argument in column-major layout.
shape = (4 * 128, 8 * 128)
duck = jax.ShapeDtypeStruct(shape, jnp.float32)
# Compile the `apply` function with layouts inferred automatically
apply_exe = jax.jit(
    apply_fn,
    in_shardings=Format(Layout.AUTO),
    out_shardings=Format(Layout.AUTO),
).trace(duck, duck).lower().compile()
# Read back the inferred input layout
arg_formats, kwarg_formats = apply_exe.input_formats
assert len(kwarg_formats) == 0
assert arg_formats[0].layout.major_to_minor == (0, 1)
assert arg_formats[1].layout.major_to_minor == (1, 0)
We can then compile init_fn to explicitly match this layout in its
outputs.
init_exe = jax.jit(init_fn, out_shardings=arg_formats).trace(
    duck, duck).lower().compile()
assert init_exe.output_formats == arg_formats
Finally, we can see how the compiled apply_fn behaves when called with input arrays of different layouts. The behavior depends on whether the inputs are committed. As the following test demonstrates, if the argument arrays are committed, the pre-compiled apply_fn requires that they match the layout determined by the compiler above. Meanwhile, it accepts uncommitted arrays of any layout (including, of course, the inferred layout); in this case, the arrays may be relaid out before the compiled computation is invoked.
def test(x, y, msg):
  print(f'-- {msg}:')
  print('x major_to_minor =', x.format.layout.major_to_minor)
  print('y major_to_minor =', y.format.layout.major_to_minor)
  try:
    apply_exe(x, y)
    print('-> `apply` called successfully')
  except ValueError as e:
    assert 'does not match' in str(e)
    print('-> error: mismatched input layouts')
  print()
dev = jax.devices()[0]
x1 = y1 = jnp.ones(shape)
test(x1, y1, 'uncommitted with mismatched layout')
x2, y2 = init_exe(x1, y1)
test(x2, y2, 'uncommitted with matching layout')
x3 = jnp.ones(shape)
y3 = jax.device_put(np.ones(shape), Format(Layout(major_to_minor=(1, 0)),
                                           SingleDeviceSharding(dev)))
test(x3, y3, 'committed with matching layout')
x4 = jnp.ones(shape)
y4 = jax.device_put(np.ones(shape), Format(Layout(major_to_minor=(0, 1)),
                                           SingleDeviceSharding(dev)))
test(x4, y4, 'committed with mismatched layout')
-- uncommitted with mismatched layout:
x major_to_minor = (0, 1)
y major_to_minor = (0, 1)
-> `apply` called successfully
-- uncommitted with matching layout:
x major_to_minor = (0, 1)
y major_to_minor = (1, 0)
-> `apply` called successfully
-- committed with matching layout:
x major_to_minor = (0, 1)
y major_to_minor = (1, 0)
-> `apply` called successfully
-- committed with mismatched layout:
x major_to_minor = (0, 1)
y major_to_minor = (0, 1)
-> error: mismatched input layouts
Constraining intermediate layouts#
We can also enforce a specific layout on an intermediate value within
a JIT-compiled function using with_layout_constraint:
from jax.experimental.layout import with_layout_constraint
@jax.jit
def f(x):
  y = x.T
  # Enforce a specific layout on `y`
  y = with_layout_constraint(y, Layout(major_to_minor=(0, 1)))
  return y * 2
This is analogous to jax.lax.with_sharding_constraint, but for constraining layouts rather than shardings.
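Continuing from the definition of f above, a quick usage sketch (the layout of the returned array is printed rather than asserted, since with_layout_constraint pins the intermediate y, not the function’s output):
x = jnp.ones((512, 1024), dtype=jnp.float32)
out = f(x)
print(out.shape)                         # (1024, 512)
print(out.format.layout.major_to_minor)  # layout chosen by the compiler for the output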