jax.experimental.pallas.mosaic_gpu.async_load_tmem

jax.experimental.pallas.mosaic_gpu.async_load_tmem(src, *, layout=None)

Performs an asynchronous load from the TMEM array.

The load operation is only partly asynchronous. The returned array can be used immediately, without any additional synchronization, but the read from TMEM is not guaranteed to have completed when the function returns. Before overwriting the region that was read, ensure that wait_load_tmem has been called. Otherwise the write can race with the in-flight read and produce nondeterministic results.

For example, the following sequence of operations at the end of the kernel is valid, even though the TMEM load is never awaited:

smem_ref[...] = plgpu.async_load_tmem(tmem_ref)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(smem_ref, gmem_ref)
plgpu.wait_smem_to_gmem(0)

However, if the kernel is persistent and might reuse the TMEM region, the sequence should be extended with a call to wait_load_tmem before that region is written again.

Parameters:
  • src (_Ref) – The TMEM reference to load from.

  • layout (SomeLayout | None) – The optional layout hint to use for the resulting array.

Return type:

jax.Array