jax-v0.7.2
Breaking Changes · 📦 jax
⚠ 1 breaking · 🐛 2 fixes · ⚡ 2 deprecations · 🔧 5 symbols
Summary
JAX v0.7.2 drops support for raw DLPack capsules in jax.dlpack.from_dlpack, raises the minimum NumPy/SciPy versions, introduces two deprecations, and ships two bug fixes.
⚠️ Breaking Changes
- jax.dlpack.from_dlpack no longer accepts a DLPack capsule; it must be called with an array implementing __dlpack__ and __dlpack_device__.
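A minimal sketch of the new calling convention, assuming NumPy as the producer library (NumPy arrays implement both `__dlpack__` and `__dlpack_device__`). The helper `import_via_dlpack` is hypothetical and only illustrates the check; the point is that the producer array itself is passed, not the capsule returned by its `__dlpack__()` method.

```python
import numpy as np
import jax


def import_via_dlpack(producer):
    """Import an array from another library through the DLPack protocol.

    Hypothetical helper. Since v0.7.2, jax.dlpack.from_dlpack expects the
    producer object itself, which must implement __dlpack__ and
    __dlpack_device__; a raw DLPack capsule is no longer accepted.
    """
    if not (hasattr(producer, "__dlpack__")
            and hasattr(producer, "__dlpack_device__")):
        raise TypeError(
            "expected an array implementing __dlpack__ and __dlpack_device__, "
            f"got {type(producer).__name__}"
        )
    return jax.dlpack.from_dlpack(producer)


host = np.arange(6, dtype=np.float32).reshape(2, 3)

# Old style, rejected since v0.7.2: passing a raw capsule.
#   jax.dlpack.from_dlpack(host.__dlpack__())
# New style: pass the NumPy array, which implements both protocol methods.
x = import_via_dlpack(host)
print(x.shape, x.dtype)  # (2, 3) float32
```

The same pattern applies to other producers (for example GPU tensors from other frameworks), as long as the object exposes both protocol methods.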
Migration Steps
- Update calls to jax.dlpack.from_dlpack to pass an object that implements __dlpack__ and __dlpack_device__ instead of a raw DLPack capsule.
- If code relies on isinstance(x, np.ndarray) checks for JAX constants, convert the value with np.asarray(x) first to obtain a plain NumPy array.
- Remove usage of the enable_xla and native_serialization arguments from jax2tf.convert calls.
- Stop setting jax_pmap_no_rank_reduction to False; rely on the default True behavior.
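The np.asarray migration step can be sketched as below. The notes do not spell out the class hierarchy of the affected constants (see LiteralArray in the symbol list), so an ordinary JAX array stands in for such a constant here, and `as_plain_ndarray` is a hypothetical helper name.

```python
import numpy as np
import jax.numpy as jnp


def as_plain_ndarray(value):
    """Return `value` as a plain NumPy array (hypothetical helper).

    np.asarray converts a JAX array to an np.ndarray on the host; an
    existing np.ndarray with no dtype change is returned as-is.
    """
    return np.asarray(value)


const = jnp.array([1.0, 2.0, 3.0])   # stand-in for a JAX-side constant
print(isinstance(const, np.ndarray))  # False

host = as_plain_ndarray(const)
print(isinstance(host, np.ndarray))   # True
```

Normalizing once at the boundary keeps downstream isinstance checks unchanged, at the cost of a possible device-to-host transfer.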
🐛 Bug Fixes
- arr.view(dtype=None) now returns the array unchanged, matching NumPy semantics.
- jax.random.randint now produces a less biased distribution for 8-bit and 16-bit integer dtypes. The previous (biased) behavior can be restored temporarily by setting the jax_safer_randint config option to False.
🔧 Affected Symbols
- jax.dlpack.from_dlpack
- jax2tf.convert
- jax_pmap_no_rank_reduction
- jax.random.randint
- LiteralArray
⚡ Deprecations
- Parameters enable_xla and native_serialization for jax2tf.convert are deprecated and will be removed in a future version.
- Setting the jax_pmap_no_rank_reduction config option to False is deprecated; leave it at its default value of True.