
jax-v0.7.2

📦 jax
1 breaking · 🐛 2 fixes · 2 deprecations · 🔧 5 symbols

Summary

JAX drops support for raw DLPack capsules in jax.dlpack.from_dlpack, raises minimum NumPy/SciPy versions, and introduces several deprecations and bug fixes.

⚠️ Breaking Changes

  • jax.dlpack.from_dlpack no longer accepts a DLPack capsule; it must be called with an array implementing __dlpack__ and __dlpack_device__.
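
A minimal sketch of the new calling convention, assuming a NumPy array as the producer (NumPy arrays implement both __dlpack__ and __dlpack_device__). The helper name to_jax is hypothetical and only illustrates the check; it is not part of JAX.

```python
import numpy as np
import jax.dlpack


def to_jax(obj):
    """Hypothetical helper: import obj into JAX through the DLPack protocol."""
    # A raw capsule (the result of obj.__dlpack__()) has neither method,
    # so it is rejected here with a clearer message.
    if not (hasattr(obj, "__dlpack__") and hasattr(obj, "__dlpack_device__")):
        raise TypeError(
            "expected an array implementing __dlpack__ and __dlpack_device__, "
            f"got {type(obj).__name__}"
        )
    return jax.dlpack.from_dlpack(obj)


x_np = np.arange(6, dtype=np.float32).reshape(2, 3)

# Before (no longer accepted): jax.dlpack.from_dlpack(x_np.__dlpack__())
# After: pass the array object itself.
x_jax = to_jax(x_np)
```

The explicit hasattr check is optional; it is included here only so that leftover capsule call sites fail with a message that points at the fix.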

Migration Steps

  1. Update calls to jax.dlpack.from_dlpack to pass an object that implements __dlpack__ and __dlpack_device__ instead of a raw DLPack capsule.
  2. If code relies on isinstance(x, np.ndarray) checks for JAX constants, which may now be LiteralArray instances, convert the value with np.asarray(x) to obtain a classic NumPy array.
  3. Remove usage of the enable_xla and native_serialization arguments from jax2tf.convert calls.
  4. Stop setting jax_pmap_no_rank_reduction to False; rely on the default True behavior.
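
Step 2 above can be sketched as follows. The helper names as_ndarray and describe_constant are hypothetical, and the example uses plain Python and NumPy inputs as stand-ins, since the constant type that JAX passes around is private.

```python
import numpy as np


def as_ndarray(x):
    """Hypothetical helper: return x as a classic numpy.ndarray."""
    # np.asarray returns the input unchanged when it already is an ndarray
    # and converts other array-like objects otherwise.
    return np.asarray(x)


def describe_constant(x):
    """Normalize first, then rely on ndarray-specific behavior."""
    arr = as_ndarray(x)
    # After normalization an isinstance check is safe again.
    assert isinstance(arr, np.ndarray)
    return arr.dtype, arr.shape
```

Normalizing at the boundary keeps the rest of the code free of type checks against private JAX classes.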

🐛 Bug Fixes

  • arr.view(dtype=None) now returns the array unchanged, matching NumPy semantics.
  • jax.random.randint now produces a less-biased distribution for 8-bit and 16-bit integer types; the previous biased behavior can be temporarily restored by setting the jax_safer_randint config to False.
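
The two fixes above touch the calls sketched below. This is an illustration only: the array contents, seed, and bounds are arbitrary choices, and the dtype-preserving behavior of view(dtype=None) applies to 0.7.2 and later as described in the note.

```python
import jax
import jax.numpy as jnp

# view(dtype=None): per the fix above, 0.7.2+ returns the array unchanged.
arr = jnp.arange(4, dtype=jnp.int32)
same = arr.view(dtype=None)

# randint with a narrow integer dtype, the case affected by the bias fix.
key = jax.random.PRNGKey(0)
samples = jax.random.randint(key, (1000,), 0, 100, dtype=jnp.uint8)

# To temporarily restore the previous sampling behavior (flag name taken
# from the note above; only available on versions that define it):
# jax.config.update("jax_safer_randint", False)
```

Code that compares randint output against stored values for 8-bit or 16-bit dtypes may see different samples after upgrading, since the sampling procedure changed.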

🔧 Affected Symbols

  • jax.dlpack.from_dlpack
  • jax2tf.convert
  • jax_pmap_no_rank_reduction
  • jax.random.randint
  • LiteralArray

⚡ Deprecations

  • Parameters enable_xla and native_serialization for jax2tf.convert are deprecated and will be removed in a future version.
  • Setting the config state jax_pmap_no_rank_reduction to False is deprecated; the default is True.
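
One way to prepare for the jax2tf.convert deprecation is to filter the deprecated arguments out of any shared keyword dictionary before the call. The sketch below assumes call sites build their arguments as a dict; the helper name strip_deprecated_convert_kwargs and the function jax_fn are hypothetical.

```python
# Argument names taken from the deprecation note above.
DEPRECATED_CONVERT_KWARGS = frozenset({"enable_xla", "native_serialization"})


def strip_deprecated_convert_kwargs(kwargs):
    """Hypothetical helper: drop the deprecated jax2tf.convert arguments."""
    return {k: v for k, v in kwargs.items() if k not in DEPRECATED_CONVERT_KWARGS}


old_kwargs = {"enable_xla": True, "native_serialization": True, "with_gradient": True}
new_kwargs = strip_deprecated_convert_kwargs(old_kwargs)

# Usage (requires TensorFlow, so left as a comment):
# from jax.experimental import jax2tf
# tf_fn = jax2tf.convert(jax_fn, **new_kwargs)
```

For the second deprecation no helper is needed: deleting any line that sets jax_pmap_no_rank_reduction to False is enough.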