jax.experimental.pallas.mosaic_gpu module

Experimental GPU backend for Pallas targeting H100.

These APIs are highly unstable and can change weekly. Use at your own risk.

Classes

Barrier(*[, num_arrivals, num_barriers, ...])

Describes a barrier reference.
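A barrier tracks how many arrivals are still expected in its current phase. When num_arrivals arrivals have been recorded, the phase completes, waiters are released, and the barrier re-arms for the next phase. The pure-Python model below sketches that bookkeeping. It assumes semantics similar to Hopper's hardware mbarrier. BarrierModel and its methods are hypothetical, are not part of jax.experimental.pallas.mosaic_gpu, and run on the host only.

```python
class BarrierModel:
    """Host-side model of a phase-based barrier (hypothetical, illustration only)."""

    def __init__(self, num_arrivals: int = 1):
        if num_arrivals < 1:
            raise ValueError("num_arrivals must be at least 1")
        self.num_arrivals = num_arrivals
        self.pending = num_arrivals  # arrivals still missing in the current phase
        self.completed_phases = 0

    def arrive(self) -> None:
        # Each arrival decrements the pending count; the last one completes the
        # phase and re-arms the barrier for the next phase.
        self.pending -= 1
        if self.pending == 0:
            self.completed_phases += 1
            self.pending = self.num_arrivals

    def can_pass(self, phase: int) -> bool:
        # A waiter on `phase` (0-based) may proceed once that phase has completed.
        return self.completed_phases > phase


# Two producers must arrive before a consumer waiting on phase 0 may proceed.
bar = BarrierModel(num_arrivals=2)
bar.arrive()
print(bar.can_pass(0))  # False: one arrival is still missing
bar.arrive()
print(bar.can_pass(0))  # True: phase 0 has completed
```

By analogy, barrier_arrive and barrier_wait play the roles of arrive and can_pass. The real barrier_wait blocks rather than returning a flag.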

BlockSpec([block_shape, index_map, ...])

CompilerParams(*[, approx_math, ...])

Mosaic GPU compiler parameters.

MemorySpace(value[, names, module, ...])

Layout(value[, names, module, qualname, ...])

SwizzleTransform(swizzle)
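As background, the swizzled shared-memory layouts used by TMA and WGMMA on Hopper permute 16-byte chunks within each row to reduce bank conflicts. The function below is a host-side model of the 128-byte mode, written in the style of CUTLASS's Swizzle<3, 4, 3>. It XORs the 16-byte chunk index (byte-offset bits 4 to 6) with the row index inside an 8-row, 1024-byte atom (bits 7 to 9). This is an assumption about the hardware pattern, not a statement of what SwizzleTransform computes internally. The swizzle argument is presumably a byte width such as 128, and swizzle_128b is a hypothetical helper.

```python
def swizzle_128b(byte_offset: int) -> int:
    """Hypothetical model of a 128-byte swizzle (CUTLASS-style Swizzle<3, 4, 3>)."""
    # Bits 7-9: index of the 128-byte row within an 8-row (1024-byte) swizzle atom.
    row_in_atom = (byte_offset >> 7) & 0b111
    # XOR that row index into bits 4-6, the index of the 16-byte chunk in the row.
    return byte_offset ^ (row_in_atom << 4)


# Row 0 is unchanged; in row 1, the first 16-byte chunk moves to chunk slot 1.
print(swizzle_128b(0), swizzle_128b(128))  # 0 144
# The mapping only permutes offsets within each 1024-byte atom.
assert sorted(swizzle_128b(o) for o in range(1024)) == list(range(1024))
```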

TilingTransform(tiling)

Represents a tiling transformation for memory refs.
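The NumPy sketch below shows what tiling means for a 2-D array. It assumes a tiling (tm, tn) applied to the trailing two dimensions turns shape (M, N) into (M // tm, N // tn, tm, tn). That is, the tile indices come first, followed by the offsets within a tile. The helper tile_2d is hypothetical. It only reshapes host data and does not construct a TilingTransform.

```python
import numpy as np


def tile_2d(x: np.ndarray, tiling: tuple[int, int]) -> np.ndarray:
    """Hypothetical model: (M, N) -> (M // tm, N // tn, tm, tn)."""
    m, n = x.shape
    tm, tn = tiling
    if m % tm or n % tn:
        raise ValueError("tiling must divide the array shape")
    # Split each dimension into (tile index, offset within tile) ...
    split = x.reshape(m // tm, tm, n // tn, tn)
    # ... then move both tile indices in front of both in-tile offsets.
    return split.transpose(0, 2, 1, 3)


x = np.arange(8 * 16).reshape(8, 16)
t = tile_2d(x, (4, 8))
print(t.shape)  # (2, 2, 4, 8)
# Element (i, j) lives at tile (i // 4, j // 8), offset (i % 4, j % 8).
assert t[1, 0, 2, 3] == x[1 * 4 + 2, 0 * 8 + 3]
```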

TransposeTransform(permutation)

Transposes a tiled memref.
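As a host-side illustration of permuting the dimensions of a tiled array, the NumPy sketch below tiles an (M, N) array into (M // tm, N // tn, tm, tn). It then swaps the two tile-index dimensions and the two in-tile dimensions with a single permutation. The result equals the tiled layout of the transposed array. This summary does not state whether TransposeTransform's permutation refers to the logical or the tiled dimensions, so treat this as a conceptual model; tile_2d is hypothetical.

```python
import numpy as np


def tile_2d(x: np.ndarray, tm: int, tn: int) -> np.ndarray:
    """Hypothetical helper: view (M, N) as (M // tm, N // tn, tm, tn) tiles."""
    m, n = x.shape
    return x.reshape(m // tm, tm, n // tn, tn).transpose(0, 2, 1, 3)


x = np.arange(8 * 16).reshape(8, 16)
tiled = tile_2d(x, 4, 8)                  # shape (2, 2, 4, 8)
# Swap the tile-index dims and the in-tile dims with one permutation.
transposed = tiled.transpose(1, 0, 3, 2)  # shape (2, 2, 8, 4)
# The result equals tiling x.T with the tile shape swapped.
assert np.array_equal(transposed, tile_2d(x.T, 8, 4))
```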

WGMMAAccumulatorRef(shape, dtype, _init)

Functions

barrier_arrive(barrier)

Arrives at the given barrier.

barrier_wait(barrier)

Waits on the given barrier.

commit_smem()

Commits all writes to SMEM, making them visible to TMA and MMA operations.

copy_gmem_to_smem(src, dst, barrier, *[, ...])

Asynchronously copies a GMEM reference to an SMEM reference.

copy_smem_to_gmem(src, dst[, predicate, ...])

Asynchronously copies an SMEM reference to a GMEM reference.

emit_pipeline(body, *, grid[, in_specs, ...])

Creates a function to emit a manual pipeline within a Pallas kernel.
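A pipeline lets copies of upcoming blocks overlap with computation on the current block. The sketch below does not call emit_pipeline. It is a plain-Python model of a generic multi-buffered schedule that assumes a fixed number of buffer slots reused round-robin. pipeline_schedule is hypothetical, and the real lowering, including how it waits on copies and barriers, may differ.

```python
def pipeline_schedule(num_steps: int, num_slots: int = 2) -> list[tuple[str, int, int]]:
    """Hypothetical model of a multi-buffered pipeline: (event, step, buffer slot)."""
    events = []
    # Prologue: start copying the first `num_slots` blocks into distinct slots.
    for step in range(min(num_slots, num_steps)):
        events.append(("copy_in", step, step % num_slots))
    for step in range(num_steps):
        slot = step % num_slots
        # The body runs on the block that was copied into `slot`.
        events.append(("compute", step, slot))
        # Once the body is done with the slot, refill it with a later block,
        # so that this copy overlaps with computation on the other slots.
        later = step + num_slots
        if later < num_steps:
            events.append(("copy_in", later, slot))
    return events


for event in pipeline_schedule(num_steps=4):
    print(event)
```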

layout_cast(x, new_layout)

Casts the layout of the given array.

set_max_registers(n, *, action)

Sets the maximum number of registers owned by a warp.
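Rebalancing registers only helps if the totals still fit in the register file. The worked check below assumes that n is forwarded to PTX setmaxnreg, whose count is per thread and must be a multiple of 8 in the range 24 to 256. It also assumes that one thread block of three 128-thread warpgroups occupies an H100 SM, which has 65,536 32-bit registers. The action argument presumably selects between increasing and decreasing. register_budget_fits is a hypothetical helper.

```python
REGISTERS_PER_SM = 65536      # 32-bit registers in one H100 SM
THREADS_PER_WARPGROUP = 128   # 4 warps x 32 threads


def register_budget_fits(per_thread_counts: list[int]) -> bool:
    """Hypothetical check: do these per-thread register counts, one per
    warpgroup, fit in a single SM's register file?"""
    for n in per_thread_counts:
        # PTX setmaxnreg accepts counts in [24, 256] that are multiples of 8.
        if not (24 <= n <= 256 and n % 8 == 0):
            raise ValueError(f"invalid register count: {n}")
    used = sum(n * THREADS_PER_WARPGROUP for n in per_thread_counts)
    return used <= REGISTERS_PER_SM


# One producer warpgroup giving up registers, two consumer warpgroups taking more.
print(register_budget_fits([40, 232, 232]))  # True: 5120 + 2 * 29696 = 64512
print(register_budget_fits([40, 240, 240]))  # False: 5120 + 2 * 30720 = 66560
```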

wait_smem_to_gmem(n[, wait_read_only])

Waits until there are no more than n SMEM->GMEM copies in flight.
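The host-side model below illustrates the "no more than n in flight" rule. It assumes copies retire oldest first, as with PTX bulk async-group waits (cp.async.bulk.wait_group), which this function is presumably built on. The wait_read_only flag presumably relaxes the wait so that it only requires the source SMEM to have been read. The model does not capture that. CopyQueue is hypothetical.

```python
from collections import deque


class CopyQueue:
    """Hypothetical host-side model of in-flight SMEM -> GMEM copies."""

    def __init__(self):
        self.in_flight = deque()  # issued but not yet known to be complete
        self.done = []

    def issue(self, name: str) -> None:
        self.in_flight.append(name)

    def wait(self, n: int) -> None:
        # Return once at most n copies remain; copies retire oldest first.
        while len(self.in_flight) > n:
            self.done.append(self.in_flight.popleft())


q = CopyQueue()
for name in ("tile0", "tile1", "tile2"):
    q.issue(name)
q.wait(1)
print(q.done, list(q.in_flight))  # ['tile0', 'tile1'] ['tile2']
q.wait(0)  # e.g. before the SMEM of every tile may be overwritten
print(q.done)  # ['tile0', 'tile1', 'tile2']
```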

wgmma(acc, a, b)

Performs an asynchronous warp group matmul-accumulate on the given references.
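The NumPy reference below checks only the arithmetic that wgmma performs on the host: acc + a @ b, computed in the accumulator's dtype. It does not model the asynchrony, operand placement, or hardware rounding. The tile sizes are illustrative assumptions, and wgmma_reference is hypothetical.

```python
import numpy as np


def wgmma_reference(acc: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy reference for the math of a matmul-accumulate."""
    # Promote the operands to the accumulator dtype and add the product.
    return acc + a.astype(acc.dtype) @ b.astype(acc.dtype)


m, k, n = 64, 32, 128  # illustrative tile sizes (an assumption)
acc = np.ones((m, n), dtype=np.float32)
a = np.ones((m, k), dtype=np.float16)
b = np.full((k, n), 2, dtype=np.float16)
out = wgmma_reference(acc, a, b)
print(out.dtype, out[0, 0])  # float32 65.0
```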

wgmma_wait(n)

Waits until there are no more than n WGMMA operations in flight.
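Because wgmma is asynchronous, a pipelined loop can issue the matmul for the current step and then call wgmma_wait(1). That call leaves the current matmul running while guaranteeing that the previous one has finished, so its operand buffer can be refilled. The plain-Python model below shows that bookkeeping. It assumes, as with PTX wgmma.wait_group, that operations complete in issue order. completed_after_each_wait is hypothetical and does not call the Mosaic GPU API.

```python
def completed_after_each_wait(num_steps: int, n: int) -> list[list[int]]:
    """Hypothetical model: for each step, issue one async matmul and then wait
    until at most `n` are in flight. Returns the steps known complete so far."""
    in_flight, completed, history = [], [], []
    for step in range(num_steps):
        in_flight.append(step)          # issue this step's matmul
        while len(in_flight) > n:       # wait: the oldest finishes first
            completed.append(in_flight.pop(0))
        history.append(list(completed))
    return history


# With n=1, the previous step's matmul is complete after each wait.
print(completed_after_each_wait(4, n=1))  # [[], [0], [0, 1], [0, 1, 2]]
# With n=0, every matmul is complete before the next step begins.
print(completed_after_each_wait(4, n=0))  # [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3]]
```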

Aliases

ACC

Alias of WGMMAAccumulatorRef.

GMEM

Alias of jax.experimental.pallas.mosaic_gpu.MemorySpace.GMEM.

SMEM

Alias of jax.experimental.pallas.mosaic_gpu.MemorySpace.SMEM.