jax.lax.map#
- jax.lax.map(f, xs, *, batch_size=None)[source]#
Map a function over leading array axes.
Like Python’s builtin map, except inputs and outputs are in the form of stacked arrays. Consider using the vmap() transform instead, unless you need to apply a function element by element for reduced memory usage or heterogeneous computation with other control flow primitives.

When xs is an array type, the semantics of map() are given by this Python implementation:

def map(f, xs):
  return np.stack([f(x) for x in xs])
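For instance, mapping a simple function over a 1-D array matches applying it to each element and stacking the results. A minimal sketch; the helper square below is illustrative, not part of the API:

>>> import jax.numpy as jnp
>>> from jax import lax
>>> def square(x):  # illustrative helper, not part of jax.lax
...   return x * x
>>> xs = jnp.arange(4)
>>> ys = lax.map(square, xs)
>>> bool(jnp.array_equal(ys, jnp.stack([square(x) for x in xs])))
True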
Like scan(), map() is implemented in terms of JAX primitives so many of the same advantages over a Python loop apply: xs may be an arbitrary nested pytree type, and the mapped computation is compiled only once.
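As a brief sketch of mapping over a pytree input (the dictionary keys and the function g are illustrative assumptions, not part of the API): each call receives a pytree slice taken along the leading axis of every leaf, and all leaves must share that leading axis size.

>>> import jax.numpy as jnp
>>> from jax import lax
>>> xs = {'a': jnp.ones((5, 2)), 'b': jnp.zeros((5, 3))}  # leaves share leading axis 5
>>> def g(x):  # receives {'a': shape (2,), 'b': shape (3,)} at each step
...   return x['a'].sum() + x['b'].sum()
>>> lax.map(g, xs).shape
(5,)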
If batch_size is provided, the computation is executed in batches of that size and parallelized using vmap(). This can be used as either a more performant version of map or as a memory-efficient version of vmap. If the axis is not divisible by the batch size, the remainder is processed in a separate vmap and concatenated to the result.

>>> x = jnp.ones((10, 3, 4))
>>> def f(x):
...   print('inner shape:', x.shape)
...   return x + 1
>>> y = lax.map(f, x, batch_size=3)
inner shape: (3, 4)
inner shape: (3, 4)
>>> y.shape
(10, 3, 4)
In the example above, “inner shape” is printed twice, once while tracing the batched computation and once while tracing the remainder computation.
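As a rough sketch (the pure function h is defined here only for illustration), the plain map, the batched map with a remainder, and vmap all agree numerically:

>>> import jax
>>> import jax.numpy as jnp
>>> from jax import lax
>>> xs = jnp.arange(10.0).reshape(5, 2)
>>> def h(row):  # illustrative pure function of one row
...   return jnp.dot(row, row)
>>> a = lax.map(h, xs)
>>> b = lax.map(h, xs, batch_size=2)  # 5 rows: two batches of 2 plus a remainder of 1
>>> c = jax.vmap(h)(xs)
>>> bool(jnp.allclose(a, b)) and bool(jnp.allclose(a, c))
True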
- Parameters:
f – a Python function to apply element-wise over the first axis or axes of xs.
xs – values over which to map along the leading axis.
batch_size (int | None) – (optional) integer specifying the size of the batch for each step to execute in parallel.
- Returns:
Mapped values.
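A minimal sketch of the return structure when f itself returns a pytree (the stats helper and its field names are illustrative): each leaf of the result is stacked along a new leading axis.

>>> import jax.numpy as jnp
>>> from jax import lax
>>> xs = jnp.ones((4, 3))
>>> def stats(row):  # illustrative helper returning a dict pytree
...   return {'mean': row.mean(), 'max': row.max()}
>>> out = lax.map(stats, xs)
>>> out['mean'].shape, out['max'].shape
((4,), (4,))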