jax.experimental.pallas.GridSpec

class jax.experimental.pallas.GridSpec(grid=(), in_specs=NoBlockSpec, out_specs=NoBlockSpec, scratch_shapes=())

Encodes the grid parameters for jax.experimental.pallas.pallas_call().

For a more detailed description of the parameters, see the documentation for jax.experimental.pallas.pallas_call() and Grids and BlockSpecs.

Parameters:
  • grid (TupleGrid)

  • in_specs (BlockSpecTree)

  • out_specs (BlockSpecTree)

  • scratch_shapes (ScratchShapeTree)

__init__(grid=(), in_specs=NoBlockSpec, out_specs=NoBlockSpec, scratch_shapes=())
Parameters:
  • grid (Grid)

  • in_specs (BlockSpecTree)

  • out_specs (BlockSpecTree)

  • scratch_shapes (ScratchShapeTree)

Methods

__init__([grid, in_specs, out_specs, ...])

Attributes

scratch_shapes

grid

grid_names

in_specs

out_specs