JAX
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
Release History
jax-v0.9.0.1 (patch): JAX v0.9.0.1 is a patch release identical to v0.9.0, incorporating fixes from four specific XLA pull requests.
jax-v0.8.3 (2 fixes): JAX v0.8.3 incorporates two bug fixes that were missing from the previous v0.8.2 release.
jax-v0.9.0 (breaking; 1 fix; 1 feature): This release introduces `jax.thread_guard` for multi-threaded JAX environments and updates the serialization format for `jax.export` to support explicit sharding, which tightens mesh matching requirements. Several configuration states and deprecated functions such as `jax.numpy.fix` have been removed or deprecated.
jax-v0.8.2 (breaking): This release deprecates several core and interpreter symbols, removes `jax.experimental.si_vjp`, and changes `Tracer` inheritance, requiring code updates to use `jax.vjp` and the new `pcast` API.
jax-v0.8.1 (1 fix; 3 features): JAX adds decorator-factory support for `jax.jit`, new `implementation` and `algorithm` options for linear algebra ops, fixes a GPU eigh bug, and deprecates several sharding and Cloud TPU utilities.
jax-v0.8.0 (breaking; 5 features): JAX introduces several breaking changes, including a new default implementation for `jax.pmap` and removal of many deprecated APIs, while adding features such as namedtuple returns for `eig` and enhanced DLPack support.
jax-v0.7.2 (breaking; 2 fixes): JAX drops support for raw DLPack capsules in `jax.dlpack.from_dlpack`, raises minimum NumPy/SciPy versions, and introduces several deprecations and bug fixes.
jax-v0.7.1 (breaking; 5 features): JAX introduces new Python 3.13t/3.14t wheels, a new `jax.set_mesh` API, CUDA 12.9 builds, and several deprecations including removal of `jax.sharding.use_mesh`.
jax-v0.7.0 (breaking; 2 features): JAX 0.7 introduces Shardy as the default execution model, updates autodiff to direct linearization, raises the minimum Python version to 3.11, and deprecates or removes several legacy APIs.
jax-v0.6.2 (breaking; 1 feature): This release introduces the new `jax.tree.broadcast` helper and raises the minimum required versions of NumPy and SciPy.
jax-v0.6.1 (breaking; 1 feature): This release adds the new `jax.lax.axis_size` feature, makes `PartitionSpec` and `ShapeDtypeStruct` behavior stricter, re-enables CUDA version checks, and deprecates `custom_jvp_call_jaxpr_p`.
jax-v0.6.0 (breaking; 4 features): This release removes several legacy tracing and configuration options, raises the minimum CUDA/cuDNN versions, updates package-extras syntax, and deprecates many old APIs while introducing stricter `jax.jit` calling conventions.
jax-v0.5.3 (2 features): This release introduces new options for slicing functions and categorical sampling in JAX, improving code size and adding support for sampling without replacement.
jax-v0.5.2 (1 fix): This patch release fixes TPU metric logging and the `tpu-info` command.
jax-v0.5.1 (2 fixes; 3 features): This release adds experimental custom DCE support, new low-level reduction ops, column-pivoting QR, updates CPU collective defaults, removes the libtpu-nightly dependency, deprecates internal functions requiring debug info, and includes TPU runtime and compilation-cache fixes.
jax-v0.5.0 (breaking; 2 fixes; 3 features): This release adds multi-dimensional FFT support, FFI state registration, and debugging info for AOT lowering, while changing PRNG semantics, dropping Mac x86 wheels, and raising NumPy/SciPy minimum versions.
Common Errors
NotImplementedError (4 reports): The "NotImplementedError" in JAX usually arises when a function or operation is called with a data type or configuration (e.g., sparse arrays, custom autodiff rules, Pallas kernels in sharded contexts) that hasn't been implemented in the JAX backend being used. To fix it, either implement the missing functionality for that type or configuration, or rewrite the code using JAX primitives defined for the supported data types, avoiding the unsupported operations. Consider using `jax.experimental` libraries to work with experimental features, or contributing the missing feature to the JAX project.
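One common instance is a missing autodiff rule. A minimal sketch of supplying one yourself with `jax.custom_jvp` (the function `stable_log1p` here is a hypothetical stand-in for whatever op lacks a rule):

```python
import jax
import jax.numpy as jnp

# Hypothetical op whose autodiff rule we supply ourselves instead of
# relying on a rule JAX may not have implemented for our case.
@jax.custom_jvp
def stable_log1p(x):
    return jnp.log1p(x)

@stable_log1p.defjvp
def stable_log1p_jvp(primals, tangents):
    # d/dx log(1 + x) = 1 / (1 + x)
    (x,), (t,) = primals, tangents
    return stable_log1p(x), t / (1.0 + x)

grad = jax.grad(stable_log1p)(1.0)  # analytically 1 / (1 + 1) = 0.5
```

With the custom rule registered, `jax.grad` no longer needs a built-in differentiation rule for the wrapped operation.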
XlaRuntimeError (2 reports): XlaRuntimeError in JAX often arises from mismatches between the expected and actual data shapes or dtypes within XLA computations, especially during operations like FFTs or when specifying memory layouts. To resolve it, carefully inspect the input shapes, dtypes, and sharding specifications to ensure they match XLA's requirements and the constraints of the target device. Explicit dtype casting (e.g., using `jnp.asarray(x, dtype=jnp.complex64)`) or adjusting sharding annotations can often rectify these discrepancies.
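The casting fix mentioned above can be sketched as follows for an FFT input (a minimal example, not tied to any particular reported failure):

```python
import jax.numpy as jnp

x = jnp.arange(8.0)  # real-valued input, float32 by default

# Cast explicitly so the XLA FFT kernel receives the dtype it expects,
# rather than relying on implicit promotion at run time.
xc = jnp.asarray(x, dtype=jnp.complex64)
y = jnp.fft.fft(xc)
```

Making the dtype explicit also keeps the behavior stable across backends that differ in which implicit conversions they accept.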
JaxRuntimeError (2 reports): JaxRuntimeError often arises from inconsistencies between JAX's device context (GPU/TPU) and TensorFlow's, especially when using TF datasets or interoperating with TF code. Ensure JAX's device is initialized *before* any TensorFlow operations that might implicitly initialize a different device; try calling `jax.default_backend()` or `jax.devices()` early in your program to force JAX device initialization. Additionally, setting `TF_FORCE_UNIFIED_MEMORY=1` might resolve conflicts by enforcing memory allocation compatibility between JAX and TensorFlow.
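A sketch of the initialization-order fix, assuming TensorFlow is only imported afterwards (the TF lines are illustrative and commented out so this snippet stays TF-free):

```python
import jax

# Touch JAX's backend first so JAX claims its devices before any
# TensorFlow code can implicitly initialize them.
backend = jax.default_backend()   # e.g. "cpu", "gpu", or "tpu"
devices = jax.devices()

# Only afterwards bring in TensorFlow-based input pipelines, e.g.:
# import tensorflow as tf
# tf.config.set_visible_devices([], "GPU")  # keep TF off the accelerator
```

Keeping TF pinned to the CPU (as in the commented lines) is a common companion measure when TF is used purely for data loading.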
FloatingPointError (1 report): FloatingPointError in JAX often arises from operations like division by zero or overflow when using lower precision floats like float32, especially within functions like `dot_product_attention` if intermediate values become too large. To mitigate this, consider promoting the relevant data tensors to float64 to increase numerical range and precision, or use `jax.numpy.finfo` to determine minimum and maximum values for scaling or clamping input data to prevent overflows before potentially problematic operations. Regularly monitoring intermediate values during debugging can also help pinpoint the exact operation causing the issue.
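A minimal sketch of the scaling idea for attention-style score normalization: subtracting the running max keeps `exp` within float32 range, and `jnp.finfo` reports the representable limits if you prefer explicit clamping instead.

```python
import jax.numpy as jnp

def stable_softmax(logits):
    # Subtract the per-row max so exp() never overflows float32;
    # the result is mathematically identical to a plain softmax.
    shifted = logits - jnp.max(logits, axis=-1, keepdims=True)
    unnorm = jnp.exp(shifted)
    return unnorm / jnp.sum(unnorm, axis=-1, keepdims=True)

# Logits this large would overflow a naive exp() in float32.
big = jnp.array([1e4, 1e4 + 1.0], dtype=jnp.float32)
probs = stable_softmax(big)

# For clamping-based fixes, query the dtype's representable range:
fmax = jnp.finfo(jnp.float32).max
```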
ShardingTypeError (1 report): The "ShardingTypeError" often arises when JAX cannot infer the sharding for an operation, especially after type conversions or views that alter array layouts unexpectedly. To resolve this, explicitly specify the sharding of the resulting array after the problematic operation (e.g., using `jax.device_put(x, sharding)` or `jax.lax.with_sharding_constraint`), ensuring JAX knows how to distribute it across devices. Alternatively, revise the code to avoid the implicit type conversion or view if the intended sharding is ambiguous.
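A minimal sketch of pinning an explicit sharding after a reshape whose output sharding JAX could not infer; it uses a fully replicated spec so it runs on any device count, and the mesh/axis names are illustrative:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over whatever devices are available.
mesh = Mesh(np.array(jax.devices()), ("x",))
replicated = NamedSharding(mesh, PartitionSpec())  # fully replicated

x = jnp.arange(16.0).reshape(4, 4)
# After the layout-changing reshape, make the sharding explicit so
# downstream ops no longer need to infer it.
y = jax.device_put(x.reshape(2, 8), replicated)
```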
TypeError (1 report): This error usually arises when JAX's automatic differentiation (autodiff) encounters a `ShapedArray` unexpectedly during the backward pass of a `fori_loop`, often triggered by operations incompatible with sharded arrays within the loop. Resolve this by explicitly materializing the `ShapedArray` into a concrete array using `jax.device_put` or `jax.block_until_ready` before it's used in the operation that triggers the error, or by restructuring the calculation to avoid the incompatible operation within the loop's backward pass. Review the sharding annotations to ensure data is properly distributed and materialized when needed.
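A sketch of the restructuring approach: with static bounds and a loop body kept to plain differentiable ops, `jax.lax.fori_loop` supports reverse-mode autodiff cleanly. The polynomial example below is purely illustrative.

```python
import jax
import jax.numpy as jnp

def cumulative_poly(x):
    # Horner-style accumulation: after 5 steps this computes
    # x**4 + x**3 + x**2 + x + 1.
    def body(i, acc):
        return acc * x + 1.0  # simple, differentiable update only
    # Static trip count, so reverse-mode differentiation is supported.
    return jax.lax.fori_loop(0, 5, body, jnp.float32(0.0))

g = jax.grad(cumulative_poly)(jnp.float32(2.0))
# Analytically: 4*x**3 + 3*x**2 + 2*x + 1 at x = 2 gives 49.
```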
Related Data & ML Packages
An Open Source Machine Learning Framework for Everyone
🤗 Transformers: the model-definition framework for state-of-the-art machine learning models in text, vision, audio, and multimodal models, for both inference and training.
Tensors and Dynamic neural networks in Python with strong GPU acceleration
scikit-learn: machine learning in Python
Flexible and powerful data analysis / manipulation library for Python, providing labeled data structures similar to R data.frame objects, statistical functions, and much more
Streamlit — A faster way to build and share data apps.