jax-v0.8.0
⚠ 19 breaking · ✨ 5 features · ⚡ 3 deprecations · 🔧 36 symbols
Summary
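JAX v0.8.0 introduces several breaking changes, including a new default implementation of `jax.pmap` built on `jax.jit` and `jax.shard_map` and the removal of many deprecated APIs. It also adds new features such as a namedtuple return value for `jax.numpy.linalg.eig` and support for non-default array layouts in `jax.dlpack.from_dlpack`.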
⚠️ Breaking Changes
- The default implementation of `jax.pmap` now uses `jax.jit` and `jax.shard_map`; new code should call `jax.shard_map` directly (see the migration guide).
- The `auto=` parameter has been removed from `jax.experimental.shard_map.shard_map`, so it no longer supports nesting; use `jax.shard_map` for nested calls.
- Objects implementing `__jax_array__` can no longer be passed directly to JIT‑compiled functions; convert them with `jax.numpy.asarray` first.
- `jax.numpy.cov` now returns NaN for empty arrays and matches NumPy 2.2 behavior for single‑row design matrices.
- `Array` values are no longer accepted where a `dtype` is expected; extract the dtype via `.dtype` before passing.
- Removed deprecated function `jax.interpreters.mlir.custom_call`.
- Modules `jax.util`, `jax.extend.ffi`, and `jax.experimental.host_callback` have been removed.
- Removed deprecated symbol `jax.custom_derivatives.custom_jvp_call_jaxpr_p`.
- `jax.experimental.multihost_utils.process_allgather` now errors when given a non‑fully‑addressable `jax.Array` with `tiled=False`; pass `tiled=True`.
- Deprecated symbols `is_initialized` and `initialize_cache` removed from `jax.experimental.compilation_cache`.
- Removed deprecated function `jax.interpreters.xla.canonicalize_dtype`.
- `jaxlib.hlo_helpers` removed; use `jax.ffi` instead.
- Option `jax_cpu_enable_gloo_collectives` removed; set `jax_cpu_collectives_implementation` to `'gloo'` instead.
- `interpolation` argument removed from `jax.numpy.percentile` and `jax.numpy.quantile`; use `method` argument.
- Internal `for_loop` primitive removed; use `jax.lax.fori_loop` directly.
- `jax.numpy.trimzeros` now raises an error for non‑1D input.
- `where` argument to reduction functions like `jax.numpy.sum` must now be boolean.
- Removed deprecated functions in `jax.dlpack`, `jax.errors`, `jax.lib.xla_bridge`, `jax.lib.xla_client`, and `jax.lib.xla_extension`.
- Removed `jax.interpreters.mlir.dense_bool_array`; construct attributes via MLIR APIs.
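Several of these changes need only small, local edits. The sketch below shows four of them together: converting a `__jax_array__` object before a jitted call, passing a boolean `where`, using `method=` instead of `interpolation=`, and passing `.dtype` rather than an `Array`. `Wrapped` and `masked_sum` are hypothetical names chosen for illustration, and the integer mask is an assumed input that the function casts to `bool`.

```python
import jax
import jax.numpy as jnp


class Wrapped:
    """Hypothetical user type that implements the __jax_array__ protocol."""

    def __init__(self, values):
        self.values = values

    def __jax_array__(self):
        return jnp.asarray(self.values)


@jax.jit
def masked_sum(x, mask):
    # `where` must be boolean in v0.8.0, so cast a numeric mask explicitly.
    return jnp.sum(x, where=mask.astype(bool))


w = Wrapped([1.0, 2.0, 3.0])

# Convert the wrapper to an Array before it reaches a jitted function.
x = jnp.asarray(w)
mask = jnp.array([1, 0, 1])  # integer mask, cast to bool inside masked_sum
total = masked_sum(x, mask)

# `method=` replaces the removed `interpolation=` argument.
median = jnp.percentile(x, 50, method="linear")

# Pass a dtype, not an Array, wherever a dtype is expected.
zeros = jnp.zeros(3, dtype=x.dtype)
```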
Migration Steps
- Update code to use `jax.shard_map` instead of `jax.pmap` or the deprecated `jax.experimental.shard_map.shard_map`.
- If you relied on nesting `shard_map`, switch to `jax.shard_map` for nested calls.
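- Convert objects that implement `__jax_array__` with `jax.numpy.asarray` before passing them to JIT‑compiled functions.
- When calling `jax.experimental.multihost_utils.process_allgather` on a `jax.Array` that is not fully addressable, pass `tiled=True`.
- Replace imports from the removed modules (`jax.util`, `jax.extend.ffi`, `jax.experimental.host_callback`) with current alternatives, e.g. `jax.ffi` for `jax.extend.ffi`, and `jax.pure_callback`, `jax.experimental.io_callback`, or `jax.debug.callback` for `jax.experimental.host_callback`.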
- Use `jax.ffi` instead of the removed `jaxlib.hlo_helpers`.
- Replace the removed `jax_cpu_enable_gloo_collectives` option by setting `jax_cpu_collectives_implementation` to `'gloo'`.
- Change calls to `jax.numpy.percentile` and `jax.numpy.quantile` to use the `method` argument rather than `interpolation`.
- Replace any use of the internal `for_loop` primitive with `jax.lax.fori_loop`.
- Ensure the `where` argument passed to reduction ops like `jax.numpy.sum` is a boolean array.
- Update any code that used the removed deprecated functions in `jax.dlpack`, `jax.errors`, `jax.lib.xla_bridge`, `jax.lib.xla_client`, or `jax.lib.xla_extension` to the current public APIs.
- If you used `jax.interpreters.mlir.dense_bool_array`, construct MLIR attributes via the MLIR APIs instead.
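The first migration step, replacing `jax.pmap` with `jax.shard_map`, can be sketched as follows. This is a minimal illustration rather than the official migration recipe: the mesh axis name `"i"`, the helper `local_then_global_sum`, and the `try`/`except` import fallback for releases that predate `jax.shard_map` are all assumptions made for the example. Note the layout difference: `pmap` expects an explicit leading device axis, while `shard_map` splits an ordinary array according to `in_specs`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # public entry point recommended in v0.8.0
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases only

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ("i",))


def local_then_global_sum(block):
    # Reduce the per-device block, then combine results across the "i" axis.
    return jax.lax.psum(jnp.sum(block, keepdims=True), "i")


# Before: pmap maps over a leading axis whose size equals the device count.
pmapped = jax.pmap(local_then_global_sum, axis_name="i")

# After: shard_map splits a flat array along the mesh axis "i".
sharded = jax.jit(
    shard_map(local_then_global_sum, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
)

x = jnp.arange(4.0 * n)
old = pmapped(x.reshape(n, 4))  # shape (n, 1): one row per device
new = sharded(x)                # shape (n,): per-device results concatenated
```

With `out_specs=P("i")`, each device's length-1 result is concatenated into a length-`n` vector, mirroring the stacked output of `pmap`; `out_specs=P()` would instead return a single replicated length-1 array.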
✨ New Features
- `jax.numpy.linalg.eig` now returns a namedtuple with fields `eigenvalues` and `eigenvectors` instead of a plain tuple.
- `jax.grad` and `jax.vjp` now always round primals to `float32` when float64 mode is disabled.
- `jax.dlpack.from_dlpack` now accepts arrays with non‑default layouts, such as transposed arrays.
- The default nonsymmetric eigendecomposition on NVIDIA GPUs now uses cuSOLVER; alternative implementations remain selectable via the new `implementation` argument to `jax.lax.linalg.eig`, which supersedes the now-deprecated `use_magma` argument.
- `jax.numpy.trim_zeros` now supports multi‑dimensional inputs, matching NumPy 2.2 behavior.
🔧 Affected Symbols
`jax.pmap`, `jax.experimental.shard_map.shard_map`, `jax.numpy.cov`, `jax.interpreters.mlir.custom_call`, `jax.util`, `jax.extend.ffi`, `jax.experimental.host_callback`, `jax.custom_derivatives.custom_jvp_call_jaxpr_p`, `jax.experimental.multihost_utils.process_allgather`, `jax.experimental.compilation_cache.is_initialized`, `jax.experimental.compilation_cache.initialize_cache`, `jax.interpreters.xla.canonicalize_dtype`, `jaxlib.hlo_helpers`, `jax_cpu_enable_gloo_collectives`, `jax.numpy.percentile`, `jax.numpy.quantile`, `for_loop`, `jax.numpy.trimzeros`, `jax.numpy.sum`, `jax.dlpack`, `jax.errors`, `jax.lib.xla_bridge`, `jax.lib.xla_client`, `jax.lib.xla_extension`, `jax.interpreters.mlir.dense_bool_array`, `jax.numpy.linalg.eig`, `jax.grad`, `jax.vjp`, `jax.dlpack.from_dlpack`, `jax.lax.linalg.eig`, `jax.numpy.trim_zeros`, `jax.enable_x64`, `jax.experimental.enable_x64`, `jax.experimental.disable_x64`, `jax.experimental.pjit.pjit`, `jax.jit`
⚡ Deprecations
- `jax.experimental.enable_x64` and `jax.experimental.disable_x64` are deprecated; use the new context manager `jax.enable_x64`.
- `jax.experimental.shard_map.shard_map` is deprecated; replace with `jax.shard_map`.
- `jax.experimental.pjit.pjit` is deprecated; replace with `jax.jit`.
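The `pjit` and `enable_x64` deprecations map onto the replacements sketched below. This assumes a one-dimensional device mesh with the illustrative axis name `"data"` and an input length divisible by the device count; `scale` is a hypothetical function, and the `try`/`except` import is an assumed compatibility shim for releases that predate the top-level `jax.enable_x64`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    enable_x64 = jax.enable_x64  # non-experimental context manager (v0.8.0)
except AttributeError:
    from jax.experimental import enable_x64  # older releases; deprecated in v0.8.0

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ("data",))
row_sharding = NamedSharding(mesh, P("data"))


def scale(x):
    return 2.0 * x


# Before (deprecated): jax.experimental.pjit.pjit(scale, in_shardings=..., out_shardings=...)
# After: jax.jit accepts the same sharding arguments.
scale_sharded = jax.jit(scale, in_shardings=(row_sharding,), out_shardings=row_sharding)
y = scale_sharded(jnp.arange(4.0 * n))

# Before (deprecated): with jax.experimental.enable_x64(): ...
# After: the top-level context manager enables 64-bit types inside the block.
with enable_x64(True):
    z = jnp.arange(3, dtype=jnp.float64)
```

Arrays created inside the `with` block keep their 64-bit dtype after it exits; only new array creation outside the block falls back to 32-bit defaults.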