jax-v0.7.1

3 breaking, 5 features, 4 deprecations, 9 symbols

Summary

JAX v0.7.1 ships wheels for Python 3.14 and 3.14t, adds free-threaded 3.13t and 3.14t wheels on macOS, exposes a new `jax.set_mesh` API, and is now built with CUDA 12.9. It removes `jax.sharding.use_mesh` and introduces several deprecations.

⚠️ Breaking Changes

  • Removed `jax.sharding.use_mesh`; code using it will break. Replace with `jax.set_mesh`.
  • Importing `jax.experimental.host_callback` now emits a DeprecationWarning and will raise ImportError starting in JAX v0.8.0. Remove imports of this module before upgrading.
  • Passing `precision` and `preferred_element_type` to `jax.lax.dot` positionally is deprecated; pass them as keyword arguments.

Migration Steps

  1. Replace any usage of `jax.sharding.use_mesh` with `jax.set_mesh`.
  2. Remove imports of `jax.experimental.host_callback`, and remove or replace any code that uses it, before upgrading to v0.8.0.
  3. Change calls to `jax.lax.dot` that pass `precision` or `preferred_element_type` positionally to use keyword arguments.
  4. Replace `jax.lax.zeros_like_array` with `jax.numpy.zeros_like`.
  5. If your code relies on internal APIs from `jax.interpreters.ad`, `jax.interpreters.batching`, or `jax.interpreters.partial_eval`, refactor to avoid them; several are deprecated without public replacements.
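
Steps 3 and 4 can be sketched as follows. The array shapes and dtypes are arbitrary choices for illustration; only the calling conventions matter.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((2, 3), dtype=jnp.bfloat16)
b = jnp.ones((3, 4), dtype=jnp.bfloat16)

# Step 3: name `precision` and `preferred_element_type` explicitly.
# Deprecated positional form: lax.dot(a, b, lax.Precision.HIGHEST, jnp.float32)
out = lax.dot(
    a,
    b,
    precision=lax.Precision.HIGHEST,
    preferred_element_type=jnp.float32,
)

# Step 4: jnp.zeros_like takes the place of lax.zeros_like_array.
zeros = jnp.zeros_like(out)
```

The keyword form also works on older JAX releases, so this change can be made before upgrading.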

✨ New Features

  • JAX now ships wheels for Python 3.14 and 3.14t.
  • Free-threaded wheels (Python 3.13t and 3.14t) are now also available on macOS.
  • Exposed `jax.set_mesh` as a global setter and context manager.
  • JAX is now built using CUDA 12.9 while maintaining support for CUDA 12.1+.
  • `jax.lax.dot` now supports general dot product via the optional `dimension_numbers` argument.
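
As a rough illustration of the `dimension_numbers` feature, the sketch below performs a batched matrix multiply. It assumes the argument uses the same `((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))` layout as `jax.lax.dot_general`, and it falls back to `dot_general` on JAX versions whose `lax.dot` does not accept the keyword.

```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((5, 2, 3))
rhs = jnp.ones((5, 3, 4))

# Contract lhs axis 2 with rhs axis 1; treat axis 0 of both as a batch axis.
dims = (((2,), (1,)), ((0,), (0,)))

try:
    out = lax.dot(lhs, rhs, dimension_numbers=dims)
except TypeError:
    # Older lax.dot has no `dimension_numbers` keyword.
    out = lax.dot_general(lhs, rhs, dimension_numbers=dims)
```

The result has shape `(5, 2, 4)`: the batch axis first, then the remaining free axes of `lhs` and `rhs`.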

🔧 Affected Symbols

  • `jax.set_mesh`
  • `jax.sharding.use_mesh`
  • `jax.lax.dot`
  • `jax.lax.zeros_like_array`
  • `jax.numpy.zeros_like`
  • `jax.experimental.host_callback`
  • `jax.interpreters.ad`
  • `jax.interpreters.batching`
  • `jax.interpreters.partial_eval`

⚡ Deprecations

  • `jax.lax.zeros_like_array` is deprecated; use `jax.numpy.zeros_like` instead.
  • Importing `jax.experimental.host_callback` now emits a DeprecationWarning and will raise ImportError starting in v0.8.0.
  • Positional arguments `precision` and `preferred_element_type` in `jax.lax.dot` are deprecated; use explicit keyword arguments.
  • Several internal APIs in `jax.interpreters.ad`, `jax.interpreters.batching`, and `jax.interpreters.partial_eval` are deprecated without public replacements.
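
One way to find affected call sites before upgrading is a small static scan. The sketch below uses only the standard library; `FLAGGED` and `find_flagged_names` are hypothetical names chosen for this example, and the scan only catches fully qualified references (for instance, it would miss `lax.zeros_like_array` reached through `from jax import lax`).

```python
import ast

# Dotted names taken from the notes above; extend as needed.
FLAGGED = (
    "jax.sharding.use_mesh",
    "jax.experimental.host_callback",
    "jax.lax.zeros_like_array",
)


def find_flagged_names(source: str) -> list[tuple[int, str]]:
    """Return sorted (line, dotted_name) pairs for flagged imports and attribute chains."""
    hits = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.module:
            names = [f"{node.module}.{alias.name}" for alias in node.names]
        elif isinstance(node, ast.Attribute):
            names = [ast.unparse(node)]
        else:
            continue
        for name in names:
            for flagged in FLAGGED:
                if name == flagged or name.startswith(flagged + "."):
                    hits.add((node.lineno, flagged))
    return sorted(hits)


sample = """\
import jax
from jax.experimental import host_callback

with jax.sharding.use_mesh(mesh):
    z = jax.lax.zeros_like_array(x)
"""

hits = find_flagged_names(sample)
```

The source is only parsed, never executed, so undefined names such as `mesh` and `x` in the sample are harmless.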