JAX
Data & ML
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
Release History
jax-v0.10.1 (Breaking, 11 features): This release introduces several new linear algebra functions in `jax.scipy.linalg` and updates RNG API handling by moving related functionality to dtypes. It also deprecates positional arguments for certain array creation parameters and modifies the mesh context manager usage.
jax-v0.10.0 (Breaking, 4 fixes, 6 features): This release introduces new features like PyTorch-compatible cubic resizing and improved LAPACK parallelization, alongside significant breaking changes related to the removal of the C++ pmap infrastructure and updates to array stacking functions.
jax-v0.9.2 (Breaking): JAX 0.9.2 updates the internal type structure of `TypedNdArray` and changes the execution context for `jnp.arange` with a step argument, potentially affecting float precision.
jax-v0.9.1 (Breaking, 1 fix, 1 feature): This release improves strictness in `jax.shard_map` explicit mode by enforcing PartitionSpec matching and updates JAX tracer type reporting. A new debug configuration for the compilation cache was also introduced.
jax-v0.9.0.1: JAX v0.9.0.1 is a patch release identical to v0.9.0, incorporating fixes from four specific XLA pull requests.
jax-v0.8.3 (2 fixes): JAX v0.8.3 is released, incorporating two specific bug fixes that were missing from the previous v0.8.2 release.
jax-v0.9.0 (Breaking, 1 fix, 1 feature): This release introduces `jax.thread_guard` for multi-threaded JAX environments and updates the serialization format for `jax.export` to support explicit sharding, which tightens mesh matching requirements. Several configuration states and deprecated functions like `jax.numpy.fix` have been removed or deprecated.
jax-v0.8.2 (Breaking): This release deprecates several core and interpreter symbols, removes `jax.experimental.si_vjp`, and changes `Tracer` inheritance, requiring code updates to use `jax.vjp` and the new `pcast` API.
jax-v0.8.1 (1 fix, 3 features): JAX adds decorator‑factory support for `jax.jit`, new `implementation` and `algorithm` options for linear algebra ops, fixes a GPU eigh bug, and deprecates several sharding and cloud‑TPU utilities.
jax-v0.8.0 (Breaking, 5 features): JAX introduces several breaking changes, including a new default implementation for `jax.pmap` and removal of many deprecated APIs, while adding new features such as namedtuple returns for eig and enhanced dlpack support.
jax-v0.7.2 (Breaking, 2 fixes): JAX drops support for raw DLPack capsules in `jax.dlpack.from_dlpack`, raises minimum NumPy/SciPy versions, and introduces several deprecations and bug fixes.
jax-v0.7.1 (Breaking, 5 features): JAX introduces new Python 3.13t/3.14t wheels, a new `jax.set_mesh` API, CUDA 12.9 builds, and several deprecations including removal of `jax.sharding.use_mesh`.
jax-v0.7.0 (Breaking, 2 features): JAX 0.7 introduces Shardy as the default execution model, updates autodiff to direct linearization, raises the minimum Python version to 3.11, and deprecates or removes several legacy APIs.
jax-v0.6.2 (Breaking, 1 feature): This release introduces the new `jax.tree.broadcast` helper and raises the minimum required versions of NumPy and SciPy.
jax-v0.6.1 (Breaking, 1 feature): This release adds the new `jax.lax.axis_size` feature, makes `PartitionSpec` and `ShapeDtypeStruct` behavior stricter, re‑enables CUDA version checks, and deprecates `custom_jvp_call_jaxpr_p`.
jax-v0.6.0 (Breaking, 4 features): This release removes several legacy tracing and configuration options, raises the minimum CUDA/CuDNN versions, updates package extras syntax, and deprecates many old APIs while introducing stricter `jax.jit` calling conventions.
jax-v0.5.3 (2 features): This release introduces new options for slicing functions and categorical sampling in JAX, improving code size and adding support for sampling without replacement.
jax-v0.5.2 (1 fix): A patch release on top of 0.5.1 that fixes TPU metric logging and the `tpu-info` command.
jax-v0.5.1 (2 fixes, 3 features): This release adds experimental custom DCE support, new low‑level reduction ops, column‑pivoting QR, updates CPU collective defaults, removes the libtpu‑nightly dependency, deprecates internal functions requiring debug info, and includes TPU runtime and compilation cache fixes.
jax-v0.5.0 (Breaking, 2 fixes, 3 features): JAX meso release adds multi‑dimensional FFT support, FFI state registration, and debugging info for AOT lowering, while breaking PRNG semantics, dropping Mac x86 wheels, and raising NumPy/SciPy minimum versions.
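The jax-v0.8.2 entry above points code that used `jax.experimental.si_vjp` toward `jax.vjp`. As a minimal sketch of what a `jax.vjp` call looks like (the function `f` and its input are hypothetical and chosen only for illustration; this is not a migration guide for any particular codebase):

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function: elementwise x * sin(x).
    return jnp.sin(x) * x


x = jnp.array([0.0, 1.0, 2.0])

# jax.vjp returns the primal output and a function that maps an output
# cotangent to a tuple of input cotangents (one per positional argument).
y, f_vjp = jax.vjp(f, x)
(grad_x,) = f_vjp(jnp.ones_like(y))

# Analytic derivative for comparison: cos(x) * x + sin(x).
expected = jnp.cos(x) * x + jnp.sin(x)
```

Because `f` is elementwise, pulling back a cotangent of ones gives the elementwise derivative, which can be checked against `expected`.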
Common Errors
NotImplementedError (4 reports): A `NotImplementedError` in JAX usually arises when a function or operation is called with a data type or configuration (e.g., sparse arrays, custom autodiff rules, Pallas kernels in sharded contexts) that the JAX backend in use does not implement. To fix it, either implement the missing functionality for that type or configuration, or rewrite the code so that it uses only JAX primitives and operations defined for the supported data types. Consider the `jax.experimental` libraries for experimental features, or contribute the missing feature to the JAX project.
JaxRuntimeError (4 reports): A `JaxRuntimeError` often indicates an internal error within JAX's XLA compiler or the underlying CUDA/cuDNN/cuSOLVER libraries. Try updating JAX, CUDA, and cuDNN to the latest mutually compatible versions. If the error persists, consider disabling JIT compilation via `jax.config.update("jax_disable_jit", True)` as a temporary workaround (this flag turns off `jax.jit`; it does not control XLA autotuning), and report the issue with a minimal reproducible example to the JAX issue tracker.
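A scoped alternative to setting the `jax_disable_jit` flag globally is the `jax.disable_jit()` context manager. The sketch below assumes a hypothetical jitted function `normalize`; it runs the function once without JIT and once with it, which can help narrow down whether a failure is tied to the compiled program:

```python
import jax
import jax.numpy as jnp


@jax.jit
def normalize(x):
    # Hypothetical function under investigation.
    return (x - x.mean()) / x.std()


x = jnp.arange(6.0)

# Inside this block, jax.jit is bypassed and operations run one at a time,
# so a failing operation is easier to locate.
with jax.disable_jit():
    eager = normalize(x)

# Outside the block, the function is compiled as usual.
compiled = normalize(x)
```

Running without JIT is usually much slower, so this is meant as a debugging aid rather than a permanent setting.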
XlaRuntimeError (2 reports): An `XlaRuntimeError` in JAX often arises from mismatches between the expected and actual data shapes or dtypes within XLA computations, especially during operations like FFTs or when specifying memory layouts. To resolve this, carefully inspect the input shapes, dtypes, and sharding specifications to ensure they align with XLA's requirements and the constraints of the targeted device. Explicit dtype casting (e.g., using `jnp.asarray(x, dtype=jnp.complex64)`) or adjusting sharding annotations can often rectify these discrepancies.
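As an illustration of the explicit cast mentioned above, here is a minimal sketch that checks the input's rank and casts to `complex64` before an FFT. The helper name `checked_fft` and the one-dimensional requirement are assumptions made for this example, not part of JAX:

```python
import jax.numpy as jnp


def checked_fft(x):
    # Hypothetical helper: make dtype and rank explicit before the FFT.
    x = jnp.asarray(x, dtype=jnp.complex64)
    if x.ndim != 1:
        raise ValueError(f"expected a 1-D signal, got shape {x.shape}")
    return jnp.fft.fft(x)


signal = jnp.arange(8)           # integer input
spectrum = checked_fft(signal)   # complex64 output with the same length
```

Validating shape and dtype in ordinary Python before the call tends to produce a clearer error message than one raised from inside the compiled computation.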
KeyReuseError (2 reports): A `KeyReuseError` in JAX arises when the same PRNG key is consumed more than once, for example inside a transformation such as `vmap` or `pmap`, without being split. Reusing a key yields identical or correlated random numbers. To fix this, split the key with `jax.random.split` before each use so that every random operation receives its own subkey, and carry the remaining key forward for subsequent operations. Always split your PRNG key before passing it to JAX functions that use randomness within loops or vmapped code.
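The splitting pattern described above can be sketched as follows. The function `sample_pair` and the array shapes are hypothetical; the point is only that each key is consumed exactly once:

```python
import jax
import jax.numpy as jnp


def sample_pair(key):
    # Split once, then use each subkey for exactly one random call.
    k_noise, k_mask = jax.random.split(key)
    noise = jax.random.normal(k_noise, (3,))
    mask = jax.random.bernoulli(k_mask, 0.5, (3,))
    return noise, mask


key = jax.random.key(0)

# Keep one key to carry forward and hand the other to the function.
key, subkey = jax.random.split(key)
noise, mask = sample_pair(subkey)

# For vmapped code, create one key per batch element instead of
# passing the same key to every element.
batch_keys = jax.random.split(key, 4)
batch_noise = jax.vmap(lambda k: jax.random.normal(k, (3,)))(batch_keys)
```

Each row of `batch_noise` comes from a different key, so the rows are independent draws rather than copies of one another.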
ShardingTypeError (1 report): A `ShardingTypeError` often arises when JAX cannot infer the sharding for an operation, especially after type conversions or views that alter array layouts unexpectedly. To resolve this, explicitly specify the sharding of the resulting array after the problematic operation (e.g., using `jax.device_put(x, sharding)` or `jax.lax.with_sharding_constraint(x, sharding)`), so that JAX knows how to distribute it across devices. Alternatively, revise the code to avoid the implicit type conversion or view if the intended sharding is ambiguous.
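As a rough sketch of placing an array with an explicit sharding after a dtype conversion, the example below builds a one-dimensional mesh over all available devices and uses `jax.device_put`. The axis name `"data"` and the array shape are assumptions for illustration, and whether this resolves a particular `ShardingTypeError` depends on the operation that raised it:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over every available device (a single device works too).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard the leading axis across the "data" mesh axis; replicate the second axis.
sharding = NamedSharding(mesh, P("data", None))

# The leading dimension is a multiple of the device count so it divides evenly.
x = jnp.zeros((2 * devices.size, 4), dtype=jnp.float32)

# After a dtype conversion, state the intended placement explicitly.
y = x.astype(jnp.bfloat16)
y = jax.device_put(y, sharding)
```

On a single-device machine the sharding is trivial, but the same code applies unchanged when more devices are present.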
TypeError (1 report): This error usually arises when JAX's automatic differentiation (autodiff) encounters a `ShapedArray` unexpectedly during the backward pass of a `fori_loop`, often triggered by operations incompatible with sharded arrays within the loop. Resolve this by explicitly materializing the `ShapedArray` into a concrete array using `jax.device_put` or `jax.block_until_ready` before it is used in an operation that triggers the error, or by restructuring the calculation to avoid the incompatible operation within the loop's backward pass. Review the sharding annotations to ensure data is properly distributed and materialized when needed.
Related Data & ML Packages
An Open Source Machine Learning Framework for Everyone
🤗 Transformers: the model-definition framework for state-of-the-art machine learning models in text, vision, audio, and multimodal models, for both inference and training.
Tensors and Dynamic neural networks in Python with strong GPU acceleration
scikit-learn: machine learning in Python
Flexible and powerful data analysis / manipulation library for Python, providing labeled data structures similar to R data.frame objects, statistical functions, and much more
Streamlit — A faster way to build and share data apps.