jax.numpy.fft.ifftn

jax.numpy.fft.ifftn(a, s=None, axes=None, norm=None)
Compute a multidimensional inverse discrete Fourier transform.
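For reference, here is a sketch of the standard NumPy convention that this function follows, assuming the default "backward" normalization. Transforming d axes of lengths N_1 through N_d gives the expression below. Under "ortho" the leading factor becomes 1/sqrt(N_1 N_2 ... N_d), and under "forward" it is dropped.

```latex
y[n_1, \dots, n_d] = \frac{1}{N_1 N_2 \cdots N_d}
  \sum_{k_1=0}^{N_1-1} \cdots \sum_{k_d=0}^{N_d-1}
  a[k_1, \dots, k_d] \,
  \exp\!\left( 2\pi i \sum_{j=1}^{d} \frac{k_j n_j}{N_j} \right)
```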
JAX implementation of numpy.fft.ifftn().

Parameters:
a (ArrayLike) – input array
s (Shape | None) – sequence of integers. Specifies the shape of the result along the transformed axes; each of those axes is cropped or zero-padded to the given length. If not specified, it defaults to the shape of a along the specified axes.
axes (Sequence[int] | None) – sequence of integers, default=None. Specifies the axes along which the transform is computed. If None, the transform is computed over the last len(s) axes, or over all axes if s is also None.
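norm (str | None) – string. The normalization mode. “backward”, “ortho” and “forward” are supported. None is equivalent to “backward”, which scales the inverse transform by 1/n, where n is the number of transformed elements.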
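The three norm modes differ only by a constant factor. The sketch below assumes the usual NumPy conventions: "backward" divides by the number of transformed elements n, "ortho" divides by sqrt(n), and "forward" applies no scaling on the inverse transform. The array and variable names are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 5., 3.],
               [4., 1., 2., 6.],
               [5., 3., 2., 1.]])
n = x.size  # 12 elements, since all axes are transformed by default

backward = jnp.fft.ifftn(x)                 # default: scaled by 1/n
ortho = jnp.fft.ifftn(x, norm="ortho")      # scaled by 1/sqrt(n)
forward = jnp.fft.ifftn(x, norm="forward")  # no scaling

# Rescaling the default result reproduces the other two modes.
print(jnp.allclose(ortho, backward * jnp.sqrt(n), atol=1e-5))  # True
print(jnp.allclose(forward, backward * n, atol=1e-4))          # True
```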
Returns:
An array containing the multidimensional inverse discrete Fourier transform of a.

Return type:
Array
See also
jax.numpy.fft.fftn(): Computes a multidimensional discrete Fourier transform.
jax.numpy.fft.fft(): Computes a one-dimensional discrete Fourier transform.
jax.numpy.fft.ifft(): Computes a one-dimensional inverse discrete Fourier transform.
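Because ifftn inverts fftn when both use the same axes and the same norm, a round trip should recover the input up to floating-point error. A minimal sketch, with illustrative variable names:

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 5., 3.],
               [4., 1., 2., 6.],
               [5., 3., 2., 1.]])

# Forward transform over all axes, then invert it with the matching norm.
roundtrip = jnp.fft.ifftn(jnp.fft.fftn(x))
roundtrip_ortho = jnp.fft.ifftn(jnp.fft.fftn(x, norm="ortho"), norm="ortho")

# Results are complex; for a real input the imaginary parts are close to 0.
print(jnp.allclose(roundtrip, x, atol=1e-5))        # True
print(jnp.allclose(roundtrip_ortho, x, atol=1e-5))  # True
```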
Examples
By default, when both s and axes are None, jnp.fft.ifftn computes the transform along all axes.

>>> x = jnp.array([[1, 2, 5, 3],
...                [4, 1, 2, 6],
...                [5, 3, 2, 1]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.fft.ifftn(x))
[[ 2.92+0.j    0.08-0.33j  0.25+0.j    0.08+0.33j]
 [-0.08+0.14j -0.04-0.03j  0.  -0.29j -1.05-0.11j]
 [-0.08-0.14j -1.05+0.11j  0.  +0.29j -0.04+0.03j]]
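The multidimensional inverse DFT is separable, so the all-axes result should match applying the one-dimensional jnp.fft.ifft along each axis in turn. A minimal sketch checking that equivalence (variable names are illustrative):

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 5, 3],
               [4, 1, 2, 6],
               [5, 3, 2, 1]])

# Apply the 1-D inverse transform along axis 0, then along axis 1.
step = jnp.fft.ifft(x, axis=0)
step = jnp.fft.ifft(step, axis=1)

print(jnp.allclose(step, jnp.fft.ifftn(x), atol=1e-5))  # True
```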
When s=[3] and axes is None, the transform is computed only along the last axis (axis -1), which is cropped from 4 to 3 elements. The other axes are not transformed and keep their input size.

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.fft.ifftn(x, s=[3]))
[[ 2.67+0.j   -0.83-0.87j -0.83+0.87j]
 [ 2.33+0.j    0.83-0.29j  0.83+0.29j]
 [ 3.33+0.j    0.83+0.29j  0.83-0.29j]]
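A sketch of what s does along the single transformed axis, assuming the NumPy semantics of cropping when s is smaller than the input and zero-padding when it is larger. The helper names cropped, padded and x_padded are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 5, 3],
               [4, 1, 2, 6],
               [5, 3, 2, 1]])

# s=[3] with axes=None: only the last axis is transformed, after cropping to 3.
cropped = jnp.fft.ifftn(x, s=[3])
print(jnp.allclose(cropped, jnp.fft.ifft(x[:, :3], axis=-1), atol=1e-5))  # True
# It is NOT a transform over both axes of the cropped array.
print(jnp.allclose(cropped, jnp.fft.ifftn(x[:, :3]), atol=1e-5))  # False

# A length larger than the input zero-pads that axis instead.
padded = jnp.fft.ifftn(x, s=[6])
x_padded = jnp.pad(x, ((0, 0), (0, 2)))  # append two zero columns
print(padded.shape)  # (3, 6)
print(jnp.allclose(padded, jnp.fft.ifft(x_padded, axis=-1), atol=1e-5))  # True
```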
When s=[2] and axes=[0], the transform is computed only along axis 0, which is cropped to 2 elements. Axis 1 is not transformed and keeps its input size of 4.

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.fft.ifftn(x, s=[2], axes=[0]))
[[ 2.5+0.j  1.5+0.j  3.5+0.j  4.5+0.j]
 [-1.5+0.j  0.5+0.j  1.5+0.j -1.5+0.j]]
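For a length-2 axis, the inverse DFT reduces to (a + b) / 2 and (a - b) / 2, which makes this output easy to check by hand. A minimal sketch comparing ifftn against the one-dimensional jnp.fft.ifft with n=2 and against that closed form (variable names are illustrative):

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 5, 3],
               [4, 1, 2, 6],
               [5, 3, 2, 1]])

# Crop axis 0 to length 2 and transform along it only.
via_ifftn = jnp.fft.ifftn(x, s=[2], axes=[0])
via_ifft = jnp.fft.ifft(x, n=2, axis=0)  # n crops axis 0 the same way

# Closed form of a length-2 inverse DFT on the first two rows.
a, b = x[0], x[1]
by_hand = jnp.stack([(a + b) / 2, (a - b) / 2])

print(jnp.allclose(via_ifftn, via_ifft, atol=1e-5))  # True
print(jnp.allclose(via_ifftn, by_hand, atol=1e-5))   # True
```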
When s=[2, 3] and axes is None, the transform is computed over the last two axes (here, both axes), and the result has shape (2, 3).

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.fft.ifftn(x, s=[2, 3]))
[[ 2.5 +0.j    0.  -0.58j  0.  +0.58j]
 [ 0.17+0.j   -0.83-0.29j -0.83+0.29j]]
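Since every axis is cropped here, the result should equal transforming the top-left 2x3 block of x over all axes, and its [0, 0] entry should be the mean of that block. A minimal sketch, with illustrative variable names:

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 5, 3],
               [4, 1, 2, 6],
               [5, 3, 2, 1]])

# With axes=None, s=[2, 3] applies to the last two axes, i.e. both axes here.
cropped = jnp.fft.ifftn(x, s=[2, 3])
print(cropped.shape)  # (2, 3)

# Same as cropping first and then transforming every axis.
print(jnp.allclose(cropped, jnp.fft.ifftn(x[:2, :3]), atol=1e-5))  # True

# The zero-frequency entry is the block mean: (1+2+5+4+1+2) / 6 = 2.5.
print(jnp.isclose(cropped[0, 0], x[:2, :3].mean(), atol=1e-5))  # True
```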