jax.numpy.matvec
- jax.numpy.matvec(x1, x2, /)
Batched matrix-vector product.
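JAX implementation of numpy.matvec().
- Parameters:
  - x1 (ArrayLike): array of shape (..., M, N).
  - x2 (ArrayLike): array of shape (..., N). Leading dimensions must be broadcast-compatible with the leading dimensions of x1.
- Returns:
  An array of shape (..., M) containing the batched matrix-vector product.
- Return type:
  Array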
See also
- jax.numpy.linalg.vecdot(): batched vector product.
- jax.numpy.vecmat(): vector-matrix product.
- jax.numpy.matmul(): general matrix multiplication.
Examples
Simple matrix-vector product:
>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> x2 = jnp.array([7, 8, 9])
>>> jnp.matvec(x1, x2)
Array([ 50, 122], dtype=int32)
Batched matrix-vector product:
>>> x2 = jnp.array([[7, 8, 9],
...                 [5, 6, 7]])
>>> jnp.matvec(x1, x2)
Array([[ 50, 122],
       [ 38,  92]], dtype=int32)