jax.lax.eq
- jax.lax.eq(x, y)
Elementwise equals: \(x = y\).
This function lowers directly to the stablehlo.compare operation with comparison_direction=EQ and compare_type set according to the input dtype.
- Parameters:
x (ArrayLike) – Input arrays. Must have matching dtypes. If neither is a scalar, x and y must have the same number of dimensions and be broadcast compatible.
y (ArrayLike) – Input arrays. Must have matching dtypes. If neither is a scalar, x and y must have the same number of dimensions and be broadcast compatible.
- Returns:
A boolean array of shape lax.broadcast_shapes(x.shape, y.shape) containing the elementwise equal comparison.
- Return type:
Array
See also
- jax.numpy.equal(): NumPy wrapper for this API, also accessible via the x == y operator on JAX arrays.
- jax.lax.ne(): elementwise not-equal
- jax.lax.ge(): elementwise greater-than-or-equal
- jax.lax.gt(): elementwise greater-than
- jax.lax.le(): elementwise less-than-or-equal
- jax.lax.lt(): elementwise less-than
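One practical difference between jax.numpy.equal() (and the == operator) and lax.eq is that the NumPy wrapper applies JAX's type promotion and NumPy-style rank broadcasting before comparing, while lax.eq requires the caller to supply matching dtypes and ranks. A minimal sketch, assuming int32 and float32 inputs promote to float32 under JAX's default promotion rules:

```python
import jax.numpy as jnp
from jax import lax

ints = jnp.array([1, 2, 3], dtype=jnp.int32)
floats = jnp.array([1.0, 2.5, 3.0], dtype=jnp.float32)

# jnp.equal promotes int32/float32 to a common dtype before comparing.
promoted = jnp.equal(ints, floats)
# The == operator on JAX arrays gives the same result.
operator_form = ints == floats

# lax.eq does no promotion, so cast explicitly to get the same comparison.
explicit = lax.eq(ints.astype(jnp.float32), floats)

# jnp.equal also broadcasts operands of different rank, NumPy style:
# shape (3,) against shape (2, 3) yields shape (2, 3).
row_vs_matrix = jnp.equal(ints, jnp.ones((2, 3), dtype=jnp.int32))

print(promoted)             # [ True False  True]
print(row_vs_matrix.shape)  # (2, 3)
```

Using lax.eq directly is mainly useful when the dtypes and shapes are already known to match and the explicit, promotion-free semantics are wanted.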