jax.numpy.linalg.tensorinv
- jax.numpy.linalg.tensorinv(a, ind=2)
Compute the tensor inverse of an array.
JAX implementation of numpy.linalg.tensorinv().
This computes the inverse of the tensordot() operation with the same ind value: up to floating-point error, tensordot(tensorinv(a, ind), a, ind) is the identity tensor.
- Parameters:
  - a (ArrayLike) – array to be inverted. Must satisfy prod(a.shape[:ind]) == prod(a.shape[ind:]).
  - ind (int) – positive integer specifying the number of indices in the tensor product.
- Returns:
  array of shape (*a.shape[ind:], *a.shape[:ind]) containing the tensor inverse of a.
- Return type:
  Array
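To make the shape requirement and the return shape concrete, the tensor inverse can be sketched by flattening a into a square matrix, inverting it with jnp.linalg.inv, and reshaping the result. This mirrors the approach described for numpy.linalg.tensorinv; it is an illustration under that assumption, not necessarily how JAX implements the function, and tensorinv_via_reshape is a hypothetical helper name.

```python
import math

import jax
import jax.numpy as jnp


def tensorinv_via_reshape(a, ind=2):
    """Hypothetical helper (not part of JAX): tensor inverse via a matrix inverse."""
    a = jnp.asarray(a)
    leading, trailing = a.shape[:ind], a.shape[ind:]
    rows, cols = math.prod(leading), math.prod(trailing)
    # The flattened matrix must be square for an ordinary inverse to exist.
    if rows != cols:
        raise ValueError("prod(a.shape[:ind]) must equal prod(a.shape[ind:])")
    # Invert the (rows, cols) matrix view of the tensor.
    inv = jnp.linalg.inv(a.reshape(rows, cols))
    # Reshape to (*a.shape[ind:], *a.shape[:ind]), matching the documented return shape.
    return inv.reshape(*trailing, *leading)


x = jax.random.normal(jax.random.key(0), shape=(2, 2, 4))
ours = tensorinv_via_reshape(x, ind=2)
ref = jnp.linalg.tensorinv(x, ind=2)
print(ours.shape)  # (4, 2, 2)
print(jnp.allclose(ours, ref, atol=1e-5))
```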
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> key = jax.random.key(1337)
>>> x = jax.random.normal(key, shape=(2, 2, 4))
>>> xinv = jnp.linalg.tensorinv(x, 2)
>>> xinv_x = jnp.linalg.tensordot(xinv, x, axes=2)
>>> jnp.allclose(xinv_x, jnp.eye(4), atol=1E-4)
Array(True, dtype=bool)
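The inverse-of-tensordot property can also be used to solve a tensor equation. This is a minimal sketch, assuming a well-conditioned input: contracting tensorinv(a, ind=2) with b through jnp.tensordot should agree with jnp.linalg.tensorsolve(a, b) up to floating-point tolerance. The shapes, the 0.05 noise scale, and the tolerance are illustrative choices, not values from this page.

```python
import jax
import jax.numpy as jnp

key_a, key_b = jax.random.split(jax.random.key(0))

# Identity plus small noise gives a well-conditioned 24x24 matrix,
# viewed as a tensor of shape (4, 6, 8, 3); 4 * 6 == 8 * 3 == 24.
m = jnp.eye(24) + 0.05 * jax.random.normal(key_a, (24, 24))
a = m.reshape(4, 6, 8, 3)

ainv = jnp.linalg.tensorinv(a, ind=2)
print(ainv.shape)  # (8, 3, 4, 6)

# Solve tensordot(a, x, axes=2) == b for x of shape (8, 3) in two ways.
b = jax.random.normal(key_b, (4, 6))
x_inv = jnp.tensordot(ainv, b, axes=2)  # contract ainv's last two axes with b
x_solve = jnp.linalg.tensorsolve(a, b)
print(jnp.allclose(x_inv, x_solve, atol=1e-4))
```

Forming the explicit inverse can pay off when the same a is applied to many right-hand sides; for a single system, tensorsolve is generally the more direct choice.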