jax.lax.while_loop
- jax.lax.while_loop(cond_fun, body_fun, init_val)
Call body_fun repeatedly in a loop while cond_fun is True.

The Haskell-like type signature in brief is
while_loop :: (a -> Bool) -> (a -> a) -> a -> a
The semantics of while_loop are given by this Python implementation:

    def while_loop(cond_fun, body_fun, init_val):
        val = init_val
        while cond_fun(val):
            val = body_fun(val)
        return val
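As a minimal sketch of how these semantics look in practice, the snippet below runs the same loop twice: once through jax.lax.while_loop and once through a plain-Python copy of the reference implementation. The names while_loop_reference, cond_fun, body_fun, and result are chosen here for illustration only, and the threshold and step size are arbitrary.

```python
import jax
import jax.numpy as jnp


def while_loop_reference(cond_fun, body_fun, init_val):
    # Plain-Python reference semantics, copied from the description above.
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def cond_fun(val):
    # Keep looping while the carried value is below 10.
    return val < 10


def body_fun(val):
    # Each iteration adds 3 to the carried value (shape and dtype unchanged).
    return val + 3


# The carry is a scalar int32 array; it visits 0, 3, 6, 9, 12 and stops at 12.
result = jax.lax.while_loop(cond_fun, body_fun, jnp.int32(0))
print(result)  # 12

# The reference implementation gives the same answer on a Python int.
print(while_loop_reference(cond_fun, body_fun, 0))  # 12
```

Both calls agree on the final value; the difference lies in how the loop is staged, not in what it computes.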
Unlike that Python version, while_loop is a JAX primitive and is lowered to a single WhileOp. That makes it useful for reducing compilation times for jit-compiled functions, since native Python loop constructs in an @jit function are unrolled, leading to large XLA computations.

Also unlike the Python analogue, the loop-carried value val must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type a in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

Another difference from using Python-native loop constructs is that while_loop is not reverse-mode differentiable because XLA computations require static bounds on memory requirements.

Note
while_loop() compiles cond_fun and body_fun, so while it can be combined with jit(), it’s usually unnecessary.

- Parameters:
cond_fun (Callable[[T], BooleanNumeric]) – function of type a -> Bool.
body_fun (Callable[[T], T]) – function of type a -> a.
init_val (T) – value of type a, a type that can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof, representing the initial loop carry value.
- Returns:
The output from the final iteration of body_fun, of type a.
- Return type:
T
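To illustrate the pytree carry and the fixed shape/dtype requirement described above, here is a small sketch in which init_val is a tuple of two scalars. The loop halves a float32 value until it is no longer greater than 1 and counts the iterations in an int32 counter. The names init_val, steps, and final_x and the starting value 100.0 are illustrative assumptions, not part of the API.

```python
import jax
import jax.numpy as jnp


def cond_fun(carry):
    # The carry is a (counter, value) tuple; only the value drives the condition.
    _, x = carry
    return x > 1.0


def body_fun(carry):
    i, x = carry
    # Return the same tuple structure with the same dtypes:
    # i stays int32 and x stays float32 on every iteration.
    return i + 1, x / 2.0


# Initial carry: a tuple pytree with fixed-dtype scalar leaves.
init_val = (jnp.int32(0), jnp.float32(100.0))

# The result has the same structure as init_val.
steps, final_x = jax.lax.while_loop(cond_fun, body_fun, init_val)
print(steps, final_x)  # 7 0.78125
```

Because body_fun returns a tuple with the same structure and leaf dtypes as init_val, the carry type stays constant across iterations, which is what the type a in the signature requires.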