
jax-v0.6.0

Breaking Changes
5 breaking changes · 4 features · 9 deprecations · 43 affected symbols

Summary

This release removes several legacy tracing and configuration options, raises the minimum CuDNN version to 9.8, switches package extras to dash-separated names, deprecates many older APIs, and introduces a stricter `jax.jit` calling convention.

⚠️ Breaking Changes

  • `jax.numpy.array` no longer accepts `None`; passing `None` now raises a `TypeError`. Replace `None` arguments with valid arrays or check for `None` before the call.
  • The `config.jax_data_dependent_tracing_fallback` option, added temporarily in v0.4.36 as an opt-out from the new "stackless" tracing machinery, has been removed; delete any code or configuration that sets it.
  • The `config.jax_eager_pmap` option has been removed; delete any code or configuration that sets it.
  • Calling the AOT APIs `lower` or `trace` on the result of `jax.jit` after further wrappers have been applied to it is now disallowed; previously this worked but silently ignored those wrappers. Apply `jax.jit` last (outermost) among the wrappers; the same rule applies to `jax.pmap`.
  • The `cuda12_pip` extra has been removed; install with `pip install jax[cuda12]` instead.
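A minimal sketch of the wrapper ordering that keeps the AOT APIs working, assuming `jax.vmap` is the extra wrapper. The names `row_norm`, `batched`, and the input values are illustrative, not part of the release notes. Because `jax.jit` is applied last, `lower` sees the whole composed function.

```python
import jax
import jax.numpy as jnp

def row_norm(x):
    # Per-example function: squared L2 norm of a single row.
    return jnp.sum(x * x)

# jax.jit is the outermost (last-applied) wrapper, so the AOT
# methods on `batched` cover the vmap as well.
batched = jax.jit(jax.vmap(row_norm))

xs = jnp.arange(6.0).reshape(3, 2)   # rows: [0, 1], [2, 3], [4, 5]
lowered = batched.lower(xs)          # AOT: lower for this shape/dtype
compiled = lowered.compile()         # AOT: compile the lowered module
result = compiled(xs)                # run the compiled executable

print(lowered.as_text()[:80])        # start of the lowered module text
print(result)
```

Reversing the order, as in `jax.vmap(jax.jit(row_norm))`, still works when called normally. The restriction concerns calling `lower` or `trace` through an outer wrapper placed around the `jax.jit` result.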

Migration Steps

  1. Replace any `jax.numpy.array(None, …)` calls with valid array creation or guard against None.
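  2. Remove references to `config.jax_data_dependent_tracing_fallback` and `config.jax_eager_pmap` from your code and configuration.
  3. If you call `lower` or `trace` on a function built from `jax.jit` (or `jax.pmap`) plus other wrappers, make `jax.jit`/`jax.pmap` the outermost (last-applied) wrapper.
  4. Install JAX with the new extra syntax, e.g., `pip install jax[cuda12]` or `pip install jax[cuda12-local]`, replacing underscore forms such as `jax[cuda12_local]` with their dash equivalents and `jax[cuda12_pip]` with `jax[cuda12]`.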
  5. Update `jax.jit` calls to pass the function argument positionally and use keyword arguments for other parameters.
  6. Switch any usage of deprecated APIs to their recommended replacements (e.g., `jax.tree.unflatten`, `jax.ffi`, `jax.extend.mlir`).
  7. Make sure your CUDA installation is 12.1 or newer (this minimum is unchanged) and upgrade CuDNN to at least 9.8.
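For step 1, one option is to handle `None` explicitly before it reaches `jnp.array`. The sketch below uses a hypothetical helper, `as_array_or_empty`, and assumes an empty `float32` array is an acceptable stand-in for `None`; choose whatever default your code actually needs.

```python
import jax.numpy as jnp

def as_array_or_empty(x, dtype=jnp.float32):
    """Hypothetical helper: convert x to a JAX array, mapping None to an empty array.

    jnp.array(None) raises TypeError as of JAX v0.6.0, so None is
    handled here instead of being passed through.
    """
    if x is None:
        return jnp.zeros((0,), dtype=dtype)
    return jnp.asarray(x, dtype=dtype)

print(as_array_or_empty(None).shape)  # (0,)
print(as_array_or_empty([1, 2, 3]))
```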

✨ New Features

  • Minimum CuDNN version raised to v9.8.
  • JAX is now built with CUDA 12.8 while still supporting CUDA 12.1+.
  • Package extras now use dash instead of underscore per PEP 685 (e.g., `jax[cuda12-local]`).
  • `jax.jit` now requires `fun` to be passed positionally and all other arguments to be passed by keyword; non-conforming calls emit a DeprecationWarning in v0.6.x and will raise an error starting in v0.7.x.
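A short sketch of `jax.jit` calls that follow the new convention, both as a direct call and as a decorator via `functools.partial`. The functions `scale` and `repeat_sum` are illustrative only; `static_argnames` is an existing `jax.jit` keyword.

```python
import functools

import jax
import jax.numpy as jnp

def scale(x, factor):
    return x * factor

# Conforming: `fun` positional, every other option by keyword.
scale_jit = jax.jit(scale, static_argnames="factor")

# Decorator form: functools.partial binds the keyword options,
# and jax.jit then receives the decorated function positionally.
@functools.partial(jax.jit, static_argnames=("n",))
def repeat_sum(x, n):
    # n is static, so this Python loop is unrolled at trace time.
    total = jnp.zeros_like(x)
    for _ in range(n):
        total = total + x
    return total

# Non-conforming patterns (DeprecationWarning in v0.6.x, error from v0.7.x):
#   - passing `fun` by keyword, e.g. jax.jit(fun=scale)
#   - passing any option after `fun` positionally instead of by keyword

print(scale_jit(jnp.arange(3.0), 2.0))
print(repeat_sum(jnp.ones(2), n=3))
```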

🔧 Affected Symbols

  • `jax.numpy.array`
  • `config.jax_data_dependent_tracing_fallback`, `config.jax_eager_pmap`
  • `jax.jit`, `jax.pmap`, `jax.jit.lower`, `jax.jit.trace`, `jax.pmap.lower`, `jax.pmap.trace`
  • `jax.lib.xla_extension`
  • `jax.interpreters.mlir.hlo`, `jax.interpreters.mlir.func_dialect`, `jax.interpreters.mlir.custom_call`
  • `jax.ffi.ffi_call`
  • `jax.lib.xla_client.get_topology_for_devices`, `jax.lib.xla_client.heap_profile`, `jax.lib.xla_client.mlir_api_version`, `jax.lib.xla_client.Client`, `jax.lib.xla_client.CompileOptions`, `jax.lib.xla_client.DeviceAssignment`, `jax.lib.xla_client.Frame`, `jax.lib.xla_client.HloSharding`, `jax.lib.xla_client.OpSharding`, `jax.lib.xla_client.Traceback`
  • `jax.util.HashableFunction`, `jax.util.as_hashable_function`, `jax.util.cache`, `jax.util.safe_map`, `jax.util.safe_zip`, `jax.util.split_dict`, `jax.util.split_list`, `jax.util.split_list_checked`, `jax.util.split_merge`, `jax.util.subvals`, `jax.util.toposort`, `jax.util.unzip2`, `jax.util.wrap_name`, `jax.util.wraps`
  • `jax.dlpack.to_dlpack`
  • `jax.lax.infeed`, `jax.lax.infeed_p`, `jax.lax.outfeed`, `jax.lax.outfeed_p`

⚡ Deprecations

  • `jax.tree_util.build_tree` is deprecated; use `jax.tree.unflatten`.
  • All APIs in `jax.lib.xla_extension` are deprecated.
  • `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect` have been removed; import from `jax.extend.mlir` if needed.
  • `jax.interpreters.mlir.custom_call` is deprecated; use `jax.ffi` APIs.
  • `jax.ffi.ffi_call` no longer accepts the operands inline; it now always returns a callable, which you then invoke with the operands.
  • Exports in `jax.lib.xla_client` (`get_topology_for_devices`, `heap_profile`, `mlir_api_version`, `Client`, `CompileOptions`, `DeviceAssignment`, `Frame`, `HloSharding`, `OpSharding`, `Traceback`) are deprecated.
  • Internal APIs in `jax.util` (`HashableFunction`, `as_hashable_function`, `cache`, `safe_map`, `safe_zip`, `split_dict`, `split_list`, `split_list_checked`, `split_merge`, `subvals`, `toposort`, `unzip2`, `wrap_name`, `wraps`) are deprecated.
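  • `jax.dlpack.to_dlpack` is deprecated. You can usually pass a JAX array directly to another framework's `from_dlpack` function; if you need the old functionality, use the array's `__dlpack__` attribute.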
  • `jax.lax.infeed`, `jax.lax.infeed_p`, `jax.lax.outfeed`, and `jax.lax.outfeed_p` are deprecated and will be removed in v0.7.0.
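A minimal sketch of rebuilding a pytree with `jax.tree.unflatten`, the suggested replacement for `jax.tree_util.build_tree`. It assumes your existing `build_tree` calls can be expressed as "rebuild this structure from a flat list of leaves"; the `params` dictionary is illustrative.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# Flatten into a list of leaves plus a treedef describing the container.
leaves, treedef = jax.tree.flatten(params)

# Transform the leaves, then rebuild the same structure from them
# with jax.tree.unflatten instead of jax.tree_util.build_tree.
doubled = jax.tree.unflatten(treedef, [2 * leaf for leaf in leaves])

print(doubled["w"])
print(doubled["b"])
```

For a simple leafwise transform like this, `jax.tree.map` is shorter; the flatten/unflatten pair is useful when the leaves are processed as a flat list.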