jax.scipy.linalg.lu_solve
- jax.scipy.linalg.lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True)
Solve a linear system using an LU factorization.
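Concretely, factorization chooses a row permutation \(P\) so that \(PA = LU\). Solving \(A x = b\) then comes down to a forward substitution \(L y = P b\) followed by a back substitution \(U x = y\). The following is a minimal sketch of those two steps done by hand with jax.scipy.linalg.solve_triangular and checked against lu_solve. It assumes 0-based, LAPACK-style pivots (at step i, row i was swapped with row piv[i]). `pivots_to_permutation` and `manual_lu_solve` are hypothetical helpers written for this illustration, not JAX functions, and the sketch is not meant to mirror JAX's internal implementation.

```python
import numpy as np
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve, solve_triangular


def pivots_to_permutation(piv, m):
    """Hypothetical helper: replay the row swaps encoded in `piv`.

    Assumes LAPACK-style, 0-based pivots: at step i, row i was swapped with
    row piv[i]. The returned `perm` satisfies a[perm] == L @ U.
    """
    perm = list(range(m))
    for i, p in enumerate(np.asarray(piv).tolist()):
        perm[i], perm[p] = perm[p], perm[i]
    return jnp.array(perm)


def manual_lu_solve(lu, piv, b):
    """Hypothetical helper: solve A x = b (the trans=0 case) from (lu, piv)."""
    perm = pivots_to_permutation(piv, lu.shape[0])
    # Forward substitution L y = P b: L is the strict lower triangle of `lu`
    # with an implicit unit diagonal.
    y = solve_triangular(lu, b[perm], lower=True, unit_diagonal=True)
    # Back substitution U x = y: U is the upper triangle of `lu`.
    return solve_triangular(lu, y, lower=False)


# Partial pivoting is expected to swap rows here, since |3| > |1| in column 0.
a = jnp.array([[1., 2.],
               [3., 4.]])
b = jnp.array([5., 6.])

lu, piv = lu_factor(a)
x_manual = manual_lu_solve(lu, piv, b)
x_ref = lu_solve((lu, piv), b)
print(x_manual, x_ref)  # both approximately [-4.0, 4.5]
```

Because `pivots_to_permutation` loops over concrete pivot values in Python, this sketch cannot be traced under jax.jit as written; lu_solve itself has no such restriction.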
JAX implementation of scipy.linalg.lu_solve(). Uses the output of jax.scipy.linalg.lu_factor().
- Parameters:
  - lu_and_piv (tuple[Array, ArrayLike]) – (lu, piv), the output of lu_factor(). lu is an array of shape (..., M, N) containing L in its lower triangle and U in its upper triangle. piv is an array of shape (..., K), with K = min(M, N), which encodes the pivots.
  - b (ArrayLike) – right-hand side of the linear system. Must have shape (..., M).
  - trans (int) – type of system to solve. Options are:
    - 0: \(A x = b\)
    - 1: \(A^T x = b\)
    - 2: \(A^H x = b\)
  - overwrite_b (bool) – unused by JAX
  - check_finite (bool) – unused by JAX
- Returns:
  Array of shape (..., N) representing the solution of the linear system.
- Return type:
  Array
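The three trans options can share a single factorization, because the \(A^T\) and \(A^H\) systems are solved from the same \(L\), \(U\) and pivots. A short sketch with arbitrary illustrative matrices (a non-symmetric real one, and a complex64 one for trans=2); each result is checked against the system it is supposed to solve:

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

# A non-symmetric matrix, so A x = b and A^T x = b have different solutions.
a = jnp.array([[3., 1.],
               [2., 4.]])
b = jnp.array([1., 2.])
lu_and_piv = lu_factor(a)  # factor once, reuse for both solves

x0 = lu_solve(lu_and_piv, b, trans=0)  # solves A x = b    -> about [0.2, 0.4]
x1 = lu_solve(lu_and_piv, b, trans=1)  # solves A^T x = b  -> about [0.0, 0.5]

# trans=2 uses the conjugate transpose A^H, which differs from A^T
# only for complex matrices.
ac = jnp.array([[2. + 1.j, 1.],
                [0.5j, 3. - 1.j]], dtype=jnp.complex64)
bc = jnp.array([1., 2. - 1.j], dtype=jnp.complex64)
x2 = lu_solve(lu_factor(ac), bc, trans=2)  # solves A^H x = b

print(jnp.allclose(a @ x0, b), jnp.allclose(a.T @ x1, b))
print(jnp.allclose(ac.conj().T @ x2, bc))
```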
Examples
Solving a small linear system via LU factorization:
>>> a = jnp.array([[2., 1.],
...                [1., 2.]])
Compute the LU factorization via lu_factor(), and use it to solve the linear system via lu_solve():
>>> b = jnp.array([3., 4.])
>>> lufac = jax.scipy.linalg.lu_factor(a)
>>> y = jax.scipy.linalg.lu_solve(lufac, b)
>>> y
Array([0.6666666, 1.6666667], dtype=float32)
Check that the result is consistent:
>>> jnp.allclose(a @ y, b)
Array(True, dtype=bool)
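The example above solves a single system. Since lu and piv may carry leading batch dimensions (shapes (..., M, N) and (..., K)), a stack of systems can be solved in one call, and a single factorization can also be reused for several right-hand sides with jax.vmap. A sketch of both, with arbitrary sizes and random matrices constructed to be strictly diagonally dominant so that they are nonsingular:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

# 1) Batched systems: four independent 3x3 matrices, one right-hand side each.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
# Identity plus uniform noise in [-0.25, 0.25]: each diagonal entry is at
# least 0.75 while each row's off-diagonal entries sum to at most 0.5 in
# absolute value, so every matrix is strictly diagonally dominant.
a = jnp.eye(3) + jax.random.uniform(key_a, (4, 3, 3), minval=-0.25, maxval=0.25)
b = jax.random.normal(key_b, (4, 3))

lu, piv = lu_factor(a)       # lu: (4, 3, 3), piv: (4, 3)
x = lu_solve((lu, piv), b)   # x: (4, 3), one solution per matrix
print(x.shape, jnp.allclose(jnp.einsum('bij,bj->bi', a, x), b, atol=1e-4))

# 2) One factorization, several right-hand sides: vmap over the rows of bs
#    while the (lu, piv) pair from a single lu_factor call is shared.
a2 = jnp.array([[2., 1.],
                [1., 2.]])
lufac = lu_factor(a2)
bs = jnp.array([[3., 4.],
                [1., 0.],
                [0., 1.]])
ys = jax.vmap(lambda rhs: lu_solve(lufac, rhs))(bs)  # ys: (3, 2)
print(ys.shape, jnp.allclose(ys @ a2.T, bs, atol=1e-5))
```

In the vmap form the factorization stays unbatched, so its cost is paid once no matter how many right-hand sides are solved.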