jax-v0.9.1
Summary
This release makes `jax.shard_map` stricter in Explicit mode by requiring each input's PartitionSpec to match `in_specs`, and changes how JAX tracers report their type. It also adds a debug configuration option for the compilation cache.
⚠️ Breaking Changes
- JAX tracers whose type is not `Array` (e.g., tracers of `Ref` type) no longer report themselves as instances of `Array`, so `isinstance(tracer, jax.Array)` now returns `False` for them. Code that relied on the previous, incorrect behavior needs to be updated.
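For tracers of ordinary `Array` type, `isinstance(x, jax.Array)` is still `True`; only non-`Array` tracers (such as `Ref`) are affected. The sketch below records the check at trace time inside a jitted function. The function `f` and the `observed` list are illustrative names, not part of JAX; constructing a `Ref` tracer is deliberately left out.

```python
import jax
import jax.numpy as jnp

observed = []

@jax.jit
def f(x):
    # This line runs while tracing, so x is a tracer of Array type here.
    observed.append(isinstance(x, jax.Array))
    return x + 1

f(jnp.ones(3))
print(observed)  # [True]: Array-typed tracers still report as jax.Array
```

Code that used `isinstance(..., jax.Array)` as a catch-all for "any traced value" should be reviewed, since it will no longer match tracers of other types.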
Migration Steps
- If using `jax.shard_map` in Explicit mode and relying on implicit resharding when `in_specs` is provided, explicitly call `jax.reshard` on inputs before passing them to `shard_map`.
✨ New Features
- Added a debug config option, `jax_compilation_cache_check_contents`. When set, it checks the compilation cache contents more strictly on `get()`, and verifies the contents on `put()` calls made by the current process.
🐛 Bug Fixes
- `jax.shard_map` in Explicit mode now raises an error if an input's PartitionSpec does not match the corresponding PartitionSpec in `in_specs`, instead of implicitly resharding the input.