jax.experimental.pallas.mosaic_gpu.wait_load_tmem

jax.experimental.pallas.mosaic_gpu.wait_load_tmem()

Awaits all asynchronous TMEM loads previously issued by the calling thread.

Once this function returns, all TMEM loads issued by the calling thread are guaranteed to have completed. The TMEM regions they read can then be safely overwritten by the calling thread, or by any thread signalled through a ``Barrier`` with ``orders_tensor_core=True``.