Handling of closed-over constants

“Closed-over constants” are non-scalar arrays that are encountered during JAX tracing of a function and do not depend on any of the function’s arguments. JAX operations, such as those in jax.numpy and lax, are staged out and do not create closed-over constants. In the following example, a_jax_array and the result of np.full(...) are closed-over constants, but the result of jnp.full(...) is not. Below, we refer to closed-over constants simply as constants.

import numpy as np
from jax import jit
from jax import numpy as jnp

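# A jax.Array created outside f: a closed-over constant when f is traced.
a_jax_array = jnp.ones((16,), dtype=np.float32)

@jit
def f(x):
  # np.full(...) builds a NumPy array at trace time: a closed-over constant.
  # jnp.full(...) is a JAX operation that is staged out: not a constant.
  return x + a_jax_array + np.full((16,), 42.) + jnp.full((16,), 142.)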

We describe below the future internal implementation details for constants. As of July 2025, this is not yet the default implementation; it is enabled by the environment variable JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=True. See further below for the details of the previous implementation, including its drawbacks.

Tracing

When JAX tracing encounters a constant that is either an argument of a JAX primitive or a function return value, the constant is represented as a core.Literal and is embedded in the Jaxpr along with the primitives that use it. The function core.is_literalable decides which constants are turned into core.Literal. All scalar constants are turned into core.Literal, along with non-scalar np.ndarray and jax.Array values.

Lowering

When lowering the code to HLO we could just emit a stablehlo.constant operation for a core.Literal, but this would have several disadvantages:

  • if the constant is a jax.Array (e.g., the a_jax_array above), then it is pulled from the device to the host during lowering, and it will later be re-materialized on the device when the lowered module executes. This can increase the host memory usage, sometimes dramatically. Furthermore, if the constant is sharded across multiple devices, this sharding is lost.

  • large constants increase the size of the HLO, especially if the same constant is used multiple times. Also, the XLA compiler will attempt to constant-fold them, resulting in warnings and slow compilation. Furthermore, we have observed that XLA constant-folding sometimes produces slightly different numerics compared to compiled code. See also “Large closed-over constants are inlined in the HLO code” (#29684).

Instead, during lowering we use the function core.jaxpr_const_args to scan a Jaxpr and return a list of the constants contained within it, uniquified by their id. The function core.jaxpr_const_args is memoized for each Jaxpr and sub-Jaxpr on which it is called.
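The scan-and-uniquify step described above can be sketched in plain Python. This is only an illustration under stated assumptions: ToyJaxpr and toy_const_args are hypothetical stand-ins, not the actual core.Jaxpr or core.jaxpr_const_args, and plain lists play the role of array constants.

```python
import functools
from dataclasses import dataclass, field

# Hypothetical stand-in for a Jaxpr; eq=False keeps identity-based hashing,
# so instances can be used as memoization keys.
@dataclass(eq=False)
class ToyJaxpr:
  literals: list = field(default_factory=list)    # constants used directly
  sub_jaxprs: list = field(default_factory=list)  # nested Jaxprs

@functools.cache  # memoized per (sub-)Jaxpr object
def toy_const_args(jaxpr: ToyJaxpr) -> tuple:
  unique = {}  # id(constant) -> constant, in first-seen order
  for c in jaxpr.literals:
    unique.setdefault(id(c), c)
  for sub in jaxpr.sub_jaxprs:
    for c in toy_const_args(sub):
      unique.setdefault(id(c), c)
  return tuple(unique.values())

big = [1.0] * 16
same_value = [1.0] * 16  # equal contents, but a different object
inner = ToyJaxpr(literals=[big])
outer = ToyJaxpr(literals=[big, same_value], sub_jaxprs=[inner])

consts = toy_const_args(outer)
print(len(consts))  # 2: `big` is counted once, `same_value` is kept separately
```

Note that uniquifying by id means two arrays with equal contents but different identities are treated as two separate constants.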

All the lowered HLO functions will take one additional argument for each unique constant appearing in the Jaxpr to which it corresponds. These arguments, referred to as const_args, come after the dimension variable arguments, after the token arguments, and just before the actual array arguments. During lowering we maintain a mapping const_lowering: dict[int, mlir.IrValues] from the id of the constants to the HLO values for the corresponding const args. This mapping is stored in the mlir.LoweringRuleContext and is used by mlir.ir_constant: when a constant is encountered, we just reuse the existing lowering from const_lowering instead of emitting a stablehlo.constant.
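A minimal sketch of the argument ordering and of the lookup performed for constants follows. The helper names build_function_args and toy_ir_constant are hypothetical, and strings such as "%arg2" stand in for the mlir.IrValues that the real mapping holds.

```python
def build_function_args(dim_vars, tokens, const_args, array_args):
  """Orders the arguments as described above and builds a toy const_lowering."""
  ordered = [*dim_vars, *tokens, *const_args, *array_args]
  first_const = len(dim_vars) + len(tokens)  # const args follow dim vars and tokens
  const_lowering = {
      id(c): f"%arg{first_const + i}" for i, c in enumerate(const_args)
  }
  return ordered, const_lowering

def toy_ir_constant(value, const_lowering):
  """Reuses the const arg if one exists; otherwise emits an inline constant."""
  existing = const_lowering.get(id(value))
  if existing is not None:
    return existing
  return f"stablehlo.constant({value!r})"

weights = [0.0] * 4  # stand-in for a non-scalar closed-over constant
ordered, const_lowering = build_function_args(
    dim_vars=["b"], tokens=["tok"], const_args=[weights], array_args=["x"])

print(toy_ir_constant(weights, const_lowering))  # %arg2
print(toy_ir_constant(3.0, const_lowering))      # stablehlo.constant(3.0)
```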

When we lower an HLO inner function (i.e., not the main function), we call core.jaxpr_const_args again to get the actual constants in the corresponding Jaxpr. These are expected to be among the constants for which we have a const_lowering. The inner function will get its own smaller set of const_args and its own const_lowering mapping to be used when lowering the body. E.g., the function mlir.lower_jaxpr_as_fun is one place where some of this happens.

The function mlir.jaxpr_subcomp does not create a new HLO function, but instead creates a block within the current function. It uses the enclosing function’s const_lowering.
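The difference between lowering an inner function and lowering an inline block can be sketched as follows. Both helpers are hypothetical, strings stand in for HLO values, and passing the enclosing function's values as call operands is an assumption about how the inner const_args get their values.

```python
def lower_inner_function(inner_consts, outer_const_lowering):
  """New HLO function: gets its own const args and its own mapping."""
  # Assumption: the call site passes the enclosing function's values for the
  # constants that the inner Jaxpr actually uses.
  call_operands = [outer_const_lowering[id(c)] for c in inner_consts]
  inner_const_lowering = {
      id(c): f"%inner_arg{i}" for i, c in enumerate(inner_consts)
  }
  return call_operands, inner_const_lowering

def lower_inline_block(outer_const_lowering):
  """Block within the current function: reuses the enclosing mapping."""
  return outer_const_lowering

a = [1.0] * 4
b = [2.0] * 4
outer_const_lowering = {id(a): "%arg0", id(b): "%arg1"}

# The inner Jaxpr only uses `b`, so it receives a smaller set of const args.
operands, inner_const_lowering = lower_inner_function([b], outer_const_lowering)
print(operands)  # ['%arg1']
```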

Note also that there will still be stablehlo.constant in the lowered code, in three cases:

  • when the constant is a scalar; we want these constants to be available to XLA for constant folding.

  • when the constant did not appear in the traced program, and is hence not in the Jaxpr. This can happen for constants that arise during lowering, e.g., the lowering of some PRNG functions includes constants.

  • when we are exporting: at the moment, we do not hoist constant args when we export because the export serialization does not currently support serialization of arrays. We use the mlir.LoweringParameters.hoist_constants_as_args parameter to control this.
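The three cases above can be summarized as a small decision function. This is a sketch only: ToyConst and constant_strategy are hypothetical, and only the parameter name hoist_constants_as_args is taken from the text.

```python
from dataclasses import dataclass

@dataclass(eq=False)
class ToyConst:
  shape: tuple  # () means a scalar

def constant_strategy(c, const_lowering, hoist_constants_as_args=True):
  """Returns how a constant would be lowered in this toy model."""
  if c.shape == ():
    return "inline: scalar"             # kept visible to XLA constant folding
  if not hoist_constants_as_args:
    return "inline: hoisting disabled"  # e.g., when exporting
  if id(c) not in const_lowering:
    return "inline: not in the Jaxpr"   # arose during lowering
  return "const_arg"

scalar = ToyConst(shape=())
traced = ToyConst(shape=(16,))
late = ToyConst(shape=(4,))  # created during lowering, never traced
const_lowering = {id(traced): "%arg0"}

for c in (scalar, traced, late):
  print(constant_strategy(c, const_lowering))
```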

One additional complication is that some of the internal lowering functions need to take the argument avals, and sometimes also the shardings and layouts for the arguments. Furthermore, the avals, shardings, and layouts for all arguments, including the const args, are also used after lowering. Therefore, it is convenient to compute these fairly high in the call stack, e.g., in pxla.lower_sharding_computations, and pass them down.

For example, the functions mlir.lower_jaxpr_to_module, pjit._pjit_cached_lower_jaxpr_to_fun, and mlir.lower_jaxpr_to_fun take in_avals, in_shardings, and in_layouts that include both the entries for the const_args and the entries for the regular args (the ones corresponding to the Jaxpr.invars). They also take a num_const_args argument.
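As a rough illustration of why num_const_args is passed along, the combined lists can be split back into their two parts. The helper split_const_and_regular is hypothetical, strings stand in for avals, and the sketch assumes the const_args entries come first, consistent with the argument ordering described earlier.

```python
def split_const_and_regular(values, num_const_args):
  """Splits a combined per-argument list into (const part, regular part)."""
  return list(values[:num_const_args]), list(values[num_const_args:])

# Two hoisted constants followed by one regular argument (for Jaxpr.invars).
in_avals = ["f32[16]", "f32[16]", "f32[8]"]
const_avals, regular_avals = split_const_and_regular(in_avals, num_const_args=2)

print(const_avals)    # ['f32[16]', 'f32[16]']
print(regular_avals)  # ['f32[8]']
```

The same split would apply to in_shardings and in_layouts, since they are indexed the same way.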

Compilation and execution

TO DO …

Previous implementation

This section describes the current way we handle closed-over constants, as of July 2025 (as long as JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=False).

When JAX traces a function to a Jaxpr it collects the closed-over values into a set of constants, and adds a corresponding set of constvars to the Jaxpr (the actual arguments are represented by invars). Most tracing functions, e.g., trace_to_jaxpr_dynamic, return both the Jaxpr and the constants.

In many places in the code we use a class core.ClosedJaxpr that contains a Jaxpr and consts corresponding to the Jaxpr.constvars.
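The pairing of consts with constvars can be inspected with jax.make_jaxpr, which returns a closed Jaxpr. The function g below is an illustrative variant of the earlier example; what gets printed depends on the installed JAX version and on which implementation of constants is enabled, so no output is shown.

```python
import numpy as np
import jax
from jax import numpy as jnp

a_jax_array = jnp.ones((16,), dtype=np.float32)

def g(x):
  # Both a_jax_array and the NumPy array are closed over by g.
  return x + a_jax_array + np.full((16,), 42., dtype=np.float32)

closed = jax.make_jaxpr(g)(jnp.zeros((16,), dtype=np.float32))

# Each entry in consts is the value for the constvar at the same position.
print(closed.jaxpr.constvars)
print(len(closed.consts))
print(closed.jaxpr.invars)  # the actual arguments
```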

There are several issues with ClosedJaxpr:

  • the lowering of the consts in ClosedJaxpr results in inlined stablehlo.constant, with all the issues described above.

  • Jaxpr and ClosedJaxpr are used pervasively in JAX, often with the generic name jaxpr, and it is not easy to tell which kind of Jaxpr we have. We have started to add type declarations, but in some places the code is written with isinstance conditionals to work with both.

  • Since Jaxpr and ClosedJaxpr are sometimes used as caching keys, and they are hashed by id, we would like to memoize their construction. For example, the function pe.closed_jaxpr memoizes the construction of ClosedJaxpr but only for the case when consts is empty. This is because sometimes consts are not hashable.

  • Handling the constants in ClosedJaxpr requires some extra care. E.g., there are places in the Mosaic lowering where we have not yet implemented the handling of ClosedJaxpr with non-empty constants (e.g. here).

  • When we turn closed-over constants into inputs we have to be careful during transformations with how we handle these auxiliary inputs.
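The memoization limitation mentioned in the list above (construction is memoized only when consts is empty) can be sketched with toy classes. Everything here is hypothetical and only mirrors the behavior described for pe.closed_jaxpr; it is not the actual implementation.

```python
import functools
from dataclasses import dataclass

@dataclass(eq=False)  # hashed by id, like the real classes
class ToyJaxpr:
  name: str

@dataclass(eq=False)
class ToyClosedJaxpr:
  jaxpr: ToyJaxpr
  consts: tuple

@functools.cache
def _closed_without_consts(jaxpr):
  return ToyClosedJaxpr(jaxpr, ())

def toy_closed_jaxpr(jaxpr, consts):
  if not consts:
    # Safe to memoize: the key is just the (hashable) Jaxpr.
    return _closed_without_consts(jaxpr)
  # consts may be unhashable, so a new object is built on every call.
  return ToyClosedJaxpr(jaxpr, tuple(consts))

j = ToyJaxpr("j")
unhashable_const = [1.0, 2.0]  # a list cannot be part of a cache key

print(toy_closed_jaxpr(j, []) is toy_closed_jaxpr(j, []))  # True
print(toy_closed_jaxpr(j, [unhashable_const])
      is toy_closed_jaxpr(j, [unhashable_const]))          # False
```

Because the objects are hashed by id, the second case yields distinct cache keys for what is logically the same closed Jaxpr.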