jax.experimental.checkify.check_error
- jax.experimental.checkify.check_error(error)
Raise an Exception if error represents a failure. Functionalized by checkify().

The semantics of this function are equivalent to:

>>> def check_error(err: Error) -> None:
...   err.throw()  # can raise ValueError
But unlike that implementation, check_error can be functionalized using the checkify() transformation.

This function is similar to check() but has a different signature: whereas check() takes a boolean predicate and a new error message string as arguments, this function takes an Error value as its argument. Both check() and this function raise a Python Exception on failure (a side effect), and thus cannot be staged out by jit(), pmap(), scan(), etc. Both can also be functionalized by using checkify().

But unlike check(), this function acts as a direct inverse of checkify(): whereas checkify() takes as input a function that can raise a Python Exception and produces a new function without that effect which instead returns an Error value, check_error accepts an Error value as input and can produce the side effect of raising an Exception. That is, while checkify() goes from a functionalizable Exception effect to an error value, check_error goes from an error value to a functionalizable Exception effect.

check_error is useful when you want to turn checks represented by an Error value (produced by functionalizing checks via checkify()) back into Python Exceptions.

- Parameters:
error (Error) – Error to check.
- Return type:
None
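To illustrate the inverse relationship between checkify() and check_error described above, here is a minimal sketch of the full round trip: checkify() turns a raised check into an Error value, check_error turns that value back into an Exception, and checkify() can functionalize the result once more. It assumes only check(), checkify(), check_error() and Error.get(); the names f and reraise are hypothetical, chosen for this example. Because the exact exception class raised has varied between JAX versions, the sketch catches a broad Exception rather than naming one.

```python
from jax.experimental import checkify

def f(x):
    # check() expresses a failure as a Python Exception (a side effect).
    checkify.check(x > 0, "x must be positive")
    return x + 1.0

# checkify(): Exception effect -> Error value (nothing is raised here).
checked_f = checkify.checkify(f)
err, out = checked_f(-1.0)
print(err.get())  # failure message as a string; None if every check passed

def reraise(x):
    # check_error(): Error value -> Exception effect.
    inner_err, inner_out = checked_f(x)
    checkify.check_error(inner_err)
    return inner_out

raised_msg = None
try:
    reraise(-1.0)
except Exception as e:  # concrete exception class may differ across JAX versions
    raised_msg = str(e)
print("raised:", raised_msg)

# check_error is itself functionalizable, so reraise can be checkified again.
err2, _ = checkify.checkify(reraise)(-1.0)
print(err2.get())
```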
For example, you might want to functionalize part of your program through checkify, stage out your functionalized code through jit(), then re-inject your error value outside of the jit():

>>> import jax
>>> from jax.experimental import checkify
>>> def f(x):
...   checkify.check(x>0, "must be positive!")
...   return x
>>> def with_inner_jit(x):
...   checked_f = checkify.checkify(f)
...   # a checkified function can be jitted
...   error, out = jax.jit(checked_f)(x)
...   checkify.check_error(error)
...   return out
>>> _ = with_inner_jit(1)  # no failed check
>>> with_inner_jit(-1)
Traceback (most recent call last):
  ...
jax._src.JaxRuntimeError: must be positive!
>>> # can re-checkify
>>> error, _ = checkify.checkify(with_inner_jit)(-1)
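The converse arrangement also follows from the description above: a function that calls check_error cannot itself be staged out by jit(), but once it is wrapped in checkify() the check_error call is functionalized and the whole thing can be jitted. The sketch below assumes only check(), checkify(), check_error(), Error.get() and jax.jit; safe_sqrt, g and staged_g are hypothetical names used for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def safe_sqrt(x):
    checkify.check(x >= 0, "negative input to sqrt")
    return jnp.sqrt(x)

def g(x):
    # Produce an Error value from an inner checkified call...
    inner_err, y = checkify.checkify(safe_sqrt)(x)
    # ...and turn it back into an Exception effect.
    checkify.check_error(inner_err)
    return y * 2.0

# Wrapping g with checkify() first functionalizes check_error,
# so the resulting function can be staged out by jax.jit.
staged_g = jax.jit(checkify.checkify(g))

err_ok, y_ok = staged_g(4.0)
print(err_ok.get())  # no check failed
print(y_ok)

err_bad, _ = staged_g(-1.0)
print(err_bad.get())  # message from the failed check in safe_sqrt
```

The choice of where to call check_error (inside or outside the checkified, jitted region) decides whether the failure surfaces as an Error value returned to the caller or as a raised Exception.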