jax.extend.linear_util.cache
- jax.extend.linear_util.cache(call, *, explain=None)
Memoization decorator for functions taking a WrappedFun as first argument.
- Parameters:
call (Callable) – a Python callable that takes a WrappedFun as its first argument. The underlying transforms and params on the WrappedFun are used as part of the memoization cache key.
explain (Callable[[WrappedFun, bool, dict, tuple, float], None] | None) – a function that is invoked upon cache misses to log an explanation of the miss. Invoked with (fun, is_cache_first_use, cache, key, elapsed_sec).
- Returns:
A memoized version of call.
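The caching behavior described above can be illustrated with a small pure-Python model. This is only a sketch under stated assumptions and not JAX's implementation: `ToyWrappedFun`, `toy_cache`, `log_miss`, and `apply_cached` are hypothetical names introduced here, and including the positional arguments in the key is an assumption of the sketch. It shows a key built from the wrapped function's transforms and params, and an `explain` callback invoked on each miss with `(fun, is_cache_first_use, cache, key, elapsed_sec)`.

```python
import time
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class ToyWrappedFun:
    """Hypothetical stand-in for WrappedFun; not the real JAX class."""
    f: Callable
    transforms: tuple = ()
    params: tuple = ()


def toy_cache(call, *, explain=None):
    """Toy memoizer that mirrors the documented signature of `cache`."""
    store = {}

    def memoized(fun, *args):
        # Transforms and params are part of the key, as the docs describe.
        # Including `args` is an assumption made for this sketch.
        key = (fun.f, fun.transforms, fun.params, args)
        if key in store:
            return store[key]
        is_first_use = not store  # True only while the cache is still empty
        start = time.perf_counter()
        result = call(fun, *args)
        elapsed = time.perf_counter() - start
        if explain is not None:
            # Argument order follows the documented callback signature.
            explain(fun, is_first_use, store, key, elapsed)
        store[key] = result
        return result

    return memoized


misses = []


def log_miss(fun, is_cache_first_use, cache, key, elapsed_sec):
    # Record whether each miss was the first use of the cache.
    misses.append(is_cache_first_use)


def square(x):
    return x * x


def _apply(fun, x):
    return fun.f(x)


apply_cached = toy_cache(_apply, explain=log_miss)

wf = ToyWrappedFun(square, params=(("scale", 1),))
apply_cached(wf, 3)  # miss: cache is empty
apply_cached(wf, 3)  # hit: same f, transforms, params, and args
apply_cached(ToyWrappedFun(square, params=(("scale", 2),)), 3)  # miss: params differ
```

In this model, two wrapped functions that share the same underlying callable but differ in params produce separate cache entries, which is why the third call triggers `log_miss` again.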