jax.experimental.pallas.mosaic_gpu.SwizzleTransform
- class jax.experimental.pallas.mosaic_gpu.SwizzleTransform(swizzle: 'int')
- Parameters:
swizzle (int)
Methods
- __init__(swizzle)
- batch(leading_rank): Returns a transform that accepts a ref with `leading_rank` additional leading dimensions.
- to_gpu_transform()
- to_gpu_transform_attr()
- undo(ref)
- undo_to_gpu_transform()

Attributes
- swizzle
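This page does not describe what the `swizzle` value controls. The sketch below offers a rough mental model. It assumes, without confirmation from this page, that the value selects one of the 32-, 64- or 128-byte XOR swizzle modes that NVIDIA GPUs use for shared-memory tiles. Under that assumption, the sketch computes where a byte lands after swizzling. The function `swizzle_offset` is hypothetical: it is not part of `jax.experimental.pallas.mosaic_gpu` and does not reproduce how the library applies the transform.

```python
# Conceptual model of an XOR-based shared-memory swizzle (the pattern behind
# the 32/64/128-byte hardware swizzle modes). Illustration only; this is an
# assumption about what `swizzle` selects, not code from Pallas Mosaic GPU.


def swizzle_offset(byte_offset: int, swizzle: int) -> int:
    """Map a linear byte offset to its swizzled byte offset.

    The 16-byte chunk index (bits 4 and up) is XORed with bits 7 and up of
    the offset, using log2(swizzle // 16) bits. The mapping is its own
    inverse, so applying it twice returns the original offset.
    """
    if swizzle not in (32, 64, 128):
        raise ValueError(f"unsupported swizzle width: {swizzle}")
    mask = swizzle // 16 - 1               # 1, 3 or 7 chunk-index bits permuted
    row_bits = (byte_offset >> 7) & mask   # which 128-byte row within the pattern
    return byte_offset ^ (row_bits << 4)   # permute 16-byte chunks within the row


if __name__ == "__main__":
    # Chunk 0 of eight consecutive 128-byte rows. Unswizzled, all eight sit in
    # the same 16-byte slot of their row (the same 4 of the 32 four-byte
    # banks); with a 128-byte swizzle they spread across all eight slots.
    rows = range(8)
    plain = [(r * 128) // 16 % 8 for r in rows]
    swizzled = [swizzle_offset(r * 128, 128) // 16 % 8 for r in rows]
    print(plain)     # [0, 0, 0, 0, 0, 0, 0, 0]
    print(swizzled)  # [0, 1, 2, 3, 4, 5, 6, 7]
```

The XOR form is the usual choice for swizzles like this. It only permutes offsets within each repeating block (256, 512 or 1024 bytes for the three widths), so no two offsets collide. It is also cheap to invert.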