jax.numpy.tri
- jax.numpy.tri(N, M=None, k=0, dtype=None)
Return an array with ones on and below the diagonal and zeros elsewhere.
JAX implementation of numpy.tri().
- Parameters:
N (int) – Number of rows in the returned array.
M (int | None) – optional. Number of columns in the returned array. If not specified, M = N.
k (int) – optional, default=0. The diagonal on and below which the array is filled with ones. k=0 refers to the main diagonal, k<0 to a sub-diagonal below the main diagonal, and k>0 to a super-diagonal above the main diagonal.
dtype (DTypeLike | None) – optional. Data type of the returned array. The default is a floating-point type (float32 in the examples below).
- Returns:
An array of shape (N, M) with ones on and below the diagonal specified by k and zeros elsewhere.
- Return type:
Array
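The role of k can be stated elementwise: entry (i, j) of the result is one when j <= i + k and zero otherwise. The plain-Python sketch below illustrates that rule with nested lists so it runs without JAX installed. The helper name tri_rule is hypothetical and not part of JAX; it shows the semantics only, not how jnp.tri is implemented.

```python
def tri_rule(N, M=None, k=0):
    """Hypothetical helper: entry (i, j) is 1.0 when j <= i + k, else 0.0."""
    if M is None:
        M = N  # square output when M is omitted, as documented for jnp.tri
    return [[1.0 if j <= i + k else 0.0 for j in range(M)] for i in range(N)]


# Same arguments as the k<0 example: ones start one diagonal below the main one.
for row in tri_rule(3, 4, k=-1):
    print(row)
# [0.0, 0.0, 0.0, 0.0]
# [1.0, 0.0, 0.0, 0.0]
# [1.0, 1.0, 0.0, 0.0]
```

Reading the rule row by row: row i contains i + k + 1 leading ones, clipped to the range 0..M.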
See also
- jax.numpy.tril(): Returns the lower triangle of an array.
- jax.numpy.triu(): Returns the upper triangle of an array.
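The functions listed under See also relate to tri through two identities: tri(N, M, k) equals the lower triangle (on and below diagonal k) of an all-ones (N, M) matrix, and 1 - tri(N, M, k - 1) equals its upper triangle (on and above diagonal k). The sketch below checks both with NumPy, assuming NumPy is installed; because jnp.tri, jnp.tril and jnp.triu follow the NumPy functions of the same names, the same identities are expected to hold with jnp in place of np.

```python
import numpy as np

N, M, k = 3, 4, 1
ones = np.ones((N, M))

# tri(N, M, k) matches the lower triangle (on and below diagonal k) of an all-ones matrix.
lower_matches = np.array_equal(np.tri(N, M, k), np.tril(ones, k))

# 1 - tri(N, M, k - 1) matches the upper triangle (on and above diagonal k) of an all-ones matrix.
upper_matches = np.array_equal(1 - np.tri(N, M, k - 1), np.triu(ones, k))

print(lower_matches, upper_matches)  # True True
```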
Examples
>>> import jax.numpy as jnp
>>> jnp.tri(3)
Array([[1., 0., 0.],
       [1., 1., 0.],
       [1., 1., 1.]], dtype=float32)
When M is not equal to N:
>>> jnp.tri(3, 4)
Array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.]], dtype=float32)
When k > 0:
>>> jnp.tri(3, k=1)
Array([[1., 1., 0.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32)
When k < 0:
>>> jnp.tri(3, 4, k=-1)
Array([[0., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 1., 0., 0.]], dtype=float32)
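The examples above all use the default floating-point dtype. A sketch of the dtype parameter follows, written against numpy.tri on the assumption that NumPy is installed (jnp.tri takes the same dtype argument per its signature, but its output is not shown here). With dtype=bool the result is a boolean mask, which can then be used to keep only the entries on and below the main diagonal of another array.

```python
import numpy as np

x = np.arange(1, 10).reshape(3, 3)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# dtype=bool yields True/False instead of floating-point ones and zeros.
mask = np.tri(3, dtype=bool)

# Keep entries on and below the main diagonal; replace the rest with 0.
lower_x = np.where(mask, x, 0)
print(lower_x.tolist())  # [[1, 0, 0], [4, 5, 0], [7, 8, 9]]
```

For this particular task np.tril(x) gives the same result directly; the mask form is useful when the same triangular pattern is applied to several arrays.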