jax.numpy.eye

jax.numpy.eye(N, M=None, k=0, dtype=None, *, device=None)

Create a square or rectangular identity matrix.

JAX implementation of numpy.eye().

Parameters:
  • N (DimSize) – integer specifying the first dimension of the array.

  • M (DimSize | None) – optional integer specifying the second dimension of the array; defaults to the same value as N.

  • k (int | ArrayLike) – optional integer specifying the offset of the diagonal. Use positive values for upper diagonals, and negative values for lower diagonals. Default is zero.

  • dtype (DTypeLike | None) – optional dtype; defaults to the default floating-point dtype (float32 unless 64-bit mode is enabled).

  • device (xc.Device | Sharding | None) – optional Device or Sharding to which the created array will be committed.
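Taken together, these parameters mean that entry (i, j) of the result is one exactly when j - i == k, and zero otherwise. The sketch below checks that reading by rebuilding the matrix from broadcasted row and column indices. It is illustrative only: `eye_reference` is a hypothetical helper written for this comparison. It is not how JAX implements `jnp.eye`.

```python
import jax.numpy as jnp

def eye_reference(N, M=None, k=0, dtype=jnp.float32):
    """Hypothetical reference construction: ones where column - row == k."""
    M = N if M is None else M
    rows = jnp.arange(N)[:, None]  # shape (N, 1)
    cols = jnp.arange(M)[None, :]  # shape (1, M)
    # Broadcasting gives an (N, M) boolean grid, then cast to the requested dtype.
    return (cols - rows == k).astype(dtype)

# Compare against jnp.eye for square, rectangular, offset, and out-of-range k.
for N, M, k in [(3, None, 0), (3, 5, 1), (4, 2, -1), (3, 3, 5)]:
    expected = jnp.eye(N, M, k=k)
    assert expected.shape == (N, N if M is None else M)
    assert jnp.array_equal(eye_reference(N, M, k), expected)
```

An offset whose magnitude exceeds the matrix size, such as k=5 for a 3x3 matrix, simply yields an all-zero array.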

Returns:

Identity array of shape (N, M), or (N, N) if M is not specified.

Return type:

Array

See also

jax.numpy.identity(): Simpler API for generating square identity matrices.
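As a rough comparison, `jnp.identity(n)` corresponds to `jnp.eye(n)` with the defaults M=None and k=0. Rectangular shapes or shifted diagonals require `jnp.eye`. A short sketch:

```python
import jax.numpy as jnp

# jnp.identity only builds square matrices with ones on the main diagonal,
# so it matches jnp.eye(n) with the default M=None and k=0.
square = jnp.identity(4, dtype=jnp.int32)
assert jnp.array_equal(square, jnp.eye(4, dtype=jnp.int32))

# Rectangular shapes or offset diagonals need jnp.eye, which accepts M and k.
rect = jnp.eye(2, 4, k=2, dtype=jnp.int32)
print(rect)
```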

Examples

A simple 3x3 identity matrix:

>>> import jax.numpy as jnp
>>> jnp.eye(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

Integer identity matrices with offset diagonals:

>>> jnp.eye(3, k=1, dtype=int)
Array([[0, 1, 0],
       [0, 0, 1],
       [0, 0, 0]], dtype=int32)
>>> jnp.eye(3, k=-1, dtype=int)
Array([[0, 0, 0],
       [1, 0, 0],
       [0, 1, 0]], dtype=int32)

Non-square identity matrix with an offset diagonal:

>>> jnp.eye(3, 5, k=1)
Array([[0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.]], dtype=float32)
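Because N and M set the output shape, they must be concrete Python integers when `jnp.eye` is called inside `jax.jit`. A traced value cannot serve as an array dimension. The sketch below marks the size argument as static. `shifted_eye` is a hypothetical helper used only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=0)
def shifted_eye(n, scale):
    # n is static (a Python int), so the output shape (n, n + 2) is known
    # at trace time; scale is an ordinary traced array argument.
    return scale * jnp.eye(n, n + 2, k=1)

out = shifted_eye(3, 2.0)
print(out.shape)  # (3, 5)
```

Each distinct value of `n` triggers a separate compilation. This is the usual trade-off when an argument determines an array shape.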