jax.numpy.linalg.tensorsolve
- jax.numpy.linalg.tensorsolve(a, b, axes=None)
Solve the tensor equation a x = b for x.
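In index notation, the equation can be sketched as follows. The index names and the symbols m = b.ndim and n = x.ndim are introduced here only for illustration, and the shapes assume `axes` has already been applied to a:

```latex
% Shapes after `axes` has been applied to a:
%   a : (I_1, ..., I_m, K_1, ..., K_n),  b : (I_1, ..., I_m),  x : (K_1, ..., K_n)
%   with K_1 \cdots K_n = I_1 \cdots I_m, so the flattened system is square.
\sum_{k_1,\dots,k_n} a_{i_1 \dots i_m k_1 \dots k_n}\, x_{k_1 \dots k_n} = b_{i_1 \dots i_m}
\qquad \text{for all } (i_1, \dots, i_m)
```

The left-hand side is what tensordot(a, x, x.ndim) computes. For a of shape (2, 2, 4) and b of shape (2, 2), this amounts to 4 scalar equations in the 4 entries of x.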
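JAX implementation of numpy.linalg.tensorsolve().
- Parameters:
  - a (ArrayLike): coefficient tensor. After reordering by axes, its shape must be (*b.shape, *x.shape), where the product of x.shape equals b.size.
  - b (ArrayLike): right-hand-side tensor.
  - axes (tuple[int, ...] | None): optional tuple of axes of a to move to the end, in the given order, before solving.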
- Returns:
  array x such that, after reordering of the axes of a, tensordot(a, x, x.ndim) is equivalent to b.
- Return type:
  Array
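The following is a minimal sketch, not JAX's actual source, of how a tensor solve can be reduced to an ordinary matrix solve. It assumes the NumPy-style approach: move the axes listed in axes to the end of a, flatten the leading (b-shaped) axes into rows and the trailing axes into columns, then call jnp.linalg.solve on the resulting square system. The helper name tensorsolve_sketch is hypothetical.

```python
import math

import jax
import jax.numpy as jnp


def tensorsolve_sketch(a, b, axes=None):
    """Hypothetical reshape-and-solve version of tensorsolve (illustration only)."""
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    if axes is not None:
        # Move the requested axes of `a` to the end, keeping their given order.
        a = jnp.moveaxis(a, axes, tuple(range(a.ndim - len(axes), a.ndim)))
    # The trailing axes of `a` (those not matched by b) give the shape of x.
    x_shape = a.shape[b.ndim:]
    n = math.prod(x_shape)
    if a.shape[:b.ndim] != b.shape or n != b.size:
        raise ValueError("a must have shape (*b.shape, *x.shape) with prod(x.shape) == b.size")
    # Flatten to a square (b.size, n) matrix problem and solve it.
    x_flat = jnp.linalg.solve(a.reshape(b.size, n), b.ravel())
    return x_flat.reshape(x_shape)


# Compare the sketch against the library function on random inputs.
key1, key2 = jax.random.split(jax.random.key(0))
a = jax.random.normal(key1, shape=(2, 2, 4))
b = jax.random.normal(key2, shape=(2, 2))
x_ref = jnp.linalg.tensorsolve(a, b)
x_sketch = tensorsolve_sketch(a, b)
print(x_sketch.shape)  # (4,)
print(jnp.allclose(x_ref, x_sketch, rtol=1e-4, atol=1e-4))
```

The reshape makes the "tensor" part purely a matter of bookkeeping: once a is viewed as a b.size by b.size matrix, any standard dense solver applies.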
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> key1, key2 = jax.random.split(jax.random.key(8675309))
>>> a = jax.random.normal(key1, shape=(2, 2, 4))
>>> b = jax.random.normal(key2, shape=(2, 2))
>>> x = jnp.linalg.tensorsolve(a, b)
>>> x.shape
(4,)
Now show that x can be used to reconstruct b using tensordot():
>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim)
>>> jnp.allclose(b, b_reconstructed)
Array(True, dtype=bool)
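The example above does not use the axes argument. The sketch below, which assumes the same random inputs, stores the unknown's axis first instead of last and uses axes=(0,) to move it back to the end before solving. The names a_t and x_t are introduced only for illustration.

```python
import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.key(8675309))
a = jax.random.normal(key1, shape=(2, 2, 4))
b = jax.random.normal(key2, shape=(2, 2))
x = jnp.linalg.tensorsolve(a, b)

# Put the unknown's axis first: a_t has shape (4, 2, 2).
a_t = jnp.moveaxis(a, -1, 0)

# axes=(0,) moves axis 0 of a_t to the end, recovering the layout of `a`.
x_t = jnp.linalg.tensorsolve(a_t, b, axes=(0,))
print(x_t.shape)  # (4,)
print(jnp.allclose(x, x_t))
```

Because moving axis 0 of a_t to the end reproduces a exactly, both calls solve the same linear system and should return matching solutions.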