jax.numpy.nonzero
- jax.numpy.nonzero(a, *, size=None, fill_value=None)
Return indices of nonzero elements of an array.
JAX implementation of numpy.nonzero().
Because the size of the output of nonzero is data-dependent, the function is not compatible with JIT and other transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.nonzero to be used within JAX's transformations.
- Parameters:
a (ArrayLike) – N-dimensional array.
size (int | None) – optional static integer specifying the number of nonzero entries to return. If there are more nonzero elements than the specified size, the indices are truncated at the end. If there are fewer nonzero elements than the specified size, the indices are padded with fill_value, which defaults to zero.
fill_value (None | ArrayLike | tuple[ArrayLike, ...]) – optional padding value used when size is specified. Defaults to 0.
- Returns:
Tuple of JAX Arrays of length a.ndim, containing the indices of each nonzero value.
- Return type:
tuple[Array, ...]
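The following sketch illustrates this return structure with an arbitrary 3-D input. The array a and its values are made up for the example. It checks two things: the tuple holds one index array per dimension, and each array has one entry per nonzero element. It also shows that stacking the tuple reproduces the (N, ndim) coordinate layout returned by jnp.argwhere.

```python
import jax.numpy as jnp

# Arbitrary 3-D example input with three nonzero entries.
a = jnp.zeros((2, 3, 4), dtype=jnp.int32)
a = a.at[0, 1, 2].set(1).at[1, 0, 3].set(2).at[1, 2, 0].set(3)

idx = jnp.nonzero(a)

# One index array per dimension of `a` ...
assert len(idx) == a.ndim
# ... each holding one entry per nonzero element, in row-major order.
assert all(i.shape == (3,) for i in idx)

# Stacking the tuple gives (N, ndim) coordinates, matching jnp.argwhere.
coords = jnp.stack(idx, axis=-1)
assert (coords == jnp.argwhere(a)).all()
print(coords.tolist())  # [[0, 1, 2], [1, 0, 3], [1, 2, 0]]
```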
See also
Examples
One-dimensional array returns a length-1 tuple of indices:
>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jnp.nonzero(x)
(Array([1, 3, 5], dtype=int32),)
Two-dimensional array returns a length-2 tuple of indices:
>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 7]])
>>> jnp.nonzero(x)
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
In either case, the resulting tuple of indices can be used directly to extract the nonzero values:
>>> indices = jnp.nonzero(x)
>>> x[indices]
Array([5, 6, 7], dtype=int32)
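Outside of jit, the same values can also be selected with a boolean mask. The index tuple also works with JAX's functional .at[] update syntax. The sketch below redefines the 2-D example array so that it runs on its own.

```python
import jax.numpy as jnp

x = jnp.array([[0, 5, 0],
               [6, 0, 7]])
indices = jnp.nonzero(x)

# Outside of jit, a boolean mask selects the same values in the same order.
assert (x[indices] == x[x != 0]).all()

# The index tuple also drives functional updates: scale only the nonzeros.
y = x.at[indices].multiply(10)
print(y.tolist())  # [[0, 50, 0], [60, 0, 70]]
```

Boolean-mask indexing has the same data-dependent output shape, so it does not get around the JIT restriction described next.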
The output of nonzero has a dynamic shape, because the number of returned indices depends on the contents of the input array. As such, it is incompatible with JIT and other JAX transformations:
>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jax.jit(jnp.nonzero)(x)
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
This can be addressed by passing a static size parameter to specify the desired output shape:
>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size')
>>> nonzero_jit(x, size=3)
(Array([1, 3, 5], dtype=int32),)
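When the output size is known in advance, it can also be written inside the jitted function as a plain Python integer. Such a literal is already static at trace time, so there is no need to route it through static_argnames. In this minimal sketch, the function name first_two_nonzero is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def first_two_nonzero(x):  # hypothetical helper for this sketch
    # `size` is a Python int literal, so the output shape (2,) is fixed
    # when the function is traced.
    (idx,) = jnp.nonzero(x, size=2)
    return x[idx]

x = jnp.array([0, 5, 0, 6, 0, 7])
print(first_two_nonzero(x).tolist())  # [5, 6]
```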
If size does not match the true number of nonzero elements, the result is either truncated or padded:
>>> nonzero_jit(x, size=2)  # size < 3: indices are truncated
(Array([1, 3], dtype=int32),)
>>> nonzero_jit(x, size=5)  # size > 3: indices are padded with zeros
(Array([1, 3, 5, 0, 0], dtype=int32),)
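The default padding value 0 is also a valid index. Looking at the indices alone, a padded entry therefore cannot be distinguished from a genuine nonzero at position 0. One way to recover the distinction is to compare each output position against jnp.count_nonzero. The sketch below does this under jit, using a hypothetical helper nonzero_with_mask and an arbitrarily chosen fixed size.

```python
import jax
import jax.numpy as jnp

SIZE = 4  # fixed output size chosen for this sketch

@jax.jit
def nonzero_with_mask(x):  # hypothetical helper for this sketch
    (idx,) = jnp.nonzero(x, size=SIZE)  # padded with the default fill_value 0
    # Output positions at or beyond the true nonzero count are padding.
    valid = jnp.arange(SIZE) < jnp.count_nonzero(x)
    return idx, valid

idx, valid = nonzero_with_mask(jnp.array([7, 0, 0, 6]))
print(idx.tolist())    # [0, 3, 0, 0]  (first 0 is real, the last two are padding)
print(valid.tolist())  # [True, True, False, False]
```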
You can specify a custom fill value for the padding using the fill_value argument:
>>> nonzero_jit(x, size=5, fill_value=len(x))
(Array([1, 3, 5, 6, 6], dtype=int32),)
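For N-dimensional input, the signature also accepts a tuple for fill_value. The sketch below assumes the tuple supplies one scalar per dimension. Padding each axis with its own length (one past the last valid index) pairs naturally with a gather in mode='fill'. That mode returns a chosen value for out-of-bounds positions, whereas default indexing would clamp them.

```python
import jax.numpy as jnp

x = jnp.array([[0, 5, 0],
               [6, 0, 7]])

# Assumed: a tuple fill_value gives one padding scalar per dimension.
# Here each axis is padded with its own length, i.e. an out-of-bounds index.
rows, cols = jnp.nonzero(x, size=5, fill_value=x.shape)

# Gathering with mode='fill' turns the out-of-bounds padded positions into 0.
vals = x.at[rows, cols].get(mode='fill', fill_value=0)
print(rows.tolist(), cols.tolist())  # [0, 1, 1, 2, 2] [1, 0, 2, 3, 3]
print(vals.tolist())                 # [5, 6, 7, 0, 0]
```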