jax.stages module

Interfaces to stages of the compiled execution process.

JAX transformations that compile just in time for execution, such as jax.jit and jax.pmap, also support a common means of explicit lowering and compilation ahead of time. This module defines types that represent the stages of this process.

For more, see the AOT walkthrough.
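The stages can also be chained by hand. The following is a minimal sketch, assuming a JAX release recent enough that `jax.jit(...)` provides `trace` as documented on this page. `scaled_sin` is a hypothetical example function, and `jax.ShapeDtypeStruct` stands in for the argument so that no concrete data is needed until execution.

```python
import jax
import jax.numpy as jnp


def scaled_sin(x):
    # Hypothetical example function.
    return 2.0 * jnp.sin(x)


# Describe the argument by shape and dtype only; no array data is needed yet.
spec = jax.ShapeDtypeStruct((3,), jnp.float32)

wrapped = jax.jit(scaled_sin)   # behaves as a jax.stages.Wrapped
traced = wrapped.trace(spec)    # a jax.stages.Traced
lowered = traced.lower()        # a jax.stages.Lowered
compiled = lowered.compile()    # a jax.stages.Compiled

# Execute with concrete data of the type the computation was compiled for.
result = compiled(jnp.arange(3, dtype=jnp.float32))
print(result)
```

Each intermediate object corresponds to one of the classes documented below.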

Classes

class jax.stages.Wrapped(*args, **kwargs)

A function ready to be traced, lowered, and compiled.

This protocol reflects the output of functions such as jax.jit. Calling it results in JIT (just-in-time) lowering, compilation, and execution. It can also be explicitly lowered prior to compilation, and the result compiled prior to execution.
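As a sketch of the two modes described above, assuming a hypothetical `affine` function: calling the result of `jax.jit` directly runs every stage implicitly, while `lower(...).compile()` runs them explicitly and yields a callable for the same argument types.

```python
import jax
import jax.numpy as jnp


def affine(x, w, b):
    # Hypothetical example: a small affine map.
    return x @ w + b


jitted = jax.jit(affine)

x = jnp.ones((2, 3), dtype=jnp.float32)
w = jnp.ones((3, 4), dtype=jnp.float32)
b = jnp.zeros((4,), dtype=jnp.float32)

# Just in time: the first call traces, lowers, compiles, and executes.
jit_out = jitted(x, w, b)

# Ahead of time: the same stages, triggered explicitly.
aot_out = jitted.lower(x, w, b).compile()(x, w, b)

print(jnp.allclose(jit_out, aot_out))
```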

__call__(*args, **kwargs)

Executes the wrapped function, lowering and compiling as needed.

lower(*args, **kwargs)

Lower this function explicitly for the given arguments.

This is a shortcut for self.trace(*args, **kwargs).lower().

A lowered function is staged out of Python and translated to a compiler’s input language, possibly in a backend-dependent manner. It is ready for compilation but not yet compiled.

Returns:

A Lowered instance representing the lowering.

Return type:

Lowered
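A minimal sketch of the shortcut, using a hypothetical `square` function: both forms below should produce a `Lowered` instance for the same argument types.

```python
import jax
import jax.numpy as jnp


def square(x):
    # Hypothetical example function.
    return x * x


jitted = jax.jit(square)
x = jnp.arange(4.0)

via_shortcut = jitted.lower(x)       # lower(...) directly
via_trace = jitted.trace(x).lower()  # the two-step form it abbreviates

# Both are ready for compilation but not yet compiled.
for low in (via_shortcut, via_trace):
    print(type(low).__name__)
```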

trace(*args, **kwargs)

Trace this function explicitly for the given arguments.

A traced function is staged out of Python and translated to a jaxpr. It is ready for lowering but not yet lowered.

Returns:

A Traced instance representing the tracing.

Return type:

Traced
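A minimal sketch of explicit tracing with a hypothetical function `f`. It assumes that the resulting `Traced` object exposes its jaxpr as a `jaxpr` attribute, matching the constructor parameter of the same name listed below; the jaxpr can be printed before anything is lowered.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function.
    return jnp.sin(x) + 1.0


traced = jax.jit(f).trace(jnp.zeros((3,), jnp.float32))

# The traced form holds a jaxpr; nothing has been lowered or compiled yet.
print(traced.jaxpr)
```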

class jax.stages.Traced(jaxpr, args_info, fun_name, out_tree, lower_callable, args_flat=None, arg_names=None, num_consts=0, params_out_shardings=None)

Traced form of a function specialized to argument types and values.

A traced computation is ready for lowering. This class carries the traced representation with the remaining information needed to later lower, compile, and execute it.

lower(*, lowering_platforms=None, _private_parameters=None)

Lower to compiler input, returning a Lowered instance.

Parameters:
  • lowering_platforms (tuple[str, ...] | None)

  • _private_parameters (mlir.LoweringParameters | None)

class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)

Lowering of a function specialized to argument types and values.

A lowering is a computation ready for compilation. This class carries a lowering together with the remaining information needed to later compile and execute it. It also provides a common API for querying properties of lowered computations across JAX’s various lowering paths (jit(), pmap(), etc.).

Parameters:
  • lowering (Lowering)

  • args_info (Any)

  • out_tree (tree_util.PyTreeDef)

  • no_kwargs (bool)
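Because a lowering is specialized to its argument types, lowering the same function for two shapes yields two distinct computations. A minimal sketch, assuming the default `as_text()` dialect is StableHLO, in which a `float32[3]` argument is spelled `tensor<3xf32>`:

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function.
    return jnp.sin(x)


jitted = jax.jit(f)

# Each lowering is specialized to the argument types it was given.
low3 = jitted.lower(jax.ShapeDtypeStruct((3,), jnp.float32))
low4 = jitted.lower(jax.ShapeDtypeStruct((4,), jnp.float32))

text3, text4 = low3.as_text(), low4.as_text()
print("tensor<3xf32>" in text3, "tensor<4xf32>" in text4)
```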

as_text(dialect=None, *, debug_info=False)

A human-readable text representation of this lowering.

Intended for visualization and debugging purposes. This need not be a valid or reliable serialization. Use jax.export if you want reliable and portable serialization.

Parameters:
  • dialect (str | None) – Optional string specifying a lowering dialect (e.g. “stablehlo” or “hlo”).

  • debug_info (bool) – Whether to include debugging information, e.g., source location.

Return type:

str

compile(compiler_options=None, *, device_assignment=None)

Compile, returning a corresponding Compiled instance.

Parameters:
  • compiler_options (CompilerOptions | None)

  • device_assignment (tuple[xc.Device, ...] | None)

Return type:

Compiled
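A minimal sketch that compiles with the default options and calls the result; `f`, `x`, and `y` are illustrative. Since `y` is the identity matrix, the output should equal `x`.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical example function.
    return jnp.dot(x, y)


x = jnp.ones((4, 4), jnp.float32)
y = jnp.eye(4, dtype=jnp.float32)

lowered = jax.jit(f).lower(x, y)
compiled = lowered.compile()  # default compiler options and device assignment

out = compiled(x, y)
print(out)
```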

compiler_ir(dialect=None)

An arbitrary object representation of this lowering.

Intended for debugging purposes. This is neither a valid nor a reliable serialization, and the output is not guaranteed to be consistent across invocations. Use jax.export if you want reliable and portable serialization.

Returns None if unavailable, e.g. based on backend, compiler, or runtime.

Parameters:

dialect (str | None) – Optional string specifying a lowering dialect (e.g. “stablehlo” or “hlo”).

Return type:

Any | None
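A minimal sketch, assuming that with `dialect="stablehlo"` the returned object is an MLIR module whose string form is the textual IR; the code handles a `None` result as described above.

```python
import jax
import jax.numpy as jnp


def hyperbolic(x):
    # Hypothetical example function.
    return jnp.tanh(x)


lowered = jax.jit(hyperbolic).lower(jnp.ones((2,), jnp.float32))

ir = lowered.compiler_ir(dialect="stablehlo")
if ir is None:
    print("compiler IR unavailable here")
else:
    # Assumed to be an MLIR module; treat it as a debugging aid only.
    print(type(ir).__name__)
    print(str(ir)[:300])
```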

cost_analysis()

A summary of execution cost estimates.

Intended for visualization and debugging purposes. The returned object is a simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.

Returns None if unavailable, e.g. based on backend, compiler, or runtime.

Return type:

Any | None
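A hedged sketch of querying estimates before compilation. Because the structure is unspecified, it prints the result rather than reading fixed keys. As an extra precaution (an assumption, not documented behavior), it also treats a `NotImplementedError` as "unavailable".

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function.
    return jnp.sin(x) @ x


lowered = jax.jit(f).lower(jnp.ones((8, 8), jnp.float32))

try:
    cost = lowered.cost_analysis()
except NotImplementedError:
    # Assumption: some backends might raise instead of returning None.
    cost = None

# The structure is not guaranteed, so print it instead of indexing keys.
print(cost)
```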

property in_tree: tree_util.PyTreeDef

Tree structure of the pair (positional arguments, keyword arguments).
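A minimal sketch with a hypothetical function that takes one positional and one keyword argument. The tree covers both, so it should have two leaves; the compiled function must later be called with the same structure.

```python
import jax
import jax.numpy as jnp


def f(x, scale):
    # Hypothetical example function.
    return x * scale


x = jnp.ones((3,), jnp.float32)
lowered = jax.jit(f).lower(x, scale=jnp.float32(2.0))

# One tree for (positional args, keyword args): here ((x,), {"scale": ...}).
print(lowered.in_tree)
print(lowered.in_tree.num_leaves)

# Call the compiled form with the same positional/keyword structure.
out = lowered.compile()(x, scale=jnp.float32(2.0))
```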

class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)

Compiled representation of a function specialized to argument types and values.

A compiled computation is associated with an executable and the remaining information needed to execute it. It also provides a common API for querying properties of compiled computations across JAX’s various compilation paths and backends.

Parameters:
  • args_info (Any)

  • out_tree (tree_util.PyTreeDef)
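A `Compiled` object is fixed to the argument types it was compiled for and does not recompile. A hedged sketch, assuming that a call with a mismatched shape is rejected with a `TypeError`:

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function.
    return x + 1.0


spec = jax.ShapeDtypeStruct((3,), jnp.float32)
compiled = jax.jit(f).lower(spec).compile()

ok = compiled(jnp.zeros((3,), jnp.float32))  # matches the compiled types

try:
    compiled(jnp.zeros((5,), jnp.float32))   # different shape
    mismatch_rejected = False
except TypeError as err:
    # Unlike calling a Wrapped function, no recompilation happens here.
    mismatch_rejected = True
    print("rejected:", err)
```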

__call__(*args, **kwargs)

Executes the compiled computation on the given arguments.

as_text()

A human-readable text representation of this executable.

Intended for visualization and debugging purposes. This is neither a valid nor a reliable serialization.

Returns None if unavailable, e.g. based on backend, compiler, or runtime.

Return type:

str | None
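A minimal sketch that prints the executable's text when available. Comparing it with `Lowered.as_text()` can show what the compiler changed, for example through fusion; `f` is a hypothetical function.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function.
    return jnp.exp(x) + jnp.exp(x)


compiled = jax.jit(f).lower(jnp.ones((4,), jnp.float32)).compile()

text = compiled.as_text()
if text is not None:
    # Post-compilation representation of the executable.
    print(text)
```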

cost_analysis()

A summary of execution cost estimates.

Intended for visualization and debugging purposes. The returned object is a simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.

Returns None if unavailable, e.g. based on backend, compiler, or runtime.

Return type:

Any | None
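A hedged sketch of reading a FLOP estimate. It assumes a `"flops"` key, which is common but not guaranteed, and accepts either a dict or a list of dicts (using the first), since the structure may differ between JAX versions.

```python
import jax
import jax.numpy as jnp


def matmul(a, b):
    # Hypothetical example function.
    return a @ b


a = jnp.ones((64, 64), jnp.float32)
compiled = jax.jit(matmul).lower(a, a).compile()

cost = compiled.cost_analysis()
# Normalize: older versions may return a list of dicts.
if isinstance(cost, list):
    cost = cost[0] if cost else None
flops = cost.get("flops") if isinstance(cost, dict) else None
print("estimated flops:", flops)
```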

property in_tree: tree_util.PyTreeDef

Tree structure of the pair (positional arguments, keyword arguments).

memory_analysis()

A summary of estimated memory requirements.

Intended for visualization and debugging purposes. The returned object is a simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.

Returns None if unavailable, e.g. based on backend, compiler, or runtime.

Return type:

Any | None
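A hedged sketch of summarizing memory estimates. The attribute names are assumptions about the returned object (they match fields commonly exposed by XLA's compiled memory statistics), so each one is read only if present.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function.
    return jnp.tanh(x @ x.T)


compiled = jax.jit(f).lower(jnp.ones((128, 32), jnp.float32)).compile()

stats = compiled.memory_analysis()

# Assumed field names; read defensively since the structure is not guaranteed.
FIELDS = ("argument_size_in_bytes", "output_size_in_bytes",
          "temp_size_in_bytes", "generated_code_size_in_bytes")
summary = {} if stats is None else {
    name: getattr(stats, name) for name in FIELDS if hasattr(stats, name)
}
print(summary)
```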

runtime_executable()

An arbitrary object representation of this executable.

Intended for debugging purposes. This is neither a valid nor a reliable serialization, and the output is not guaranteed to be consistent across invocations.

Returns None if unavailable, e.g. based on backend, compiler, or runtime.

Return type:

Any | None
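A minimal sketch of obtaining the runtime executable. The object is backend specific; this only prints its type, which is about as much as can be relied on.

```python
import jax
import jax.numpy as jnp

compiled = jax.jit(lambda x: x * 3).lower(jnp.ones((2,), jnp.float32)).compile()

exe = compiled.runtime_executable()
# Opaque and backend specific: useful in a debugger, not for storage.
print(type(exe))
```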