jax.numpy.amax

jax.numpy.amax(a, axis=None, out=None, keepdims=False, initial=None, where=None)

Alias of jax.numpy.max().

Parameters:
    a (ArrayLike)
    axis (Axis)
    out (None)
    keepdims (bool)
    initial (ArrayLike | None)
    where (ArrayLike | None)

Return type: Array
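A minimal usage sketch, assuming the alias behaves exactly like jax.numpy.max(). The array `x` and the `mask` values are illustrative and do not come from the reference entry itself.

```python
import jax.numpy as jnp

# Illustrative 2x3 input (hypothetical data).
x = jnp.array([[1, 5, 3],
               [4, 2, 6]])

# With axis=None (the default) the reduction runs over every element.
total_max = jnp.amax(x)

# axis=0 reduces down each column, giving one value per column.
col_max = jnp.amax(x, axis=0)

# keepdims=True keeps the reduced axis as size 1, so the result has shape (2, 1).
row_max = jnp.amax(x, axis=1, keepdims=True)

# `where` restricts the reduction to the selected elements. Because max has
# no identity value, `initial` is passed alongside `where`.
mask = jnp.array([[True, False, True],
                  [True, True, False]])
masked_max = jnp.amax(x, where=mask, initial=0)

print(total_max)   # 6
print(col_max)     # [4 5 6]
print(row_max)     # [[5]
                   #  [6]]
print(masked_max)  # 4
```

Since amax is an alias, calling jnp.max with the same arguments should give the same results; which name to use is a matter of style.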