jax.experimental.pallas.GridSpec
- class jax.experimental.pallas.GridSpec(grid=(), in_specs=NoBlockSpec, out_specs=NoBlockSpec, scratch_shapes=())
Encodes the grid parameters for jax.experimental.pallas.pallas_call().
See the documentation for jax.experimental.pallas.pallas_call(), and also Grids and BlockSpecs, for a more detailed description of the parameters.
- Parameters:
grid (TupleGrid)
in_specs (BlockSpecTree)
out_specs (BlockSpecTree)
scratch_shapes (ScratchShapeTree)
- __init__(grid=(), in_specs=NoBlockSpec, out_specs=NoBlockSpec, scratch_shapes=())
- Parameters:
grid (Grid)
in_specs (BlockSpecTree)
out_specs (BlockSpecTree)
scratch_shapes (ScratchShapeTree)
Methods
__init__([grid, in_specs, out_specs, ...])
Attributes
scratch_shapes
grid
grid_names
in_specs
out_specs