jax.experimental.pallas.mosaic_gpu.Barrier

class jax.experimental.pallas.mosaic_gpu.Barrier(*, num_arrivals=1, num_barriers=1, orders_tensor_core=False)

Describes a barrier reference.

Parameters:
  • num_arrivals (int)

  • num_barriers (int)

  • orders_tensor_core (bool)

num_arrivals

The number of arrivals that will be recorded by this barrier.
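To illustrate what counting arrivals means, here is a deliberately simplified, single-threaded model in plain Python. It assumes the usual GPU mbarrier behavior: a phase completes once `num_arrivals` arrivals have been recorded, after which the counter re-arms for the next phase. `ToyBarrier` is hypothetical and is not part of JAX; real hardware tracks a phase parity bit (and, for asynchronous copies, transaction bytes) rather than a phase counter.

```python
class ToyBarrier:
    """Hypothetical model of an arrival-counting barrier (not a JAX API)."""

    def __init__(self, num_arrivals: int = 1):
        if num_arrivals < 1:
            raise ValueError("num_arrivals must be positive")
        self.num_arrivals = num_arrivals
        self.pending = num_arrivals  # arrivals still needed in the current phase
        self.completed_phases = 0    # how many phases have finished so far

    def arrive(self) -> None:
        """Record one arrival; complete the phase when none are pending."""
        self.pending -= 1
        if self.pending == 0:
            # Phase complete: waiters on this phase may proceed, and the
            # barrier re-arms with the same expected arrival count.
            self.completed_phases += 1
            self.pending = self.num_arrivals

    def is_complete(self, phase: int) -> bool:
        """Return True once the given 0-based phase has completed."""
        return self.completed_phases > phase


b = ToyBarrier(num_arrivals=2)
b.arrive()
print(b.is_complete(0))  # False: only one of two arrivals recorded
b.arrive()
print(b.is_complete(0))  # True: second arrival completes phase 0
```

The model shows why `num_arrivals` must match the number of parties that will actually arrive: with too few arrivals the phase never completes and waiters hang.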

Type:

int

num_barriers

The number of barriers that will be created. Individual barriers can be accessed by indexing into the barrier Ref.

Type:

int

orders_tensor_core

If False, a successful wait from one thread does not guarantee that the TensorCore-related operations in other threads have completed. Similarly, when False, any TensorCore operation in the waiting thread is allowed to begin before the wait succeeds.

Type:

bool

__init__(*, num_arrivals=1, num_barriers=1, orders_tensor_core=False)
Parameters:
  • num_arrivals (int)

  • num_barriers (int)

  • orders_tensor_core (bool)

Return type:

None

Methods

__init__(*[, num_arrivals, num_barriers, ...])

get_ref_aval()

Attributes

num_arrivals

num_barriers

orders_tensor_core