jax.numpy.mask_indices
- jax.numpy.mask_indices(n, mask_func, k=0, *, size=None)
Return indices of a mask of an (n, n) array.
- Parameters:
n (int) – static integer array dimension.
mask_func (Callable[[ArrayLike, int], Array]) – a function that takes a shape (n, n) array and an optional offset k, and returns a shape (n, n) mask. Examples of functions with this signature are triu() and tril().
k (int) – a scalar value passed to mask_func.
size (int | None) – optional argument specifying the static size of the output arrays. This is passed to nonzero() when generating the indices from the mask.
- Returns:
a tuple of indices where the mask returned by mask_func is nonzero.
- Return type:
tuple[Array, Array]
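As a sketch, assume the mask is built by applying mask_func to an all-ones (n, n) array, which is the definition NumPy documents for numpy.mask_indices. Under that assumption, the result should match calling jnp.nonzero on that mask directly. The returned tuple can also be used as an index to gather the selected entries.

```python
import jax.numpy as jnp

n, k = 3, 0
# Build the mask by hand: apply mask_func (here jnp.tril) to an all-ones array.
mask = jnp.tril(jnp.ones((n, n)), k)
rows, cols = jnp.nonzero(mask)

# Compare against mask_indices for the same n, mask_func, and k.
rows_m, cols_m = jnp.mask_indices(n, jnp.tril, k)
same = bool(jnp.array_equal(rows, rows_m) & jnp.array_equal(cols, cols_m))
print(same)  # True

# Use the index tuple to gather x[i, j] for every returned (i, j) pair.
x = jnp.arange(9).reshape(3, 3)
print(x[rows_m, cols_m])  # [0 3 4 6 7 8]
```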
See also
jax.numpy.triu_indices(): compute mask_indices for triu().
jax.numpy.tril_indices(): compute mask_indices for tril().
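The relationship in the See also entries can be checked with a small sketch. It assumes that jnp.triu_indices(n, k) and jnp.tril_indices(n, k) return indices in the same row-major order that nonzero produces. For an offset of k = -1, both routes should yield identical index arrays.

```python
import jax.numpy as jnp

n, k = 4, -1
pairs = {
    "triu": (jnp.triu, jnp.triu_indices),
    "tril": (jnp.tril, jnp.tril_indices),
}
matches = {}
for name, (mask_func, direct) in pairs.items():
    via_mask = jnp.mask_indices(n, mask_func, k)  # generic, mask-based route
    via_direct = direct(n, k)                     # specialized helper
    matches[name] = all(
        bool(jnp.array_equal(a, b)) for a, b in zip(via_mask, via_direct)
    )
print(matches)  # {'triu': True, 'tril': True}
```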
Examples
Calling mask_indices on built-in masking functions:
>>> jnp.mask_indices(3, jnp.triu)
(Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))
>>> jnp.mask_indices(3, jnp.tril)
(Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))
Calling mask_indices on a custom masking function:
>>> def mask_func(x, k=0):
...   i = jnp.arange(x.shape[0])[:, None]
...   j = jnp.arange(x.shape[1])
...   return (i + 1) % (j + 1 + k) == 0
>>> mask_func(jnp.ones((3, 3)))
Array([[ True, False, False],
       [ True,  True, False],
       [ True, False,  True]], dtype=bool)
>>> jnp.mask_indices(3, mask_func)
(Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 2], dtype=int32))
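A custom mask can also use the offset k. The hypothetical band_mask below is a sketch and not part of JAX. It marks entries within k positions of the main diagonal, so k controls the band width. Here the offset is passed positionally, as in the function's signature.

```python
import jax.numpy as jnp

def band_mask(x, k=0):
    # Hypothetical mask: True where |i - j| <= k, i.e. a band around the diagonal.
    i = jnp.arange(x.shape[0])[:, None]
    j = jnp.arange(x.shape[1])
    return jnp.abs(i - j) <= k

rows, cols = jnp.mask_indices(3, band_mask, 1)
print(rows, cols)  # [0 0 1 1 1 2 2] [0 1 0 1 2 1 2]
```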