jax.scipy.linalg.qr
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['full', 'economic'], pivoting: Literal[False] = False, check_finite: bool = True) → tuple[Array, Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['full', 'economic'], pivoting: Literal[True] = True, check_finite: bool = True) → tuple[Array, Array, Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['full', 'economic'], pivoting: bool = False, check_finite: bool = True) → tuple[Array, Array] | tuple[Array, Array, Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['r'], pivoting: Literal[False] = False, check_finite: bool = True) → tuple[Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['r'], pivoting: Literal[True] = True, check_finite: bool = True) → tuple[Array, Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Literal['r'], pivoting: bool = False, check_finite: bool = True) → tuple[Array] | tuple[Array, Array]
- jax.scipy.linalg.qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = 'full', pivoting: bool = False, check_finite: bool = True) → tuple[Array] | tuple[Array, Array] | tuple[Array, Array, Array]
Compute the QR decomposition of an array.
JAX implementation of scipy.linalg.qr().
The QR decomposition of a matrix A is given by
\[A = QR\]
where Q is a unitary matrix (i.e. \(Q^HQ = I\)) and R is an upper-triangular matrix.
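As a quick numerical check of these properties, the minimal sketch below factors a small complex matrix (entries chosen arbitrarily for illustration) and verifies that \(Q^HQ = I\), that R is upper triangular, and that QR reproduces A. The tolerances are assumptions suited to single precision.

```python
import jax
import jax.numpy as jnp

# An arbitrary 3x3 complex matrix, used only for illustration.
a = jnp.array([[1 + 1j, 2 + 0j, 0 - 1j],
               [0 + 2j, 1 - 1j, 3 + 0j],
               [1 + 0j, 0 + 1j, 1 + 1j]], dtype=jnp.complex64)

Q, R = jax.scipy.linalg.qr(a)

# Q is unitary: Q^H Q = I, where Q^H is the conjugate transpose.
is_unitary = jnp.allclose(Q.conj().T @ Q, jnp.eye(3), atol=1e-5)
# R is upper triangular: all entries below the main diagonal are zero.
is_upper = jnp.allclose(R, jnp.triu(R))
# The factors reproduce the input: A = Q R.
reconstructs = jnp.allclose(Q @ R, a, atol=1e-5)

print(is_unitary, is_upper, reconstructs)
```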
- Parameters:
a – array of shape (…, M, N)
mode –
Computational mode. Supported values are:
  - "full" (default): return Q of shape (M, M) and R of shape (M, N).
  - "r": return only R.
  - "economic": return Q of shape (M, K) and R of shape (K, N), where K = min(M, N).
pivoting – Allows the QR decomposition to be rank-revealing. If True, compute the column-pivoted decomposition A[:, P] = Q @ R, where P is chosen such that the diagonal of R is non-increasing in magnitude.
overwrite_a – unused in JAX
lwork – unused in JAX
check_finite – unused in JAX
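The sketch below illustrates how mode changes the output shapes, and how leading batch dimensions of a are carried through unchanged. It uses a batch of two random 5x3 matrices (so M = 5, N = 3, K = 3); the random input is an arbitrary choice for illustration. Following the function signature, mode="r" returns a 1-tuple, so it is unpacked here.

```python
import jax
import jax.numpy as jnp

# A batch of two tall matrices: shape (..., M, N) with M = 5, N = 3, so K = min(M, N) = 3.
a = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 3))

Q_full, R_full = jax.scipy.linalg.qr(a, mode="full")
Q_econ, R_econ = jax.scipy.linalg.qr(a, mode="economic")
(R_only,) = jax.scipy.linalg.qr(a, mode="r")  # mode="r" returns a 1-tuple

print(Q_full.shape, R_full.shape)  # (2, 5, 5) (2, 5, 3)
print(Q_econ.shape, R_econ.shape)  # (2, 5, 3) (2, 3, 3)
print(R_only.shape)                # (2, 5, 3)

# Both "full" and "economic" factorizations reconstruct every matrix in the batch.
print(jnp.allclose(Q_full @ R_full, a, atol=1e-5),
      jnp.allclose(Q_econ @ R_econ, a, atol=1e-5))
```

The "economic" mode is usually the cheaper choice for tall inputs, since it avoids materializing the full M x M factor Q.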
- Returns:
A tuple (Q, R) or (Q, R, P) if mode is not "r" and pivoting is False or True, respectively; otherwise, if mode is "r", the 1-tuple (R,) or the tuple (R, P) if pivoting is False or True, respectively. Here:
  - Q is an orthogonal (unitary, for complex input) matrix of shape (..., M, M) if mode is "full", or a matrix with orthonormal columns of shape (..., M, K) if mode is "economic";
  - R is an upper-triangular matrix of shape (..., M, N) if mode is "r" or "full", or (..., K, N) if mode is "economic";
  - P is an index vector of shape (..., N),
with K = min(M, N).
Notes
At present, pivoting is only implemented on the CPU and GPU backends. For further details about the GPU implementation, see the documentation for jax.lax.linalg.qr().
See also
- jax.numpy.linalg.qr(): NumPy-style QR decomposition API
- jax.lax.linalg.qr(): XLA-style QR decomposition API
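For orientation, a sketch comparing the three entry points on a tall matrix. It assumes that NumPy's mode="reduced" and XLA's full_matrices=False both correspond to SciPy's mode="economic"; rather than assuming the three results match exactly, the check compares them within a small tolerance.

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg
from jax.scipy import linalg as scipy_linalg

# An arbitrary tall 4x2 matrix, so the thin factors differ from the full ones.
a = jnp.array([[1., 2.],
               [3., 4.],
               [5., 6.],
               [7., 9.]])

# SciPy-style API: thin factorization via mode="economic".
Q1, R1 = scipy_linalg.qr(a, mode="economic")
# NumPy-style API: the equivalent mode is called "reduced".
Q2, R2 = jnp.linalg.qr(a, mode="reduced")
# XLA-style API: thin factorization via full_matrices=False.
Q3, R3 = lax_linalg.qr(a, full_matrices=False)

for Q, R in [(Q2, R2), (Q3, R3)]:
    print(Q.shape, R.shape,
          jnp.allclose(Q, Q1, atol=1e-5), jnp.allclose(R, R1, atol=1e-5))
```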
Examples
Compute the QR decomposition of a matrix:
>>> import jax
>>> import jax.numpy as jnp
>>> a = jnp.array([[1., 2., 3., 4.],
...                [5., 4., 2., 1.],
...                [6., 3., 1., 5.]])
>>> Q, R = jax.scipy.linalg.qr(a)
>>> Q
Array([[-0.12700021, -0.7581426 , -0.6396022 ],
       [-0.63500065, -0.43322435,  0.63960224],
       [-0.7620008 ,  0.48737738, -0.42640156]], dtype=float32)
>>> R
Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ],
       [ 0.       , -1.7870499, -2.6534991, -1.028908 ],
       [ 0.       ,  0.       , -1.0660033, -4.050814 ]], dtype=float32)
Check that Q is orthonormal:
>>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)
Reconstruct the input:
>>> jnp.allclose(Q @ R, a)
Array(True, dtype=bool)
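Beyond checking the factors, a common use of the economic factorization is solving an overdetermined least-squares problem. The sketch below is an illustration rather than a documented recipe: it fits a line to four points (made-up data lying exactly on y = 1 + 2t) by solving the triangular system \(Rx = Q^Tb\) with jax.scipy.linalg.solve_triangular, and compares the result with jnp.linalg.lstsq.

```python
import jax.numpy as jnp
from jax.scipy.linalg import qr, solve_triangular

# Made-up data lying exactly on the line y = 1 + 2 t.
t = jnp.array([0., 1., 2., 3.])
b = jnp.array([1., 3., 5., 7.])
# Design matrix with columns [1, t]: shape (4, 2), so M = 4, N = K = 2.
A = jnp.stack([jnp.ones_like(t), t], axis=1)

# Economic QR: A = Q R with Q of shape (4, 2) and R of shape (2, 2).
Q, R = qr(A, mode="economic")
# Minimizing ||A x - b|| reduces to the upper-triangular system R x = Q^T b.
x = solve_triangular(R, Q.T @ b, lower=False)

x_ref = jnp.linalg.lstsq(A, b)[0]
print(x, jnp.allclose(x, x_ref, atol=1e-4))  # x is approximately [1., 2.]
```

Working through R avoids forming \(A^TA\), whose condition number is the square of that of A.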