jax.make_array_from_callback
- jax.make_array_from_callback(shape, sharding, data_callback, dtype=None)
Returns a jax.Array via data fetched from data_callback.

data_callback is used to fetch the data for each addressable shard of the returned jax.Array. This function must return concrete arrays, meaning that make_array_from_callback has limited compatibility with JAX transformations like jit() or vmap().

- Parameters:
  - shape (Shape) – Shape of the jax.Array.
  - sharding (Sharding | Format) – A Sharding instance which describes how the jax.Array is laid out across devices.
  - data_callback (Callable[[Index | None], ArrayLike]) – Callback that takes indices into the global array value as input and returns the corresponding data of the global array value. The data can be returned as any array-like object, e.g. a numpy.ndarray.
  - dtype (DTypeLike | None) – The dtype of the output jax.Array. If not provided, the dtype of the data for the first addressable shard is used. If there are no addressable shards, the dtype argument must be provided.
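As a rough illustration of what data_callback receives, here is a minimal sketch, assuming a single process in which every device from jax.devices() is addressable. It shards rows over a 1-D mesh of all devices so it runs on any device count, and it leaves dtype unset so the output dtype is taken from the data of the first addressable shard. The names data_callback, global_data and seen_indices are illustrative only.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all local devices; rows of the array are split along axis 'x'.
devices = jax.devices()
n = len(devices)
mesh = Mesh(np.array(devices), ("x",))
sharding = NamedSharding(mesh, P("x"))

shape = (4 * n, 3)  # 4 rows per device, so the rows divide evenly
global_data = np.arange(4 * n * 3, dtype=np.int32).reshape(shape)

seen_indices = []  # hypothetical bookkeeping: records every index the callback is given

def data_callback(index):
    # `index` locates one shard inside the global array (a tuple of slices here).
    seen_indices.append(index)
    return global_data[index]

# dtype is omitted, so it is inferred from the first addressable shard (int32).
arr = jax.make_array_from_callback(shape, sharding, data_callback)

# The same device -> index mapping can be queried from the sharding itself.
index_map = sharding.addressable_devices_indices_map(shape)
for device, index in index_map.items():
    print(device, index)
```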
- Returns:
  A jax.Array via data fetched from data_callback.
- Return type:
  ArrayImpl
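Because the result is a jax.Array (concretely an ArrayImpl), its per-device pieces can be inspected through addressable_shards. The following is a minimal sketch, again assuming a single process with all devices addressable, that checks each shard holds exactly the slice of the source data named by its index; global_data and shard_summary are illustrative names.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
n = len(devices)
mesh = Mesh(np.array(devices), ("x",))
sharding = NamedSharding(mesh, P("x"))

shape = (2 * n, 5)  # 2 rows per device
global_data = np.arange(2 * n * 5, dtype=np.float32).reshape(shape)

arr = jax.make_array_from_callback(shape, sharding, lambda idx: global_data[idx])

shard_summary = []
for shard in arr.addressable_shards:
    # Each Shard exposes its device, its index into the global array, and its data.
    assert np.array_equal(np.asarray(shard.data), global_data[shard.index])
    shard_summary.append((shard.device, shard.data.shape))
```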
Examples
This example assumes exactly 8 devices, since the device list is reshaped into a 2×4 mesh; each device then holds an (8/2, 8/4) = (4, 2) shard.

```python
>>> import math
>>> import jax
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> input_shape = (8, 8)
>>> global_input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
...
>>> def cb(index):
...   return global_input_data[index]
...
>>> arr = jax.make_array_from_callback(input_shape, inp_sharding, cb)
>>> arr.addressable_data(0).shape
(4, 2)
```
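The example above slices a fully materialized NumPy array. A callback can instead compute each shard directly from its index, so the global array never has to exist on the host. This is a sketch under that assumption (lazy_callback is a hypothetical name), using a 1-D mesh so it runs on any number of devices rather than exactly 8. Each index entry is a slice, normalized with slice.indices; the step is assumed to be 1, as it is for the contiguous shards produced here.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
n = len(devices)
mesh = Mesh(np.array(devices), ("x",))
sharding = NamedSharding(mesh, P("x", None))  # split rows, keep columns whole

shape = (8 * n, 8)

def lazy_callback(index):
    # Normalize each slice against its axis length; slice(None) covers the whole axis.
    # The step is ignored because these shards are contiguous (step 1).
    row_start, row_stop, _ = index[0].indices(shape[0])
    col_start, col_stop, _ = index[1].indices(shape[1])
    rows = np.arange(row_start, row_stop, dtype=np.int32)[:, None]
    cols = np.arange(col_start, col_stop, dtype=np.int32)[None, :]
    # Same values as np.arange(prod(shape)).reshape(shape), built one shard at a time.
    return rows * shape[1] + cols

arr = jax.make_array_from_callback(shape, sharding, lazy_callback)
```

Each device ends up with an (8, 8) block of rows, and only that block is ever computed on the host.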