Pallas Changelog
This is the list of changes specific to jax.experimental.pallas.
For changes to JAX as a whole, see the main JAX changelog.
Unreleased
Deprecations
The pl.atomic_* APIs have been moved to jax.experimental.pallas.triton. Accessing them via jax.experimental.pallas is deprecated.
Released with jax 0.7.0
New functionality
Added a new decorator, jax.experimental.pallas.loop(), which allows writing stateless loops as functions.
Added new multiple-buffering and lookahead functionality to jax.experimental.pallas.tpu.emit_pipeline(). Input buffers can now be multiple-buffered with more than 2 buffers, and they support a lookahead option that fetches blocks an arbitrary number of grid iterations ahead rather than only the next iteration. Additionally, pipeline state can now be held in registers to reduce scalar memory usage.
Deprecations
jax.experimental.pallas.triton.TritonCompilerParams has been renamed to jax.experimental.pallas.triton.CompilerParams. The old name is deprecated and will be removed in a future release.
jax.experimental.pallas.tpu.TPUCompilerParams and jax.experimental.pallas.tpu.TPUMemorySpace have been renamed to jax.experimental.pallas.tpu.CompilerParams and jax.experimental.pallas.tpu.MemorySpace. The old names are deprecated and will be removed in a future release.
Released with jax 0.6.1
Removals
Removed the previously deprecated jax.experimental.pallas.gpu. To use the Triton backend, import jax.experimental.pallas.triton.
Changes
jax.experimental.pallas.BlockSpec() now accepts special types, in addition to ints and None, in block_shape. indexing_mode has been removed. To achieve “Unblocked” indexing, pass a pl.Element(size) in block_shape for each entry that needs unblocked indexing.
jax.experimental.pallas.pallas_call() now requires compiler_params to be a backend-specific dataclass instead of a mapping from parameter names to values.
jax.experimental.pallas.debug_check() is now supported on both TPU and Mosaic GPU. Previously, this functionality was only supported on TPU and required using the APIs from jax.experimental.checkify. Note that debug checks are not executed unless jax.experimental.pallas.enable_debug_checks is set.
Released with jax 0.5.0
New functionality
Added vector support for jax.experimental.pallas.debug_print() on TPU.
Released with jax 0.4.37
New functionality
Added support for DotAlgorithmPreset precision arguments for dot lowering on the Triton backend.
Released with jax 0.4.36 (December 6, 2024)
Released with jax 0.4.35 (October 22, 2024)
Removals
Removed the previously deprecated aliases jax.experimental.pallas.tpu.CostEstimate and jax.experimental.pallas.tpu.run_scoped(). Both are now available in jax.experimental.pallas.
New functionality
Added a cost estimate tool, pl.estimate_cost(), for automatically constructing a kernel cost estimate from a JAX reference function.
Released with jax 0.4.34 (October 4, 2024)
Changes
jax.experimental.pallas.debug_print() no longer requires all arguments to be scalars. The restrictions on the arguments are backend-specific: non-scalar arguments are currently supported only on GPU, when using Triton.
jax.experimental.pallas.BlockSpec no longer supports the previously deprecated argument order, where index_map comes before block_shape.
Deprecations
The jax.experimental.pallas.gpu submodule is deprecated to avoid ambiguity with jax.experimental.pallas.mosaic_gpu. To use the Triton backend, import jax.experimental.pallas.triton.
New functionality
jax.experimental.pallas.pallas_call() now accepts scratch_shapes, a PyTree specifying backend-specific temporary objects needed by the kernel, such as buffers and synchronization primitives.
checkify.check() can now be used to insert runtime asserts when pallas_call is called within the pltpu.enable_runtime_assert(True) context manager.
Released with jax 0.4.33 (September 16, 2024)
Released with jax 0.4.32 (September 11, 2024)
Changes
The kernel function is not allowed to close over constants. Instead, all the needed arrays must be passed as inputs, with proper block specs (#22746).
New functionality
Improved error messages for mistakes in the signature of the index map functions, to include the name and source location of the index map.
Released with jax 0.4.31 (July 29, 2024)
Changes
jax.experimental.pallas.BlockSpec now expects block_shape to be passed before index_map. The old argument order is deprecated and will be removed in a future release.
jax.experimental.pallas.GridSpec no longer has the in_specs_tree and out_specs_tree fields, and the in_specs and out_specs trees now store the values as pytrees of BlockSpec. Previously, in_specs and out_specs were flattened (#22552).
The method compute_index of jax.experimental.pallas.GridSpec has been removed because it is private. Similarly, get_grid_mapping and unzip_dynamic_bounds have been removed from BlockSpec (#22593).
Fixed interpret mode to work with BlockSpecs that involve padding (#22275). In interpret mode, padding is filled with NaN to help debug out-of-bounds errors, but this behavior is not present when running in custom kernel mode and should not be relied on.
Previously it was possible to import many APIs that are meant to be private through jax.experimental.pallas.pallas. This is no longer possible.
New functionality
Added documentation for BlockSpec: Grids and BlockSpecs.
Improved error messages for the jax.experimental.pallas.pallas_call() API.
Added TPU lowering rules for lax.shift_right_arithmetic (#22279) and lax.erf_inv (#22310).
Added initial support for shape polymorphism for the Pallas TPU custom kernels (#22084).
Added TPU support for checkify (#22480).
Added clearer error messages when the block sizes do not match the TPU requirements. Previously, the errors were coming from the Mosaic backend and did not have useful Python stack traces.
Added support for TPU lowering with 1D blocks, and relaxed the requirements for the block sizes with at least 2 dimensions: the last 2 dimensions must be divisible by 8 and 128 respectively, unless they span the entire corresponding array dimension. Previously, block dimensions that spanned the entire array were allowed only if the block dimensions in the last two dimensions were smaller than 8 and 128 respectively.
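To make the rule above concrete, here is a hypothetical plain-Python checker for blocks with at least 2 dimensions. It is not part of Pallas and encodes only the rule as stated in this entry.

```python
def tpu_block_shape_ok(block_shape, array_shape):
    """Hypothetical check of the block-size rule for blocks with >= 2 dims."""
    if len(block_shape) != len(array_shape):
        raise ValueError("block and array must have the same rank")
    if len(block_shape) < 2:
        raise ValueError("this sketch only covers blocks with at least 2 dimensions")
    (b_second_minor, b_minor) = block_shape[-2:]
    (a_second_minor, a_minor) = array_shape[-2:]
    # Each of the last two block dims must be divisible by 8 and 128
    # respectively, unless it spans the entire corresponding array dim.
    second_minor_ok = b_second_minor % 8 == 0 or b_second_minor == a_second_minor
    minor_ok = b_minor % 128 == 0 or b_minor == a_minor
    return second_minor_ok and minor_ok


print(tpu_block_shape_ok((16, 256), (64, 1024)))  # divisible by 8 and 128
print(tpu_block_shape_ok((10, 100), (10, 100)))   # spans the full array dims
print(tpu_block_shape_ok((4, 128), (64, 1024)))   # 4 is neither divisible by 8 nor full
```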
Released with jax 0.4.30 (June 18, 2024)
New functionality
Added checkify support for jax.experimental.pallas.pallas_call() in interpret mode (#21862).
Improved support for PRNG keys for TPU kernels (#21773).
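Here is a minimal sketch of checkify with pallas_call in interpret mode. It uses a hypothetical kernel that checks that its input is non-negative. checkify.checkify turns the recorded check into an error value returned alongside the result, and err.get() returns None when no check failed.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify
from jax.experimental import pallas as pl


def check_non_negative_kernel(x_ref, o_ref):
    x = x_ref[...]
    # Records an error (rather than raising) when the predicate is False.
    checkify.check(jnp.all(x >= 0), "input must be non-negative")
    o_ref[...] = x


def copy_checked(x):
    return pl.pallas_call(
        check_non_negative_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,
    )(x)


checked = checkify.checkify(copy_checked)
err_ok, out_ok = checked(jnp.ones((8, 128), jnp.float32))
err_bad, _ = checked(-jnp.ones((8, 128), jnp.float32))
```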