jax.lax.linalg.eigh
- jax.lax.linalg.eigh(x, *, lower=True, symmetrize_input=True, sort_eigenvalues=True, subset_by_index=None)
Eigendecomposition of a Hermitian matrix.
Computes the eigenvectors and eigenvalues of a complex Hermitian or real symmetric square matrix.
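A minimal usage sketch, assuming a CPU backend and a small random real symmetric matrix built as B + B^T. It checks the defining relation A v[:, i] = w[i] v[:, i], the orthonormality of the eigenvectors, and the ascending order of the eigenvalues under the default sort_eigenvalues=True. Note that this function returns the eigenvectors first, as (v, w).

```python
import jax
import jax.lax.linalg
import jax.numpy as jnp

# A = B + B^T is real symmetric for any square B.
key = jax.random.PRNGKey(0)
b = jax.random.normal(key, (4, 4), dtype=jnp.float32)
a = b + b.T

# Output order: eigenvectors v first, then eigenvalues w.
v, w = jax.lax.linalg.eigh(a)

# Column i of v satisfies a @ v[:, i] ≈ w[i] * v[:, i];
# broadcasting `v * w` scales column i of v by w[i].
residual = jnp.max(jnp.abs(a @ v - v * w))

# The eigenvectors are orthonormal, so a ≈ V diag(w) V^T.
a_rebuilt = (v * w) @ v.T
orthogonality_error = jnp.max(jnp.abs(v.T @ v - jnp.eye(4)))

# With the default sort_eigenvalues=True, w is in ascending order.
is_ascending = bool(jnp.all(jnp.diff(w) >= 0))
print(residual, orthogonality_error, is_ascending)
```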
- Parameters:
  - x (Array) – A batch of square complex Hermitian or real symmetric matrices with shape [..., n, n].
  - lower (bool) – Used only when symmetrize_input is False: selects which triangle of the input matrix is read (the lower triangle if True, the upper triangle if False). The other triangle is ignored and never accessed.
  - symmetrize_input (bool) – If True, the matrix is symmetrized before the eigendecomposition by computing \(\frac{1}{2}(x + x^H)\).
  - sort_eigenvalues (bool) – If True, the eigenvalues are sorted in ascending order. If False, the eigenvalues are returned in an implementation-defined order.
  - subset_by_index (tuple[int, int] | None) – Optional 2-tuple [start, end] giving the range of indices of the eigenvalues to compute (start inclusive, end exclusive). For example, if subset_by_index = [n-2, n], then eigh computes the two largest eigenvalues and their eigenvectors.
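The sketch below, assuming a CPU backend, illustrates three of the parameters. With symmetrize_input=False and lower=True, junk written above the diagonal is ignored; a leading batch dimension is decomposed matrix by matrix; and the eigenpairs that subset_by_index=(n - 2, n) would select are emulated by slicing the full sorted result. That slicing only illustrates the index semantics; it is not a call to subset_by_index, whose support may depend on the backend.

```python
import jax
import jax.lax.linalg
import jax.numpy as jnp

n = 4
key = jax.random.PRNGKey(1)
b = jax.random.normal(key, (n, n), dtype=jnp.float32)
sym = b + b.T  # a genuinely symmetric reference matrix

# lower / symmetrize_input: keep sym's lower triangle (diagonal included) and
# fill the strict upper triangle with junk, so the input is not symmetric.
junk = 100.0 * jnp.triu(jnp.ones((n, n), dtype=jnp.float32), k=1)
messy = jnp.tril(sym) + junk

# With symmetrize_input=False and lower=True only the lower triangle is read,
# so the eigenvalues should match those of sym itself.
v_low, w_low = jax.lax.linalg.eigh(messy, lower=True, symmetrize_input=False)
v_ref, w_ref = jax.lax.linalg.eigh(sym)

# Batching: x may have shape [..., n, n]; each matrix is decomposed separately.
batch = jnp.stack([sym, 2.0 * sym])  # shape (2, n, n)
v_b, w_b = jax.lax.linalg.eigh(batch)  # shapes (2, n, n) and (2, n)

# subset_by_index=(start, end), emulated: with ascending eigenvalues,
# (n - 2, n) corresponds to the last two eigenpairs, i.e. the two largest.
# This slices the full result; it does not call subset_by_index itself.
start, end = n - 2, n
w_top2 = w_ref[start:end]     # d = end - start = 2 eigenvalues
v_top2 = v_ref[:, start:end]  # matching eigenvectors as columns
```

The comparison uses eigenvalues rather than raw eigenvectors because each eigenvector is only determined up to a sign (or a complex phase), and repeated eigenvalues leave even more freedom.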
- Returns:
  A tuple (v, w).
  - v is an array with the same dtype as x such that v[..., :, i] is the normalized eigenvector corresponding to eigenvalue w[..., i].
  - w is an array with the same dtype as x (or its real counterpart if x is complex) with shape [..., d], containing the eigenvalues of x, each repeated according to its multiplicity, in ascending order when sort_eigenvalues is True. If subset_by_index is None, d is equal to n; otherwise d is equal to subset_by_index[1] - subset_by_index[0].
- Return type: