jax.numpy.indices

jax.numpy.indices(dimensions, dtype=None, sparse=False)

Generate arrays of grid indices.

JAX implementation of numpy.indices().
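The correspondence with numpy.indices() can be checked directly. The sketch below assumes NumPy is installed alongside JAX. `matches_numpy` is a hypothetical helper written for this illustration. Only shapes and values are compared, because the default integer dtypes may differ: JAX uses 32-bit integers unless 64-bit mode is enabled.

```python
import numpy as np
import jax.numpy as jnp


def matches_numpy(dims):
    """Hypothetical helper: True if jnp.indices and np.indices agree on shape and values."""
    expected = np.indices(dims)               # NumPy reference result (dense)
    actual = np.asarray(jnp.indices(dims))    # convert the jax.Array to NumPy for comparison
    return actual.shape == expected.shape and np.array_equal(actual, expected)


for dims in [(2, 3), (4,), (2, 3, 4)]:
    print(dims, matches_numpy(dims))
```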

Parameters:
  • dimensions (Sequence[int]) – the shape of the grid.

  • dtype (DTypeLike | None) – the dtype of the indices (defaults to JAX's default integer dtype: int32, or int64 when 64-bit mode is enabled).

  • sparse (bool) – if True, return sparse indices; if False (the default), return dense indices.
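The sketch below shows the two optional parameters in isolation. The variable names are illustrative only, and the float32 dtype is just one possible choice. It requests floating-point indices through `dtype` and a per-axis tuple through `sparse=True`.

```python
import jax.numpy as jnp

dims = (2, 3)

# Default dtype: an integer type (int32 unless 64-bit mode is enabled).
dense_default = jnp.indices(dims)

# dtype: request floating-point indices instead of the default integer dtype.
dense_f = jnp.indices(dims, dtype=jnp.float32)
print(dense_f.dtype, dense_f.shape)   # float32 (2, 2, 3)

# sparse=True: one array per dimension instead of a single stacked array.
rows, cols = jnp.indices(dims, sparse=True)
print(rows.shape, cols.shape)         # (2, 1) (1, 3)
```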

Returns:

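An array of shape (len(dimensions), *dimensions) if sparse is False, or a tuple of len(dimensions) arrays if sparse is True. In the sparse case, the i-th array has size dimensions[i] along axis i and size 1 along every other axis, so the arrays broadcast against one another.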
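The relationship between the two return forms can be made concrete. This sketch assumes nothing beyond `jax.numpy`. It prints the shape of each sparse array and checks that broadcasting each one to the full grid shape and stacking the results reproduces the dense array. The 3-D grid is an arbitrary example.

```python
import jax.numpy as jnp

dims = (2, 3, 4)

dense = jnp.indices(dims)                 # shape (len(dims), *dims) == (3, 2, 3, 4)
sparse = jnp.indices(dims, sparse=True)   # tuple of len(dims) arrays

# The i-th sparse array keeps dims[i] on axis i and has size 1 elsewhere.
for i, s in enumerate(sparse):
    print(i, s.shape)   # prints 0 (2, 1, 1), then 1 (1, 3, 1), then 2 (1, 1, 4)

# Broadcasting each sparse array to the full grid shape recovers the dense result.
recovered = jnp.stack([jnp.broadcast_to(s, dims) for s in sparse])
print(bool(jnp.array_equal(recovered, dense)))   # True
```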

Return type:

Array | tuple[Array, …]

Examples

>>> import jax.numpy as jnp
>>> jnp.indices((2, 3))
Array([[[0, 0, 0],
        [1, 1, 1]],

       [[0, 1, 2],
        [0, 1, 2]]], dtype=int32)
>>> jnp.indices((2, 3), sparse=True)
(Array([[0],
       [1]], dtype=int32), Array([[0, 1, 2]], dtype=int32))
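The dense output shown above matches a stacked jax.numpy.meshgrid of aranges built with indexing="ij". The sketch below checks that equivalence, then uses sparse indices to evaluate a function of the coordinates over the whole grid. The `10 * i + j` formula is only an illustrative choice, not part of the API. The sparse form avoids materializing the full index grid, because the arithmetic broadcasts the per-axis arrays.

```python
import jax.numpy as jnp

dims = (2, 3)

# Dense indices equal an "ij"-indexed meshgrid of aranges, stacked along axis 0.
grid = jnp.indices(dims)
mesh = jnp.stack(jnp.meshgrid(*(jnp.arange(d) for d in dims), indexing="ij"))
print(bool(jnp.array_equal(grid, mesh)))   # True

# Evaluate a function of the coordinates; the (2, 1) and (1, 3) arrays broadcast to (2, 3).
i, j = jnp.indices(dims, sparse=True)
table = 10 * i + j
print(table.tolist())   # [[0, 1, 2], [10, 11, 12]]
```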