jax.experimental.shard_map.shard_map
- jax.experimental.shard_map.shard_map(f, mesh, in_specs, out_specs, check_rep=True, auto=frozenset({}))
Map a function over shards of data.
Note
shard_map is an experimental API and is still subject to change. For an introduction to sharded data, refer to Introduction to parallel programming. For a more in-depth look at using shard_map, refer to SPMD multi-device parallelism with shard_map.
- Parameters:
  - f (Callable) – callable to be mapped. Each application of f, or “instance” of f, takes as input a shard of the mapped-over arguments and produces a shard of the output.
  - mesh (Mesh | AbstractMesh) – a jax.sharding.Mesh representing the array of devices over which to shard the data and on which to execute instances of f. The names of the Mesh can be used in collective communication operations in f. This is typically created by a utility function like jax.experimental.mesh_utils.create_device_mesh().
  - in_specs (Any) – a pytree with PartitionSpec instances as leaves, with a tree structure that is a tree prefix of the args tuple to be mapped over. Similar to NamedSharding, each PartitionSpec represents how the corresponding argument (or subtree of arguments) should be sharded along the named axes of mesh. In each PartitionSpec, mentioning a mesh axis name at a position expresses sharding the corresponding argument array axis along that positional axis; not mentioning an axis name expresses replication. If an argument, or argument subtree, has a corresponding spec of None, that argument is not sharded.
  - out_specs (Any) – a pytree with PartitionSpec instances as leaves, with a tree structure that is a tree prefix of the output of f. Each PartitionSpec represents how the corresponding output shards should be concatenated. In each PartitionSpec, mentioning a mesh axis name at a position expresses concatenation of that mesh axis’s shards along the corresponding positional axis. Not mentioning a mesh axis name expresses a promise that the output values are equal along that mesh axis, so that rather than concatenating, only a single value is produced. (See the sketch after the Returns section below.)
  - check_rep (bool) – if True (the default), enable additional validity checks and automatic differentiation optimizations. The validity checks concern whether any mesh axis names not mentioned in out_specs are consistent with how the outputs of f are replicated. Must be set to False if using a Pallas kernel in f.
  - auto (frozenset[Hashable]) – (experimental) an optional set of axis names from mesh over which we do not shard the data or map the function, but rather allow the compiler to control sharding. These names cannot be used in in_specs, out_specs, or in communication collectives in f.
- Returns:
  A callable that applies the input function f across data sharded according to the mesh and in_specs.
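A minimal sketch of the in_specs/out_specs semantics described above. The 4x2 mesh, axis names, and array shapes are chosen only for illustration, and the code assumes eight devices are available (for example via the --xla_force_host_platform_device_count XLA flag):

```python
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

# Hypothetical 4x2 device mesh with axis names 'x' and 'y' (assumes 8 devices).
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), axis_names=('x', 'y'))

a = jnp.arange(8 * 16.).reshape(8, 16)

def double(a_block):
    # With in_specs=P('x', 'y'), each instance receives a (2, 8) shard of the
    # (8, 16) global array: axis 0 is split 4 ways along 'x', axis 1 is split
    # 2 ways along 'y'.
    return 2 * a_block

doubled = shard_map(double, mesh, in_specs=P('x', 'y'), out_specs=P('x', 'y'))(a)
# out_specs=P('x', 'y') concatenates the per-instance results back along both
# axes, so `doubled` has the same global shape (8, 16) as `a`.
```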
Examples
For examples, refer to Introduction to parallel programming or SPMD multi-device parallelism with shard_map.
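As a further hedged sketch (not taken from the linked tutorials), the example below uses a collective inside f: each instance contracts a column block of x against a row block of w and sums the partial products across the mesh with jax.lax.psum, so the result is replicated and out_specs mentions no mesh axis. The one-dimensional mesh of four devices and the names used here are assumptions for illustration only.

```python
import functools

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

# Assumed 1-D mesh over four devices, axis name 'i'.
mesh = Mesh(mesh_utils.create_device_mesh((4,)), axis_names=('i',))

x = jnp.arange(8 * 16.).reshape(8, 16)
w = jnp.arange(16 * 4.).reshape(16, 4)

@functools.partial(
    shard_map, mesh=mesh,
    in_specs=(P(None, 'i'), P('i', None)),  # shard x's columns and w's rows over 'i'
    out_specs=P(None, None))                # output replicated: no mesh axis mentioned
def matmul_psum(x_block, w_block):
    # x_block: (8, 4) column block of x; w_block: (4, 4) row block of w.
    partial_product = x_block @ w_block
    # Sum the partial products over the 'i' axis; afterwards every instance
    # holds the full (8, 4) result, consistent with the out_specs (and with
    # the check_rep replication check).
    return jax.lax.psum(partial_product, 'i')

y = matmul_psum(x, w)  # equal to x @ w, shape (8, 4)
```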