jax-v0.7.1
⚠ 3 breaking · ✨ 5 features · ⚡ 4 deprecations · 🔧 9 affected symbols
Summary
JAX v0.7.1 adds Python 3.14 and 3.14t wheels (plus free-threaded 3.13t/3.14t wheels on macOS), exposes the new `jax.set_mesh` API, is built with CUDA 12.9, removes `jax.sharding.use_mesh`, and deprecates several APIs.
⚠️ Breaking Changes
- Removed `jax.sharding.use_mesh`; code using it will break. Replace with `jax.set_mesh`.
- Importing `jax.experimental.host_callback` now emits a DeprecationWarning and will raise ImportError starting in JAX v0.8.0. Remove imports of this module before upgrading.
- Positional use of `precision` and `preferred_element_type` in `jax.lax.dot` is deprecated and will be removed; pass them as keyword arguments.
Migration Steps
- Replace any usage of `jax.sharding.use_mesh` with `jax.set_mesh`.
- Update imports to stop using `jax.experimental.host_callback`; remove or replace its usage before upgrading to v0.8.0.
- Change calls to `jax.lax.dot` that pass `precision` or `preferred_element_type` positionally to use keyword arguments.
- Replace `jax.lax.zeros_like_array` with `jax.numpy.zeros_like`.
- If your code relies on internal APIs from `jax.interpreters.ad`, `jax.interpreters.batching`, or `jax.interpreters.partial_eval`, refactor away from them: they are deprecated, and most have no public replacement.
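The two mechanical rewrites in the steps above (keyword arguments for `jax.lax.dot`, and `jax.numpy.zeros_like` in place of `jax.lax.zeros_like_array`) look like this in a small sketch; the shapes and dtypes are arbitrary:

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((2, 3), dtype=jnp.float32)
b = jnp.ones((3, 4), dtype=jnp.float32)

# Deprecated positional form:
#     lax.dot(a, b, lax.Precision.HIGHEST, jnp.float32)
# Preferred: pass precision and preferred_element_type by keyword.
c = lax.dot(a, b, precision=lax.Precision.HIGHEST,
            preferred_element_type=jnp.float32)

# Deprecated: lax.zeros_like_array(c)
z = jnp.zeros_like(c)
```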
✨ New Features
- JAX now ships wheels for Python 3.14 and 3.14t.
- JAX now ships free-threaded Python 3.13t and 3.14t wheels on macOS.
- Exposed `jax.set_mesh`, which works both as a global setter and as a context manager.
- JAX is now built using CUDA 12.9 while maintaining support for CUDA 12.1+.
- `jax.lax.dot` now implements the general dot product via the optional `dimension_numbers` argument.
🔧 Affected Symbols
`jax.set_mesh`, `jax.sharding.use_mesh`, `jax.lax.dot`, `jax.lax.zeros_like_array`, `jax.numpy.zeros_like`, `jax.experimental.host_callback`, `jax.interpreters.ad`, `jax.interpreters.batching`, `jax.interpreters.partial_eval`
⚡ Deprecations
- `jax.lax.zeros_like_array` is deprecated; use `jax.numpy.zeros_like` instead.
- Importing `jax.experimental.host_callback` now emits a DeprecationWarning; starting in v0.8.0 it will raise ImportError.
- Passing `precision` and `preferred_element_type` to `jax.lax.dot` positionally is deprecated; pass them as explicit keyword arguments.
- Several internal APIs in `jax.interpreters.ad`, `jax.interpreters.batching`, and `jax.interpreters.partial_eval` are deprecated without public replacements.
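For code that still uses `jax.experimental.host_callback`, the usual replacements are `jax.debug.callback` for side effects such as logging (roughly the role of `id_tap`) and `jax.pure_callback` for host computations with a declared result shape (roughly the role of `call`); `jax.experimental.io_callback` covers ordered side effects. The sketch below is illustrative only, and the helpers `record` and `host_sin` are hypothetical names:

```python
import jax
import jax.numpy as jnp
import numpy as np

logged = []

def record(x):
    # Hypothetical host-side logger: stores a NumPy copy of the value.
    logged.append(np.asarray(x))

def host_sin(x):
    # Hypothetical host computation; returns an array matching the declared shape/dtype.
    return np.sin(np.asarray(x))

@jax.jit
def f(x):
    # Side-effecting callback (in place of host_callback.id_tap).
    jax.debug.callback(record, x)
    # Host computation with a declared result (in place of host_callback.call).
    y = jax.pure_callback(host_sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    return y * 2

out = f(jnp.arange(3.0))
jax.effects_barrier()  # wait until pending debug callbacks have run
```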