jax.numpy.linalg.matrix_rank
- jax.numpy.linalg.matrix_rank(M, rtol=None, *, tol=Deprecated)
Compute the rank of a matrix.
JAX implementation of numpy.linalg.matrix_rank().
The rank is calculated via the Singular Value Decomposition (SVD) and equals the number of singular values greater than the specified tolerance.
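The SVD-based procedure described above can be sketched directly with jnp.linalg.svd. This is a minimal illustration, not JAX's actual implementation: the helper name rank_via_svd is hypothetical, and the default tolerance of max(N, K) * eps (relative to the largest singular value) is an assumption borrowed from NumPy's documented convention.

```python
import jax.numpy as jnp


def rank_via_svd(M, rtol=None):
    """Hypothetical helper: count singular values above a relative tolerance."""
    M = jnp.asarray(M, dtype=jnp.float32)
    # Singular values only: shape (..., min(N, K)), sorted in descending order.
    s = jnp.linalg.svd(M, compute_uv=False)
    if rtol is None:
        # Assumed default (NumPy's convention): max(N, K) * machine epsilon.
        rtol = max(M.shape[-2:]) * jnp.finfo(s.dtype).eps
    # Per-matrix threshold relative to the largest singular value s[..., 0].
    threshold = jnp.asarray(rtol)[..., None] * s[..., :1]
    # Singular values at or below the threshold are treated as zero.
    return jnp.sum(s > threshold, axis=-1)


a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[1, 0], [0, 0]])  # rank-deficient
print(rank_via_svd(a), jnp.linalg.matrix_rank(a))  # 2 2
print(rank_via_svd(b), jnp.linalg.matrix_rank(b))  # 1 1
```

Counting with a strict comparison against a per-matrix threshold keeps the sketch batch-friendly: leading dimensions of M pass straight through svd and the final sum.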
- Parameters:
  - M (ArrayLike) – array of shape (..., N, K) whose rank is to be computed.
  - rtol (ArrayLike | None) – optional array of shape (...) specifying the tolerance. Singular values smaller than rtol * largest_singular_value are considered to be zero. If rtol is None (the default), a reasonable default is chosen based on the floating-point precision of the input.
  - tol (ArrayLike | DeprecatedArg | None) – deprecated alias of the rtol argument. Will result in a DeprecationWarning if used.
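To illustrate rtol, the sketch below uses an illustrative diagonal matrix whose singular values are exactly 1 and 1e-3; the largest singular value is 1, so the relative threshold rtol * 1 is easy to read. The default, precision-based tolerance keeps the small singular value, while a looser rtol treats it as zero. Per the parameter description, rtol may also be an array of shape (...) giving one tolerance per matrix in a batch; the specific tolerances here are assumptions chosen for the demonstration.

```python
import jax.numpy as jnp

# Diagonal matrix, so its singular values are exactly 1.0 and 1e-3.
M = jnp.diag(jnp.array([1.0, 1e-3]))

# 1e-3 is far above the float32-based default tolerance, so both values count.
default_rank = jnp.linalg.matrix_rank(M)
# 1e-3 < 1e-2 * 1.0, so the smaller singular value is treated as zero.
loose_rank = jnp.linalg.matrix_rank(M, rtol=1e-2)

# A batch of shape (2, 2, 2) with an rtol of shape (2,): one tolerance per matrix.
batch = jnp.stack([M, M])
batched_ranks = jnp.linalg.matrix_rank(batch, rtol=jnp.array([1e-4, 1e-2]))

print(default_rank, loose_rank, batched_ranks)  # 2 1 [2 1]
```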
- Returns:
  array of shape M.shape[:-2] giving the rank of each matrix in M.
- Return type:
  Array
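Because M may carry leading batch dimensions, the result holds one rank per matrix, so its shape is M.shape[:-2]. A minimal sketch with an illustrative batch of three 2x2 matrices (the matrices are arbitrary choices for the demonstration):

```python
import jax.numpy as jnp

# A batch of shape (3, 2, 2): a full-rank, a rank-1, and an all-zero matrix.
batch = jnp.array([
    [[1.0, 2.0], [3.0, 4.0]],
    [[1.0, 0.0], [0.0, 0.0]],
    [[0.0, 0.0], [0.0, 0.0]],
])

ranks = jnp.linalg.matrix_rank(batch)
print(ranks.shape)  # (3,), i.e. batch.shape[:-2]
print(ranks)        # [2 1 0]
```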
Notes
The rank calculation may be inaccurate for matrices with very small singular values or for those that are numerically ill-conditioned. In such cases, consider adjusting the rtol parameter or using a more specialized rank computation method.
Examples
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.linalg.matrix_rank(a)
Array(2, dtype=int32)
>>> b = jnp.array([[1, 0],  # Rank-deficient matrix
...                [0, 0]])
>>> jnp.linalg.matrix_rank(b)
Array(1, dtype=int32)
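As the Notes above suggest, when a matrix is nearly rank-deficient it can help to inspect its singular values before choosing rtol. The sketch below uses jnp.linalg.svd with compute_uv=False to look at the spectrum relative to its largest value; the diagonal matrix and the cutoff of 1e-2 are illustrative assumptions, picked so the cutoff falls inside the clear gap in the spectrum.

```python
import jax.numpy as jnp

# Illustrative nearly rank-deficient matrix with singular values 2.0, 1.0 and 1e-3.
M = jnp.diag(jnp.array([2.0, 1.0, 1e-3]))

# Singular values only, in descending order.
s = jnp.linalg.svd(M, compute_uv=False)
relative = s / s[0]  # spectrum relative to the largest value: about 1.0, 0.5, 0.0005
print(relative)

# A cutoff inside the gap between 0.5 and 0.0005 drops the smallest value.
print(jnp.linalg.matrix_rank(M))             # 3 with the default tolerance
print(jnp.linalg.matrix_rank(M, rtol=1e-2))  # 2
```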