jax.numpy.triu_indices
- jax.numpy.triu_indices(n, k=0, m=None)
Return the indices of the upper triangle of an array of size (n, m).

JAX implementation of numpy.triu_indices().

- Parameters:
n (DimSize) – int. Number of rows of the array for which the indices are returned.
k (DimSize) – optional, int, default=0. Specifies the diagonal on and above which the indices of the upper triangle are returned: an index pair (i, j) is included when j - i >= k. k=0 refers to the main diagonal, k<0 refers to a diagonal below the main diagonal, and k>0 refers to a diagonal above the main diagonal.
m (DimSize | None) – optional, int. Number of columns of the array for which the indices are returned. If not specified, then m = n.
- Returns:
A tuple of two arrays containing the indices of the upper triangle, one along each axis.
- Return type:
tuple[Array, Array]
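To illustrate what k selects, here is a minimal sketch (not necessarily how JAX implements this function). It assumes the returned pairs are exactly the nonzero positions of a boolean mask where j - i >= k. The helper triu_indices_sketch is hypothetical. It relies on jnp.nonzero, which, like NumPy's nonzero, reports positions in row-major order, and it checks itself against jnp.triu_indices for a few shapes.

```python
import jax.numpy as jnp

def triu_indices_sketch(n, k=0, m=None):
    # Hypothetical helper: rebuilds the (row, col) pairs from a boolean mask.
    if m is None:
        m = n
    rows = jnp.arange(n)[:, None]  # shape (n, 1)
    cols = jnp.arange(m)[None, :]  # shape (1, m)
    # Cell (i, j) belongs to the upper triangle when j - i >= k.
    mask = (cols - rows) >= k
    # jnp.nonzero returns one index array per axis, in row-major order.
    return jnp.nonzero(mask)

# Compare against jnp.triu_indices for square and rectangular cases.
for n, k, m in [(3, 0, None), (3, 0, 2), (3, 1, None), (3, -1, None), (4, 2, 5)]:
    expected = jnp.triu_indices(n, k=k, m=m)
    got = triu_indices_sketch(n, k=k, m=m)
    assert all(jnp.array_equal(e, g) for e, g in zip(expected, got))
```

Because the mask is materialized in full, this sketch uses O(n * m) memory. It is meant to show the selection rule, not to replace the library function.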
See also
- jax.numpy.tril_indices(): Returns the indices of the lower triangle of an array of size (n, m).
- jax.numpy.triu_indices_from(): Returns the indices of the upper triangle of a given array.
- jax.numpy.tril_indices_from(): Returns the indices of the lower triangle of a given array.
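The following short sketch shows how the related functions fit together. It checks two properties: the upper triangle (k=0) and the strictly lower triangle (jnp.tril_indices with k=-1) partition a square grid, and jnp.triu_indices_from(a) matches jnp.triu_indices called with a's shape. The helper as_pairs and the 4 x 4 example array are illustrative choices, not part of the API.

```python
import jax.numpy as jnp

n = 4
a = jnp.arange(n * n).reshape(n, n)

def as_pairs(idx):
    # Illustrative helper: turn (rows, cols) arrays into a set of (i, j) tuples.
    rows, cols = idx
    return set(zip(rows.tolist(), cols.tolist()))

upper = as_pairs(jnp.triu_indices(n))               # on and above the main diagonal
strict_lower = as_pairs(jnp.tril_indices(n, k=-1))  # strictly below it

# The two sets are disjoint and together cover every cell of the n x n grid.
assert upper.isdisjoint(strict_lower)
assert len(upper | strict_lower) == n * n

# triu_indices_from takes the row and column counts from the array's shape.
from_a = jnp.triu_indices_from(a)
direct = jnp.triu_indices(a.shape[0], m=a.shape[1])
assert all(jnp.array_equal(x, y) for x, y in zip(from_a, direct))
```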
Examples
If only n is provided, the indices of the upper triangle of an (n, n) array are returned.

>>> import jax.numpy as jnp
>>> jnp.triu_indices(3)
(Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))
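The returned tuple can be used directly as an advanced index. The sketch below, which uses an arbitrary 3 x 3 example matrix, reads the upper-triangle values in row-major order. JAX arrays are immutable, so it writes through the .at[...] update syntax to zero that triangle.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3],
               [4, 5, 6],
               [7, 8, 9]])

rows, cols = jnp.triu_indices(3)

# Gather: values on and above the main diagonal, row by row.
upper_values = a[rows, cols]  # [1, 2, 3, 5, 6, 9]

# Functional update: zero the upper triangle, keep the strictly lower part.
lower_only = a.at[rows, cols].set(0)  # [[0, 0, 0], [4, 0, 0], [7, 8, 0]]
```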
If both n and m are provided, the indices of the upper triangle of an (n, m) array are returned.

>>> jnp.triu_indices(3, m=2)
(Array([0, 0, 1], dtype=int32), Array([0, 1, 1], dtype=int32))
If k = 1, the indices on and above the first diagonal above the main diagonal are returned, so the main diagonal itself is excluded.

>>> jnp.triu_indices(3, k=1)
(Array([0, 0, 1], dtype=int32), Array([1, 2, 2], dtype=int32))
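A common use of k=1 is enumerating each unordered pair (i, j) with i < j exactly once. In the sketch below, that property is used to compute pairwise Euclidean distances for a small set of hypothetical 2-D points. It yields n * (n - 1) / 2 values in the same order as the index arrays.

```python
import jax.numpy as jnp

# Hypothetical example data: 4 points in 2-D.
points = jnp.array([[0.0, 0.0],
                    [3.0, 4.0],
                    [6.0, 8.0],
                    [0.0, 1.0]])
n = points.shape[0]

# k=1 skips the diagonal, so every pair i < j appears once.
i, j = jnp.triu_indices(n, k=1)
assert i.shape[0] == n * (n - 1) // 2  # 6 pairs for n = 4

# Distance for each unique pair, ordered like (i, j).
dists = jnp.linalg.norm(points[i] - points[j], axis=-1)
```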
If k = -1, the indices on and above the first diagonal below the main diagonal are returned.

>>> jnp.triu_indices(3, k=-1)
(Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32))
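The number of returned indices depends on n, k and m, so under jax.jit these values presumably have to be known at trace time rather than traced. The sketch below assumes this and marks n as a static argument. The function name upper_sum is hypothetical. It sums the entries selected by k=-1, which for the 3 x 3 input leaves out only the bottom-left element.

```python
from functools import partial

import jax
import jax.numpy as jnp

# n sets the output shape of triu_indices, so it is passed as a static argument.
@partial(jax.jit, static_argnums=1)
def upper_sum(a, n):
    rows, cols = jnp.triu_indices(n, k=-1)
    return a[rows, cols].sum()

a = jnp.arange(9.0).reshape(3, 3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
total = upper_sum(a, 3)            # every entry except a[2, 0] = 6
```

Each distinct value of n triggers a separate compilation, which is the usual trade-off for static arguments.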