jax.lax.reduce_max
- jax.lax.reduce_max(operand, axes)
Compute the maximum of elements over one or more array axes.
- Parameters:
operand (ArrayLike) – array over which to compute maximum.
axes (Sequence[int]) – sequence of zero or more unique integers specifying the axes over which to reduce. Each entry must satisfy 0 <= axis < operand.ndim.
- Returns:
An array of the same dtype as operand, with shape equal to operand.shape with the dimensions listed in axes removed.
- Return type:
Array
See also
jax.numpy.max(): more flexible NumPy-style max-reduction API, built around jax.lax.reduce_max().
Other low-level jax.lax reduction operators: jax.lax.reduce_sum(), jax.lax.reduce_prod(), jax.lax.reduce_min(), jax.lax.reduce_and(), jax.lax.reduce_or(), jax.lax.reduce_xor().