
jax-v0.9.1

Breaking Changes
1 breaking · 1 feature · 🐛 1 fix · 🔧 2 symbols

Summary

This release makes `jax.shard_map` stricter in Explicit mode by requiring each input's PartitionSpec to match `in_specs`, and it stops JAX tracers that are not of `Array` type from reporting themselves as `Array` instances. It also adds a debug configuration option for the compilation cache.

⚠️ Breaking Changes

  • JAX tracers that are not of `Array` type (e.g., of `Ref` type) no longer report themselves as instances of `Array`. Code relying on this incorrect reporting will need updating.
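As a point of reference, the sketch below records what an `isinstance` check against `jax.Array` sees while a function is being traced. It covers only the unaffected case, a tracer standing in for an ordinary array; the names `seen_flags` and `add_one` are illustrative and not part of JAX. Per the note above, the same check on a tracer of `Ref` type no longer succeeds in this release.

```python
import jax
import jax.numpy as jnp

seen_flags = []  # illustrative name: collects isinstance results seen during tracing

@jax.jit
def add_one(x):
    # During tracing, x is a tracer standing in for an ordinary array,
    # so it still reports itself as a jax.Array.
    seen_flags.append(isinstance(x, jax.Array))
    return x + 1

out = add_one(jnp.ones(3))
print(seen_flags, out)
```

Code that branches on this check should make sure it does not assume the result is true for every tracer it may receive.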

Migration Steps

  1. If you use `jax.shard_map` in Explicit mode and rely on implicit resharding when `in_specs` is provided, call `jax.reshard` on the inputs explicitly before passing them to `shard_map`.

✨ New Features

  • Added a debug config `jax_compilation_cache_check_contents`. When set, it enforces stricter checking of the compilation cache contents during `get()` and verifies contents during `put()` by the current process.
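One way to turn the option on from Python is through `jax.config.update`. The sketch below assumes the option is a boolean flag, which the note above suggests but does not state, and the helper `enable_cache_content_checks` is a hypothetical name introduced here. It falls back gracefully on JAX versions that do not know the option.

```python
import jax

FLAG = "jax_compilation_cache_check_contents"

def enable_cache_content_checks() -> bool:
    """Try to enable the debug check; return False if this JAX lacks the option."""
    try:
        # Assumption: the option is boolean, so True switches it on.
        jax.config.update(FLAG, True)
    except AttributeError:
        # jax.config.update rejects option names it does not recognize.
        return False
    return True

enabled = enable_cache_content_checks()
print(f"{FLAG} enabled: {enabled}")
```

Because this is a debug option, it is probably best enabled only while investigating suspected compilation cache problems.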

🐛 Bug Fixes

  • Using `jax.shard_map` in Explicit mode now raises an error if the PartitionSpec of an input does not match the PartitionSpec specified in `in_specs`, enforcing explicit matching instead of implicit resharding.

Affected Symbols