jax.numpy.clip
- jax.numpy.clip(arr=None, /, min=None, max=None, *, a=Deprecated, a_min=Deprecated, a_max=Deprecated)
Clip array values to a specified range.
JAX implementation of numpy.clip().
- Parameters:
arr (ArrayLike | None) – N-dimensional array to be clipped.
min (ArrayLike | None) – optional minimum value of the clipped range; if None (default), the result is not clipped to any minimum value. If specified, it should be broadcast-compatible with arr and max.
max (ArrayLike | None) – optional maximum value of the clipped range; if None (default), the result is not clipped to any maximum value. If specified, it should be broadcast-compatible with arr and min.
a (ArrayLike | DeprecatedArg) – deprecated alias of the arr argument. Will result in a DeprecationWarning if used.
a_min (ArrayLike | None | DeprecatedArg) – deprecated alias of the min argument. Will result in a DeprecationWarning if used.
a_max (ArrayLike | None | DeprecatedArg) – deprecated alias of the max argument. Will result in a DeprecationWarning if used.
- Returns:
An array containing values from arr, with values smaller than min set to min, and values larger than max set to max. Wherever min is larger than max, the value of max is returned.
- Return type:
See also
jax.numpy.minimum(): Compute the element-wise minimum value of two arrays.
jax.numpy.maximum(): Compute the element-wise maximum value of two arrays.
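The relationship to the two functions listed above can be illustrated with a small sketch. It compares clip against a composition of maximum and minimum for one set of bounds. The names lo and hi are hypothetical, chosen only for this illustration, and the comparison demonstrates the documented semantics rather than describing how clip is implemented.

```python
import jax.numpy as jnp

arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])
lo, hi = 2, 5  # hypothetical bound names, used only in this sketch

# Clip directly.
clipped = jnp.clip(arr, lo, hi)

# Raise values below lo with maximum, then lower values above hi with minimum.
composed = jnp.minimum(jnp.maximum(arr, lo), hi)

# For these inputs both approaches give the same values.
print(clipped)
print(composed)
```

When only one bound is needed, calling minimum or maximum directly may read more clearly than passing a single bound to clip.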
Examples
>>> import jax.numpy as jnp
>>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])
>>> jnp.clip(arr, 2, 5)
Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32)
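A further sketch, assuming the parameter semantics described above, covers three cases the basic example does not: clipping with only one bound, bounds that broadcast against arr, and a min that is larger than max. The variable names (grid, col_min, row_max, and so on) are illustrative and do not come from the JAX documentation.

```python
import jax.numpy as jnp

arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])

# One-sided clipping: leaving a bound as None disables that side.
lower_only = jnp.clip(arr, min=2)  # values: [2, 2, 2, 3, 4, 5, 6, 7]
upper_only = jnp.clip(arr, max=5)  # values: [0, 1, 2, 3, 4, 5, 5, 5]

# Array bounds: min and max are broadcast against arr.
grid = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
col_min = jnp.array([1, 2, 3])      # shape (3,): one lower bound per column
row_max = jnp.array([[2], [4]])     # shape (2, 1): one upper bound per row
bounded = jnp.clip(grid, col_min, row_max)  # values: [[1, 2, 2], [3, 4, 4]]

# min larger than max: the value of max is returned everywhere.
inverted = jnp.clip(arr, 5, 2)  # values: [2, 2, 2, 2, 2, 2, 2, 2]

print(lower_only, upper_only, bounded, inverted, sep="\n")
```

In the broadcast case, the element in the first row and third column has a lower bound of 3 and an upper bound of 2, so it also shows the rule that max takes precedence when the bounds conflict.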