
JAX

Data & ML

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Latest: jax-v0.9.0.1 · 16 releases · 10 breaking changes · 8 common errors · View on GitHub

Release History

jax-v0.9.0.1
Feb 5, 2026

JAX v0.9.0.1 is a patch release on top of v0.9.0, incorporating fixes from four XLA pull requests.

jax-v0.8.3 (2 fixes)
Jan 29, 2026

JAX v0.8.3 incorporates two bug fixes that were missing from the previous v0.8.2 release.

jax-v0.9.0 (Breaking, 1 fix, 1 feature)
Jan 20, 2026

This release introduces `jax.thread_guard` for multi-threaded JAX environments and updates the serialization format for `jax.export` to support explicit sharding, which tightens mesh matching requirements. Several configuration states and deprecated functions like `jax.numpy.fix` have been removed or deprecated.

jax-v0.8.2 (Breaking)
Dec 18, 2025

This release deprecates several core and interpreter symbols, removes `jax.experimental.si_vjp`, and changes `Tracer` inheritance, requiring code updates to use `jax.vjp` and the new `pcast` API.

jax-v0.8.1 (1 fix, 3 features)
Nov 18, 2025

JAX adds decorator‑factory support for `jax.jit`, new `implementation` and `algorithm` options for linear algebra ops, fixes a GPU eigh bug, and deprecates several sharding and cloud‑TPU utilities.

jax-v0.8.0 (Breaking, 5 features)
Oct 15, 2025

JAX introduces several breaking changes, including a new default implementation for `jax.pmap` and removal of many deprecated APIs, while adding new features such as namedtuple returns for eig and enhanced dlpack support.

jax-v0.7.2 (Breaking, 2 fixes)
Sep 16, 2025

JAX drops support for raw DLPack capsules in `jax.dlpack.from_dlpack`, raises the minimum NumPy/SciPy versions, and introduces several deprecations and bug fixes.

jax-v0.7.1 (Breaking, 5 features)
Aug 20, 2025

JAX introduces new Python 3.13t/3.14t wheels, a new `jax.set_mesh` API, CUDA 12.9 builds, and several deprecations including removal of `jax.sharding.use_mesh`.

jax-v0.7.0 (Breaking, 2 features)
Jul 22, 2025

JAX 0.7 introduces Shardy as the default execution model, updates autodiff to direct linearization, raises the minimum Python version to 3.11, and deprecates or removes several legacy APIs.

jax-v0.6.2 (Breaking, 1 feature)
Jun 17, 2025

This release introduces the new `jax.tree.broadcast` helper and raises the minimum required versions of NumPy and SciPy.

jax-v0.6.1 (Breaking, 1 feature)
May 21, 2025

This release adds the new `jax.lax.axis_size` feature, makes `PartitionSpec` and `ShapeDtypeStruct` behavior stricter, re‑enables CUDA version checks, and deprecates `custom_jvp_call_jaxpr_p`.

jax-v0.6.0 (Breaking, 4 features)
Apr 17, 2025

This release removes several legacy tracing and configuration options, raises the minimum CUDA/CuDNN versions, updates package extras syntax, and deprecates many old APIs while introducing stricter `jax.jit` calling conventions.

jax-v0.5.3 (2 features)
Mar 19, 2025

This release introduces new options for slicing functions and for categorical sampling, reducing generated code size and adding support for sampling without replacement.

jax-v0.5.2 (1 fix)
Mar 5, 2025

Patch 0.5.2 fixes TPU metric logging and the `tpu-info` command.

jax-v0.5.1 (2 fixes, 3 features)
Feb 24, 2025

This release adds experimental custom DCE support, new low‑level reduction ops, column‑pivoting QR, updates CPU collective defaults, removes the libtpu‑nightly dependency, deprecates internal functions requiring debug info, and includes TPU runtime and compilation cache fixes.

jax-v0.5.0 (Breaking, 2 fixes, 3 features)
Jan 17, 2025

JAX v0.5.0 adds multi-dimensional FFT support, FFI state registration, and debugging info for AOT lowering, while making breaking changes to PRNG semantics, dropping Mac x86 wheels, and raising the NumPy/SciPy minimum versions.

Common Errors

NotImplementedError (4 reports)

The "NotImplementedError" in JAX usually arises when a function or operation is called with a data type or configuration (e.g., sparse arrays, custom autodiff rules, Pallas kernels in sharded contexts) that hasn't been implemented for the JAX backend in use. To fix it, either implement the missing functionality for the given type or configuration, or rewrite the code to use JAX primitives defined for the supported data types, avoiding the unsupported operation. Consider using `jax.experimental` libraries to work with experimental features, or contributing the missing functionality to the JAX project.
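For the custom-autodiff case, one hedged sketch is to register the missing rule yourself with `jax.custom_jvp` (the `safe_log` op here is illustrative, not taken from any report):

```python
import jax
import jax.numpy as jnp

# Illustrative only: if differentiating an op raises NotImplementedError
# because no autodiff rule exists, jax.custom_jvp lets you supply one.
@jax.custom_jvp
def safe_log(x):
    return jnp.log(x)

@safe_log.defjvp
def safe_log_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # d/dx log(x) = 1/x
    return safe_log(x), t / x

print(jax.grad(safe_log)(2.0))  # 0.5
```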

XlaRuntimeError (2 reports)

XlaRuntimeError in JAX often arises from mismatches between expected and actual data shapes or dtypes within XLA computations, especially during operations like FFTs or when specifying memory layouts. To resolve it, carefully inspect the input shapes, dtypes, and sharding specifications to ensure they align with XLA's requirements and the constraints of the target device. Explicit dtype casting (e.g., `jnp.asarray(x, dtype=jnp.complex64)`) or adjusting sharding annotations often rectifies these discrepancies.
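A minimal sketch of the casting fix, assuming the mismatch is a real-vs-complex dtype feeding an FFT:

```python
import numpy as np
import jax.numpy as jnp

x = np.arange(8, dtype=np.float32)

# Cast explicitly to the complex dtype the XLA FFT lowering expects,
# rather than relying on implicit promotion at the call site.
xc = jnp.asarray(x, dtype=jnp.complex64)
y = jnp.fft.fft(xc)

print(y.dtype)  # complex64
```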

JaxRuntimeError (2 reports)

JaxRuntimeError often arises from inconsistencies between JAX's device context (GPU/TPU) and TensorFlow's, especially when using TF datasets or interoperating with TF code. Ensure JAX's device is initialized *before* any TensorFlow operations that might implicitly initialize a different device; try calling `jax.default_backend()` or `jax.devices()` early in your program to force JAX device initialization. Additionally, setting `TF_FORCE_UNIFIED_MEMORY=1` might resolve conflicts by enforcing memory allocation compatibility between JAX and TensorFlow.
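A sketch of the initialization-order fix; the TensorFlow usage is shown only as comments, since the exact pipeline is application-specific:

```python
import jax

# Touch the JAX backend first so it claims its devices before any
# other framework (e.g. TensorFlow) initializes the accelerator.
backend = jax.default_backend()   # "cpu", "gpu", or "tpu"
devices = jax.devices()
print(backend, len(devices))

# Only after this point:
# import tensorflow as tf
# ds = tf.data.Dataset.range(10)
```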

FloatingPointError (1 report)

FloatingPointError in JAX often arises from operations like division by zero or overflow when using lower precision floats like float32, especially within functions like `dot_product_attention` if intermediate values become too large. To mitigate this, consider promoting the relevant data tensors to float64 to increase numerical range and precision, or use `jax.numpy.finfo` to determine minimum and maximum values for scaling or clamping input data to prevent overflows before potentially problematic operations. Regularly monitoring intermediate values during debugging can also help pinpoint the exact operation causing the issue.
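As a hedged sketch of the `jnp.finfo` clamping idea (the `safe_exp` helper and its safety margin are assumptions for illustration, not part of any JAX API):

```python
import jax.numpy as jnp

def safe_exp(x):
    # Anything above log(finfo.max) overflows exp to inf in this dtype;
    # clamp with a small safety margin to guard against rounding.
    upper = jnp.log(jnp.finfo(x.dtype).max) - 1.0
    return jnp.exp(jnp.minimum(x, upper))

y = safe_exp(jnp.float32(1000.0))
print(bool(jnp.isfinite(y)))  # True: clamped instead of overflowing
```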

ShardingTypeError (1 report)

The "ShardingTypeError" often arises when JAX cannot infer the sharding of an operation's output, especially after type conversions or views that alter array layouts unexpectedly. To resolve it, explicitly specify the resulting array's sharding (e.g., with `jax.device_put(x, sharding)` or `jax.lax.with_sharding_constraint`) so JAX knows how to distribute it across devices. Alternatively, revise the code to avoid the implicit type conversion or view if the intended sharding is ambiguous.
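A single-device sketch of pinning a result explicitly after a dtype cast; on a multi-device mesh you would pass a `Sharding` instead of a single device:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(16.0).reshape(4, 4)

# After a cast/view whose output sharding JAX cannot infer, pin the
# result to an explicit placement so downstream ops see it unambiguously.
y = jax.device_put(x.astype(jnp.bfloat16), jax.devices()[0])

print(y.devices() == {jax.devices()[0]})  # True
```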

TypeError (1 report)

This error usually arises when JAX's automatic differentiation encounters a `ShapedArray` unexpectedly during the backward pass of a `fori_loop`, often triggered by operations incompatible with sharded arrays inside the loop. Resolve it by explicitly materializing the array with `jax.device_put` or `jax.block_until_ready` before it's used in the operation that triggers the error, or by restructuring the calculation to keep the incompatible operation out of the loop's backward pass. Also review the sharding annotations to ensure data is properly distributed and materialized when needed.
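A minimal sketch of the restructuring approach: with static Python-int bounds, `jax.lax.fori_loop` lowers to a `scan`, which does have a reverse-mode pass. The `power5` function is illustrative, not from any report:

```python
import jax
import jax.numpy as jnp

def power5(x):
    # Static bounds let fori_loop lower to scan, which is
    # reverse-mode differentiable; dynamic bounds would not be.
    return jax.lax.fori_loop(0, 5, lambda i, acc: acc * x,
                             jnp.float32(1.0))

g = jax.grad(power5)(2.0)
print(g)  # d/dx x**5 at x=2.0 -> 80.0
```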
