jax.numpy.linalg.qr
- jax.numpy.linalg.qr(a, mode='reduced')
Compute the QR decomposition of an array.
JAX implementation of numpy.linalg.qr().
The QR decomposition of a matrix A is given by
\[A = QR\]
where Q is a unitary matrix (i.e. \(Q^HQ=I\)) and R is an upper-triangular matrix.
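For complex input, the unitarity condition uses the conjugate transpose \(Q^H\), not the plain transpose. The sketch below checks the defining properties on a small complex matrix whose values are made up for illustration: \(Q^HQ=I\), R is upper-triangular, and \(QR\) reproduces the input.

```python
import jax.numpy as jnp

# Illustrative complex 3 x 2 matrix (values are arbitrary).
a = jnp.array([[1 + 1j, 2 - 1j],
               [0 + 2j, 1 + 0j],
               [3 - 1j, 1 + 1j]], dtype=jnp.complex64)

# Default mode="reduced": Q has shape (3, 2) and R has shape (2, 2).
Q, R = jnp.linalg.qr(a)

# Q^H Q = I: take the conjugate transpose, not just the transpose.
print(jnp.allclose(Q.conj().T @ Q, jnp.eye(2), atol=1e-5))

# R is upper-triangular: zeroing everything below the diagonal leaves it unchanged.
print(jnp.allclose(R, jnp.triu(R)))

# The two factors reproduce the input.
print(jnp.allclose(Q @ R, a, atol=1e-5))
```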
- Parameters:
a (ArrayLike) – array of shape (…, M, N)
mode (str) –
Computational mode. Supported values are:
  - "reduced" (default): return Q of shape (..., M, K) and R of shape (..., K, N), where K = min(M, N).
  - "complete": return Q of shape (..., M, M) and R of shape (..., M, N).
  - "raw": return LAPACK-internal representations of shape (..., M, N) and (..., K).
  - "r": return R only.
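The leading "…" dimensions of a are treated as batch dimensions. The following sketch uses an arbitrary random batch of tall 4 x 3 matrices (so M = 4, N = 3, K = 3) to show how the output shapes follow the mode rules listed above; "raw" is left out here.

```python
import jax
import jax.numpy as jnp

# A batch of 5 tall matrices, each 4 x 3: M = 4, N = 3, K = min(M, N) = 3.
# The random values are for illustration only.
a = jax.random.normal(jax.random.PRNGKey(0), (5, 4, 3))

# "reduced" (default): Q is (..., M, K), R is (..., K, N).
Q, R = jnp.linalg.qr(a, mode="reduced")
print(Q.shape, R.shape)  # (5, 4, 3) (5, 3, 3)

# "complete": Q is (..., M, M), R is (..., M, N).
Qc, Rc = jnp.linalg.qr(a, mode="complete")
print(Qc.shape, Rc.shape)  # (5, 4, 4) (5, 4, 3)

# "r": a single array is returned instead of a (Q, R) pair.
R_only = jnp.linalg.qr(a, mode="r")
print(isinstance(R_only, tuple))  # False
```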
- Returns:
A tuple (Q, R) if mode is not "r", otherwise a single array R, where:
  - Q is an orthogonal (unitary, for complex input) matrix of shape (..., M, K) if mode is "reduced", or (..., M, M) if mode is "complete".
  - R is an upper-triangular matrix of shape (..., M, N) if mode is "r" or "complete", or (..., K, N) if mode is "reduced",
with K = min(M, N).
- Return type:
Array | QRResult
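The distinction between "reduced" and "complete" matters for what "orthogonal" means: a complete Q is square, so \(Q^TQ = QQ^T = I\), while a reduced Q of a tall matrix only has orthonormal columns. The sketch below illustrates this on an arbitrary random 4 x 3 matrix. It assumes the QRResult named tuple exposes fields named Q and R (as NumPy's result type does); plain tuple unpacking works either way.

```python
import jax
import jax.numpy as jnp

# One tall 4 x 3 matrix (illustrative random values): M = 4, N = 3, K = 3.
a = jax.random.normal(jax.random.PRNGKey(0), (4, 3))

# Named-field access (assumed fields Q and R); `Qc, Rc = res` also works.
res = jnp.linalg.qr(a, mode="complete")
Qc, Rc = res.Q, res.R

# A complete Q is square and orthogonal: Q^T Q = Q Q^T = I.
print(jnp.allclose(Qc.T @ Qc, jnp.eye(4), atol=1e-5))
print(jnp.allclose(Qc @ Qc.T, jnp.eye(4), atol=1e-5))

# R has shape (M, N) = (4, 3) and is upper-triangular, so its last row,
# which lies entirely below the diagonal, is zero.
print(jnp.allclose(Rc[3], 0.0))

# A reduced Q has orthonormal columns (Q^T Q = I), but Q Q^T is only a
# projector onto the column space of `a`, not the 4 x 4 identity.
Qr, Rr = jnp.linalg.qr(a, mode="reduced")
print(jnp.allclose(Qr.T @ Qr, jnp.eye(3), atol=1e-5))
print(jnp.allclose(Qr @ Qr.T, jnp.eye(4), atol=1e-5))
```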
See also
jax.scipy.linalg.qr(): SciPy-style QR decomposition API
jax.lax.linalg.qr(): XLA-style QR decomposition API
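The three APIs differ mainly in how the reduced factorization is requested: "reduced" is the default here, jax.scipy.linalg.qr defaults to mode="full" and uses mode="economic" for the reduced form, and jax.lax.linalg.qr defaults to full_matrices=True. The sketch below requests the reduced factorization from each on an arbitrary random matrix. Our understanding is that all three dispatch to the same underlying factorization, so the factors should agree; treat that as an assumption rather than a documented guarantee.

```python
import jax
import jax.lax.linalg
import jax.numpy as jnp
import jax.scipy.linalg

# Illustrative random 4 x 3 matrix.
a = jax.random.normal(jax.random.PRNGKey(1), (4, 3))

# NumPy-style API: "reduced" is the default mode.
Q_np, R_np = jnp.linalg.qr(a)

# SciPy-style API: "economic" is the reduced form (the default is "full").
Q_sp, R_sp = jax.scipy.linalg.qr(a, mode="economic")

# XLA-style API: full_matrices=False requests the reduced factorization.
Q_lax, R_lax = jax.lax.linalg.qr(a, full_matrices=False)

print(Q_np.shape, Q_sp.shape, Q_lax.shape)  # (4, 3) (4, 3) (4, 3)
print(jnp.allclose(Q_np, Q_sp, atol=1e-5), jnp.allclose(R_np, R_lax, atol=1e-5))
```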
Examples
Compute the QR decomposition of a matrix:
>>> a = jnp.array([[1., 2., 3., 4.],
...                [5., 4., 2., 1.],
...                [6., 3., 1., 5.]])
>>> Q, R = jnp.linalg.qr(a)
>>> Q
Array([[-0.12700021, -0.7581426 , -0.6396022 ],
       [-0.63500065, -0.43322435,  0.63960224],
       [-0.7620008 ,  0.48737738, -0.42640156]], dtype=float32)
>>> R
Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ],
       [ 0.       , -1.7870499, -2.6534991, -1.028908 ],
       [ 0.       ,  0.       , -1.0660033, -4.050814 ]], dtype=float32)
Check that Q is orthonormal:
>>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)
Reconstruct the input:
>>> jnp.allclose(Q @ R, a)
Array(True, dtype=bool)