jax-v0.6.1
⚠ 3 breaking · ✨ 1 feature · ⚡ 1 deprecation · 🔧 4 symbols
Summary
This release adds `jax.lax.axis_size`, makes `jax.sharding.PartitionSpec` and `jax.ShapeDtypeStruct` stricter, re-enables CUDA dependency version checks, and deprecates `jax.custom_derivatives.custom_jvp_call_jaxpr_p`.
⚠️ Breaking Changes
- `jax.sharding.PartitionSpec` no longer inherits from `tuple`, so code that relied on tuple behavior (for example, `isinstance(spec, tuple)` checks) may break. Treat it as a regular object and access its contents explicitly instead of relying on tuple semantics.
- `jax.ShapeDtypeStruct` is now immutable, so in-place attribute updates raise an error. Use the `.update` method to create a modified copy instead.
- Version checks for CUDA package dependencies, which had been accidentally disabled in a previous release, are re-enabled. They may report errors if incompatible CUDA packages are installed, so make sure your CUDA packages meet the required versions.
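The sketch below shows one way to adapt code that used to rely on `PartitionSpec` being a `tuple`, assuming JAX v0.6.1 or later. The mesh axis name `"data"` and the helper `describe` are hypothetical, and converting with `tuple(spec)` assumes `PartitionSpec` remains iterable.

```python
from jax.sharding import PartitionSpec as P

# Shard the leading dimension over a mesh axis named "data" (illustrative name).
spec = P("data", None)

# Before v0.6.1, PartitionSpec subclassed tuple, so isinstance(spec, tuple) was True.
# From v0.6.1 on it is False, so dispatch on PartitionSpec itself, and check it first.
def describe(s):
    if isinstance(s, P):
        return "partition spec"
    if isinstance(s, tuple):
        return "plain tuple"
    return "other"

# When a real tuple is required, convert explicitly
# (assumption: PartitionSpec is still iterable).
as_tuple = tuple(spec)

# Build a new spec from the old entries rather than treating the spec as a tuple.
extended = P(*spec, None)
```

Checking `isinstance(s, P)` before `isinstance(s, tuple)` keeps the dispatch correct on releases both before and after the change.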
Migration Steps
- Update any code that treats `jax.sharding.PartitionSpec` as a tuple to use its explicit attributes.
- Replace in‑place modifications of `jax.ShapeDtypeStruct` with calls to its `.update` method.
- Verify that installed CUDA packages satisfy the version requirements now enforced by JAX.
✨ New Features
- Added `jax.lax.axis_size`, which returns the size of a mapped axis given its name.
🔧 Affected Symbols
- `jax.lax.axis_size`
- `jax.sharding.PartitionSpec`
- `jax.ShapeDtypeStruct`
- `jax.custom_derivatives.custom_jvp_call_jaxpr_p`
⚡ Deprecations
- `jax.custom_derivatives.custom_jvp_call_jaxpr_p` is deprecated and will be removed in JAX v0.7.0.