jax.make_array_from_single_device_arrays
- jax.make_array_from_single_device_arrays(shape, sharding, arrays, *, dtype=None)
- Returns a jax.Array from a sequence of jax.Arrays, each on a single device. Every device in the input sharding's mesh that is addressable by the current process must have a corresponding array in arrays.
- Parameters:
shape (Shape) – Shape of the output jax.Array. This conveys information already included with sharding and arrays and serves as a double check.
sharding (Sharding) – A global Sharding instance that describes how the output jax.Array is laid out across devices.
arrays (Sequence[basearray.Array]) – List or tuple of jax.Arrays that are each single-device addressable. len(arrays) must equal len(sharding.addressable_devices), and every array must have the same shape. For multiprocess code, each process calls this function with a different arrays argument that corresponds to that process's data. These arrays are commonly created via jax.device_put.
dtype (DTypeLike | None) – The dtype of the output jax.Array. If not provided, the dtype of the first array in arrays is used. If arrays is empty, the dtype argument must be provided.
- Returns:
- A global jax.Array, sharded as sharding, with shape equal to shape, and with per-device contents matching arrays.
- Return type:
ArrayImpl
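The parameter requirements above can be checked directly. This is a minimal sketch, assuming a 1-D mesh over all visible devices (so it also runs on a single CPU device); global_shape, data, and bad_shape are illustrative names. It asserts the len(arrays) and equal-shape requirements, omits dtype so the first array's dtype is used, and shows shape acting as a double check. The exact exception raised for an inconsistent shape is an assumption (ValueError in recent JAX releases), so the sketch catches Exception.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1-D mesh over every visible device, so this runs even on one CPU device.
devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))
sharding = NamedSharding(mesh, P("x"))  # shard the leading axis over mesh axis "x"

# Leading axis chosen to divide evenly across the mesh (2 rows per device).
global_shape = (2 * devices.size, 3)
data = np.arange(np.prod(global_shape), dtype=np.float32).reshape(global_shape)

# One single-device array per addressable device, each holding that device's slice.
index_map = sharding.addressable_devices_indices_map(global_shape)
arrays = [jax.device_put(data[index], device) for device, index in index_map.items()]

# Requirements from the parameter list: one array per addressable device, all the same shape.
assert len(arrays) == len(sharding.addressable_devices)
assert all(a.shape == arrays[0].shape for a in arrays)

# dtype is omitted, so the output takes the dtype of arrays[0] (float32 here).
arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)

# shape is a double check: a global shape inconsistent with the shards is rejected.
bad_shape = (global_shape[0], global_shape[1] + 1)
try:
    jax.make_array_from_single_device_arrays(bad_shape, sharding, arrays)
    shape_rejected = False
except Exception:  # ValueError in recent JAX releases (assumption for other versions)
    shape_rejected = True
```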
Examples
>>> import math
>>> import jax
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> mesh_rows = 2
>>> mesh_cols = jax.device_count() // 2  # assumes 2, 4, 8, or 16 devices so both mesh axes divide 8
...
>>> global_shape = (8, 8)
>>> mesh = Mesh(np.array(jax.devices()).reshape(mesh_rows, mesh_cols), ('x', 'y'))
>>> sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
>>> inp_data = np.arange(math.prod(global_shape)).reshape(global_shape)
...
>>> arrays = [
...     jax.device_put(inp_data[index], d)
...     for d, index in sharding.addressable_devices_indices_map(global_shape).items()]
...
>>> arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
>>> assert arr.shape == (8, 8)  # the global shape does not depend on jax.device_count()
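The example above needs a device count that splits evenly over a 2-row mesh. As a hedged variant, the sketch below uses a fully replicated sharding (an empty PartitionSpec) on a 1-D mesh, so any device count works, and then checks the Returns guarantee: each addressable shard of the result holds the same data as the single-device array supplied for that device. The per_device dictionary and matches list are illustrative names.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a fully replicated sharding (empty PartitionSpec) over a 1-D mesh of all devices.
mesh = Mesh(np.array(jax.devices()), ("x",))
sharding = NamedSharding(mesh, P())

global_shape = (4, 2)
data = np.arange(8, dtype=np.int32).reshape(global_shape)

# With full replication, every device's index covers the whole array, so each gets a full copy.
index_map = sharding.addressable_devices_indices_map(global_shape)
per_device = {device: jax.device_put(data[index], device) for device, index in index_map.items()}

arr = jax.make_array_from_single_device_arrays(
    global_shape, sharding, list(per_device.values()))

# Per-device contents of the result match the single-device input placed on that device.
matches = [
    np.array_equal(np.asarray(shard.data), np.asarray(per_device[shard.device]))
    for shard in arr.addressable_shards
]
```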
For cases where you have a local array and want to convert it to a global jax.Array, use
jax.make_array_from_process_local_data.
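For comparison, a minimal sketch of the process-local route, assuming a single-process run in which the process-local data is the whole global array. It assumes the call form jax.make_array_from_process_local_data(sharding, local_data) with global_shape left at its default so it is inferred; no per-device slicing or jax.device_put is needed because the data is split according to sharding.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: single process, so local_data is the entire global array.
mesh = Mesh(np.array(jax.devices()), ("x",))
sharding = NamedSharding(mesh, P("x"))  # shard rows over mesh axis "x"

n = len(jax.devices())
local_data = np.arange(2 * n * 3, dtype=np.float32).reshape(2 * n, 3)

# JAX slices local_data per the sharding and places each slice on its device.
arr = jax.make_array_from_process_local_data(sharding, local_data)
```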