jax.experimental.pallas.triton module

Triton-specific Pallas APIs.

Classes

CompilerParams([num_warps, num_stages])

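Compiler parameters for Triton: the number of warps used by each kernel instance (num_warps) and the number of software-pipelining stages for loops (num_stages).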

Functions

atomic_and(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] &= val.

atomic_add(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] += val.
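
atomic_add is useful when several kernel instances (programs in the grid) update the same output location, for example when building a histogram. The sketch below is a plain NumPy model of that pattern, not Pallas code. The helpers `simulate_atomic_add` and `histogram` are hypothetical, and `idx` is assumed to be an integer index array rather than a Pallas indexer. Programs are applied one at a time in a shuffled order to stand in for an unpredictable hardware schedule. Because each read-modify-write is indivisible, the final counts do not depend on that order.

```python
import numpy as np

def simulate_atomic_add(out, idx, val, mask=None):
    """Hypothetical NumPy model of one masked atomic add: out[idx] += val.

    np.add.at applies repeated indices one by one, like a sequence of
    indivisible read-modify-write operations. Lanes where mask is False
    are skipped entirely.
    """
    idx = np.asarray(idx)
    val = np.broadcast_to(np.asarray(val, dtype=out.dtype), idx.shape)
    if mask is not None:
        mask = np.asarray(mask, dtype=bool)
        idx, val = idx[mask], val[mask]
    np.add.at(out, idx, val)

def histogram(data, num_bins, block_size, seed=0):
    counts = np.zeros(num_bins, dtype=np.int32)
    num_programs = -(-len(data) // block_size)  # ceiling division
    order = np.random.default_rng(seed).permutation(num_programs)
    for pid in order:  # programs run in an arbitrary order
        offsets = pid * block_size + np.arange(block_size)
        mask = offsets < len(data)  # the last block may be partial
        bins = data[np.minimum(offsets, len(data) - 1)]  # clamp out-of-range loads
        simulate_atomic_add(counts, bins, 1, mask=mask)
    return counts

data = np.array([0, 3, 1, 3, 3, 2, 0, 1, 3, 2])
print(histogram(data, num_bins=4, block_size=4))  # [2 2 2 4]
```

In a real kernel, a plain `+=` on a location shared between programs can lose updates when two programs read the same old value. The atomic version avoids this, and the mask keeps the partial last block from contributing out-of-range elements.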

atomic_cas(ref, cmp, val)

Performs an atomic compare-and-swap: if the value in ref equals cmp, it is replaced with val. Returns the value held in ref before the operation.
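
Compare-and-swap is the usual building block for locks and other custom synchronization. Following Triton's `tl.atomic_cas`, the sketch assumes that the operation returns the value stored before the swap, so a caller can tell whether its write took effect. This is a single-threaded NumPy model built on a hypothetical `simulate_atomic_cas` helper, not the Pallas primitive itself. On a GPU, the body of that helper would execute as one indivisible step.

```python
import numpy as np

def simulate_atomic_cas(ref, cmp, val):
    """Hypothetical model of atomic_cas on a 0-d array `ref`.

    If ref holds `cmp`, overwrite it with `val`. Either way, return the
    value that was stored before the operation.
    """
    old = ref[()]  # indexing a 0-d array with () yields a scalar copy
    if old == cmp:
        ref[()] = val
    return old

def try_acquire(lock):
    # Succeeds only if the lock was free (previous value 0).
    return simulate_atomic_cas(lock, 0, 1) == 0

def release(lock):
    # Swap the held state (1) back to free (0).
    simulate_atomic_cas(lock, 1, 0)

lock = np.zeros((), dtype=np.int32)  # 0 = unlocked, 1 = locked
print(try_acquire(lock))  # True: the lock was free and is now held
print(try_acquire(lock))  # False: a second caller sees it already held
release(lock)
print(try_acquire(lock))  # True again after release
```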

atomic_max(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] = max(x_ref_or_view[idx], val).

atomic_min(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] = min(x_ref_or_view[idx], val).

atomic_or(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] |= val.

atomic_xchg(x_ref_or_view, idx, val, *[, mask])

Atomically exchanges the given value with the value at the given index.

atomic_xor(x_ref_or_view, idx, val, *[, mask])

Atomically computes x_ref_or_view[idx] ^= val.
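
The read-modify-write functions above share one pattern. Each reads `x_ref_or_view[idx]`, combines it with `val`, and writes the result back as a single indivisible step, skipping any element whose `mask` entry is False. The NumPy model below tabulates the combining rule for each variant. The names `UPDATE_RULES` and `simulate_rmw` are illustrative only, `idx` is assumed to be an integer index array, and the sketch ignores concurrency: it describes what one call does to memory, not how hardware orders competing calls.

```python
import numpy as np

# Combining rule for each read-modify-write variant: new = rule(old, val).
UPDATE_RULES = {
    "atomic_add": np.add,
    "atomic_and": np.bitwise_and,
    "atomic_or": np.bitwise_or,
    "atomic_xor": np.bitwise_xor,
    "atomic_max": np.maximum,
    "atomic_min": np.minimum,
    "atomic_xchg": lambda old, val: val,  # the old value is simply replaced
}

def simulate_rmw(op, x, idx, val, mask=None):
    """Hypothetical model: x[idx] = rule(x[idx], val) wherever mask is True."""
    idx = np.asarray(idx)
    val = np.broadcast_to(np.asarray(val, dtype=x.dtype), idx.shape)
    keep = np.ones(idx.shape, dtype=bool) if mask is None else np.asarray(mask, dtype=bool)
    for i, v in zip(idx[keep], val[keep]):  # one indivisible update at a time
        x[i] = UPDATE_RULES[op](x[i], v)
    return x

for op in UPDATE_RULES:
    x = np.array([0b1100, 0b1010, 7], dtype=np.int32)
    # The third element is masked out, so it stays 7 for every op.
    simulate_rmw(op, x, idx=[0, 1, 2], val=[0b0110, 0b0110, 9], mask=[True, True, False])
    print(op, x.tolist())
```

The loop applies repeated indices one after another. This matches the guarantee atomics give: each update sees the result of the previous one, even though the real order between programs is unspecified.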

approx_tanh(x)

Elementwise approximate hyperbolic tangent: \(\tanh(x)\).

debug_barrier()

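Inserts a barrier that synchronizes all threads within the current kernel instance (program). It does not synchronize different programs in the grid.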

elementwise_inline_asm(asm, *, args, ...)

Applies an elementwise operation written as inline assembly.