jax.experimental.pallas.triton.CompilerParams
- class jax.experimental.pallas.triton.CompilerParams(num_warps=None, num_stages=None)
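Compiler parameters for the Triton backend of Pallas. An instance is passed to jax.experimental.pallas.pallas_call through its compiler_params argument.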
- num_warps
The number of warps to use for the kernel. Each warp consists of 32 threads.
- Type:
int | None
- num_stages
The number of pipeline stages the compiler should use when software-pipelining loops in the kernel.
- Type:
int | None
- __init__(num_warps=None, num_stages=None)
Methods
__init__([num_warps, num_stages])
Attributes
BACKEND