jax.experimental.pallas.mosaic_gpu.TransposeTransform
class jax.experimental.pallas.mosaic_gpu.TransposeTransform(permutation)
Transpose a tiled memref.
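The shape arithmetic below gives a rough sense of what "transposing a tiled memref" involves. It is a plain-Python illustration, not the Pallas API. It assumes that a 2D array tiled with (tr, tc) tiles is viewed as a 4D array (tile-grid dims first, then in-tile dims), and that the permutation follows the numpy.transpose convention (output dim i takes input dim permutation[i]). Both assumptions are labeled here rather than confirmed by this page. The helper names tile_shape and permute_shape are hypothetical.

```python
# Hypothetical illustration only: plain-Python shape arithmetic, not the Pallas API.


def tile_shape(shape, tiling):
    """Shape of a 2D array after splitting it into tiles of size `tiling`.

    Assumed layout: a (rows, cols) array with (tr, tc) tiles becomes
    (rows // tr, cols // tc, tr, tc), tile-grid dims first, then in-tile dims.
    """
    rows, cols = shape
    tr, tc = tiling
    assert rows % tr == 0 and cols % tc == 0, "shape must be divisible by tiling"
    return (rows // tr, cols // tc, tr, tc)


def permute_shape(shape, permutation):
    """Output dim i takes input dim permutation[i] (numpy.transpose convention, assumed)."""
    assert sorted(permutation) == list(range(len(shape))), "not a permutation"
    return tuple(shape[p] for p in permutation)


tiled = tile_shape((128, 256), (64, 64))
print(tiled)  # (2, 4, 64, 64)

# Swapping both the tile-grid dims and the in-tile dims matches tiling the
# transposed (256, 128) array with (64, 64) tiles.
print(permute_shape(tiled, (1, 0, 3, 2)))  # (4, 2, 64, 64)
```

Swapping only the grid dims (permutation (1, 0, 2, 3)) would reorder tiles without transposing their contents. Reversing both pairs is what corresponds to a full 2D transpose under this layout assumption.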
Methods
__init__(permutation)
batch(leading_rank): Returns a transform that accepts a ref with leading_rank extra leading dimensions.
to_gpu_transform()
to_gpu_transform_attr()
undo(ref)

Attributes
permutation
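The sketch below is a rough model of how the permutation attribute could relate to batch and undo. It uses a hypothetical PermutationSketch class that works on shapes, not refs, so it is not the real TransposeTransform. Two behaviors are assumptions and are not documented on this page. First, batch(leading_rank) is assumed to keep the new leading dims in place and shift the original permutation past them. Second, undoing the transform is assumed to apply the inverse permutation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PermutationSketch:
    """Hypothetical stand-in for the permutation held by TransposeTransform."""

    permutation: tuple

    def transform_shape(self, shape):
        # Output dim i takes input dim permutation[i].
        return tuple(shape[p] for p in self.permutation)

    def inverse(self):
        # inv[permutation[i]] = i, so applying self and then inverse is the identity.
        inv = [0] * len(self.permutation)
        for i, p in enumerate(self.permutation):
            inv[p] = i
        return PermutationSketch(tuple(inv))

    def batch(self, leading_rank):
        # Assumed behavior: the new leading dims stay in place and the
        # original permutation is shifted to act on the trailing dims.
        return PermutationSketch(
            tuple(range(leading_rank))
            + tuple(p + leading_rank for p in self.permutation)
        )


t = PermutationSketch((1, 2, 0))
out = t.transform_shape((2, 3, 5))
print(out)  # (3, 5, 2)
print(t.inverse().transform_shape(out))  # (2, 3, 5)

b = t.batch(1)
print(b.permutation)  # (0, 2, 3, 1)
print(b.transform_shape((8, 2, 3, 5)))  # (8, 3, 5, 2)
```

A non-self-inverse permutation such as (1, 2, 0) makes the inverse step visible. Its inverse is (2, 0, 1), which restores the original shape.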