jax.experimental.pallas.mosaic_gpu.emit_pipeline

jax.experimental.pallas.mosaic_gpu.emit_pipeline(body, *, grid, in_specs=(), out_specs=(), max_concurrent_steps=1, delay_release=0, init_carry=None)

Creates a function to emit a manual pipeline within a Pallas kernel.
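Leaving aside the overlap of copies with compute, the emitted pipeline behaves like a loop over the grid. At each step, one block per input is copied into SMEM, body is called, and each output block is written back to GMEM. The sketch below models only those semantics in plain NumPy. `reference_pipeline` and `_block_slices` are hypothetical names, each BlockSpec is approximated by a `(block_shape, index_map)` pair whose index_map returns block (not element) indices, and nothing here calls the Pallas API or runs on a GPU.

```python
import itertools

import numpy as np


def _block_slices(spec, indices):
    """Turn a (block_shape, index_map) pair into element slices for one grid step."""
    block_shape, index_map = spec
    block_idx = index_map(*indices)
    return tuple(slice(b * s, (b + 1) * s) for b, s in zip(block_idx, block_shape))


def reference_pipeline(body, grid, in_specs, out_specs, inputs, outputs):
    """Sequential NumPy model of the pipeline's semantics (hypothetical helper).

    Nothing overlaps here: each step copies its input blocks, calls `body`,
    and then writes the output blocks back into `outputs` in place.
    """
    for indices in itertools.product(*(range(g) for g in grid)):
        # Stand-ins for the SMEM copies of the current input blocks.
        in_blocks = [x[_block_slices(spec, indices)].copy()
                     for spec, x in zip(in_specs, inputs)]
        out_slices = [_block_slices(spec, indices) for spec in out_specs]
        out_blocks = [np.empty_like(o[s]) for o, s in zip(outputs, out_slices)]
        body(indices, *in_blocks, *out_blocks)
        # Write the output blocks back to the "GMEM" arrays.
        for o, s, blk in zip(outputs, out_slices, out_blocks):
            o[s] = blk


# Usage: add 1.0 to a (32, 64) array, processed as four (32, 16) column blocks.
def add_one(indices, x_smem, o_smem):
    o_smem[...] = x_smem[...] + 1.0


x = np.arange(32 * 64, dtype=np.float32).reshape(32, 64)
out = np.zeros_like(x)
spec = ((32, 16), lambda i: (0, i))
reference_pipeline(add_one, grid=(4,), in_specs=[spec], out_specs=[spec],
                   inputs=[x], outputs=[out])
```

After the call, `out` equals `x + 1.0`. The real pipeline differs in that copies for upcoming steps are issued while the body of the current step runs.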

Parameters:
  • body (Callable[..., T]) –

    The pipeline body function, which is called with

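    • indices: Tuple of the current grid indices.

    • *input_refs: SMEM (shared memory) refs holding the current step's input blocks.

    • *output_refs: SMEM refs for the current step's output blocks, which the pipeline copies back to GMEM.

    If init_carry is provided, body receives an additional argument, carry, which holds the value returned by the previous iteration (init_carry itself on the first iteration). In that case body must return the carry for the next iteration.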

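  • grid (pallas_core.TupleGrid) – The grid dimensions for the pipeline; body runs once for each index tuple in the grid.

  • in_specs (Sequence[pallas_core.BlockSpec]) – BlockSpecs for the inputs, one per GMEM input ref. Each maps the grid indices to the block that is copied into SMEM at that step.

  • out_specs (Sequence[pallas_core.BlockSpec]) – BlockSpecs for the outputs, one per GMEM output ref. Each maps the grid indices to the block of the output that is written back from SMEM at that step.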

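  • max_concurrent_steps (int) – Maximum number of pipeline steps that can be in flight at once. Larger values let more copies overlap with compute, at the cost of more SMEM for buffers.

  • delay_release (int) – Number of steps by which to delay reusing the input/output buffers of a step. Must be less than max_concurrent_steps. Typically set to 1 to hide WGMMA latency: an asynchronous WGMMA issued in one step may still be reading its SMEM operands during the next step, so those buffers must not be overwritten yet.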

  • init_carry (T | None) – Optional initial carry. If provided, body threads this state through the iterations, and the pipeline returns the final carry.
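To illustrate how max_concurrent_steps and delay_release interact, the sketch below models the buffers of a single operand as a ring of max_concurrent_steps slots. It assumes that step k uses slot k % max_concurrent_steps and that a slot is refilled only delay_release steps after its last use. This is a simplified model under those assumptions, not the library's scheduler, and `prefetch_schedule` is a hypothetical helper. In this model the prefetch lookahead is max_concurrent_steps - delay_release steps, which is why delay_release must stay below max_concurrent_steps.

```python
def prefetch_schedule(num_steps, max_concurrent_steps, delay_release=0):
    """Return (after_step, fetched_step, slot) events for one operand.

    after_step == -1 marks fetches issued before the first step runs.
    Hypothetical simplified ring-buffer model, not the library's implementation.
    """
    if not 0 <= delay_release < max_concurrent_steps:
        raise ValueError("need 0 <= delay_release < max_concurrent_steps")
    # How many steps ahead of the current one the pipeline can fetch.
    lookahead = max_concurrent_steps - delay_release
    events = []
    # Fill the first `lookahead` slots before the loop starts.
    for step in range(min(lookahead, num_steps)):
        events.append((-1, step, step % max_concurrent_steps))
    for t in range(num_steps):
        nxt = t + lookahead
        if nxt < num_steps:
            # Slot nxt % max_concurrent_steps last held step t - delay_release
            # (if any), whose use ended delay_release steps ago, so it is free.
            events.append((t, nxt, nxt % max_concurrent_steps))
    return events


# Two slots, no delay: the slot used by step t is refilled right after step t.
print(prefetch_schedule(4, 2, 0))  # [(-1, 0, 0), (-1, 1, 1), (0, 2, 0), (1, 3, 1)]
# Two slots, delay_release=1: only one step of lookahead remains.
print(prefetch_schedule(4, 2, 1))  # [(-1, 0, 0), (0, 1, 1), (1, 2, 0), (2, 3, 1)]
```

With delay_release=1, the buffers of step t stay untouched while step t + 1 runs, which is the window an asynchronous WGMMA from step t may still need.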

Returns:

A function that, when called with the GMEM input and output refs, executes the pipeline. It returns the final carry value if init_carry was provided, and None otherwise.
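The carry contract and return value can be modeled the same way. In the sketch below, the hypothetical `run_pipeline_model` iterates over fixed-size blocks of a 1-D array and passes the carry to body as its last argument. That argument position is an assumption, since the description of body only states that carry is an additional argument. The model returns the final carry when init_carry is given and None otherwise; plain NumPy values stand in for traced JAX values.

```python
import numpy as np


def run_pipeline_model(body, num_steps, block_size, x, init_carry=None):
    """Hypothetical model of carry threading and the pipeline's return value."""
    carry = init_carry
    for i in range(num_steps):
        x_block = x[i * block_size:(i + 1) * block_size]
        if init_carry is None:
            # No carry: body is called for its side effects only.
            body((i,), x_block)
        else:
            # With a carry (assumed last argument): body returns the next carry.
            carry = body((i,), x_block, carry)
    # The final carry is returned only when init_carry was provided.
    return carry if init_carry is not None else None


def accumulate(indices, x_smem, acc):
    return acc + x_smem.sum()


# Running sum over four blocks of eight elements.
total = run_pipeline_model(accumulate, num_steps=4, block_size=8,
                           x=np.arange(32.0), init_carry=0.0)
# total == 496.0, the sum of 0.0 through 31.0

# Without init_carry, the return value is None.
seen = []
result = run_pipeline_model(lambda idx, blk: seen.append(idx), 4, 8, np.arange(32.0))
# result is None; seen == [(0,), (1,), (2,), (3,)]
```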