jax.experimental.pallas.mosaic_gpu.copy_gmem_to_smem
- jax.experimental.pallas.mosaic_gpu.copy_gmem_to_smem(src, dst, barrier, *, collective_axes=None, partitioned_axis=None)
Asynchronously copies a GMEM reference to a SMEM reference.
If collective_axes is specified, this performs a multicast copy where all CUDA blocks that share the same index along the collective axis receive a copy of the same block of data loaded from src to dst.
If both collective_axes and partitioned_axis are specified, this performs a partitioned collective copy in which each block in the cluster receives a tile of transfer_size // cluster_size data from the src Ref. For example, if src has shape (256, 256) and a partitioned copy is performed along axis 0 with a cluster size of 2, then the first block receives src[0:128, :] and the second receives src[128:256, :].
Note: only the first block in the cluster arrives on the barrier, so an additional cluster barrier is necessary to ensure that all blocks in the cluster have finished the copy.
- Parameters:
src (_Ref) – The source Ref. Must be in GMEM.
dst (_Ref) – The destination Ref. Must be in SMEM.
barrier (_Ref) – The barrier to use for tracking completion of the copy.
collective_axes (str | tuple[str, ...] | None) – The name (or names) of the cluster axes over which to perform the collective copy.
partitioned_axis (int | None) – The array axis of the src/dst Refs along which to partition the data during a partitioned collective copy. Requires collective_axes to also be specified.
- Return type:
None
See also
jax.experimental.mosaic.gpu.barrier_arrive(), jax.experimental.mosaic.gpu.barrier_wait()
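Example
A minimal sketch of the non-collective case, following the usual Pallas Mosaic GPU pattern: the kernel copies its GMEM input into an SMEM scratch buffer, waits on the barrier, and then writes the result out. The harness details (the scratch_shapes, the GMEM BlockSpec, and any backend/compiler_params selection your JAX version requires) are assumptions about the surrounding setup rather than part of this API.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


def copy_kernel(x_gmem, o_ref, x_smem, barrier):
  # Kick off the asynchronous GMEM -> SMEM copy, then block on the barrier
  # until the data has landed in x_smem.
  plgpu.copy_gmem_to_smem(x_gmem, x_smem, barrier)
  plgpu.barrier_wait(barrier)
  o_ref[...] = x_smem[...]


x = jnp.arange(128, dtype=jnp.float32)
copy = pl.pallas_call(
    copy_kernel,
    # Keep the input in GMEM so the kernel itself issues the copy.
    in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    scratch_shapes=[
        plgpu.SMEM(x.shape, x.dtype),   # SMEM destination of the copy
        plgpu.Barrier(num_arrivals=1),  # tracks completion of the copy
    ],
)
y = copy(x)  # y should equal x
```

The copy is only guaranteed to have completed once barrier_wait() returns; reading x_smem before that point would race with the asynchronous transfer.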