jax.numpy.fft.irfftn
jax.numpy.fft.irfftn(a, s=None, axes=None, norm=None)
Compute a real-valued multidimensional inverse discrete Fourier transform.
JAX implementation of numpy.fft.irfftn().
- Parameters:
a (ArrayLike) – input array.
s (Shape | None) – optional sequence of integers. Specifies the size of the output along each specified axis. If not specified, the size of the output along axis axes[-1] is 2*(m-1), where m is the size of the input along axis axes[-1], and the size along the other axes is the same as that of the input.
axes (Sequence[int] | None) – optional sequence of integers, default=None. Specifies the axes along which the transform is computed. If not specified, the transform is computed along the last len(s) axes. If neither axes nor s is specified, the transform is computed along all the axes.
norm (str | None) – string, default="backward". The normalization mode. "backward", "ortho" and "forward" are supported.
- Returns:
A real-valued array containing the multidimensional inverse discrete Fourier transform of a, with size s along the specified axes and the same size as the input along the other axes.
- Return type:
Array
See also
- jax.numpy.fft.rfftn(): Computes a multidimensional discrete Fourier transform of a real-valued array.
- jax.numpy.fft.irfft(): Computes a real-valued one-dimensional inverse discrete Fourier transform.
- jax.numpy.fft.irfft2(): Computes a real-valued two-dimensional inverse discrete Fourier transform.
Examples
jnp.fft.irfftn computes the transform along all the axes by default.
>>> x = jnp.array([[[1, 3, 5],
...                 [2, 4, 6]],
...                [[7, 9, 11],
...                 [8, 10, 12]]])
>>> jnp.fft.irfftn(x)
Array([[[ 6.5, -1. ,  0. , -1. ],
        [-0.5,  0. ,  0. ,  0. ]],

       [[-3. ,  0. ,  0. ,  0. ],
        [ 0. ,  0. ,  0. ,  0. ]]], dtype=float32)
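The default output size of 2*(m-1) along the last axis means that irfftn undoes rfftn without any extra arguments when the original last axis has even size. The following is a minimal sketch of that round trip, assuming jax.numpy is imported as jnp; the array y and its shape are illustrative choices, not part of the original page.

```python
import jax.numpy as jnp

# Illustrative real-valued input whose last axis has even size (4).
y = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4)

# rfftn keeps only the non-negative frequencies along the last axis,
# so that axis shrinks from 4 to 4 // 2 + 1 = 3.
Y = jnp.fft.rfftn(y)
print(Y.shape)       # (2, 3, 3)

# With no s given, the last axis of the output has size 2 * (3 - 1) = 4.
y_back = jnp.fft.irfftn(Y)
print(y_back.shape)  # (2, 3, 4)

# The round trip should reproduce y up to floating-point error.
print(jnp.allclose(y_back, y, atol=1e-4))
```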
When s=[3, 4], the size of the transform along axes (-2, -1) will be (3, 4), and the size along the other axes will be the same as that of the input.
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.irfftn(x, s=[3, 4])
Array([[[ 2.33, -0.67,  0.  , -0.67],
        [ 0.33, -0.74,  0.  ,  0.41],
        [ 0.33,  0.41,  0.  , -0.74]],

       [[ 6.33, -0.67,  0.  , -0.67],
        [ 1.33, -1.61,  0.  ,  1.28],
        [ 1.33,  1.28,  0.  , -1.61]]], dtype=float32)
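One practical use of s is recovering a signal whose last axis has odd size: the default 2*(m-1) is always even, so the original size has to be passed explicitly. The sketch below illustrates this under the assumption that jax.numpy is imported as jnp; the array y is a made-up example.

```python
import jax.numpy as jnp

# Illustrative real-valued input whose last axis has odd size (5).
y = jnp.arange(10, dtype=jnp.float32).reshape(2, 5)

# Forward transform: the last axis becomes 5 // 2 + 1 = 3.
Y = jnp.fft.rfftn(y)

# Default inverse: the last axis is 2 * (3 - 1) = 4, not the original 5.
default_back = jnp.fft.irfftn(Y)
print(default_back.shape)  # (2, 4)

# Passing the original shape as s restores the odd-sized axis.
exact_back = jnp.fft.irfftn(Y, s=y.shape)
print(exact_back.shape)    # (2, 5)

# Only the inverse with the explicit s should match y up to rounding.
print(jnp.allclose(exact_back, y, atol=1e-4))
```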
When s=[3] and axes=[0], the size of the transform along axis 0 will be 3, and the size along the other axes will be the same as that of the input.
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.irfftn(x, s=[3], axes=[0])
Array([[[ 5.,  7.,  9.],
        [ 6.,  8., 10.]],

       [[-2., -2., -2.],
        [-2., -2., -2.]],

       [[-2., -2., -2.],
        [-2., -2., -2.]]], dtype=float32)
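The norm argument only changes the overall scale of the result. As a rough sketch, assuming the usual NumPy convention in which the inverse transform is scaled by 1/n for "backward", by 1/sqrt(n) for "ortho" and left unscaled for "forward" (n being the number of output points along the transformed axes), the three modes can be compared as follows. The input array is illustrative.

```python
import jax.numpy as jnp

# Illustrative 2-D input; both axes are transformed by default.
x2 = jnp.array([[1, 3, 5],
                [2, 4, 6]])

backward = jnp.fft.irfftn(x2)                  # default normalization
ortho = jnp.fft.irfftn(x2, norm="ortho")
forward = jnp.fft.irfftn(x2, norm="forward")

# Number of output points along the transformed axes: 2 * 4 = 8.
n = backward.size
print(backward.shape)  # (2, 4)

# The three results should differ only by a constant factor.
print(jnp.allclose(forward, backward * n, atol=1e-4))
print(jnp.allclose(ortho, backward * n ** 0.5, atol=1e-4))
```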