
jax-v0.7.0

Breaking Changes
10 breaking changes · 2 features · 8 deprecations · 43 symbols

Summary

JAX 0.7 makes Shardy the default partitioner in place of GSPMD, switches autodiff to direct linearization by default, raises the minimum Python version to 3.11, and deprecates or removes several legacy APIs.

⚠️ Breaking Changes

  • JAX is migrating from GSPMD to Shardy by default; see the migration guide at https://docs.jax.dev/en/latest/shardy_jax_migration.html for details.
  • JAX autodiff is switching to using direct linearization by default (instead of implementing linearization via JVP and partial eval); see the migration guide at https://docs.jax.dev/en/latest/direct_linearize_migration.html.
  • `jax.stages.OutInfo` has been replaced with `jax.ShapeDtypeStruct`.
  • `jax.jit` now requires the `fun` argument to be passed positionally and all other arguments to be passed by keyword; doing otherwise raises an error starting in v0.7.x (it was a `DeprecationWarning` in v0.6.x).
  • The minimum supported Python version is now 3.11 and will remain so until July 2026.
  • Layout API renames: `Layout`, `.layout`, `.input_layouts` and `.output_layouts` have been renamed to `Format`, `.format`, `.input_formats` and `.output_formats`; `DeviceLocalLayout` and `.device_local_layout` have been renamed to `Layout` and `.layout`.
  • The `jax.experimental.shard` module has been deleted and all of its APIs have moved to `jax.sharding`; use `jax.sharding.reshard`, `jax.sharding.auto_axes` and `jax.sharding.explicit_axes` instead of the experimental endpoints.
  • `lax.infeed` and `lax.outfeed` were removed, as were the `transfer_to_infeed` and `transfer_from_outfeed` methods on `Device` objects.
  • The `jax.extend.core.primitives.pjit_p` primitive has been renamed to `jit_p` and its `name` attribute has changed from "pjit" to "jit", which also changes how jaxprs are printed; the primitive is no longer exported from `jax.experimental.pjit`.
  • The undocumented function `jax.extend.backend.add_clear_backends_callback` has been removed; use `jax.extend.backend.register_backend_cache` instead.
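
A minimal sketch of the new `jax.jit` calling convention, and of `jax.ShapeDtypeStruct` as the description of abstract outputs, assuming JAX 0.7 on any backend. The functions `scale` and `power` are hypothetical, and `jax.eval_shape` is used here only as one convenient source of `ShapeDtypeStruct` values.

```python
import functools

import jax
import jax.numpy as jnp


def scale(x, n):
    # Hypothetical example function; `n` is marked static below.
    return x * n


# `fun` goes positionally; every other jit option goes by keyword.
scale_jit = jax.jit(scale, static_argnames=("n",))


# The decorator form follows the same rule through functools.partial.
@functools.partial(jax.jit, static_argnums=(1,))
def power(x, k):
    return x ** k


# Both of these are rejected starting in v0.7.x:
#   jax.jit(scale, (1,))    # a jit option passed positionally
#   jax.jit(fun=scale)      # `fun` passed by keyword

x = jnp.arange(3.0)
print(scale_jit(x, n=2))  # [0. 2. 4.]
print(power(x, 2))        # [0. 1. 4.]

# Abstract output descriptions are jax.ShapeDtypeStruct values.
out = jax.eval_shape(scale, jax.ShapeDtypeStruct((3,), jnp.float32), 2)
print(out.shape, out.dtype)  # (3,) float32
```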

Migration Steps

  1. Follow the GSPMD‑to‑Shardy migration guide at https://docs.jax.dev/en/latest/shardy_jax_migration.html.
  2. Check any code that depends on linearization being implemented via JVP and partial eval against the direct-linearization guide at https://docs.jax.dev/en/latest/direct_linearize_migration.html.
  3. Replace any use of `jax.stages.OutInfo` with `jax.ShapeDtypeStruct`.
  4. Call `jax.jit` with the function argument positionally and pass other arguments as keyword arguments.
  5. Rename Layout API symbols: replace `Layout` with `Format`, `.layout` with `.format`, `.input_layouts` with `.input_formats`, `.output_layouts` with `.output_formats`; replace `DeviceLocalLayout` and `.device_local_layout` with `Layout` and `.layout`.
  6. Replace imports from `jax.experimental.shard` with `jax.sharding` and use `jax.sharding.reshard`, `jax.sharding.auto_axes`, and `jax.sharding.explicit_axes`.
  7. Remove any usage of `lax.infeed`, `lax.outfeed`, and the `transfer_to_infeed` / `transfer_from_outfeed` methods on `Device` objects.
  8. Update primitive references: replace `jax.extend.core.primitives.pjit_p` with `jax.extend.core.primitives.jit_p`, update any checks against the name "pjit" to "jit", and stop importing the primitive from `jax.experimental.pjit`.
  9. Replace calls to `jax.extend.backend.add_clear_backends_callback` with `jax.extend.backend.register_backend_cache`.

✨ New Features

  • Added `jax.P` as an alias for `jax.sharding.PartitionSpec`.
  • Added `jax.tree.reduce_associative`.

🔧 Affected Symbols

`jax.stages.OutInfo`, `jax.ShapeDtypeStruct`, `jax.jit`, `Layout`, `DeviceLocalLayout`, `jax.experimental.shard`, `jax.sharding.reshard`, `jax.sharding.auto_axes`, `jax.sharding.explicit_axes`, `lax.infeed`, `lax.outfeed`, `Device.transfer_to_infeed`, `Device.transfer_from_outfeed`, `jax.extend.core.primitives.pjit_p`, `jit_p`, `jax.experimental.pjit`, `jax.extend.backend.add_clear_backends_callback`, `jax.extend.backend.register_backend_cache`, `jax.dlpack.SUPPORTED_DTYPES`, `jax.dlpack.is_supported_dtype`, `jax.scipy.special.sph_harm`, `jax.scipy.special.sph_harm_y`, `jax.interpreters.xla.abstractify`, `jax.interpreters.xla.pytype_aval_mappings`, `jax.interpreters.xla.canonicalize_dtype`, `jax.dtypes.canonicalize_dtype`, `jax.core.valid_jaxtype`, `jax.core.AxisName`, `jax.core.ConcretizationTypeError`, `jax.core.axis_frame`, `jax.core.call_p`, `jax.core.closed_call_p`, `jax.core.get_type`, `jax.core.trace_state_clean`, `jax.core.typematch`, `jax.core.typecheck`, `jax.lib.xla_client.DeviceAssignment`, `jax.lib.xla_client.get_topology_for_devices`, `jax.lib.xla_client.mlir_api_version`, `jax.extend.ffi`, `jax.ffi`, `jax.lib.xla_bridge.get_compile_options`, `jax.extend.backend.get_compile_options`

⚡ Deprecations

  • `jax.dlpack.SUPPORTED_DTYPES` is deprecated; use `jax.dlpack.is_supported_dtype` instead.
  • `jax.scipy.special.sph_harm` is deprecated; use `jax.scipy.special.sph_harm_y` instead.
  • From `jax.interpreters.xla`, the previously deprecated symbols `abstractify` and `pytype_aval_mappings` have been removed.
  • `jax.interpreters.xla.canonicalize_dtype` is deprecated; prefer `jax.dtypes.canonicalize_dtype` for dtype canonicalization and `jax.core.valid_jaxtype` for input validation.
  • From `jax.core`, the previously deprecated symbols `AxisName`, `ConcretizationTypeError`, `axis_frame`, `call_p`, `closed_call_p`, `get_type`, `trace_state_clean`, `typematch`, and `typecheck` have been removed.
  • From `jax.lib.xla_client`, the previously deprecated symbols `DeviceAssignment`, `get_topology_for_devices`, and `mlir_api_version` have been removed.
  • `jax.extend.ffi` was removed after being deprecated in v0.5.0; use `jax.ffi` instead.
  • `jax.lib.xla_bridge.get_compile_options` is deprecated and replaced by `jax.extend.backend.get_compile_options`.
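
The replacements for `jax.interpreters.xla.canonicalize_dtype` split by intent: one function for canonicalizing dtypes, another for validating inputs. A minimal sketch, assuming JAX 0.7 with the default 32-bit configuration (`jax_enable_x64` off):

```python
import numpy as np

import jax
import jax.numpy as jnp

# Dtype canonicalization: with 64-bit mode off, float64 maps to float32.
print(jax.dtypes.canonicalize_dtype(np.float64))  # float32

# Input validation: can this value be passed to a JAX transformation?
print(jax.core.valid_jaxtype(jnp.ones(3)))        # True
print(jax.core.valid_jaxtype("not an array"))     # False
```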