jax.vmap
- jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None)
Vectorizing map. Creates a function which maps `fun` over argument axes.
- Parameters:
fun (F) – Function to be mapped over additional axes.
in_axes (int | None | Sequence[Any]) –
An integer, None, or sequence of values specifying which input array axes to map over.

If each positional argument to `fun` is an array, then `in_axes` can be an integer, None, or a tuple of integers and Nones with length equal to the number of positional arguments to `fun`. An integer or `None` indicates which array axis to map over for all arguments (with `None` indicating not to map any axis), and a tuple indicates which axis to map for each corresponding positional argument. Axis integers must be in the range `[-ndim, ndim)` for each array, where `ndim` is the number of dimensions (axes) of the corresponding input array.

If the positional arguments to `fun` are container (pytree) types, `in_axes` must be a sequence with length equal to the number of positional arguments to `fun`, and for each argument the corresponding element of `in_axes` can be a container with a matching pytree structure specifying the mapping of its container elements. In other words, `in_axes` must be a container tree prefix of the positional argument tuple passed to `fun`. See the pytree documentation for more detail: https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees

Either `axis_size` must be provided explicitly, or at least one positional argument must have `in_axes` not None. The sizes of the mapped input axes for all mapped positional arguments must all be equal.

Arguments passed as keywords are always mapped over their leading axis (i.e. axis index 0).

See below for examples.
out_axes (Any) – An integer, None, or (nested) standard Python container (tuple/list/dict) thereof indicating where the mapped axis should appear in the output. All outputs with a mapped axis must have a non-None `out_axes` specification. Axis integers must be in the range `[-ndim, ndim)` for each output array, where `ndim` is the number of dimensions (axes) of the array returned by the `vmap()`-ed function, which is one more than the number of dimensions (axes) of the corresponding array returned by `fun`.
axis_name (AxisName | None) – Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.
axis_size (int | None) – Optional, an integer indicating the size of the axis to be mapped. If not provided, the mapped axis size is inferred from arguments.
spmd_axis_name (AxisName | tuple[AxisName, ...] | None)
- Returns:
Batched/vectorized version of `fun` with arguments that correspond to those of `fun`, but with extra array axes at positions indicated by `in_axes`, and a return value that corresponds to that of `fun`, but with extra array axes at positions indicated by `out_axes`.
- Return type:
F
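The `in_axes` rules above can be sketched in a few lines. This is an illustrative example rather than part of the reference: `scale` and `add` are hypothetical helpers, and the all-`None` call relies on the documented rule that `axis_size` must then be supplied explicitly.

```python
import jax
import jax.numpy as jnp


def scale(x, y):
    # Hypothetical helper: elementwise product of one example from each input.
    return x * y


x = jnp.arange(6.0).reshape(2, 3)  # shape [2, 3]
y = jnp.ones(3)                    # shape [3]

# A negative axis counts from the end: -1 maps over axis 1 of x (size 3),
# which matches the size of the mapped axis 0 of y.
out_neg = jax.vmap(scale, in_axes=(-1, 0))(x, y)
print(out_neg.shape)  # (3, 2)

# With in_axes=None nothing is mapped, so the mapped size cannot be inferred
# and axis_size must be given; the unmapped result is broadcast to that size.
const = jax.vmap(lambda c: c + 1.0, in_axes=None, axis_size=4)(2.0)
print(const.shape)  # (4,)


def add(a, b):
    # Hypothetical helper used to show how keyword arguments are mapped.
    return a + b


# in_axes=0 covers the positional argument; the keyword argument b is
# mapped over its leading axis (axis 0) in any case.
out_kw = jax.vmap(add, in_axes=0)(jnp.zeros((5, 2)), b=jnp.ones((5, 2)))
print(out_kw.shape)  # (5, 2)
```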
For example, we can implement a matrix-matrix product using a vector dot product:
```python
>>> import jax.numpy as jnp
>>> from jax import vmap
>>>
>>> vv = lambda x, y: jnp.vdot(x, y)  # ([a], [a]) -> []
>>> mv = vmap(vv, (0, None), 0)      # ([b,a], [a]) -> [b]      (b is the mapped axis)
>>> mm = vmap(mv, (None, 1), 1)      # ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)
```
Here we use `[a,b]` to indicate an array with shape (a,b). Here are some variants:

```python
>>> mv1 = vmap(vv, (0, 0), 0)   # ([b,a], [b,a]) -> [b]        (b is the mapped axis)
>>> mv2 = vmap(vv, (0, 1), 0)   # ([b,a], [a,b]) -> [b]        (b is the mapped axis)
>>> mm2 = vmap(mv2, (1, 1), 0)  # ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)
```
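To make the shape notation concrete, the sketch below checks these compositions on small arrays. The sizes b=2, a=3, c=4 are arbitrary choices for illustration; `mm` is compared against the ordinary matrix product `A @ B`.

```python
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)   # ([a], [a]) -> []
mv = vmap(vv, (0, None), 0)        # ([b,a], [a]) -> [b]
mm = vmap(mv, (None, 1), 1)        # ([b,a], [a,c]) -> [b,c]
mv1 = vmap(vv, (0, 0), 0)          # ([b,a], [b,a]) -> [b]
mv2 = vmap(vv, (0, 1), 0)          # ([b,a], [a,b]) -> [b]
mm2 = vmap(mv2, (1, 1), 0)         # ([b,c,a], [a,c,b]) -> [c,b]

b, a, c = 2, 3, 4  # illustrative sizes
A = jnp.arange(b * a, dtype=jnp.float32).reshape(b, a)  # shape [b, a]
B = jnp.arange(a * c, dtype=jnp.float32).reshape(a, c)  # shape [a, c]

# mm reproduces the matrix product.
print(jnp.allclose(mm(A, B), A @ B))  # True

# mv1 and mv2 compute the same row-wise dot products; mv2 reads its
# second argument along axis 1, so it is given the transpose.
print(jnp.allclose(mv1(A, A), mv2(A, A.T)))  # True

# mm2 puts the mapped axis c first in its output.
print(mm2(jnp.ones((b, c, a)), jnp.ones((a, c, b))).shape)  # (4, 2)
```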
Here’s an example of using container types in `in_axes` to specify which axes of the container elements to map over:

```python
>>> A, B, C, D = 2, 3, 4, 5
>>> x = jnp.ones((A, B))
>>> y = jnp.ones((B, C))
>>> z = jnp.ones((C, D))
>>> def foo(tree_arg):
...   x, (y, z) = tree_arg
...   return jnp.dot(x, jnp.dot(y, z))
>>> tree = (x, (y, z))
>>> print(foo(tree))
[[12. 12. 12. 12. 12.]
 [12. 12. 12. 12. 12.]]
>>> from jax import vmap
>>> K = 6  # batch size
>>> x = jnp.ones((K, A, B))  # batch axis in different locations
>>> y = jnp.ones((B, K, C))
>>> z = jnp.ones((C, D, K))
>>> tree = (x, (y, z))
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
>>> print(vfoo(tree).shape)
(6, 2, 5)
```
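Since `in_axes` only has to be a tree prefix of the argument tuple, a single integer can stand for an entire subtree. The sketch below redefines `foo` so it runs on its own and assumes every leaf now carries the batch axis first; it illustrates the prefix rule and does not introduce any additional API.

```python
import jax.numpy as jnp
from jax import vmap

A, B, C, D, K = 2, 3, 4, 5, 6


def foo(tree_arg):
    x, (y, z) = tree_arg
    return jnp.dot(x, jnp.dot(y, z))


# Every leaf now carries the batch axis (size K) first.
x = jnp.ones((K, A, B))
y = jnp.ones((K, B, C))
z = jnp.ones((K, C, D))
tree = (x, (y, z))

# The second 0 sits where the (y, z) subtree is, so it applies to both y and z.
vfoo_prefix = vmap(foo, in_axes=((0, 0),))
# A bare integer applies to every leaf of every positional argument.
vfoo_all = vmap(foo, in_axes=0)

print(vfoo_prefix(tree).shape)  # (6, 2, 5)
print(vfoo_all(tree).shape)     # (6, 2, 5)
```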
Here’s another example using container types in `in_axes`, this time a dictionary, to specify the elements of the container to map over:

```python
>>> dct = {'a': 0., 'b': jnp.arange(5.)}
>>> x = 1.
>>> def foo(dct, x):
...   return dct['a'] + dct['b'] + x
>>> out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x)
>>> print(out)
[1. 2. 3. 4. 5.]
```
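Every entry of a dictionary can be mapped, each along its own axis. In this sketch (the values are made up for illustration), `'a'` is mapped over axis 0 and `'b'` over axis 1, while `x` stays unmapped; both mapped axes must have the same size.

```python
import jax.numpy as jnp
from jax import vmap


def foo(dct, x):
    return dct['a'] + dct['b'] + x


dct = {'a': jnp.arange(3.0),    # shape [3]: mapped over axis 0
       'b': jnp.ones((5, 3))}   # shape [5, 3]: mapped over axis 1 (size 3)
x = 10.0                        # unmapped

# The dict in in_axes mirrors the keys of dct.
out = vmap(foo, in_axes=({'a': 0, 'b': 1}, None))(dct, x)
print(out.shape)  # (3, 5)
```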
The results of a vectorized function can be mapped or unmapped. For example, the function below returns a pair with the first element mapped and the second unmapped. Only unmapped results can have `out_axes` set to `None` (which keeps them unmapped):

```python
>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
(Array([4., 5.], dtype=float32), 8.0)
```
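Because `out_axes` can be a (nested) container, a function returning a dictionary can mark each entry separately. A sketch, with the hypothetical keys `'shifted'` and `'offset'`:

```python
import jax.numpy as jnp
from jax import vmap


def f(x, y):
    # Hypothetical function returning a dict: one entry depends on the
    # mapped x, the other only on the unmapped y.
    return {'shifted': x + y, 'offset': y * 2.0}


out = vmap(f, in_axes=(0, None),
           out_axes={'shifted': 0, 'offset': None})(jnp.arange(3.0), 4.0)
print(out['shifted'].shape)     # (3,)
print(jnp.ndim(out['offset']))  # 0: left unmapped, not broadcast
```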
If `out_axes` is specified for an unmapped result, the result is broadcast across the mapped axis:

```python
>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
(Array([4., 5.], dtype=float32), Array([8., 8.], dtype=float32, weak_type=True))
```
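The mapped axis does not have to come first. In this sketch (`f` and `w` are illustrative), `out_axes=1` places the mapped axis at position 1 of each output: the mapped product has the batch in its second axis, and the unmapped `w` is broadcast along that same axis.

```python
import jax.numpy as jnp
from jax import vmap


def f(x, w):
    # x: one mapped scalar per example; w: unmapped vector of shape [2].
    return x * w, w


xs = jnp.arange(3.0)        # mapped axis of size 3
w = jnp.array([1.0, 10.0])  # unmapped

prod, w_out = vmap(f, in_axes=(0, None), out_axes=1)(xs, w)
print(prod.shape)   # (2, 3): the mapped axis is placed at position 1
print(w_out.shape)  # (2, 3): w is broadcast along the mapped axis at position 1
```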
If `out_axes` is specified for a mapped result, the mapped axis is moved to the position it gives, i.e. the result is transposed accordingly.

Finally, here’s an example using `axis_name` together with collectives:

```python
>>> from jax import lax
>>> xs = jnp.arange(3. * 4.).reshape(3, 4)
>>> print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs))
[[12. 15. 18. 21.]
 [12. 15. 18. 21.]
 [12. 15. 18. 21.]]
```
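Other named-axis collectives work the same way under `vmap`. A sketch, assuming `lax.pmean` (which averages over the named axis) alongside `lax.psum`; the `normalize` and `center` names are illustrative:

```python
import jax.numpy as jnp
from jax import lax, vmap

xs = jnp.arange(1.0, 13.0).reshape(3, 4)  # 3 examples of shape [4]

# Divide each example by the elementwise total across the mapped axis 'i'.
normalize = vmap(lambda x: x / lax.psum(x, 'i'), axis_name='i')
# Subtract the elementwise mean across the mapped axis 'i'.
center = vmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')

print(normalize(xs).sum(axis=0))  # every column now sums to 1
print(center(xs).mean(axis=0))    # every column now has mean 0 (up to rounding)
```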
See the `jax.pmap()` docstring for more examples involving collectives.