jax-v0.9.0
Summary
This release introduces `jax.thread_guard` for multi-threaded JAX environments and updates the `jax.export` serialization format to support explicit sharding, which tightens the mesh-matching requirements for calling exported modules. Several configuration states have been removed or deprecated, and `jax.numpy.fix` is deprecated in favor of `jax.numpy.trunc`.
⚠️ Breaking Changes
- Calling an exported module now behaves differently because of a new `jax.export` serialization format: the abstract mesh in effect at call time must exactly match the one used at export time, including axis names. Previously, only the number of devices had to match.
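Because the requirement now covers axis names as well as device count, it can help to record a mesh's axis names and sizes at export time and compare them before calling. The following is a minimal sketch, assuming at least one available device; `mesh_signature` and `check_mesh_matches` are hypothetical helpers, not part of `jax.export`, and JAX's own check may compare more than names and sizes (for example, axis types).

```python
import numpy as np
import jax
from jax.sharding import Mesh


def mesh_signature(mesh: Mesh) -> tuple[tuple[str, int], ...]:
    """Ordered (axis_name, axis_size) pairs taken from Mesh.shape."""
    return tuple(mesh.shape.items())


def check_mesh_matches(export_sig: tuple[tuple[str, int], ...], call_mesh: Mesh) -> None:
    """Hypothetical pre-flight check: raise if axis names or sizes differ."""
    call_sig = mesh_signature(call_mesh)
    if call_sig != export_sig:
        raise ValueError(
            f"mesh mismatch: exported with {export_sig}, calling with {call_sig}"
        )


# Assumption: a single device is enough to illustrate the comparison.
devices = np.array(jax.devices()[:1])

# Record the signature when exporting (e.g. store it next to the serialized module).
export_mesh = Mesh(devices, ("x",))
saved_sig = mesh_signature(export_mesh)

# Same axis names and sizes: the check passes silently.
check_mesh_matches(saved_sig, Mesh(devices, ("x",)))

# Same device count but a different axis name: rejected.
try:
    check_mesh_matches(saved_sig, Mesh(devices, ("data",)))
    renamed_ok = True
except ValueError:
    renamed_ok = False
```

In the rejected case both meshes contain the same number of devices, which the previous serialization format would have accepted; only the axis name differs.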
Migration Steps
- If you use exported modules, ensure that the abstract mesh in effect when calling them matches the abstract mesh used at export time, including axis names.
- Replace uses of `jax.numpy.fix` with `jax.numpy.trunc`.
✨ New Features
- Added `jax.thread_guard`, a context manager to detect device usage by multiple threads in multi-controller JAX.
🐛 Bug Fixes
- Fixed a workspace size calculation error for pivoted QR (`magma_zgeqp3_gpu`) with MAGMA 2.9.0 when `use_magma=True` and `pivoting=True` are set.
⚡ Deprecations
- The flag `jax_collectives_common_channel_id` has been removed.
- The `jax_pmap_no_rank_reduction` config state has been removed. The no-rank-reduction behavior is now the only supported behavior: a `jax.pmap`ped function sees inputs of the same rank as the input to `jax.pmap(f)`.
- Setting the `jax_pmap_shmap_merge` config state is deprecated and will be removed in JAX v0.10.0.
- `jax.numpy.fix` is deprecated; use `jax.numpy.trunc` as a drop-in replacement.
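`jax.numpy.trunc` rounds each element toward zero, which is the behavior `jax.numpy.fix` provided. A minimal migration sketch, assuming a floating-point input array; the cross-check against `sign(x) * floor(abs(x))` is only there to make the rounding direction explicit.

```python
import jax.numpy as jnp

x = jnp.array([-2.7, -0.5, 0.5, 2.7])

# Before (deprecated in v0.9.0): y = jnp.fix(x)
y = jnp.trunc(x)  # rounds each element toward zero

# Explicit round-toward-zero formula, for comparison.
expected = jnp.sign(x) * jnp.floor(jnp.abs(x))
```

Negative inputs round up toward zero (for example, -2.7 becomes -2.0) rather than down as `jnp.floor` would.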