Parallel computation

Note

This is a placeholder for a section in the new JAX tutorials draft.

For the time being, you may find some related content in the old documentation:

  • Introduction to multi-controller JAX (aka multi-process/multi-host JAX)

  • Distributed arrays and automatic parallelization
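The sketch below is a minimal illustration of the workflow covered by the distributed-arrays guide. It assumes a recent JAX release that provides `jax.sharding.Mesh`, `NamedSharding`, and `PartitionSpec`. The mesh axis name `"data"`, the array shape, and the function `f` are illustrative choices and are not part of this page. The snippet shards an array along its leading axis across all local devices, then lets `jax.jit` compile an elementwise function that runs on each shard.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D device mesh over every local device and name its axis "data".
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
n = len(devices)

# The leading dimension must divide evenly across the devices, so use 8 rows per device.
x = jnp.arange(8 * n * 4, dtype=jnp.float32).reshape(8 * n, 4)

# Split axis 0 across the "data" mesh axis. Axis 1 is left unsharded.
sharding = NamedSharding(mesh, P("data"))
x_sharded = jax.device_put(x, sharding)

@jax.jit
def f(x):
    # Elementwise op: the compiler can run it independently on each shard.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

y = f(x_sharded)
print(y.sharding)  # shows how the output is laid out across the mesh
```

On a CPU-only machine `jax.devices()` usually returns a single device, so the array has only one shard. To simulate several devices on CPU, a common approach is to set `XLA_FLAGS=--xla_force_host_platform_device_count=8` in the environment before importing JAX.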

By The JAX authors

© Copyright 2024, The JAX Authors.