jax.experimental.pallas.mosaic_gpu.Layout
- class jax.experimental.pallas.mosaic_gpu.Layout(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
- __init__(*args, **kwds)
Methods
- reduce(axes)
- to_mgpu(*args, **kwargs)

Attributes
- WGMMA: [m, n] matrix, where m % 64 == 0 == n % 8.
- WGMMA_TRANSPOSED
- WG_SPLAT
- WG_STRIDED
- TCGEN05
- TCGEN05_TRANSPOSED
- TCGEN05_M64_COLLECTIVE
- TCGEN05_TMEM_NATIVE
- WGMMA_ROW
- WGMMA_COL
- TCGEN05_ROW
- TCGEN05_COL
- TCGEN05_TMEM_NATIVE_ROW
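The shape rule on the WGMMA attribute is written as a Python chained comparison. The sketch below spells out which [m, n] shapes it accepts. It uses plain Python only, and the helper name `satisfies_wgmma_shape` is hypothetical: it is not part of `jax.experimental.pallas.mosaic_gpu`, and it checks only the documented divisibility rule, not any other requirement Mosaic GPU may impose.

```python
# Hypothetical helper, NOT part of the Pallas Mosaic GPU API.
# It encodes only the documented constraint for Layout.WGMMA:
# an [m, n] matrix where m % 64 == 0 == n % 8.


def satisfies_wgmma_shape(m: int, n: int) -> bool:
    """Return True if an [m, n] array meets the documented WGMMA shape rule."""
    # Python evaluates the chained form `m % 64 == 0 == n % 8` as
    # `(m % 64 == 0) and (0 == n % 8)`, so this explicit version is equivalent.
    return m % 64 == 0 and n % 8 == 0


if __name__ == "__main__":
    for m, n in [(64, 8), (128, 256), (64, 12), (32, 64)]:
        chained = m % 64 == 0 == n % 8  # the notation used in the docs
        assert chained == satisfies_wgmma_shape(m, n)
        print((m, n), satisfies_wgmma_shape(m, n))
```

Running it prints `True` for (64, 8) and (128, 256), and `False` for (64, 12) (12 is not a multiple of 8) and (32, 64) (32 is not a multiple of 64).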