jax.lax.select
- jax.lax.select(pred, on_true, on_false)
Selects between two branches based on a boolean predicate.
Wraps XLA’s Select operator.
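A minimal sketch of elementwise selection follows. The input arrays are arbitrary illustrative values. Because select() requires both branches to share a dtype, the second call casts an int32 array explicitly instead of relying on promotion.

```python
import jax.numpy as jnp
from jax import lax

# Predicate and both branches share the shape (3,); the branches share a dtype.
pred = jnp.array([True, False, True])
on_true = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
on_false = jnp.array([10.0, 20.0, 30.0], dtype=jnp.float32)

# Each output element comes from on_true where pred is True, else from on_false.
result = lax.select(pred, on_true, on_false)  # values: 1.0, 20.0, 3.0

# select() does not promote dtypes, so the int32 branch is cast explicitly
# to float32 so that it matches on_false before selecting.
ints = jnp.array([7, 8, 9], dtype=jnp.int32)
mixed = lax.select(pred, ints.astype(jnp.float32), on_false)  # values: 7.0, 20.0, 9.0
```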
In general, select() leads to evaluation of both branches, although the compiler may elide computations if possible. For a similar function that usually evaluates only a single branch, see cond().
- Parameters:
pred (ArrayLike) – boolean array
on_true (ArrayLike) – array containing entries to return where pred is True. Must have the same shape as pred, and the same shape and dtype as on_false.
on_false (ArrayLike) – array containing entries to return where pred is False. Must have the same shape as pred, and the same shape and dtype as on_true.
- Returns:
array with the same shape and dtype as on_true and on_false.
- Return type:
Array
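The difference between select() and cond() can be sketched as follows, assuming eager execution for the select() call and jax.jit for the cond() calls. safe_sqrt is a hypothetical helper defined only for this illustration. With select(), both branch arrays exist before selection, so 1.0 / x produces inf at the zero entry even though that value is discarded. cond() instead takes a scalar predicate and branch functions rather than precomputed arrays.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([0.0, 2.0, 4.0])

# Both branch arrays are computed before the select: 1.0 / x yields inf at
# index 0, but that element is discarded because the predicate is False there.
safe_recip = lax.select(x != 0, 1.0 / x, jnp.zeros_like(x))  # values: 0.0, 0.5, 0.25

# Hypothetical helper: lax.cond takes a scalar predicate and two branch
# functions, so typically only the chosen branch is executed.
def safe_sqrt(v):
    return lax.cond(v > 0, jnp.sqrt, jnp.zeros_like, v)

y = jax.jit(safe_sqrt)(jnp.float32(4.0))   # 2.0
z = jax.jit(safe_sqrt)(jnp.float32(-1.0))  # 0.0
```

Note that under jax.vmap with a batched predicate, cond() is converted to a select, so in that setting both branches are evaluated as well.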