jax.experimental.pallas.pallas_call
- jax.experimental.pallas.pallas_call(kernel, out_shape, *, grid_spec=None, grid=(), in_specs=NoBlockSpec, out_specs=NoBlockSpec, scratch_shapes=(), input_output_aliases={}, debug=False, interpret=False, name=None, compiler_params=None, cost_estimate=None, backend=None, metadata=None)
Invokes a Pallas kernel on some inputs.
See Pallas Quickstart.
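A minimal sketch of the calling pattern, adding two vectors elementwise. It assumes a recent JAX release and passes interpret=True so it runs on CPU. The names add_kernel and add are illustrative only. Because no in_specs or out_specs are given, each Ref covers the whole array.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Default (whole-array) BlockSpecs: each Ref sees the full input/output.
    o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),  # one output, same shape/dtype as x
        interpret=True,  # run on CPU; drop this to compile for a TPU or GPU
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
result = add(x, y)
```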
- Parameters:
kernel (Callable[..., None]) – the kernel function, which receives a Ref for each input and output. The shapes of the Refs are given by the block_shape in the corresponding in_specs and out_specs.
out_shape (Any) – a PyTree of jax.ShapeDtypeStruct describing the shapes and dtypes of the outputs.
grid_spec (GridSpec | None) – An alternative way to specify grid, in_specs, out_specs and scratch_shapes. If given, those other parameters must not also be given.
grid (TupleGrid) – the iteration space, as a tuple of integers. The kernel is executed prod(grid) times. See details at grid, a.k.a. kernels in a loop.
in_specs (BlockSpecTree) – a PyTree of jax.experimental.pallas.BlockSpec with a structure matching that of the positional arguments. The default value for in_specs specifies the whole array for all inputs, e.g., as pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim). See details at BlockSpec, a.k.a. how to chunk up inputs.
out_specs (BlockSpecTree) – a PyTree of jax.experimental.pallas.BlockSpec with a structure matching that of the outputs. The default value for out_specs specifies the whole array, e.g., as pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim). See details at BlockSpec, a.k.a. how to chunk up inputs.
scratch_shapes (ScratchShapeTree) – a PyTree of backend-specific temporary objects required by the kernel, such as temporary buffers, synchronization primitives, etc.
input_output_aliases (Mapping[int, int]) – a dictionary mapping the index of some inputs to the index of the output that aliases them. These indices are in the flattened inputs and outputs.
debug (bool) – if True, Pallas prints various intermediate forms of the kernel as it is being processed.
interpret (Any) – runs the pallas_call as a jax.jit of a scan over the grid whose body is the kernel lowered as a JAX function. This does not require a TPU or a GPU, and is the only way to run Pallas kernels on CPU. This is useful for debugging.
name (str | None) – if present, specifies the name to use for this kernel call in debugging and error messages. To this name we append the file and line where the kernel function is defined, e.g.: {name} for kernel function {kernel_name} at {file}:{line}. If missing, then we use {kernel_name} at {file}:{line}.
compiler_params (Mapping[Backend, pallas_core.CompilerParams] | pallas_core.CompilerParams | None) – Optional compiler parameters. The value should either be a backend-specific dataclass (jax.experimental.pallas.tpu.CompilerParams, jax.experimental.pallas.triton.CompilerParams, jax.experimental.pallas.mosaic_gpu.CompilerParams) or a dict mapping backend name to the corresponding platform-specific dataclass.
backend (Backend | None) – Optional string literal, one of "mosaic_tpu", "triton" or "mosaic_gpu", determining the backend to be used. None means let Pallas decide.
metadata (dict[str, str] | None) – Optional dictionary of information about the kernel that will be serialized as JSON in the HLO. Can be used for debugging and analysis.
cost_estimate (CostEstimate | None)
- Returns:
A function that can be called on a number of positional array arguments to invoke the Pallas kernel.
- Return type:
Callable[…, Any]
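Because the return value is an ordinary callable, it can be wrapped in jax.jit like any other JAX function. The sketch below also uses input_output_aliases={0: 0}, declaring that output 0 may reuse the buffer of flattened input 0. add_one_kernel and add_one_inplace are illustrative names, and interpret=True is assumed for CPU execution. Whether the caller's buffer is actually reused is up to XLA: without donating the argument (for example via donate_argnums), XLA will generally copy the input first, so x stays unchanged.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    # o_ref may share storage with x_ref because of input_output_aliases.
    o_ref[...] = x_ref[...] + 1

@jax.jit  # the callable returned by pallas_call composes with jit
def add_one_inplace(x):
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        input_output_aliases={0: 0},  # flattened input 0 -> flattened output 0
        interpret=True,
    )(x)

x = jnp.arange(8, dtype=jnp.float32)
result = add_one_inplace(x)
```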