jax.numpy.fft.fftn

jax.numpy.fft.fftn(a, s=None, axes=None, norm=None)
Compute a multidimensional discrete Fourier transform along given axes.
JAX implementation of numpy.fft.fftn().

Parameters:
a (ArrayLike) – input array
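s (Shape | None) – sequence of integers. Specifies the shape of the result along the transformed axes; each of those axes is cropped or zero-padded to the requested length. If not specified, it defaults to the shape of a along the specified axes.
axes (Sequence[int] | None) – sequence of integers, default=None. Specifies the axes along which the transform is computed. If None, the transform is computed over all axes, or over the last len(s) axes when s is given.
norm (str | None) – string. The normalization mode. “backward”, “ortho” and “forward” are supported; the default None behaves like “backward”.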
Returns:
An array containing the multidimensional discrete Fourier transform of a.

Return type:
Array
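The three norm modes differ only by a constant scale factor on the forward transform. A minimal sketch, assuming the conventional NumPy definitions: no scaling for “backward”, 1/sqrt(n) for “ortho”, and 1/n for “forward”, where n is the product of the transformed axis lengths. The input x here is arbitrary illustration data.

```python
import math

import jax.numpy as jnp

# Arbitrary illustration input of shape (3, 4).
x = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
n = x.size  # all axes are transformed, so n = 3 * 4 = 12

unscaled = jnp.fft.fftn(x)                 # norm=None: same as "backward", no forward scaling
ortho = jnp.fft.fftn(x, norm="ortho")      # forward result scaled by 1/sqrt(n)
forward = jnp.fft.fftn(x, norm="forward")  # forward result scaled by 1/n

print(jnp.allclose(ortho, unscaled / math.sqrt(n), atol=1e-4))
print(jnp.allclose(forward, unscaled / n, atol=1e-4))
```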
See also
jax.numpy.fft.fft(): Computes a one-dimensional discrete Fourier transform.
jax.numpy.fft.ifft(): Computes a one-dimensional inverse discrete Fourier transform.
jax.numpy.fft.ifftn(): Computes a multidimensional inverse discrete Fourier transform.
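Because the multidimensional DFT is separable, transforming over several axes gives the same result as applying the one-dimensional jax.numpy.fft.fft along each of those axes in turn, in any order. The sketch below checks this numerically; the helper fftn_via_fft is hypothetical, written only for illustration, and is not part of JAX.

```python
import jax.numpy as jnp


def fftn_via_fft(a, axes):
    """Hypothetical helper (not a JAX API): apply a 1-D FFT along each axis in turn."""
    out = jnp.asarray(a)
    for axis in axes:
        out = jnp.fft.fft(out, axis=axis)
    return out


x = jnp.array([[1, 2, 5, 6],
               [4, 1, 3, 7],
               [5, 9, 2, 1]], dtype=jnp.float32)

direct = jnp.fft.fftn(x)                  # transform over all axes at once
composed = fftn_via_fft(x, axes=(0, 1))   # axis 0 first, then axis 1
print(jnp.allclose(direct, composed, atol=1e-4))
```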
Examples
jnp.fft.fftn computes the transform along all axes by default, when the axes argument is None.

>>> x = jnp.array([[1, 2, 5, 6],
...                [4, 1, 3, 7],
...                [5, 9, 2, 1]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.fftn(x)
Array([[ 46.  +0.j  ,   0.  +2.j  ,  -6.  +0.j  ,   0.  -2.j  ],
       [ -2.  +1.73j,   6.12+6.73j,   0.  -1.73j, -18.12-3.27j],
       [ -2.  -1.73j, -18.12+3.27j,   0.  +1.73j,   6.12-6.73j]],      dtype=complex64)
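Since jnp.fft.fftn mirrors numpy.fft.fftn, the NumPy function can serve as a reference check. A minimal sketch, assuming NumPy is installed; the tolerance is loose because JAX computes in complex64 by default while NumPy uses complex128.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([[1, 2, 5, 6],
               [4, 1, 3, 7],
               [5, 9, 2, 1]])

jax_result = jnp.fft.fftn(x)               # axes=None: every axis is transformed
explicit = jnp.fft.fftn(x, axes=(0, 1))    # same transform with the axes spelled out
np_result = np.fft.fftn(np.asarray(x))     # NumPy reference result (complex128)

print(np.allclose(np.asarray(jax_result), np_result, atol=1e-4))
print(np.allclose(np.asarray(explicit), np_result, atol=1e-4))
```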
When s=[2] and axes is None, the transform is computed only along the last axis (axis -1), which is cropped to length 2; the other axes are not transformed and keep the same size as in the input.

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.numpy.fft.fftn(x, s=[2]))
[[ 3.+0.j -1.+0.j]
 [ 5.+0.j  3.+0.j]
 [14.+0.j -4.+0.j]]
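With a one-element s and axes=None, fftn behaves like a one-dimensional FFT over the last axis after cropping it. A minimal sketch of that equivalence, using plain slicing followed by jnp.fft.fft:

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 5, 6],
               [4, 1, 3, 7],
               [5, 9, 2, 1]])

via_fftn = jnp.fft.fftn(x, s=[2])
# Keep the first 2 entries of the last axis, then apply a 1-D FFT along that axis only.
via_slice = jnp.fft.fft(x[:, :2], axis=-1)

print(via_fftn.shape)  # (3, 2)
print(jnp.allclose(via_fftn, via_slice))
```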
When s=[2] and axes=[0], the transform is computed only along axis 0, which is cropped to length 2; the other axes are not transformed and keep the same size as in the input.

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.numpy.fft.fftn(x, s=[2], axes=[0]))
[[ 5.+0.j  3.+0.j  8.+0.j 13.+0.j]
 [-3.+0.j  1.+0.j  2.+0.j -1.+0.j]]
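Pairing s with axes crops (or pads) only the listed axes. A minimal sketch showing that this call matches the one-dimensional jnp.fft.fft with n=2 along axis 0:

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 5, 6],
               [4, 1, 3, 7],
               [5, 9, 2, 1]])

via_fftn = jnp.fft.fftn(x, s=[2], axes=[0])
# n=2 crops axis 0 to its first two rows before transforming along it.
via_fft = jnp.fft.fft(x, n=2, axis=0)

print(via_fftn.shape)  # (2, 4)
print(jnp.allclose(via_fftn, via_fft))
```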
When s=[2, 3], the transform is computed along the last two axes, which are cropped so that the result has shape (2, 3).

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.numpy.fft.fftn(x, s=[2, 3]))
[[16. +0.j   -0.5+4.33j -0.5-4.33j]
 [ 0. +0.j   -4.5+0.87j -4.5-0.87j]]
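Entries of s may also exceed the input size, in which case the input is zero-padded at the end of each transformed axis, following the NumPy convention. A minimal sketch comparing both cropping and padding against explicit slicing and jnp.pad; the pad widths are chosen for this particular x.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 5, 6],
               [4, 1, 3, 7],
               [5, 9, 2, 1]], dtype=jnp.float32)

# s=[2, 3] crops the input to x[:2, :3] before the 2-D transform.
cropped = jnp.fft.fftn(x, s=[2, 3])
print(jnp.allclose(cropped, jnp.fft.fftn(x[:2, :3]), atol=1e-4))

# s=[4, 6] zero-pads axis 0 from 3 to 4 and axis 1 from 4 to 6.
padded = jnp.fft.fftn(x, s=[4, 6])
x_padded = jnp.pad(x, ((0, 1), (0, 2)))
print(padded.shape)  # (4, 6)
print(jnp.allclose(padded, jnp.fft.fftn(x_padded), atol=1e-4))
```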
jnp.fft.ifftn can be used to reconstruct x from the result of jnp.fft.fftn.

>>> x_fftn = jnp.fft.fftn(x)
>>> jnp.allclose(x, jnp.fft.ifftn(x_fftn))
Array(True, dtype=bool)
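The round trip holds for every normalization mode as long as fftn and ifftn are given the same norm, since each mode splits the overall 1/n factor between the forward and inverse transforms. A minimal sketch looping over the three documented modes, plus one mismatched pair to show the leftover factor:

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 5, 6],
               [4, 1, 3, 7],
               [5, 9, 2, 1]], dtype=jnp.float32)

for norm in ("backward", "ortho", "forward"):
    # Forward and inverse use the same mode, so the scale factors cancel.
    restored = jnp.fft.ifftn(jnp.fft.fftn(x, norm=norm), norm=norm)
    print(norm, jnp.allclose(x, restored.real, atol=1e-4))

# Mixing modes leaves a leftover factor: "forward" then default "backward"
# applies 1/n twice, so the result is x / x.size.
mixed = jnp.fft.ifftn(jnp.fft.fftn(x, norm="forward"))
print(jnp.allclose(mixed.real, x / x.size, atol=1e-4))
```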