jax.numpy.argwhere
- jax.numpy.argwhere(a, *, size=None, fill_value=None)
Find the indices of nonzero array elements.
JAX implementation of numpy.argwhere().
jnp.argwhere(x) is essentially equivalent to jnp.column_stack(jnp.nonzero(x)), with special handling for zero-dimensional (i.e. scalar) inputs.
Because the size of the output of argwhere is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which specifies the size of the leading dimension of the output; it must be specified statically for jnp.argwhere to be compiled with non-static operands. See jax.numpy.nonzero() for a full discussion of size and its semantics.
- Parameters:
a (ArrayLike) – array for which to find nonzero elements
size (int | None) – optional integer specifying statically the number of expected nonzero elements. This must be specified in order to use argwhere within JAX transformations like jax.jit(). See jax.numpy.nonzero() for more information.
fill_value (ArrayLike | None) – optional array specifying the fill value when size is specified. See jax.numpy.nonzero() for more information.
- Returns:
a two-dimensional array of shape [size, a.ndim]. If size is not specified as an argument, it is equal to the number of nonzero elements in a.
- Return type:
See also
Examples
Two-dimensional array:
>>> x = jnp.array([[1, 0, 2],
...                [0, 3, 0]])
>>> jnp.argwhere(x)
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)
Equivalent computation using jax.numpy.column_stack() and jax.numpy.nonzero():
>>> jnp.column_stack(jnp.nonzero(x))
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)
Special case for zero-dimensional (i.e. scalar) inputs:
>>> jnp.argwhere(1)
Array([], shape=(1, 0), dtype=int32)
>>> jnp.argwhere(0)
Array([], shape=(0, 0), dtype=int32)
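Using size and fill_value under jax.jit():
The following is a minimal sketch of how the size and fill_value arguments described above might be used to make argwhere compatible with JIT. The function name first_four_nonzero and the sentinel value -1 are assumptions chosen for illustration; they are not part of the JAX API.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1, 0, 2],
               [0, 3, 0]])

# Without `size`, the output shape depends on the data,
# so this form is only usable outside of jit.
eager = jnp.argwhere(x)

# With a static `size`, the output shape is fixed at (size, arr.ndim),
# so the call can be traced and compiled.
@jax.jit
def first_four_nonzero(arr):  # hypothetical helper name
    # size=4 exceeds the 3 nonzero entries of `x`; the unused row is
    # padded with fill_value (-1 is an arbitrary sentinel for this sketch).
    return jnp.argwhere(arr, size=4, fill_value=-1)

padded = first_four_nonzero(x)

# A `size` smaller than the number of nonzero elements truncates the result.
truncated = jnp.argwhere(x, size=2)

print(padded)     # three index rows followed by one padding row
print(truncated)  # only the first two index rows
```

Here padded has shape (4, 2) with a final row of [-1, -1], and truncated keeps only the first two index pairs. Choosing a fill value that cannot be a valid index makes padded rows easy to mask out afterwards.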