jax.scipy.signal.correlate
- jax.scipy.signal.correlate(in1, in2, mode='full', method='auto', precision=None)
Cross-correlation of two N-dimensional arrays.
JAX implementation of scipy.signal.correlate().
- Parameters:
in1 (Array) – left-hand input to the cross-correlation.
in2 (Array) – right-hand input to the cross-correlation. Must have in1.ndim == in2.ndim.
mode (str) – controls the size of the output. Available options are:
  - "full": (default) output the full cross-correlation of the inputs.
  - "same": return a centered portion of the "full" output which is the same size as in1.
  - "valid": return the portion of the "full" output which does not depend on padding at the array edges.
method (str) – controls the computation method. Options are:
  - "auto": (default) always uses the "direct" method.
  - "direct": lower to jax.lax.conv_general_dilated().
  - "fft": compute the result via a fast Fourier transform.
precision (PrecisionLike) – Specify the precision of the computation. Refer to jax.lax.Precision for a description of available values.
- Returns:
Array containing the cross-correlation result.
- Return type:
Array
See also
- jax.numpy.correlate(): 1D cross-correlation
- jax.scipy.signal.correlate2d(): 2D cross-correlation
- jax.scipy.signal.convolve(): ND convolution
Examples
A few 1D correlation examples:
>>> import jax.numpy as jnp
>>> import jax.scipy.signal
>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([1, 3, 2])
Full 1D correlation uses implicit zero-padding at the edges:
>>> jax.scipy.signal.correlate(x, y, mode='full')
Array([ 2.,  7., 13., 15., 11.,  5.,  1.], dtype=float32)
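To see where these values come from, the following sketch reproduces the "full" result by hand for the same real-valued 1D inputs: it pads x with len(y) - 1 zeros on each side and takes a sliding dot product with y. The names pad, x_padded, and manual are illustrative only and not part of the JAX API, and this is a check of the definition rather than a description of how JAX computes the result internally.

```python
import jax.numpy as jnp
from jax.scipy import signal

x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([1, 3, 2])

# Zero-pad x so that y can slide past both edges (illustrative helper names).
pad = y.size - 1
x_padded = jnp.pad(x, pad)

# One dot product per window position; there are len(x) + len(y) - 1 of them.
manual = jnp.stack(
    [jnp.dot(x_padded[k:k + y.size], y) for k in range(x.size + pad)]
)

result = signal.correlate(x, y, mode='full')
print(manual)
print(result)
```

The two printed arrays should hold the same values; the manual version keeps the integer dtype of the inputs, while correlate() returns floating point.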
Specifying mode = 'same' returns a centered 1D correlation of the same size as the first input:
>>> jax.scipy.signal.correlate(x, y, mode='same')
Array([ 7., 13., 15., 11.,  5.], dtype=float32)
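As a rough illustration of what "centered" means here, the sketch below slices the "full" output to recover the "same" output for these inputs. The start index (len(full) - len(x)) // 2 is an assumption that mirrors SciPy's centering convention; the names start and centered are illustrative.

```python
import jax.numpy as jnp
from jax.scipy import signal

x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([1, 3, 2])

full = signal.correlate(x, y, mode='full')
same = signal.correlate(x, y, mode='same')

# Assumed centering rule (following SciPy's convention): drop an equal
# number of leading and trailing entries so that len(x) values remain.
start = (full.size - x.size) // 2
centered = full[start:start + x.size]

print(centered)
print(same)
```

For these inputs the slice and the mode='same' result should match entry for entry.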
Specifying mode = 'valid' returns only the portion of the 1D correlation where the two arrays fully overlap:
>>> jax.scipy.signal.correlate(x, y, mode='valid')
Array([13., 15., 11.], dtype=float32)
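The same modes apply to N-dimensional inputs, and the method argument only changes how the result is computed. The sketch below, using small made-up 2D arrays a and b, records the output shape for each mode and checks that method='fft' agrees with method='direct' up to floating-point tolerance. The tolerances are assumptions chosen for float32, and precision is passed only to the direct call as an example of that parameter.

```python
import jax
import jax.numpy as jnp
from jax.scipy import signal

# Small illustrative 2D inputs with the same number of dimensions.
a = jnp.arange(20.0).reshape(4, 5)
b = jnp.arange(6.0).reshape(2, 3)

shapes = {}
agree = {}
for mode in ("full", "same", "valid"):
    direct = signal.correlate(
        a, b, mode=mode, method="direct",
        precision=jax.lax.Precision.HIGHEST,
    )
    fft = signal.correlate(a, b, mode=mode, method="fft")
    shapes[mode] = tuple(direct.shape)
    # FFT-based results carry small rounding error, so compare with a tolerance.
    agree[mode] = bool(jnp.allclose(direct, fft, rtol=1e-4, atol=1e-3))

print(shapes)
print(agree)
```

Along each axis, "full" has length n + m - 1, "same" has the length of the first input, and "valid" has length n - m + 1 when the first input is the larger one.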