jax-v0.5.0
Summary
This JAX release adds FFT support beyond three dimensions, user-defined state in the FFI, and debugging information for AOT lowering. It also enables `jax_threefry_partitionable` by default (changing the random values generated for a given key), drops Mac x86 wheels, and raises the minimum NumPy and SciPy versions.
⚠️ Breaking Changes
- Enabled `jax_threefry_partitionable` by default, which changes the random values generated from a given PRNG key; update code that relied on the values produced under the previous default.
- Dropped support for Mac x86 wheels; users on that platform must switch to Mac ARM or another supported platform.
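A minimal sketch of the PRNG change, assuming JAX v0.5.0 or later: it draws the same samples with the flag on (the new default) and off (the previous behavior) using `jax.config.update`. The helper `sample` is hypothetical. Expect the two results to differ; turning the flag off globally is best treated as a temporary stopgap while code that depends on specific random values is updated.

```python
import jax


def sample(seed: int, partitionable: bool) -> jax.Array:
    """Hypothetical helper: draw four standard-normal samples under the given flag setting."""
    # The flag is global: it affects every PRNG call made after this update.
    jax.config.update("jax_threefry_partitionable", partitionable)
    key = jax.random.key(seed)
    return jax.random.normal(key, (4,))


new_default = sample(0, partitionable=True)  # behavior from v0.5.0 onward
legacy = sample(0, partitionable=False)      # previous default

# Restore the new default so later code is unaffected.
jax.config.update("jax_threefry_partitionable", True)

print("partitionable:", new_default)
print("legacy:       ", legacy)
```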
Migration Steps
- Review code that depends on the specific random values produced under the old default and update it for the new default; setting `jax_threefry_partitionable` to `False` restores the previous behavior.
- If you were using Mac x86 wheels, switch to a supported platform (e.g., Mac ARM) or build from source.
- Upgrade your environment to NumPy >=1.25 and SciPy >=1.11.
- If you call `jax.numpy.linalg.solve` with a batched 1‑D right-hand side, rewrite the call as `solve(a, b[..., None]).squeeze(-1)`.
- Update imports: use `jax.core.abstractify` and `jax.core.pytype_aval_mappings` instead of the `jax.interpreters.xla` versions.
- Remove or reimplement usages of the deprecated `jax.scipy.special.lpmn` and `jax.scipy.special.lpmn_values`; JAX does not provide a replacement.
- Import FFI symbols from `jax.ffi` rather than `jax.extend.ffi`.
- Remove usage of the `jax_enable_memories` flag; the behavior is now always enabled.
- Replace references to `jax.lib.xla_client.Device` and `jax.lib.xla_client.XlaRuntimeError` with `jax.Device` and `jax.errors.JaxRuntimeError`.
- Migrate any code using `jax.experimental.array_api` to use `jax.numpy` directly.
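The NumPy and SciPy minimums can be checked before upgrading. The sketch below uses only the standard library; the helper names (`parse_major_minor`, `meets_minimum`, `check_environment`) are hypothetical, and it compares only the major and minor version components, ignoring pre-release tags.

```python
import importlib.metadata

# Minimum versions from the migration steps above.
MINIMUMS = {"numpy": (1, 25), "scipy": (1, 11)}


def parse_major_minor(version: str) -> tuple:
    """Return (major, minor) from a version string such as '1.26.4' or '2.0.0rc1'."""
    parts = version.split(".")
    minor_field = parts[1] if len(parts) > 1 else "0"
    # Keep only the leading digits of the minor field (e.g. '26rc1' -> '26').
    digits = ""
    for ch in minor_field:
        if not ch.isdigit():
            break
        digits += ch
    return int(parts[0]), int(digits or "0")


def meets_minimum(version: str, minimum: tuple) -> bool:
    # Tuple comparison orders by major version first, then minor.
    return parse_major_minor(version) >= minimum


def check_environment() -> list:
    """Return human-readable problems; an empty list means both minimums are met."""
    problems = []
    for package, minimum in MINIMUMS.items():
        try:
            installed = importlib.metadata.version(package)
        except importlib.metadata.PackageNotFoundError:
            problems.append(f"{package} is not installed")
            continue
        if not meets_minimum(installed, minimum):
            problems.append(f"{package} {installed} is older than {minimum[0]}.{minimum[1]}")
    return problems


print(check_environment() or "NumPy and SciPy meet the minimum versions")
```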
✨ New Features
- `jax.numpy.fft.fftn`, `jax.numpy.fft.rfftn`, `jax.numpy.fft.ifftn`, and `jax.numpy.fft.irfftn` now support transforms in more than three dimensions.
- Added user‑defined state support in the FFI via the new `jax.ffi.register_ffi_type_id` function.
- AOT lowering `.as_text()` now accepts a `debug_info` option to include source location and other debugging information.
🐛 Bug Fixes
- The default `optimize` argument of `jax.numpy.einsum` changed from `'optimal'` to `'auto'`, which avoids trace time that grows exponentially with the number of operands.
- `jax.numpy.linalg.solve` no longer supports batched 1‑D RHS; use `solve(a, b[..., None]).squeeze(-1)` to emulate previous behavior.
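The `solve` workaround from the note above, together with an explicit `optimize` argument for `einsum`, can be sketched as follows. The arrays are arbitrary random data chosen for illustration; passing `optimize="optimal"` restores the previous default for a specific call.

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)
# A batch of three well-conditioned 2x2 systems and a batch of 1-D right-hand sides.
a = 3.0 * np.eye(2) + 0.1 * rng.standard_normal((3, 2, 2))
b = rng.standard_normal((3, 2))

# b has shape (3, 2): add a trailing axis so each RHS is a (2, 1) matrix, then drop it again.
x = jnp.linalg.solve(a, b[..., None]).squeeze(-1)

# Check the solution: a @ x should reproduce b for every batch element.
residual = jnp.einsum("bij,bj->bi", a, x) - b

# einsum now defaults to optimize="auto"; pass "optimal" to use the previous default.
m1, m2, m3 = (rng.standard_normal((4, 4)) for _ in range(3))
auto_path = jnp.einsum("ij,jk,kl->il", m1, m2, m3)
optimal_path = jnp.einsum("ij,jk,kl->il", m1, m2, m3, optimize="optimal")
```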
🔧 Affected Symbols
- `jax_threefry_partitionable`
- `jax.numpy.einsum`
- `jax.numpy.linalg.solve`
- `jax.numpy.fft.fftn`
- `jax.numpy.fft.rfftn`
- `jax.numpy.fft.ifftn`
- `jax.numpy.fft.irfftn`
- `jax.ffi.register_ffi_type_id`
- `as_text` (AOT lowering)
- `jax.interpreters.xla.abstractify`
- `jax.interpreters.xla.pytype_aval_mappings`
- `jax.scipy.special.lpmn`
- `jax.scipy.special.lpmn_values`
- `jax.extend.ffi`
- `jax_enable_memories`
- `jax.lib.xla_client.Device`
- `jax.lib.xla_client.XlaRuntimeError`
- `jax.experimental.array_api`
⚡ Deprecations
- `jax.interpreters.xla.abstractify` and `jax.interpreters.xla.pytype_aval_mappings` are deprecated in favor of the same names in `jax.core`.
- `jax.scipy.special.lpmn` and `jax.scipy.special.lpmn_values` are deprecated with no replacement.
- The `jax.extend.ffi` submodule has moved to `jax.ffi`; the old import path is deprecated.