
jax-v0.5.0

2 breaking · 3 features · 🐛 2 fixes · 3 deprecations · 🔧 18 symbols

Summary

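JAX 0.5.0, a "meso" version bump, adds FFT transforms in more than three dimensions, user-defined state registration in the FFI, and a `debug_info` option for AOT lowering text. It also changes the default PRNG key semantics, drops Mac x86 wheels, and raises the minimum supported versions to NumPy 1.25 and SciPy 1.11.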

⚠️ Breaking Changes

  • Enabled `jax_threefry_partitionable` by default, changing PRNG key semantics; update code that relied on the previous default.
  • Dropped support for Mac x86 wheels; users on that platform must switch to Mac ARM or another supported platform.
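
If random values shift after upgrading, one stopgap is to turn the flag back off while auditing code that depends on specific draws. A minimal sketch, assuming the `jax_threefry_partitionable` config flag can still be set to `False` in the installed version; the seed and shape are arbitrary:

```python
import jax

# Assumption: opting out of the new default is acceptable as a temporary
# measure while code that depends on exact random values is reviewed.
jax.config.update("jax_threefry_partitionable", False)

key = jax.random.key(0)                      # arbitrary seed
samples = jax.random.normal(key, (4,))       # drawn with the flag disabled
flag_value = jax.config.jax_threefry_partitionable
```

Set the flag before generating any random values so that every draw in the program uses the same setting.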

Migration Steps

  1. Review any code that relied on the previous PRNG key behavior and adjust for the default `jax_threefry_partitionable` setting.
  2. If you were using Mac x86 wheels, switch to a supported platform (e.g., Mac ARM) or build from source.
  3. Upgrade your environment to NumPy >=1.25 and SciPy >=1.11.
  4. For `jax.numpy.linalg.solve` calls with a batched 1-D right-hand side, use `solve(a, b[..., None]).squeeze(-1)` to keep the previous behavior.
  5. Update imports: use `jax.core.abstractify` and `jax.core.pytype_aval_mappings` instead of the `jax.interpreters.xla` versions.
  6. Remove uses of `jax.scipy.special.lpmn` and `jax.scipy.special.lpmn_values`; both are deprecated with no replacement.
  7. Import FFI symbols from `jax.ffi` rather than `jax.extend.ffi`.
  8. Remove usage of the `jax_enable_memories` flag; the behavior is now always enabled.
  9. Replace references to `jax.lib.xla_client.Device` and `jax.lib.xla_client.XlaRuntimeError` with `jax.Device` and `jax.errors.JaxRuntimeError`.
  10. Migrate any code using `jax.experimental.array_api` to use `jax.numpy` directly.
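
Steps 3 and 9 above can be checked with a few lines at start-up. This is a rough sketch, not an official migration tool: `meets_minimum` and `run_safely` are hypothetical helpers written for this example, and the version parsing only looks at the leading "major.minor" digits.

```python
import re

import jax
import numpy as np
import scipy


def meets_minimum(version: str, minimum: tuple) -> bool:
    """Compare the leading "major.minor" of a version string to a minimum."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return (int(match.group(1)), int(match.group(2))) >= minimum


# Step 3: minimum versions quoted in the migration steps.
numpy_ok = meets_minimum(np.__version__, (1, 25))
scipy_ok = meets_minimum(scipy.__version__, (1, 11))

# Step 9: use the public names instead of jax.lib.xla_client.*.
first_device = jax.devices()[0]
is_device = isinstance(first_device, jax.Device)


def run_safely(fn, *args):
    """Hypothetical helper: call fn and report JAX runtime failures."""
    try:
        return fn(*args)
    except jax.errors.JaxRuntimeError as err:
        print(f"JAX runtime error: {err}")
        return None
```

For example, `run_safely(lambda v: v + 1, 1)` simply returns the function's result when no runtime error is raised.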

✨ New Features

  • `jax.numpy.fft.fftn`, `jax.numpy.fft.rfftn`, `jax.numpy.fft.ifftn`, and `jax.numpy.fft.irfftn` now support transforms in more than three dimensions.
  • Added user‑defined state support in the FFI via the new `jax.ffi.register_ffi_type_id` function.
  • AOT lowering `.as_text()` now accepts a `debug_info` option to include source location and other debugging information.

🐛 Bug Fixes

  • The default `optimize` setting of `jax.numpy.einsum` changed from `'optimal'` to `'auto'`, avoiding trace time that grows exponentially with the number of arguments.
  • `jax.numpy.linalg.solve` no longer supports batched 1‑D RHS; use `solve(a, b[..., None]).squeeze(-1)` to emulate previous behavior.
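
Both changes above can be handled explicitly at the call site. A small sketch with made-up shapes and random data: passing `optimize="optimal"` requests the earlier einsum default, and the `solve` line applies the workaround quoted above to a batch of three 4x4 systems.

```python
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)

# einsum: request the previous default explicitly where it is still wanted.
a1 = rng.normal(size=(3, 4)).astype(np.float32)
a2 = rng.normal(size=(4, 5)).astype(np.float32)
a3 = rng.normal(size=(5, 2)).astype(np.float32)
product = jnp.einsum("ij,jk,kl->il", a1, a2, a3, optimize="optimal")

# solve: a batch of 3 systems, each with a 1-D right-hand side of length 4.
# The matrices are kept close to 4 * identity so they are well conditioned.
eye = 4.0 * np.eye(4, dtype=np.float32)
a = 0.1 * rng.normal(size=(3, 4, 4)).astype(np.float32) + eye
b = rng.normal(size=(3, 4)).astype(np.float32)

# Add a trailing axis so each RHS is a column vector, then drop it again.
x = jnp.linalg.solve(a, b[..., None]).squeeze(-1)
```

Here `x` has the same shape as `b`, one solution vector per matrix in the batch.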

🔧 Affected Symbols

`jax_threefry_partitionable`, `jax.numpy.einsum`, `jax.numpy.linalg.solve`, `jax.numpy.fft.fftn`, `jax.numpy.fft.rfftn`, `jax.numpy.fft.ifftn`, `jax.numpy.fft.irfftn`, `jax.ffi.register_ffi_type_id`, `as_text` (AOT lowering), `jax.interpreters.xla.abstractify`, `jax.interpreters.xla.pytype_aval_mappings`, `jax.scipy.special.lpmn`, `jax.scipy.special.lpmn_values`, `jax.extend.ffi`, `jax_enable_memories`, `jax.lib.xla_client.Device`, `jax.lib.xla_client.XlaRuntimeError`, `jax.experimental.array_api`

⚡ Deprecations

  • `jax.interpreters.xla.abstractify` and `jax.interpreters.xla.pytype_aval_mappings` are deprecated in favor of the same names in `jax.core`.
  • `jax.scipy.special.lpmn` and `jax.scipy.special.lpmn_values` are deprecated with no replacement.
  • `jax.extend.ffi` submodule has moved to `jax.ffi`; the old import path is deprecated.