Change8

JAX

Data & ML

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Latest: jax-v0.10.1 · 20 releases · 14 breaking changes · 9 common errors

Release History

jax-v0.10.1 · Breaking · 11 features
May 20, 2026

This release introduces several new linear algebra functions in `jax.scipy.linalg` and updates RNG API handling by moving related functionality to dtypes. It also deprecates positional arguments for certain array creation parameters and modifies the mesh context manager usage.

jax-v0.10.0 · Breaking · 4 fixes · 6 features
Apr 16, 2026

This release introduces new features like PyTorch-compatible cubic resizing and improved LAPACK parallelization, alongside significant breaking changes related to the removal of the C++ pmap infrastructure and updates to array stacking functions.

jax-v0.9.2 · Breaking
Mar 18, 2026

JAX 0.9.2 updates the internal type structure of `TypedNdArray` and changes the execution context for `jnp.arange` with a step argument, potentially affecting float precision.

jax-v0.9.1 · Breaking · 1 fix · 1 feature
Mar 2, 2026

This release improves strictness in `jax.shard_map` explicit mode by enforcing PartitionSpec matching and updates JAX tracer type reporting. A new debug configuration for the compilation cache was also introduced.

jax-v0.9.0.1
Feb 5, 2026

JAX v0.9.0.1 is a patch release identical to v0.9.0, incorporating fixes from four specific XLA pull requests.

jax-v0.8.3 · 2 fixes
Jan 29, 2026

JAX v0.8.3 is released, incorporating two specific bug fixes that were missing from the previous v0.8.2 release.

jax-v0.9.0 · Breaking · 1 fix · 1 feature
Jan 20, 2026

This release introduces `jax.thread_guard` for multi-threaded JAX environments. It also updates the serialization format for `jax.export` to support explicit sharding, which tightens mesh-matching requirements. Several configuration states and functions, such as `jax.numpy.fix`, have been removed or deprecated.

jax-v0.8.2 · Breaking
Dec 18, 2025

This release deprecates several core and interpreter symbols, removes `jax.experimental.si_vjp`, and changes `Tracer` inheritance, requiring code updates to use `jax.vjp` and the new `pcast` API.
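
The summary points code toward `jax.vjp`. Below is a minimal sketch of ordinary `jax.vjp` usage. The function `f` is a hypothetical example, and this is not a complete migration recipe for every former `si_vjp` call site.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function to differentiate.
    return jnp.sin(x) * x

x = jnp.arange(3.0)
# jax.vjp returns the primal output and a function that pulls cotangents back to x.
y, f_vjp = jax.vjp(f, x)
# Pulling back a cotangent of ones gives the elementwise derivative cos(x) * x + sin(x).
(x_bar,) = f_vjp(jnp.ones_like(y))
print(x_bar)
```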

jax-v0.8.1 · 1 fix · 3 features
Nov 18, 2025

JAX adds decorator-factory support for `jax.jit`, new `implementation` and `algorithm` options for linear algebra ops, fixes a GPU `eigh` bug, and deprecates several sharding and Cloud TPU utilities.

jax-v0.8.0 · Breaking · 5 features
Oct 15, 2025

JAX introduces several breaking changes, including a new default implementation for `jax.pmap` and removal of many deprecated APIs, while adding new features such as namedtuple returns for eig and enhanced dlpack support.

jax-v0.7.2 · Breaking · 2 fixes
Sep 16, 2025

JAX drops support for raw DLPack capsules in `jax.dlpack.from_dlpack`, raises the minimum NumPy and SciPy versions, and introduces several deprecations and bug fixes.
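
What stops working is passing a raw capsule, meaning the result of calling `__dlpack__()` yourself. The supported form is to pass the producer object itself, which implements `__dlpack__` and `__dlpack_device__`. Below is a minimal sketch, assuming a CPU backend. To stay self-contained it round-trips a JAX array; a NumPy array or PyTorch tensor would be passed the same way.

```python
import jax
import jax.dlpack
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)

# Supported: hand over the object itself; JAX calls its __dlpack__ method.
y = jax.dlpack.from_dlpack(x)

# Per the note above, passing a capsule such as x.__dlpack__() is no longer accepted.
print(y)
```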

jax-v0.7.1 · Breaking · 5 features
Aug 20, 2025

JAX introduces new Python 3.13t/3.14t wheels, a new `jax.set_mesh` API, and CUDA 12.9 builds. It also adds several deprecations, including `jax.sharding.use_mesh`, which is superseded by `jax.set_mesh`.

jax-v0.7.0 · Breaking · 2 features
Jul 22, 2025

JAX 0.7 makes Shardy the default sharding partitioner, replacing GSPMD. It also switches autodiff to direct linearization, raises the minimum Python version to 3.11, and deprecates or removes several legacy APIs.

jax-v0.6.2 · Breaking · 1 feature
Jun 17, 2025

This release introduces the new `jax.tree.broadcast` helper and raises the minimum required versions of NumPy and SciPy.

jax-v0.6.1 · Breaking · 1 feature
May 21, 2025

This release adds the new `jax.lax.axis_size` feature, makes `PartitionSpec` and `ShapeDtypeStruct` behavior stricter, re‑enables CUDA version checks, and deprecates `custom_jvp_call_jaxpr_p`.

jax-v0.6.0 · Breaking · 4 features
Apr 17, 2025

This release removes several legacy tracing and configuration options, raises the minimum CUDA/CuDNN versions, updates package extras syntax, and deprecates many old APIs while introducing stricter `jax.jit` calling conventions.

jax-v0.5.3 · 2 features
Mar 19, 2025

This release introduces new options for JAX's slicing functions, which can reduce code size. It also adds an option for categorical sampling without replacement.

jax-v0.5.2 · 1 fix
Mar 5, 2025

This patch release on top of 0.5.1 fixes TPU metric logging and the `tpu-info` command, both of which were broken in 0.5.1.

jax-v0.5.1 · 2 fixes · 3 features
Feb 24, 2025

This release adds experimental custom DCE support, new low‑level reduction ops, column‑pivoting QR, updates CPU collective defaults, removes the libtpu‑nightly dependency, deprecates internal functions requiring debug info, and includes TPU runtime and compilation cache fixes.

jax-v0.5.0 · Breaking · 2 fixes · 3 features
Jan 17, 2025

This JAX "meso" release adds multi-dimensional FFT support, FFI state registration, and debugging info for AOT lowering. On the breaking side, it changes PRNG semantics, drops Mac x86 wheels, and raises the minimum NumPy and SciPy versions.

Common Errors

NotImplementedError · 4 reports

The "NotImplementedError" in JAX usually arises when a function or operation is called with a data type or configuration that the backend in use does not implement. Examples include sparse arrays, custom autodiff rules, and Pallas kernels in sharded contexts. To fix it, you have three options:

- Implement the missing functionality for that type or configuration yourself.
- Rewrite the code so it uses only JAX primitives and operations that support your data types.
- Look for the feature in the experimental `jax.experimental` libraries, or contribute it to the JAX project.
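
When the missing piece is a differentiation rule, one way to implement it yourself is `jax.custom_jvp`. Below is a minimal sketch. `softplus` is a hypothetical stand-in for whatever function needs a hand-written derivative; it is not tied to any specific error report above.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def softplus(x):
    # log(1 + exp(x)), computed stably via logaddexp.
    return jnp.logaddexp(x, 0.0)

@softplus.defjvp
def softplus_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # The derivative of softplus is the logistic sigmoid.
    return softplus(x), jax.nn.sigmoid(x) * t

print(jax.grad(softplus)(0.0))
```

The decorated function then works with `jax.grad`, `jax.jit`, and `jax.vmap` like any built-in operation.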

JaxRuntimeError · 4 reports

JaxRuntimeError often indicates an internal error in JAX's XLA compiler or in the underlying CUDA, cuDNN, or cuSOLVER libraries. First, try updating JAX, CUDA, and cuDNN to the latest mutually compatible versions. If the error persists, two temporary workarounds are worth trying:

- Disable XLA's GPU autotuning, for example by setting `XLA_FLAGS=--xla_gpu_autotune_level=0`.
- Check whether the failure disappears with JIT disabled via `jax.config.update("jax_disable_jit", True)`.

Either way, report the issue to the JAX issue tracker with a minimal reproducible example.
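
To tell whether the failure comes from compiling the whole function as one XLA program, it can help to run the same jitted function with JIT disabled. Below is a minimal sketch; `f` and `run_both_ways` are hypothetical names. Individual operations still run through XLA under `jax.disable_jit()`, so this narrows the problem down rather than bypassing XLA entirely.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function standing in for the one that raised the error.
    return jnp.linalg.norm(x) * 2.0

def run_both_ways(fn, *args):
    jitted_fn = jax.jit(fn)
    out_jit = jitted_fn(*args)          # compiled as one XLA program
    with jax.disable_jit():
        out_eager = jitted_fn(*args)    # same function, executed op by op
    return out_jit, out_eager

out_jit, out_eager = run_both_ways(f, jnp.arange(4.0))
print(out_jit, out_eager)
```

If only the jitted call fails, the problem most likely lies in whole-program compilation, which is useful detail for a bug report.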

XlaRuntimeError · 2 reports

XlaRuntimeError in JAX often arises from mismatches between the expected and actual shapes or dtypes in XLA computations. This is especially common in operations such as FFTs or when specifying memory layouts. To resolve it, inspect the input shapes, dtypes, and sharding specifications, and make sure they satisfy XLA's requirements and the constraints of the target device. Explicit dtype casting (e.g., `jnp.asarray(x, dtype=jnp.complex64)`) or adjusted sharding annotations often resolve these mismatches.
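
`jax.eval_shape` reports the output shape and dtype of a function without running it, so mismatches can be caught before an expensive computation. Below is a minimal sketch, assuming JAX's default 32-bit mode, in which the cast yields `complex64`. `spectrum` is a hypothetical function.

```python
import jax
import jax.numpy as jnp

def spectrum(x):
    # Cast explicitly so the FFT always receives a complex input of known precision.
    xc = jnp.asarray(x, dtype=jnp.complex64)
    return jnp.fft.fft(xc)

x = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float32)

# Abstract evaluation: returns a ShapeDtypeStruct; no computation is executed.
out_spec = jax.eval_shape(spectrum, x)
print(out_spec.shape, out_spec.dtype)

y = spectrum(x)
```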

KeyReuseError · 2 reports

The "KeyReuseError" in JAX is raised by JAX's key-reuse checker, which is enabled with the `jax_debug_key_reuse` configuration option. It fires when the same PRNG key is consumed more than once without being split, for example inside a transformation such as `vmap` or `pmap`. Reusing a key does not make results non-deterministic; it yields identical or correlated random numbers. To fix this, split the key with `jax.random.split` before each use so that every random operation receives a fresh key. In particular, split your key before passing it to functions that use randomness inside loops or vmapped code.
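
Below is a minimal sketch of both splitting patterns: one key per example for `vmap`, and a fresh subkey per iteration in a loop. `add_noise` is a hypothetical function. The sketch uses new-style `jax.random.key` keys; keys from `jax.random.PRNGKey` are split the same way.

```python
import jax
import jax.numpy as jnp

def add_noise(key, x):
    # Hypothetical per-example function that consumes exactly one key.
    return x + jax.random.normal(key, x.shape)

key = jax.random.key(0)
xs = jnp.zeros((4, 3))

# vmap: split once into one independent key per example.
key, batch_key = jax.random.split(key)
per_example_keys = jax.random.split(batch_key, xs.shape[0])
ys = jax.vmap(add_noise)(per_example_keys, xs)

# Loop: split inside the body so every iteration draws with a fresh subkey.
def body(i, carry):
    key, total = carry
    key, subkey = jax.random.split(key)
    return key, total + jax.random.uniform(subkey)

key, total = jax.lax.fori_loop(0, 5, body, (key, jnp.float32(0.0)))
```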

ShardingTypeError · 1 report

The "ShardingTypeError" often arises when JAX cannot infer the sharding of an operation's result, especially after type conversions or views that change the array layout unexpectedly. To resolve it, state the intended sharding of the result explicitly so JAX knows how to distribute it across devices. Inside jitted code, use `jax.lax.with_sharding_constraint(x, sharding)`; outside it, use `jax.device_put(x, sharding)`. Alternatively, if the intended sharding is ambiguous, revise the code to avoid the implicit type conversion or view.
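
Below is a minimal sketch of stating the sharding explicitly after a dtype conversion. It assumes a single-device mesh so it runs anywhere. The `"data"` axis name and `cast_and_scale` are hypothetical. On a larger mesh, `P("data", None)` would split the rows across devices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-axis mesh over the first available device.
mesh = Mesh(np.array(jax.devices()[:1]), ("data",))
row_sharding = NamedSharding(mesh, P("data", None))

x = jax.device_put(jnp.arange(32.0).reshape(8, 4), row_sharding)

@jax.jit
def cast_and_scale(x):
    y = x.astype(jnp.bfloat16) * 2
    # Pin the result's sharding instead of leaving it to inference.
    return jax.lax.with_sharding_constraint(y, row_sharding)

y = cast_and_scale(x)
```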

TypeError · 1 report

This error usually arises when JAX's automatic differentiation (autodiff) unexpectedly encounters an abstract `ShapedArray` during the backward pass of a `fori_loop`. It is often triggered by operations inside the loop that are incompatible with sharded arrays. Values inside a traced loop are abstract, so calling `jax.device_put` or `jax.block_until_ready` on them does not make them concrete. Instead, restructure the calculation to avoid the incompatible operation in the loop's backward pass. Also review the sharding annotations to make sure data is distributed consistently across loop iterations.
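
One way to restructure is to write the loop with `jax.lax.scan`. The sketch below expresses the same loop with `jax.lax.fori_loop` and with `jax.lax.scan`, then checks that both gradients agree. The `fori_loop` version uses static Python-int bounds, which reverse-mode differentiation requires. The sketch does not reproduce the sharding-related error; `step_fn`, `loop_fori`, and `loop_scan` are hypothetical.

```python
import jax
import jax.numpy as jnp

def step_fn(v):
    # Hypothetical loop body.
    return jnp.sin(v) * 1.5

def loop_fori(x):
    # Python-int bounds let JAX lower this to a scan, which reverse-mode autodiff supports.
    return jax.lax.fori_loop(0, 3, lambda i, v: step_fn(v), x)

def loop_scan(x):
    def body(v, _):
        return step_fn(v), None
    v, _ = jax.lax.scan(body, x, xs=None, length=3)
    return v

x = jnp.float32(0.3)
g_fori = jax.grad(loop_fori)(x)
g_scan = jax.grad(loop_scan)(x)
print(g_fori, g_scan)
```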
