Handling of closed-over constants#
“Closed-over constants” are non-scalar arrays that are encountered during JAX tracing
of a function and do not have dependencies on any of the function’s arguments.
Operations from JAX modules such as jax.numpy and lax are staged out and do not create
closed-over constants.
In the following example, the arrays
a_jax_array and np.full((16,), 42.) are closed-over constants, but jnp.full((16,), 142.)
is not. We refer below to closed-over constants simply as constants.
import numpy as np
from jax import jit
from jax import numpy as jnp
a_jax_array = jnp.ones((16,), dtype=np.float32)
@jit
def f(x):
  return x + a_jax_array + np.full((16,), 42.) + jnp.full((16,), 142.)
We describe below the future internal implementation details for
constants. As of July 2025, this is not yet the default implementation;
it is enabled by the environment variable JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=True.
See further below for the details of the previous
implementation, including its drawbacks.
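A minimal sketch of opting in from Python, assuming the environment variable is read when jax is first imported (which is why it is set before the import). The function is the example from above; the flag is expected to change only how constants are handled internally, not the computed values.

```python
import os

# Assumption: the flag is picked up at import time, so set it first.
os.environ["JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS"] = "True"

import numpy as np
import jax
from jax import numpy as jnp

a_jax_array = jnp.ones((16,), dtype=np.float32)

@jax.jit
def f(x):
  return x + a_jax_array + np.full((16,), 42.) + jnp.full((16,), 142.)

# Each element is 0 + 1 + 42 + 142, whichever implementation is active.
out = f(jnp.zeros((16,), dtype=np.float32))
print(out[:4])
```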
Tracing#
When JAX tracing encounters a constant that is either an argument of a JAX primitive
or a function return, it is represented as a core.Literal, and is embedded
in the Jaxpr along with the primitives that use it.
The function core.is_literalable decides which constants are turned into
core.Literal. All scalar constants are turned into core.Literal, along with
non-scalar np.ndarray and jax.Array.
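One way to see how constants end up in a Jaxpr is to trace a small function with jax.make_jaxpr and look at the inputs of each equation. This is only an inspection sketch using public APIs; the function g is made up for illustration, and what it prints depends on which implementation is active (literals embedded in the equations, or variables that refer to constvars).

```python
import numpy as np
import jax
from jax import numpy as jnp

a_jax_array = jnp.ones((16,), dtype=np.float32)

def g(x):
  # Two closed-over constants: a jax.Array and a NumPy array.
  return x + a_jax_array + np.full((16,), 42., dtype=np.float32)

closed = jax.make_jaxpr(g)(jnp.zeros((16,), dtype=np.float32))

# For each equation, show the primitive and the kind of each input
# (a variable or a literal).
for eqn in closed.jaxpr.eqns:
  print(eqn.primitive.name, [type(v).__name__ for v in eqn.invars])
```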
Lowering#
When lowering the code to HLO we could just emit a stablehlo.constant operation
for a core.Literal, but this would have several disadvantages:
if the constant is a jax.Array (e.g., the a_jax_array above), then it is pulled from the device to the host during lowering, and it will later be re-materialized on the device when the lowered module executes. This can increase the host memory usage, sometimes dramatically. Furthermore, if the constant is sharded on multiple devices this sharding is lost.
large constants increase the size of the HLO, especially if the same constant is used multiple times. Also, the XLA compiler will attempt to constant-fold them, resulting in warnings and slow compilation. Furthermore, we have observed that XLA constant-folding sometimes produces slightly different numerics compared to compiled code. See also Large closed-over constants are inlined in the HLO code #29684.
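To check whether a closed-over array was inlined, one can lower a jitted function and search the resulting text. The sketch below uses only the public lower and as_text calls; the array big and the function add_big are illustrative, and the number of stablehlo.constant operations it reports depends on the implementation in use.

```python
import numpy as np
import jax
from jax import numpy as jnp

big = jnp.ones((1024,), dtype=np.float32)  # a closed-over jax.Array

@jax.jit
def add_big(x):
  return x + big

lowered = add_big.lower(jnp.zeros((1024,), dtype=np.float32))
hlo_text = lowered.as_text()

# Count inlined constants and show the signature of the main function,
# which reveals whether extra arguments were added.
print("stablehlo.constant ops:", hlo_text.count("stablehlo.constant"))
print([line.strip() for line in hlo_text.splitlines() if "@main" in line])
```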
Instead, during lowering we use the function core.jaxpr_const_args to scan
a Jaxpr and return a list of constants contained within, uniquified by their
id. The function core.jaxpr_const_args is memoized for each Jaxpr and sub-Jaxpr
on which it is called.
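The scan can be pictured with a toy model. The sketch below is not the real core.jaxpr_const_args: ToyJaxpr and toy_const_args are hypothetical stand-ins that only show uniquifying by id and memoizing per Jaxpr object.

```python
import functools
from dataclasses import dataclass, field
from typing import Any

@dataclass(eq=False)  # eq=False keeps hashing by identity, so it can be cached
class ToyJaxpr:
  constants: list[Any] = field(default_factory=list)
  subjaxprs: list["ToyJaxpr"] = field(default_factory=list)

@functools.cache  # memoized per ToyJaxpr object
def toy_const_args(jaxpr: ToyJaxpr) -> tuple[Any, ...]:
  """Collects the constants of a ToyJaxpr and its sub-Jaxprs, unique by id()."""
  seen: dict[int, Any] = {}
  for c in jaxpr.constants:
    seen.setdefault(id(c), c)
  for sub in jaxpr.subjaxprs:
    for c in toy_const_args(sub):
      seen.setdefault(id(c), c)
  return tuple(seen.values())

big = [1.0] * 4          # stand-in for a large array
same_values = [1.0] * 4  # equal contents, but a different object
inner = ToyJaxpr(constants=[big])
outer = ToyJaxpr(constants=[big, same_values, big], subjaxprs=[inner])

consts = toy_const_args(outer)
print(len(consts))  # big is counted once; same_values is a separate constant
```

Uniquifying by id rather than by value avoids comparing array contents, at the cost of treating equal-valued but distinct arrays as different constants.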
All the lowered HLO functions will take one additional argument
for each unique constant appearing in the Jaxpr to which it corresponds.
These arguments, referred to as const_args,
come after the dimension variable arguments, after the
token arguments, and just before the actual array arguments.
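The argument order described above can be sketched as follows. The helper toy_flat_args and the string placeholders are hypothetical; the real lowering works with MLIR values rather than names.

```python
def toy_flat_args(dim_vars, tokens, const_args, array_args):
  """Returns the flat argument list and the index range of the const args."""
  flat = [*dim_vars, *tokens, *const_args, *array_args]
  start = len(dim_vars) + len(tokens)
  return flat, range(start, start + len(const_args))

flat, const_range = toy_flat_args(
    dim_vars=["b"], tokens=["token0"],
    const_args=["c0", "c1"], array_args=["x", "y"])

print(flat)
print([flat[i] for i in const_range])
```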
During lowering we maintain a mapping const_lowering: dict[int, mlir.IrValues]
from the id of the constants to the HLO values for the corresponding
const args.
This mapping is stored in the mlir.LoweringRuleContext and is used
by mlir.ir_constant: when a constant is encountered, we just reuse
the existing lowering from const_lowering instead of emitting a
stablehlo.constant.
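A toy version of this lookup, assuming string handles in place of real HLO values: toy_ir_constant is a hypothetical stand-in for the behavior described for mlir.ir_constant, not its actual signature.

```python
def toy_ir_constant(val, const_lowering, emitted):
  """Reuses the hoisted argument for `val` if there is one, else emits a constant."""
  handle = const_lowering.get(id(val))
  if handle is not None:
    return handle                    # reuse the const arg
  emitted.append(val)                # stand-in for emitting stablehlo.constant
  return f"%cst{len(emitted) - 1}"

hoisted = [1.0] * 4                  # stand-in for a large constant
const_lowering = {id(hoisted): "%arg0"}
emitted = []

print(toy_ir_constant(hoisted, const_lowering, emitted))
print(toy_ir_constant(3.0, const_lowering, emitted))
```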
When we lower an HLO inner function (i.e., not the main function),
we call core.jaxpr_const_args again
to get the actual constants in the corresponding Jaxpr. These are
expected to be among the constants for which we have a const_lowering.
The inner function will get its own smaller set of const_args and
its own const_lowering mapping to be used when lowering the body.
E.g., the function mlir.lower_jaxpr_as_fun is one place where some
of this happens.
The function mlir.jaxpr_subcomp does not create a new HLO function,
but instead creates a block within the current function. It uses
the enclosing function’s const_lowering.
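The handoff from an enclosing function to an inner function might look like the sketch below. The helper toy_inner_const_lowering and the string handles are hypothetical; the point is only that the inner function receives a subset of the outer constants as call operands and builds its own mapping for its body.

```python
def toy_inner_const_lowering(inner_consts, outer_const_lowering):
  """Returns the call operands and the inner function's own const_lowering."""
  # Every inner constant is expected to have a lowering in the caller.
  missing = [c for c in inner_consts if id(c) not in outer_const_lowering]
  assert not missing, "inner constant without an outer const_lowering"
  # The caller passes its own values for these constants to the call.
  call_operands = [outer_const_lowering[id(c)] for c in inner_consts]
  # Inside the inner function the constants are its own leading arguments.
  inner = {id(c): f"%inner_arg{i}" for i, c in enumerate(inner_consts)}
  return call_operands, inner

c0, c1, c2 = [0.0], [1.0], [2.0]
outer = {id(c0): "%arg0", id(c1): "%arg1", id(c2): "%arg2"}

operands, inner = toy_inner_const_lowering([c2, c0], outer)
print(operands)
print(inner[id(c2)], inner[id(c0)])
```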
Note also that there will still be stablehlo.constant in the lowered
code, in three cases:
when the constant is a scalar; we want these constants to be available to XLA for constant folding.
when the constant did not appear in the traced program, and is hence not in the Jaxpr. This can happen for constants that arise during lowering, e.g., the lowering of some PRNG functions includes constants.
when we are exporting: at the moment, we do not hoist constant args when we export because the export serialization does not currently support serialization of arrays. We use the mlir.LoweringParameters.hoist_constants_as_args parameter to control this.
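The scalar case is easy to observe with public APIs. In this sketch the function add_two is illustrative; the scalar 2.0 is expected to appear as a stablehlo.constant in the lowered text.

```python
import numpy as np
import jax
from jax import numpy as jnp

@jax.jit
def add_two(x):
  return x + 2.0  # 2.0 is a scalar constant

scalar_hlo = add_two.lower(jnp.zeros((16,), dtype=np.float32)).as_text()

# Show the lines of the lowered module that define constants.
scalar_lines = [line.strip() for line in scalar_hlo.splitlines()
                if "stablehlo.constant" in line]
print(scalar_lines)
```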
One additional complication is that some of the internal lowering functions
need to take the argument avals and sometimes also the shardings and
layouts for the arguments. Furthermore, the avals, shardings, and layouts for
all arguments, including the const args,
are also used after lowering. Therefore, it is convenient
to compute these fairly high in the call stack, e.g., in
pxla.lower_sharding_computations, and pass them down.
For example, the functions mlir.lower_jaxpr_to_module,
pjit._pjit_cached_lower_jaxpr_to_fun, and mlir.lower_jaxpr_to_fun
take in_avals, in_shardings, and in_layouts that
include entries for both the const_args and the regular args
(the ones corresponding to the Jaxpr.invars).
They also take a num_const_args argument.
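A toy illustration of this convention, with hypothetical names (ToyAval, toy_all_in_avals) standing in for the real aval types and helpers: the per-argument lists are built once with the const args first, and num_const_args lets callees split them again.

```python
from typing import NamedTuple

class ToyAval(NamedTuple):
  shape: tuple[int, ...]
  dtype: str

def toy_all_in_avals(const_avals, arg_avals):
  """Concatenates const-arg avals and regular-arg avals, const args first."""
  return [*const_avals, *arg_avals], len(const_avals)

const_avals = [ToyAval((16,), "float32")]
arg_avals = [ToyAval((16,), "float32"), ToyAval((), "int32")]
in_avals, num_const_args = toy_all_in_avals(const_avals, arg_avals)

# A callee can recover the two groups from num_const_args.
const_part = in_avals[:num_const_args]
regular_part = in_avals[num_const_args:]
print(num_const_args, len(in_avals))
```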
Compilation and execution#
TO DO …
Previous implementation#
This describes the current way we handle closed-over constants, as
of July 2025 (as long as JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=False).
When JAX traces a function to a Jaxpr it collects the closed-over values
into a set of constants, and adds a corresponding set of constvars to the Jaxpr
(the actual arguments are represented by invars).
Most tracing functions, e.g., trace_to_jaxpr_dynamic,
return both the Jaxpr and the constants.
In many places in the code we use a class core.ClosedJaxpr that contains a
Jaxpr and consts corresponding to the Jaxpr.constvars.
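The relationship between invars, constvars, and consts can be inspected on the ClosedJaxpr returned by jax.make_jaxpr. The function h and the array table are illustrative; with the previous implementation the closed-over array is expected to show up as one constvar with a matching entry in consts.

```python
import numpy as np
import jax
from jax import numpy as jnp

table = jnp.arange(8, dtype=np.float32)  # a closed-over jax.Array

def h(x):
  return x * table

closed = jax.make_jaxpr(h)(jnp.ones((8,), dtype=np.float32))

print("type:     ", type(closed).__name__)
print("invars:   ", closed.jaxpr.invars)      # the actual arguments
print("constvars:", closed.jaxpr.constvars)   # one variable per constant
print("consts:   ", [np.shape(c) for c in closed.consts])
```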
There are several issues with ClosedJaxpr:
the lowering of the consts in ClosedJaxpr results in inlined stablehlo.constant, with all the issues described above.
Jaxpr and ClosedJaxpr are used pervasively in JAX, often with the generic name jaxpr and it is not easy to tell which kind of Jaxpr we have. We have started to add type declarations, but in some places the code is written with isinstance conditionals to work with both.
Since Jaxpr and ClosedJaxpr are sometimes used as caching keys, and they are hashed by id, we would like to memoize their construction. For example, the function pe.closed_jaxpr memoizes the construction of ClosedJaxpr but only for the case when consts is empty. This is because sometimes consts are not hashable.
Handling the constants in ClosedJaxpr requires some extra care. E.g., there are places in the Mosaic lowering where we have not yet implemented the handling of ClosedJaxpr with non-empty constants (e.g. here).
When we turn closed-over constants into inputs we have to be careful during transformations with how we handle these auxiliary inputs.
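A user-level analogy of this last point, not the internal mechanism: when a constant is passed as an extra leading input instead of being closed over, a transformation such as jax.grad has to be told which inputs are the actual arguments. The function names here are made up for illustration.

```python
import numpy as np
import jax
from jax import numpy as jnp

table = jnp.arange(4, dtype=np.float32)

def closed_over(x):
  return jnp.sum(x * table)       # table is a closed-over constant

def with_aux_input(table_arg, x):
  return jnp.sum(x * table_arg)   # the constant is now a leading input

x = jnp.ones((4,), dtype=np.float32)
g_closed = jax.grad(closed_over)(x)
# Differentiate only with respect to the actual argument, not the aux input.
g_aux = jax.grad(with_aux_input, argnums=1)(table, x)
print(g_closed, g_aux)
```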