jax.numpy.convolve
- jax.numpy.convolve(a, v, mode='full', *, precision=None, preferred_element_type=None)
Convolution of two one-dimensional arrays.
JAX implementation of numpy.convolve().
Convolution of one-dimensional arrays is defined as:
\[c_k = \sum_j a_{k - j} v_j\]
- Parameters:
a (ArrayLike) – left-hand input to the convolution. Must have a.ndim == 1.
v (ArrayLike) – right-hand input to the convolution. Must have v.ndim == 1.
mode (str) – controls the size of the output. Available options are:
  - "full": (default) output the full convolution of the inputs.
  - "same": return a centered portion of the "full" output which is the same size as a.
  - "valid": return the portion of the "full" output which does not depend on padding at the array edges.
precision (lax.PrecisionLike) – Specify the precision of the computation. Refer to jax.lax.Precision for a description of available values.
preferred_element_type (DTypeLike | None) – The datatype in which to accumulate results and return the output. Default is None, which means the default accumulation type for the input types.
- Returns:
Array containing the convolved result.
- Return type:
Array
See also
- jax.scipy.signal.convolve(): ND convolution
- jax.numpy.correlate(): 1D correlation
Examples
A few 1D convolution examples:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([4, 1, 2])
jax.numpy.convolve, by default, returns the full convolution using implicit zero-padding at the edges:
>>> jnp.convolve(x, y)
Array([ 4.,  9., 16., 15., 12.,  5.,  2.], dtype=float32)
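To connect the "full" output to the definition \(c_k = \sum_j a_{k - j} v_j\), the following sketch evaluates the sum with explicit loops and compares it against jnp.convolve. The helper name convolve_by_definition is hypothetical and exists only for this illustration; it is not part of JAX, and it does not reflect how JAX computes the result internally.

```python
import jax.numpy as jnp
import numpy as np


def convolve_by_definition(a, v):
    """Hypothetical helper: evaluate c_k = sum_j a[k - j] * v[j] directly."""
    a = np.asarray(a, dtype=np.float32)
    v = np.asarray(v, dtype=np.float32)
    n_out = len(a) + len(v) - 1  # length of the "full" output
    c = np.zeros(n_out, dtype=np.float32)
    for k in range(n_out):
        for j in range(len(v)):
            # Indices k - j outside of `a` correspond to the implicit zero-padding.
            if 0 <= k - j < len(a):
                c[k] += a[k - j] * v[j]
    return c


x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([4, 1, 2])

reference = convolve_by_definition(x, y)
result = jnp.convolve(x, y)  # default mode='full'
```

For these inputs both reference and result hold the values 4, 9, 16, 15, 12, 5, 2 shown in the example above.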
Specifying mode='same' returns a centered convolution the same size as the first input:
>>> jnp.convolve(x, y, mode='same')
Array([ 9., 16., 15., 12.,  5.], dtype=float32)
Specifying mode='valid' returns only the portion where the two arrays fully overlap:
>>> jnp.convolve(x, y, mode='valid')
Array([16., 15., 12.], dtype=float32)
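The three modes can also be compared side by side. The sketch below recomputes them for the same inputs and checks that, for this case, the "same" and "valid" results are contiguous slices of the "full" result. The slice offsets are worked out for these particular inputs (first input at least as long as the second, odd-length second input) and are an assumption of this illustration, not a general rule taken from the documentation.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.array([1, 2, 3, 2, 1])  # length N = 5
y = jnp.array([4, 1, 2])        # length M = 3
n, m = x.shape[0], y.shape[0]

full = jnp.convolve(x, y, mode='full')    # length N + M - 1 = 7
same = jnp.convolve(x, y, mode='same')    # length N = 5
valid = jnp.convolve(x, y, mode='valid')  # length N - M + 1 = 3

# Offsets below assume N >= M and odd M, as in this example.
same_start = (m - 1) // 2
same_from_full = full[same_start:same_start + n]
valid_from_full = full[m - 1:n]  # positions where y fully overlaps x

shapes = (full.shape, same.shape, valid.shape)
```

Here shapes evaluates to ((7,), (5,), (3,)), matching the lengths of the three outputs shown above.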
For complex-valued inputs:
>>> x1 = jnp.array([3+1j, 2, 4-3j])
>>> y1 = jnp.array([1, 2-3j, 4+5j])
>>> jnp.convolve(x1, y1)
Array([ 3. +1.j, 11. -7.j, 15.+10.j,  7. -8.j, 31. +8.j], dtype=complex64)
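The precision parameter listed under Parameters can be passed as a jax.lax.Precision value. The following is a minimal sketch with made-up float32 inputs; whether the setting changes the numerical result depends on the backend, and on many CPU setups the two results are expected to be essentially identical.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Illustrative inputs chosen for this sketch only.
a = jnp.array([0.1, 0.2, 0.3, 0.4], dtype=jnp.float32)
v = jnp.array([1.0, -1.0], dtype=jnp.float32)

default = jnp.convolve(a, v)
# Request the highest available precision for the underlying computation.
highest = jnp.convolve(a, v, precision=jax.lax.Precision.HIGHEST)
```

Both calls return a float32 array of length 5; only the internal precision of the computation is affected by the keyword.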