jax.scipy.signal.convolve2d
- jax.scipy.signal.convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0, precision=None)
Convolution of two 2-dimensional arrays.
JAX implementation of scipy.signal.convolve2d().
- Parameters:
  in1 (Array) – left-hand input to the convolution. Must have in1.ndim == 2.
  in2 (Array) – right-hand input to the convolution. Must have in2.ndim == 2.
  mode (str) – controls the size of the output. Available options are:
    - "full": (default) output the full convolution of the inputs.
    - "same": return a centered portion of the "full" output that is the same size as in1.
    - "valid": return the portion of the "full" output that does not depend on padding at the array edges.
  boundary (str) – only "fill" is supported.
  fillvalue (float) – only 0 is supported.
  method – controls the computation method. Options are:
    - "auto": (default) always uses the "direct" method.
    - "direct": lower to jax.lax.conv_general_dilated().
    - "fft": compute the result via a fast Fourier transform.
  precision (PrecisionLike) – specifies the precision of the computation. Refer to jax.lax.Precision for a description of available values.
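The "direct" and "fft" strategies listed under method can be illustrated with the sketch below. It is a minimal, hedged illustration rather than JAX's internal code: the helpers conv2d_direct and conv2d_fft are hypothetical names, both compute only mode='full', and both cast inputs to float32 by assumption. The direct version writes convolution as a cross-correlation with the kernel flipped along both axes, which is the operation jax.lax.conv_general_dilated() performs; the FFT version zero-pads both inputs to the full output size so that the circular convolution computed by the FFT equals the linear convolution.

```python
import jax.numpy as jnp
import jax.scipy.signal
from jax import lax


def conv2d_direct(in1, in2, precision=None):
    """Hypothetical helper: 'full' 2D convolution via lax.conv_general_dilated."""
    in1 = jnp.asarray(in1, dtype=jnp.float32)
    in2 = jnp.asarray(in2, dtype=jnp.float32)
    kh, kw = in2.shape
    out = lax.conv_general_dilated(
        in1[None, None],            # lhs: (N=1, C=1, H, W)
        jnp.flip(in2)[None, None],  # rhs: (O=1, I=1, kh, kw), flipped along both axes
        window_strides=(1, 1),
        # Padding by (kernel size - 1) on every side yields the 'full' output.
        padding=[(kh - 1, kh - 1), (kw - 1, kw - 1)],
        precision=precision,
    )
    return out[0, 0]  # drop the batch and channel axes


def conv2d_fft(in1, in2):
    """Hypothetical helper: 'full' 2D convolution via real FFTs."""
    in1 = jnp.asarray(in1, dtype=jnp.float32)
    in2 = jnp.asarray(in2, dtype=jnp.float32)
    # Zero-padding to the full output size turns circular convolution into linear convolution.
    shape = (in1.shape[0] + in2.shape[0] - 1, in1.shape[1] + in2.shape[1] - 1)
    spectrum = jnp.fft.rfft2(in1, s=shape) * jnp.fft.rfft2(in2, s=shape)
    return jnp.fft.irfft2(spectrum, s=shape)


x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([[2, 1, 1], [4, 3, 4], [1, 3, 2]])
reference = jax.scipy.signal.convolve2d(x, y, mode='full', precision=lax.Precision.HIGHEST)
print(jnp.allclose(conv2d_direct(x, y, precision=lax.Precision.HIGHEST), reference))
print(jnp.allclose(conv2d_fft(x, y), reference, atol=1e-3))
```

The FFT result carries small floating-point round-off, which is why it is compared with a tolerance rather than for exact equality.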
- Returns:
  Array containing the convolved result.
- Return type:
  Array
See also
- jax.numpy.convolve(): 1D convolution
- jax.scipy.signal.convolve(): ND convolution
- jax.scipy.signal.correlate(): ND correlation
Examples
A few 2D convolution examples:
>>> import jax.numpy as jnp
>>> import jax.scipy.signal
>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> y = jnp.array([[2, 1, 1],
...                [4, 3, 4],
...                [1, 3, 2]])
Full 2D convolution uses implicit zero-padding at the edges:
>>> jax.scipy.signal.convolve2d(x, y, mode='full')
Array([[ 2.,  5.,  3.,  2.],
       [10., 22., 17., 12.],
       [13., 30., 32., 20.],
       [ 3., 13., 18.,  8.]], dtype=float32)
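Written out, each entry of the "full" output is a sum of products in which entries of y with out-of-range indices count as zero; this is the implicit zero-padding. As a worked check, the entry in row 1, column 1 (zero-based) of the output above is:

```latex
\begin{aligned}
\mathrm{out}[i, j] &= \sum_{m=0}^{1} \sum_{n=0}^{1} x[m, n]\, y[i - m,\, j - n],
  \quad y[p, q] := 0 \text{ if } (p, q) \notin \{0, 1, 2\}^2 \\
\mathrm{out}[1, 1] &= x[0,0]\,y[1,1] + x[0,1]\,y[1,0] + x[1,0]\,y[0,1] + x[1,1]\,y[0,0] \\
&= 1 \cdot 3 + 2 \cdot 4 + 3 \cdot 1 + 4 \cdot 2 = 22
\end{aligned}
```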
Specifying mode='same' returns a centered 2D convolution of the same size as the first input:
>>> jax.scipy.signal.convolve2d(x, y, mode='same')
Array([[22., 17.],
       [30., 32.]], dtype=float32)
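The centered portion can be reproduced by cropping the "full" output. This sketch uses a hypothetical helper, crop_centered, and assumes the window starts at offset (full_size - in1_size) // 2 along each axis; that offset rule is an assumption about the centering convention, checked here only on the arrays from this page and on one odd-sized kernel.

```python
import jax.numpy as jnp
import jax.scipy.signal


def crop_centered(full, shape):
    """Hypothetical helper: take a centered window of size `shape` from `full`."""
    starts = [(f - s) // 2 for f, s in zip(full.shape, shape)]
    return full[starts[0]:starts[0] + shape[0], starts[1]:starts[1] + shape[1]]


x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([[2, 1, 1], [4, 3, 4], [1, 3, 2]])
full = jax.scipy.signal.convolve2d(x, y, mode='full')  # shape (4, 4)
same = jax.scipy.signal.convolve2d(x, y, mode='same')  # shape (2, 2)
# Crops full[1:3, 1:3] and compares it with the 'same' result.
print(jnp.array_equal(crop_centered(full, x.shape), same))
```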
Specifying mode='valid' returns only the portion of the 2D convolution where the two arrays fully overlap:
>>> jax.scipy.signal.convolve2d(x, y, mode='valid')
Array([[22., 17.],
       [30., 32.]], dtype=float32)
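For these small inputs, mode='same' and mode='valid' happen to give the same 2x2 result. A larger first input makes the three modes easier to tell apart. In this sketch the 5x6 signal and 3x3 kernel are arbitrary values chosen only for illustration; it checks the output shape of each mode and that the "valid" output is the interior block of the "full" output where the kernel lies entirely inside the signal.

```python
import jax.numpy as jnp
import jax.scipy.signal

# Arbitrary illustrative inputs: a 5x6 signal and a 3x3 kernel.
a = jnp.arange(30.0).reshape(5, 6)
k = jnp.array([[1.0, 0.0, -1.0],
               [2.0, 0.0, -2.0],
               [1.0, 0.0, -1.0]])

full = jax.scipy.signal.convolve2d(a, k, mode='full')
same = jax.scipy.signal.convolve2d(a, k, mode='same')
valid = jax.scipy.signal.convolve2d(a, k, mode='valid')

print(full.shape, same.shape, valid.shape)  # (7, 8) (5, 6) (3, 4)

# 'valid' keeps only the rows/columns of 'full' where the kernel fits entirely inside `a`.
kh, kw = k.shape
print(jnp.allclose(valid, full[kh - 1:a.shape[0], kw - 1:a.shape[1]]))
```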