jax.experimental.custom_dce.custom_dce
- class jax.experimental.custom_dce.custom_dce(fun, *, static_argnums=())
Customize the DCE behavior of a JAX-transformable function.
JAX uses dead code elimination (DCE) to remove unused computations from a JAX program. This typically works transparently when the program is completely specified by known JAX operations, but opaque kernels, such as calls to pallas_call() or ffi_call(), may cause problems.
In JAX, DCE is performed when a function is staged out using jax.jit(), so it won't be applied when running JAX in eager mode. Similarly, the custom_dce decorator requires that both the decorated function and the custom DCE rule be compatible with jit().
This decorator allows users to customize the DCE behavior of a function by defining a custom DCE rule. For a custom_dce wrapped function f(*args), the signature of the DCE rule is dce_rule(used_outs, *args), where used_outs is a Pytree with the same structure as the output of f, and each leaf is a bool indicating whether the corresponding output should be computed. The remaining arguments *args are the original arguments to f. The rule dce_rule should return a Pytree with the same structure as the original output of f, but any unused outputs can be replaced with None.
For example:
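>>> import jax
>>> import jax.numpy as jnp
>>> import jax.experimental.custom_dce
>>>
>>> @jax.experimental.custom_dce.custom_dce
... def f(x, y):
...   return jnp.sin(x) * y, x * jnp.sin(y)
...
>>> @f.def_dce
... def f_dce_rule(used_outs, x, y):
...   # Compute only the outputs flagged as used; return None for the rest.
...   return (
...       jnp.sin(x) * y if used_outs[0] else None,
...       x * jnp.sin(y) if used_outs[1] else None,
...   )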
In this example, used_outs is a tuple with two bool values, indicating which outputs are required. The DCE rule only computes the required outputs, replacing the unused outputs with None.
If the static_argnums argument is provided to custom_dce, the indicated arguments are treated as static when the function is traced, and they will be moved to the front when calling the DCE rule. For example, if fun takes 2 arguments fun(x, y), and static_argnums is (1,), then the DCE rule will be called as dce_rule(y, used_outs, x).
Methods
__init__(fun, *[, static_argnums])
def_dce(dce_rule): Define a custom DCE rule for this function.
Attributes
fun
static_argnums
dce_rule