jax.experimental.pallas.mosaic_gpu.TransposeTransform
class jax.experimental.pallas.mosaic_gpu.TransposeTransform(permutation)
Transpose a tiled memref.
Methods
__init__(permutation)
batch(leading_rank): Returns a transform that accepts a ref with the extra leading_rank dims.
to_gpu_transform()
to_gpu_transform_attr()
undo(ref)
Attributes
permutation
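The sketch below illustrates, in plain Python, what an axis permutation does to a shape, how a transpose can be undone, and how a permutation might be extended with extra leading dims. It is not the library's implementation and does not call Pallas. It assumes that permutation follows the usual transpose convention (output axis i takes input axis permutation[i]) and that batching leaves the new leading dims in place. The helpers permute_shape, inverse_permutation, and batch_permutation are hypothetical names introduced only for illustration.

```python
def permute_shape(shape, permutation):
    """Reorder `shape` so that output axis i takes input axis permutation[i].

    Assumption: this mirrors the common transpose convention; it is not
    taken from the TransposeTransform source.
    """
    if sorted(permutation) != list(range(len(shape))):
        raise ValueError("permutation must reorder every axis exactly once")
    return tuple(shape[axis] for axis in permutation)


def inverse_permutation(permutation):
    """Return the permutation that reverses `permutation` (hypothetical helper)."""
    inverse = [0] * len(permutation)
    for new_axis, old_axis in enumerate(permutation):
        inverse[old_axis] = new_axis
    return tuple(inverse)


def batch_permutation(permutation, leading_rank):
    """Extend `permutation` with `leading_rank` untouched leading axes.

    Assumption: the leading dims stay where they are and the original
    axes are shifted right by `leading_rank`.
    """
    leading = tuple(range(leading_rank))
    shifted = tuple(axis + leading_rank for axis in permutation)
    return leading + shifted


permutation = (2, 0, 1)
shape = (4, 8, 16)

transposed_shape = permute_shape(shape, permutation)
restored_shape = permute_shape(transposed_shape, inverse_permutation(permutation))
batched = batch_permutation(permutation, leading_rank=2)

print(transposed_shape)  # (16, 4, 8)
print(restored_shape)    # (4, 8, 16)
print(batched)           # (0, 1, 4, 2, 3)
```

Applying the inverse permutation after the original one returns the starting shape, which is the general idea behind an undo step for a transpose.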