jax.lax.max
- jax.lax.max(x, y)
Elementwise maximum: \(\mathrm{max}(x, y)\).
This function lowers directly to the stablehlo.maximum operation for non-complex inputs. For complex numbers, this uses a lexicographic comparison on the (real, imaginary) pairs.
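The following minimal sketch shows both behaviours described above. It assumes a recent JAX release with default 32-bit types, and the array values are illustrative. For real inputs, the larger element wins at each position. For complex inputs, the real parts are compared first, and the imaginary parts are consulted only when the real parts tie.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# Real inputs: the larger element is taken at each position.
x = jnp.array([1.0, 5.0, -2.0], dtype=jnp.float32)
y = jnp.array([3.0, 4.0, -2.5], dtype=jnp.float32)
real_max = np.asarray(lax.max(x, y))
print(real_max)  # elementwise: 3.0, 5.0, -2.0

# Complex inputs: (real, imag) pairs are compared lexicographically.
a = jnp.array([1 + 5j, 2 + 0j, 3 + 1j], dtype=jnp.complex64)
b = jnp.array([1 + 2j, 1 + 9j, 3 + 4j], dtype=jnp.complex64)
complex_max = np.asarray(lax.max(a, b))
# Position 0: real parts tie (1 == 1), so the larger imaginary part wins: 1+5j.
# Position 1: real part 2 > 1 decides it, even though 9 > 0: 2+0j.
# Position 2: real parts tie (3 == 3), so 3+4j wins.
print(complex_max)
```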
- Parameters:
  - x (ArrayLike) – Input array. Must have the same dtype as y. If neither x nor y is a scalar, x and y must have the same rank and be broadcast compatible.
  - y (ArrayLike) – Input array. Must have the same dtype as x. If neither x nor y is a scalar, x and y must have the same rank and be broadcast compatible.
- Returns:
  An array of the same dtype as x and y containing the elementwise maximum.
- Return type:
  Array
See also
- jax.numpy.maximum(): more flexible NumPy-style maximum.
- jax.lax.reduce_max(): maximum along an axis of an array.
- jax.lax.min(): elementwise minimum.