jax.shard_map
- jax.shard_map(f=None, /, *, out_specs, axis_names=set(), in_specs=None, mesh=None, check_vma=True)
Map a function over shards of data using a mesh of devices.
See the docs at https://docs.jax.dev/en/latest/notebooks/shard_map.html.
- Parameters:
  - f – callable to be mapped. Each application of f, or “instance” of f, takes as input a shard of the mapped-over arguments and produces a shard of the output.
  - mesh (Mesh | AbstractMesh | None) – (optional, default None) a jax.sharding.Mesh representing the array of devices over which to shard the data and on which to execute instances of f. The axis names of the Mesh can be used in collective communication operations in f. If mesh is None, it is inferred from the context, which can be set via the jax.sharding.use_mesh context manager.
  - in_specs (Specs | None) – (optional, default None) a pytree with jax.sharding.PartitionSpec instances as leaves, with a tree structure that is a tree prefix of the args tuple to be mapped over. Similar to jax.sharding.NamedSharding, each PartitionSpec represents how the corresponding argument (or subtree of arguments) should be sharded along the named axes of mesh. In each PartitionSpec, mentioning a mesh axis name at a position means the argument's array axis at that position is sharded along that mesh axis; not mentioning a mesh axis name means the argument is replicated along it. If None, all mesh axes must be of type Explicit, in which case the in_specs are inferred from the argument types.
  - out_specs (Specs) – a pytree with PartitionSpec instances as leaves, with a tree structure that is a tree prefix of the output of f. Each PartitionSpec represents how the corresponding output shards should be concatenated. In each PartitionSpec, mentioning a mesh axis name at a position means the shards along that mesh axis are concatenated along the output array axis at that position; not mentioning a mesh axis name is a promise that the output values are equal along that mesh axis, so a single value is produced rather than a concatenation.
  - axis_names (Set[AxisName]) – (optional, default set()) set of axis names from mesh over which the function f is manual. If empty, f is manual over all mesh axes.
  - check_vma (bool) – (optional, default True) whether to enable additional validity checks and automatic differentiation optimizations. The validity checks concern whether any mesh axis names not mentioned in out_specs are consistent with how the outputs of f are replicated.
- Returns:
A callable representing a mapped version of f, which accepts positional arguments corresponding to those of f and produces output corresponding to that of f.