jax-v0.9.2
Summary
JAX 0.9.2 changes the internal type structure of `TypedNdArray`, which is now an `np.ndarray` subclass. It also no longer generates the output of `jnp.arange` on the host when a `step` argument is given, which can reduce the precision of narrow-width float results.
⚠️ Breaking Changes
- The type `jax._src.literals.TypedNdArray` is now a subclass of `np.ndarray` instead of a duck type. Code that relied on it not being an `np.ndarray` (for example, `isinstance` checks) may need updates.
- `jax.numpy.arange` with `step` specified no longer generates the array on the host, which may lead to less precise outputs for narrow-width floats (e.g., bfloat16). To restore the previous behavior for narrow-width floats, use `jnp.array(np.arange(...))`.
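The first change matters mainly for code that branches on the type of a value. The sketch below uses two hypothetical stand-in classes, `DuckTypedArray` and `SubclassArray` (neither is part of JAX, and the internals of the private `TypedNdArray` are not described here), to show how the same `isinstance` check takes different paths for a duck type and for an `np.ndarray` subclass.

```python
import numpy as np


class DuckTypedArray:
    """Hypothetical duck type: exposes some ndarray attributes without subclassing."""

    def __init__(self, data):
        self._data = np.asarray(data)

    @property
    def shape(self):
        return self._data.shape

    @property
    def dtype(self):
        return self._data.dtype


class SubclassArray(np.ndarray):
    """Hypothetical np.ndarray subclass, analogous to the new TypedNdArray layout."""


def describe(x):
    # Branching like this changes behavior when a type moves from a duck
    # type to an np.ndarray subclass.
    if isinstance(x, np.ndarray):
        return "ndarray path"
    return "duck-typed path"


duck = DuckTypedArray([1, 2, 3])
sub = np.asarray([1, 2, 3]).view(SubclassArray)

print(describe(duck))  # duck-typed path
print(describe(sub))   # ndarray path
```

If a code path previously assumed these values would not satisfy `isinstance(x, np.ndarray)`, it may now take the other branch.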
Migration Steps
- If your code relied on the previous duck-typing behavior of `jax._src.literals.TypedNdArray`, review it: `isinstance(x, np.ndarray)` now returns `True` for these values.
- If you observe precision issues with narrow-width floats (like bfloat16) when using `jnp.arange(..., step=...)`, use `jnp.array(np.arange(...))` instead.
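A minimal sketch of the workaround, assuming the target dtype is bfloat16; the range values are arbitrary illustration choices. The host-side version computes the range with NumPy (in float64 here) and only then converts it to bfloat16. The exact values and length of the device-side result may differ by JAX version and backend, so no output is shown for it.

```python
import numpy as np
import jax.numpy as jnp

# Illustrative range (assumed values, not from the release notes).
start, stop, step = 0.0, 1.0, 0.1

# Since JAX 0.9.2: with a step, the range is no longer generated on the host,
# so narrow-width floats such as bfloat16 may come out less precise.
device_vals = jnp.arange(start, stop, step, dtype=jnp.bfloat16)

# Workaround from the release notes: build the range on the host with NumPy,
# then convert it to a bfloat16 JAX array.
host_vals = jnp.array(np.arange(start, stop, step), dtype=jnp.bfloat16)

print(device_vals)
print(host_vals)
```

Passing `dtype` to `jnp.array` keeps the rounding to bfloat16 as a single final step applied to host-computed values.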