jax.experimental.pallas.mosaic_gpu.emit_pipeline
- jax.experimental.pallas.mosaic_gpu.emit_pipeline(body, *, grid, in_specs=(), out_specs=(), max_concurrent_steps=1, delay_release=0, init_carry=None)
Creates a function to emit a manual pipeline within a Pallas kernel.
- Parameters:
body (Callable[..., T]) – The pipeline body function, which is called with:
  - indices: Tuple of current loop indices.
  - *input_refs: SMEM refs for inputs.
  - *output_refs: SMEM refs for outputs.
  If init_carry is provided, body receives an additional argument carry – the carry from the previous iteration. It must then return the next carry value.
grid (pallas_core.TupleGrid) – The grid dimensions for the pipeline.
in_specs (Sequence[pallas_core.BlockSpec]) – A sequence of BlockSpecs for inputs.
out_specs (Sequence[pallas_core.BlockSpec]) – A sequence of BlockSpecs for outputs.
max_concurrent_steps (int) – Maximum concurrently active pipeline stages.
delay_release (int) – Number of steps to delay before reusing input/output references. Must be < max_concurrent_steps. Useful for hiding WGMMA latency (typically set to 1).
init_carry (T | None) – Optional initial carry. If provided, body handles carry-over state between iterations, and the pipeline returns the final carry.
- Returns:
A function that, when called with GMEM input and output refs, executes the pipeline. It returns the final carry value if init_carry was provided; otherwise it returns None.
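The calling convention described above can be illustrated with a small pure-Python model. This is only a sketch of the semantics and does not use Pallas or Mosaic GPU. The helper `emit_pipeline_model` and its `(block_size, index_map)` spec tuples are hypothetical stand-ins for `emit_pipeline` and `BlockSpec`. It assumes 1-D data, runs the steps sequentially (so `max_concurrent_steps` and `delay_release` are not modelled), and passes `carry` as the last positional argument, which is itself an assumption.

```python
import itertools


def emit_pipeline_model(body, *, grid, in_specs=(), out_specs=(), init_carry=None):
    """Hypothetical, sequential model of the emit_pipeline calling convention.

    Each spec is a (block_size, index_map) tuple over 1-D Python lists; this is
    a simplification for illustration, not the real BlockSpec.
    """

    def pipeline(*gmem):
        # Inputs come first, then outputs, mirroring "GMEM input and output refs".
        inputs, outputs = gmem[: len(in_specs)], gmem[len(in_specs):]
        carry = init_carry
        for indices in itertools.product(*(range(n) for n in grid)):
            # Stand-in for copying the current input blocks from GMEM to SMEM.
            in_blocks = []
            for arr, (size, index_map) in zip(inputs, in_specs):
                start = index_map(*indices) * size
                in_blocks.append(arr[start:start + size])
            # Scratch blocks that the body writes its results into.
            out_blocks = [[None] * size for size, _ in out_specs]
            if init_carry is None:
                body(indices, *in_blocks, *out_blocks)
            else:
                # Assumption: the carry is the final positional argument.
                carry = body(indices, *in_blocks, *out_blocks, carry)
            # Stand-in for copying the output blocks back from SMEM to GMEM.
            for arr, blk, (size, index_map) in zip(outputs, out_blocks, out_specs):
                start = index_map(*indices) * size
                arr[start:start + size] = blk
        return carry if init_carry is not None else None

    return pipeline


def add_one_and_sum(indices, x_blk, o_blk, carry):
    # Write x + 1 into the output block and accumulate a running sum.
    for k, v in enumerate(x_blk):
        o_blk[k] = v + 1
    return carry + sum(x_blk)


x = list(range(8))
out = [0] * 8
spec = (4, lambda i: i)  # blocks of 4 elements, block index = loop index
total = emit_pipeline_model(
    add_one_and_sum, grid=(2,), in_specs=[spec], out_specs=[spec], init_carry=0
)(x, out)
```

In this model the body is invoked once per grid index, the carry returned by one step is fed to the next, and the final carry is returned to the caller. When `init_carry` is omitted, the body takes no carry argument and the call returns None.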