Getting Started with JAX
Welcome to JAX! The JAX documentation contains a number of useful resources for getting started. Quickstart: How to think in JAX is the easiest place to jump in and get an overview of the JAX project, its execution model, and differences with NumPy.
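To illustrate the NumPy comparison the quickstart draws, here is a minimal sketch. It assumes `jax` and `numpy` are installed. It evaluates one expression with both libraries, then shows one well-known difference: JAX arrays cannot be modified in place. The variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# The same expression written against NumPy and against jax.numpy.
x_np = np.linspace(0.0, 1.0, 5)
x_jnp = jnp.linspace(0.0, 1.0, 5)
y_np = np.sin(x_np) ** 2 + np.cos(x_np) ** 2
y_jnp = jnp.sin(x_jnp) ** 2 + jnp.cos(x_jnp) ** 2

# NumPy arrays can be modified in place...
x_np[0] = 10.0

# ...but JAX arrays are immutable: item assignment raises a TypeError,
# and the functional .at[...].set(...) update returns a new array instead.
try:
    x_jnp[0] = 10.0
    jax_inplace_failed = False
except TypeError:
    jax_inplace_failed = True
x_jnp_updated = x_jnp.at[0].set(10.0)  # x_jnp itself is unchanged
```

Note that `jax.numpy` defaults to 32-bit floats while NumPy defaults to 64-bit, so results agree only up to floating-point tolerance.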
If you’re starting to explore JAX, you might also find the following resources helpful:
- Key concepts introduces the key concepts of JAX, such as transformations, tracing, jaxprs and pytrees.
- 🔪 JAX - The Sharp Bits 🔪 lists some of JAX’s sharp corners.
- Frequently asked questions (FAQ) answers common questions about JAX.
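The following minimal sketch touches each of the key concepts named above. It shows a transformation (`jax.grad`), tracing a function into a jaxpr (`jax.make_jaxpr`), and operating on a pytree (`jax.tree_util.tree_map`). The `loss` function and the `params` dictionary are hypothetical examples, not part of any JAX API.

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    # params is a pytree: a dict whose leaves are arrays.
    return jnp.sum((params["w"] * x + params["b"]) ** 2)

params = {"w": jnp.array(2.0), "b": jnp.array(1.0)}
x = jnp.array([1.0, 2.0, 3.0])

# Transformation: jax.grad returns a new function computing d(loss)/d(params).
# The gradient has the same pytree structure (a dict with "w" and "b") as params.
grads = jax.grad(loss)(params, x)

# Tracing: jax.make_jaxpr traces loss with abstract values and returns
# its jaxpr, JAX's intermediate representation of the computation.
jaxpr = jax.make_jaxpr(loss)(params, x)
print(jaxpr)

# Pytrees: tree_map applies a function leaf by leaf across matching trees,
# here taking one plain gradient-descent step with step size 0.1.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```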
Tutorials
If you’re ready to explore JAX more deeply, the JAX tutorials go into much more detail:
- Tutorials
- Just-in-time compilation
- Automatic vectorization
- Automatic differentiation
- Introduction to debugging
- Pseudorandom numbers
- Working with pytrees
- Introduction to parallel programming
- Stateful computations
- Control flow and logical operators with JIT
- Advanced automatic differentiation
- External callbacks
- Gradient checkpointing with jax.checkpoint (jax.remat)
- JAX Internals: primitives
- JAX internals: The jaxpr language
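Several of the tutorial topics listed above are designed to compose. The sketch below assumes only `jax` itself. It combines explicit pseudorandom keys, `jax.grad`, `jax.vmap`, and `jax.jit` to compute per-example gradients. `predict` and `squared_error` are hypothetical functions invented for this illustration.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Hypothetical one-layer model: tanh of a dot product.
    return jnp.tanh(jnp.dot(x, w))

def squared_error(w, x, y):
    return (predict(w, x) - y) ** 2

# Explicit PRNG state: split one key into independent subkeys.
key = jax.random.PRNGKey(0)
w_key, x_key = jax.random.split(key)
w = jax.random.normal(w_key, (3,))
xs = jax.random.normal(x_key, (8, 3))
ys = jnp.zeros(8)

# grad differentiates with respect to the first argument (w) by default;
# vmap maps over axis 0 of xs and ys while sharing w (in_axes=None);
# jit compiles the composed function with XLA.
per_example_grads = jax.jit(
    jax.vmap(jax.grad(squared_error), in_axes=(None, 0, 0))
)(w, xs, ys)  # shape (8, 3): one gradient per example
```

Because each transformation returns an ordinary Python function, they can be nested in other orders as well. The `in_axes` argument controls which inputs are batched.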
If you prefer a video introduction, there is one from JAX contributor Jake VanderPlas.
Building on JAX
JAX provides the core numerical computing primitives for a number of tools developed by the larger community. For example, if you’re interested in using JAX for training neural networks, two well-supported options are Flax and Haiku.
For a community-curated list of JAX-related projects across a wide set of domains, check out Awesome JAX.
Finding Help
If you have questions about JAX, we’d love to answer them! Two good places to get your questions answered are: