jax.numpy.vdot
- jax.numpy.vdot(a, b, *, precision=None, preferred_element_type=None)
Compute the conjugate dot product of two 1D vectors, i.e. the sum over conj(a[i]) * b[i].
JAX implementation of numpy.vdot().
- Parameters:
  - a (Array | ndarray | bool | number | int | float | complex) – first input array; if not 1D, it will be flattened.
  - b (Array | ndarray | bool | number | int | float | complex) – second input array; if not 1D, it will be flattened. Must have a.size == b.size.
  - precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two such values indicating the precision of a and b.
  - preferred_element_type (str | type[Any] | dtype | SupportsDType | None) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating that results should be accumulated in, and returned with, that datatype.
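As a rough illustration of the parameters above, the sketch below passes inputs that are not 1D and requests a wider accumulation type. The array values and variable names are made up for the example. The precision argument is shown only for its calling convention; its numerical effect depends on the backend.

```python
import jax
import jax.numpy as jnp

# Inputs with different shapes but equal size; both are flattened first.
a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(6.0).reshape(3, 2)
flattened = jnp.vdot(a, b)
explicit = jnp.vdot(a.ravel(), b.ravel())  # same computation, flattened by hand

# The result is a scalar array regardless of the input shapes.
print(flattened.shape)

# Ask for int32 accumulation with int8 inputs whose product sum exceeds the int8 range.
p = jnp.array([100, 100], dtype=jnp.int8)
q = jnp.array([100, 100], dtype=jnp.int8)
wide = jnp.vdot(p, q, preferred_element_type=jnp.int32)
print(wide.dtype)

# precision accepts a Precision enum value (illustrative; may be a no-op on some backends).
precise = jnp.vdot(a, b, precision=jax.lax.Precision.HIGHEST)
```

Without preferred_element_type, the result dtype follows the default for the input types, so narrow integer or low-precision float inputs are the cases where this parameter is most likely to matter.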
- Returns:
  Scalar array (shape ()) containing the conjugate vector product of the inputs.
- Return type:
  Array
See also
- jax.numpy.vecdot(): batched vector product.
- jax.numpy.matmul(): general matrix multiplication.
- jax.lax.dot_general(): general N-dimensional batched dot product.
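To show how vdot relates to the batched alternative listed above, the following sketch compares jnp.vecdot on a stack of row vectors with jnp.vdot applied row by row. It assumes a JAX version that provides jnp.vecdot, and the input values are arbitrary.

```python
import jax.numpy as jnp

x = jnp.array([[1j, 2j, 3j], [1.0, 0.0, -1.0]])
y = jnp.array([[1.0, 2.0, 3.0], [2.0, 2.0, 2.0]])

# vecdot contracts the last axis and keeps the leading (batch) axis.
batched = jnp.vecdot(x, y)

# The same values obtained by calling vdot once per row.
per_row = jnp.stack([jnp.vdot(xi, yi) for xi, yi in zip(x, y)])

# vdot on the full 2D inputs flattens them and returns a single scalar.
whole = jnp.vdot(x, y)

print(batched.shape, whole.shape)
```

In other words, vdot always reduces to one scalar, while vecdot is the one to reach for when each row should be treated as its own vector.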
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([1j, 2j, 3j])
>>> y = jnp.array([1., 2., 3.])
>>> jnp.vdot(x, y)
Array(0.-14.j, dtype=complex64)

Note the difference between this and dot(), which does not conjugate the first input when complex:

>>> jnp.dot(x, y)
Array(0.+14.j, dtype=complex64)
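The relationship between the two results can also be checked directly. The sketch below reuses the same x and y and relies only on jnp.conj, jnp.dot and jnp.sum; the variable names are illustrative. It verifies that vdot matches a dot product whose first input has been conjugated by hand.

```python
import jax.numpy as jnp

x = jnp.array([1j, 2j, 3j])
y = jnp.array([1., 2., 3.])

via_vdot = jnp.vdot(x, y)
via_dot = jnp.dot(jnp.conj(x), y)     # conjugate the first input manually
via_sum = jnp.sum(jnp.conj(x) * y)    # elementwise definition

# vdot of a vector with itself is its squared Euclidean norm (zero imaginary part).
norm_sq = jnp.vdot(x, x)

# Swapping the arguments conjugates the result.
swapped = jnp.vdot(y, x)

print(via_vdot, via_dot, via_sum, norm_sq, swapped)
```

Because only the first argument is conjugated, argument order matters for complex inputs, whereas for real inputs vdot and dot agree on 1D vectors.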