jax.numpy.interp
- jax.numpy.interp(x, xp, fp, left=None, right=None, period=None)
One-dimensional linear interpolation.
JAX implementation of numpy.interp().

- Parameters:
x (ArrayLike) – N-dimensional array of x coordinates at which to evaluate the interpolation.
xp (ArrayLike) – one-dimensional array of the x coordinates of the data points, sorted in increasing order.
fp (ArrayLike) – array of shape xp.shape containing the function values associated with xp.
left (ArrayLike | str | None) – specifies how to handle points x < xp[0]. The default is to return fp[0]. If left is a scalar value, that value is returned. If left is the string "extrapolate", the value is determined by linear extrapolation. left is ignored if period is specified.
right (ArrayLike | str | None) – specifies how to handle points x > xp[-1]. The default is to return fp[-1]. If right is a scalar value, that value is returned. If right is the string "extrapolate", the value is determined by linear extrapolation. right is ignored if period is specified.
period (ArrayLike | None) – optionally specifies the period of the x coordinates, e.g. for interpolation in angular space.
- Returns:
an array of shape x.shape containing the interpolated function values at x.
- Return type:
Array
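To make the handling of left and right concrete, here is a minimal plain-Python sketch for a single scalar query point. It assumes xp is strictly increasing with at least two entries and that period is not given. interp_scalar is a hypothetical helper, not part of JAX and not JAX's implementation (jnp.interp works elementwise over an N-dimensional x); it also assumes that "extrapolate" extends the first or last segment of the data.

```python
from bisect import bisect_right


def interp_scalar(x, xp, fp, left=None, right=None):
    """Hypothetical plain-Python sketch of 1-D linear interpolation at one point.

    Assumes xp is strictly increasing and has at least two entries.
    """
    if x < xp[0]:
        if left is None:
            return fp[0]  # default: clamp to the first function value
        if isinstance(left, str) and left == "extrapolate":
            # Assumption: extend the line through the first two data points.
            slope = (fp[1] - fp[0]) / (xp[1] - xp[0])
            return fp[0] + slope * (x - xp[0])
        return left  # user-supplied fill value
    if x > xp[-1]:
        if right is None:
            return fp[-1]  # default: clamp to the last function value
        if isinstance(right, str) and right == "extrapolate":
            # Assumption: extend the line through the last two data points.
            slope = (fp[-1] - fp[-2]) / (xp[-1] - xp[-2])
            return fp[-1] + slope * (x - xp[-1])
        return right
    # Find the segment [xp[i], xp[i + 1]] containing x; the min() keeps
    # x == xp[-1] inside the last segment.
    i = min(bisect_right(xp, x) - 1, len(xp) - 2)
    t = (x - xp[i]) / (xp[i + 1] - xp[i])  # fractional position in the segment
    return fp[i] + t * (fp[i + 1] - fp[i])


xp = list(range(10))
fp = [2 * v for v in xp]
print([interp_scalar(v, xp, fp) for v in (0.5, 2.0, 3.5)])  # [1.0, 4.0, 7.0]
```

The binary search keeps each lookup at O(log n) in the number of data points, which is the usual choice when xp is sorted.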
Examples
>>> xp = jnp.arange(10)
>>> fp = 2 * xp
>>> x = jnp.array([0.5, 2.0, 3.5])
>>> jnp.interp(x, xp, fp)
Array([1., 4., 7.], dtype=float32)
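Each output inside [xp[0], xp[-1]] lies on the straight line through the two neighboring data points. Writing xp[i] as x_{p,i} and fp[i] as f_{p,i}, the interpolant and the third entry of the result above work out as:

$$
\hat f(x) = f_{p,i} + \frac{f_{p,i+1} - f_{p,i}}{x_{p,i+1} - x_{p,i}}\,(x - x_{p,i}), \qquad x_{p,i} \le x \le x_{p,i+1}
$$

$$
\hat f(3.5) = f_{p,3} + \frac{f_{p,4} - f_{p,3}}{x_{p,4} - x_{p,3}}\,(3.5 - x_{p,3}) = 6 + \frac{8 - 6}{4 - 3}\,(3.5 - 3) = 7
$$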
Unless otherwise specified, extrapolation will be constant:
>>> x = jnp.array([-10., 10.])
>>> jnp.interp(x, xp, fp)
Array([ 0., 18.], dtype=float32)
Use "extrapolate" mode for linear extrapolation:
>>> jnp.interp(x, xp, fp, left='extrapolate', right='extrapolate')
Array([-20.,  20.], dtype=float32)
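Assuming that "extrapolate" extends the first segment of the data (for left) and the last segment (for right), which is consistent with the outputs above, the two values follow from the line equations below, writing xp[i] as x_{p,i} and fp[i] as f_{p,i}:

$$
\hat f(-10) = f_{p,0} + \frac{f_{p,1} - f_{p,0}}{x_{p,1} - x_{p,0}}\,(-10 - x_{p,0}) = 0 + \frac{2 - 0}{1 - 0}\,(-10 - 0) = -20
$$

$$
\hat f(10) = f_{p,9} + \frac{f_{p,9} - f_{p,8}}{x_{p,9} - x_{p,8}}\,(10 - x_{p,9}) = 18 + \frac{18 - 16}{9 - 8}\,(10 - 9) = 20
$$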
For periodic interpolation, specify the period:
>>> xp = jnp.array([0, jnp.pi / 2, jnp.pi, 3 * jnp.pi / 2])
>>> fp = jnp.sin(xp)
>>> x = 2 * jnp.pi  # note: not in input array
>>> jnp.interp(x, xp, fp, period=2 * jnp.pi)
Array(0., dtype=float32)
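Periodic interpolation can be understood as reducing every coordinate modulo the period and letting the data wrap around. The plain-Python sketch below handles one scalar query; interp_periodic is a hypothetical helper, and its approach (normalize xp with xp % period, sort, then pad one wrapped point on each side) is modeled on NumPy's documented handling of period. That jnp.interp behaves the same way is an assumption, not something taken from JAX's source. A positive period is also assumed.

```python
import math
from bisect import bisect_right


def interp_periodic(x, xp, fp, period):
    """Hypothetical sketch of periodic 1-D linear interpolation at one point.

    Assumes period > 0 and that xp reduced modulo period has no duplicates.
    """
    x = x % period  # map the query into [0, period)
    # Normalize the data coordinates into [0, period) and sort them,
    # carrying the function values along.
    pairs = sorted((p % period, f) for p, f in zip(xp, fp))
    xs = [p for p, _ in pairs]
    fs = [f for _, f in pairs]
    # Wrap around: prepend the last point shifted back one period and append
    # the first point shifted forward one period, so every x in [0, period)
    # falls between two known points.
    xs = [xs[-1] - period] + xs + [xs[0] + period]
    fs = [fs[-1]] + fs + [fs[0]]
    i = min(bisect_right(xs, x) - 1, len(xs) - 2)
    t = (x - xs[i]) / (xs[i + 1] - xs[i])
    return fs[i] + t * (fs[i + 1] - fs[i])


xp = [0.0, math.pi / 2, math.pi, 3 * math.pi / 2]
fp = [math.sin(v) for v in xp]
print(interp_periodic(2 * math.pi, xp, fp, period=2 * math.pi))  # 2π wraps to 0; prints 0.0
```

Because of the wrap-around padding, a query such as x = 7π/4 interpolates between the point at 3π/2 and the copy of the point at 0 shifted to 2π, so no left or right handling is needed, which matches left and right being ignored when period is given.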