jax.numpy.greater
- jax.numpy.greater(x, y, /)
Return element-wise truth value of x > y.
JAX implementation of numpy.greater.
- Parameters:
x (ArrayLike) – input array or scalar.
y (ArrayLike) – input array or scalar.
x and y must either have the same shape or be broadcast compatible.
- Returns:
An array of boolean values: True where an element of x is greater than the corresponding element of y, and False otherwise.
- Return type:
See also
jax.numpy.less(): Returns element-wise truth value of x < y.
jax.numpy.greater_equal(): Returns element-wise truth value of x >= y.
jax.numpy.less_equal(): Returns element-wise truth value of x <= y.
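The related comparison functions listed here can be checked against one another. The following is a minimal sketch, assuming integer inputs with no NaN values; the arrays and variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([5, 9, -2, 3])
y = jnp.array([4, -1, 6, 3])

gt = jnp.greater(x, y)

# Swapping the arguments of less() gives the same result as greater().
same_as_swapped_less = bool(jnp.array_equal(gt, jnp.less(y, x)))

# For inputs without NaN, x > y is the negation of x <= y.
same_as_not_le = bool(
    jnp.array_equal(gt, jnp.logical_not(jnp.less_equal(x, y)))
)

# greater_equal() differs from greater() only where the elements are equal.
only_equal_differs = bool(
    jnp.array_equal(jnp.greater_equal(x, y) & ~gt, jnp.equal(x, y))
)

print(gt)
print(same_as_swapped_less, same_as_not_le, only_equal_differs)
```

The negation identity does not carry over to floating-point inputs that contain NaN, since every ordered comparison involving NaN evaluates to False.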
Examples
Scalar inputs:
>>> import jax.numpy as jnp
>>> jnp.greater(5, 2)
Array(True, dtype=bool, weak_type=True)
Inputs with same shape:
>>> x = jnp.array([5, 9, -2])
>>> y = jnp.array([4, -1, 6])
>>> jnp.greater(x, y)
Array([ True,  True, False], dtype=bool)
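As a hedged illustration of how the boolean result is typically used, the sketch below reuses the same-shape inputs, compares the `>` operator with `jnp.greater`, and feeds the mask to `jnp.where` and `jnp.sum`. The variable names are illustrative and not part of the API.

```python
import jax.numpy as jnp

x = jnp.array([5, 9, -2])
y = jnp.array([4, -1, 6])

# The > operator on JAX arrays gives the same result as jnp.greater.
mask = x > y
same = bool(jnp.array_equal(mask, jnp.greater(x, y)))

# Use the mask to pick, per element, the larger of the two values.
larger = jnp.where(mask, x, y)

# Count how many elements of x exceed the matching element of y.
count = int(jnp.sum(mask))

print(same, larger, count)
```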
Inputs with broadcast compatibility:
>>> x1 = jnp.array([[5, -6, 7],
...                 [-2, 5, 9]])
>>> y1 = jnp.array([-4, 3, 10])
>>> jnp.greater(x1, y1)
Array([[ True, False, False],
       [ True,  True, False]], dtype=bool)
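To make the broadcasting behaviour more concrete, the sketch below checks the result shape against `jnp.broadcast_shapes` and then compares each row against its own threshold using a column vector. The `thresholds` array is a hypothetical input chosen for illustration.

```python
import jax.numpy as jnp

x1 = jnp.array([[5, -6, 7],
                [-2, 5, 9]])   # shape (2, 3)
y1 = jnp.array([-4, 3, 10])    # shape (3,)

result = jnp.greater(x1, y1)

# The result has the broadcast shape of the two inputs.
expected_shape = jnp.broadcast_shapes(x1.shape, y1.shape)
print(result.shape, expected_shape)

# A (2, 1) column vector is broadcast across columns, so each row of x1
# is compared against its own threshold (hypothetical values).
thresholds = jnp.array([[0], [6]])
per_row = jnp.greater(x1, thresholds)
print(per_row)
```

Here `y1` is stretched along the rows, while `thresholds` is stretched along the columns; in both cases the output shape is (2, 3).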