jax.experimental.pallas.triton.CompilerParams

class jax.experimental.pallas.triton.CompilerParams(num_warps=None, num_stages=None)

Compiler parameters for the Triton backend of Pallas, passed to pallas_call through its compiler_params argument.

Parameters:
  • num_warps (int | None)

  • num_stages (int | None)

num_warps

The number of warps to use for each program instance of the kernel. Each warp consists of 32 threads, so num_warps=4 runs 128 threads per program instance.

Type:

int | None

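num_stages

The number of stages the compiler should use when software-pipelining loops in the kernel, that is, how many loop iterations' memory loads may be overlapped with the computation of the current iteration.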

Type:

int | None

__init__(num_warps=None, num_stages=None)
Parameters:
  • num_warps (int | None)

  • num_stages (int | None)

Return type:

None

Methods

__init__([num_warps, num_stages])

Attributes

BACKEND

num_stages

num_warps