XLA compiler flags
Introduction
This guide gives a brief overview of XLA and how XLA relates to JAX. For in-depth details, refer to the XLA documentation.
XLA: The Powerhouse Behind JAX
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that plays a pivotal role in JAX’s performance and flexibility. It enables JAX to generate optimized code for various hardware backends (CPUs, GPUs, TPUs) by transforming and compiling your Python/NumPy-like code into efficient machine instructions.
JAX uses XLA’s JIT compilation capabilities to transform your Python functions into optimized XLA computations at runtime.
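As a minimal sketch of what this looks like in practice, the snippet below JIT-compiles a small function. The function name `scaled_sum` and its input are illustrative and not taken from the XLA documentation; `jax.jit` and the `lower(...).as_text()` inspection call are standard JAX APIs.

```python
import jax
import jax.numpy as jnp


def scaled_sum(x):
    # Plain NumPy-style code; nothing here is XLA-specific.
    return jnp.sum(x * 2.0 + 1.0)


# jax.jit traces the Python function and hands the traced computation to XLA.
fast_scaled_sum = jax.jit(scaled_sum)

x = jnp.arange(4.0)          # [0., 1., 2., 3.]
result = fast_scaled_sum(x)  # the first call triggers compilation

# The lowered program that XLA receives can be inspected as text.
lowered_text = fast_scaled_sum.lower(x).as_text()
```

Later calls with the same input shape and dtype reuse the compiled computation instead of compiling again.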
Configuring XLA in JAX:
You can influence XLA’s behavior in JAX by setting the XLA_FLAGS environment variable before running your Python script or Colab notebook.
For Colab notebooks:
Provide flags using os.environ['XLA_FLAGS']:
import os
# Set multiple flags separated by spaces
os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2'
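The placeholder flags above can be replaced with real ones. The sketch below builds the flag string from a dictionary and appends it to any value already present in XLA_FLAGS rather than overwriting it. The helper name `merge_xla_flags` is hypothetical; `--xla_force_host_platform_device_count` and `--xla_dump_to` are existing XLA flags, used here with illustrative values.

```python
import os


def merge_xla_flags(new_flags, existing=""):
    """Return `existing` followed by `new_flags` rendered as --name=value."""
    rendered = [f"--{name}={value}" for name, value in new_flags.items()]
    # Flags are separated by single spaces; split() drops empty pieces.
    return " ".join(existing.split() + rendered)


flags = {
    "xla_force_host_platform_device_count": 4,  # illustrative value
    "xla_dump_to": "/tmp/xla_dump",             # illustrative path
}

# Preserve whatever was already exported in the shell or an earlier cell.
os.environ["XLA_FLAGS"] = merge_xla_flags(flags, os.environ.get("XLA_FLAGS", ""))
```

Keeping the existing value matters in notebooks, where an earlier cell or the hosting environment may already have set XLA_FLAGS.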
For Python scripts:
Specify XLA_FLAGS as part of the CLI command:
XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py
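The same per-command scoping can be reproduced from Python when a script is launched programmatically. This sketch makes no assumptions about the contents of `source.py`: it starts a child interpreter with XLA_FLAGS in its environment and reads the value back, standing in for the real script. The flag value is illustrative.

```python
import os
import subprocess
import sys

# Copy the current environment and add XLA_FLAGS for the child only.
child_env = dict(os.environ)
child_env["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

# Stand-in for `python3 source.py`: the child just echoes the variable.
completed = subprocess.run(
    [sys.executable, "-c", "import os; print(os.environ['XLA_FLAGS'])"],
    env=child_env,
    capture_output=True,
    text=True,
    check=True,
)
seen_by_child = completed.stdout.strip()
```

As with the shell form, the variable applies only to the launched process; the parent's environment is left unchanged.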
Important Notes:
Set XLA_FLAGS before importing JAX or other relevant libraries. Changing XLA_FLAGS after backend initialization has no effect, and because the point at which the backend is initialized is not clearly defined, it is usually safer to set XLA_FLAGS before executing any JAX code.
Experiment with different flags to optimize performance for your specific use case.
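One way to catch ordering mistakes early is to check whether JAX has already been imported before setting the variable. The sketch below treats the presence of `jax` in `sys.modules` as a conservative proxy for backend initialization, which is an assumption rather than an official check. The helper name `set_xla_flags_early` is hypothetical and the flag value is illustrative.

```python
import os
import sys
import warnings


def set_xla_flags_early(flags):
    """Set XLA_FLAGS, warning if JAX was imported first."""
    if "jax" in sys.modules:
        # Importing JAX does not necessarily initialize a backend, but
        # treating the import as the cutoff is the cautious choice.
        warnings.warn("jax is already imported; XLA_FLAGS may be ignored")
    os.environ["XLA_FLAGS"] = flags


set_xla_flags_early("--xla_force_host_platform_device_count=2")

# Import JAX only after this point, e.g. `import jax`.
```

In a notebook, a warning from this helper usually means the runtime should be restarted so the flags are set in a fresh process.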
For further information:
Complete and up-to-date documentation about XLA can be found in the official XLA documentation.
For backends supported by the open-source version of XLA (CPU, GPU), XLA flags are defined with their default values in xla/debug_options_flags.cc, and a complete list of flags can be found here.
A guide on how to use key XLA flags can be found here.
Additional reading: