Change8

jax-v0.6.1

Breaking Changes
📦 jax
3 breaking · 1 feature · 1 deprecation · 🔧 4 symbols

Summary

This release adds `jax.lax.axis_size`, makes `jax.sharding.PartitionSpec` and `jax.ShapeDtypeStruct` stricter, re-enables CUDA dependency version checks, and deprecates `jax.custom_derivatives.custom_jvp_call_jaxpr_p`.

⚠️ Breaking Changes

  • `jax.sharding.PartitionSpec` no longer inherits from `tuple`, so code that relied on tuple behavior (for example, `isinstance(spec, tuple)` checks) breaks. Treat it as its own type, and convert explicitly with `tuple(spec)` where a plain tuple is required.
  • `jax.ShapeDtypeStruct` is now immutable, so in-place attribute updates raise an error. Use its `.update` method, which returns a modified copy.
  • CUDA dependency version checks, which had been accidentally disabled in a previous release, are re-enabled. Environments with incompatible CUDA packages that previously went unnoticed may now fail these checks; make sure your CUDA packages meet the required versions.
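A minimal sketch of the `PartitionSpec` change, assuming JAX v0.6.1 or later is installed. `partition_tuple` is a hypothetical helper, not part of JAX; it relies only on `PartitionSpec` being iterable, so it behaves the same whether or not the class subclasses `tuple`.

```python
from jax.sharding import PartitionSpec as P


def partition_tuple(spec: P) -> tuple:
    """Hypothetical helper: return the partitions of `spec` as a plain tuple.

    Only iteration is used, so this works both before and after
    PartitionSpec stopped subclassing tuple.
    """
    return tuple(spec)


spec = P("data", None)

# False from v0.6.1 on (it was True when PartitionSpec subclassed tuple),
# so type checks like this one need to be rewritten.
print(isinstance(spec, tuple))

# Explicit conversion where a real tuple is genuinely needed.
print(partition_tuple(spec))  # ('data', None)
```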

Migration Steps

  1. Update any code that treats `jax.sharding.PartitionSpec` as a tuple, such as `isinstance(..., tuple)` checks; use `tuple(spec)` where a real tuple is needed.
  2. Replace in-place modifications of `jax.ShapeDtypeStruct` with calls to its `.update` method.
  3. Verify that the installed CUDA packages satisfy the version requirements JAX now enforces.

✨ New Features

  • Added `jax.lax.axis_size`, which returns the size of a mapped axis given its name (for example, an axis introduced by `jax.vmap(..., axis_name=...)`).

🔧 Affected Symbols

  • `jax.lax.axis_size`
  • `jax.sharding.PartitionSpec`
  • `jax.ShapeDtypeStruct`
  • `jax.custom_derivatives.custom_jvp_call_jaxpr_p`

⚡ Deprecations

  • `jax.custom_derivatives.custom_jvp_call_jaxpr_p` is deprecated and will be removed in JAX v0.7.0.