Change8

jax-v0.9.2

Breaking Changes
📦 jax
2 breaking changes · 🔧 2 affected symbols

Summary

JAX 0.9.2 makes `jax._src.literals.TypedNdArray` a subclass of `np.ndarray`. It also stops computing `jnp.arange` on the host when a `step` argument is given, which can reduce precision for narrow-width floats.

⚠️ Breaking Changes

  • The type `jax._src.literals.TypedNdArray` is now a subclass of `np.ndarray` rather than a duck type. Code that relied on it not being an `np.ndarray`, for example an `isinstance(x, np.ndarray)` check, may need updating.
  • `jax.numpy.arange` with `step` specified no longer generates the array on the host, which may produce less precise values for narrow-width floating-point dtypes (e.g., bfloat16). To restore the previous behavior for narrow-width floats, use `jnp.array(np.arange(...))`.
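The practical difference between a duck type and a real subclass shows up in `isinstance` checks. The following is a minimal sketch using two hypothetical stand-in classes, `DuckTypedArray` and `SubclassArray`; neither is JAX's actual `TypedNdArray`. It shows how a branch on `isinstance(x, np.ndarray)` flips when a type becomes a subclass:

```python
import numpy as np

# Hypothetical stand-ins for illustration only; these are NOT JAX classes.
class DuckTypedArray:
    """Old style: converts to an ndarray via __array__, but is not an ndarray."""
    def __init__(self, data):
        self._data = np.asarray(data)

    def __array__(self, dtype=None, copy=None):
        # Accept `copy` so NumPy 2 can call this without a deprecation warning.
        return np.asarray(self._data, dtype=dtype)


class SubclassArray(np.ndarray):
    """New style: a genuine np.ndarray subclass."""
    def __new__(cls, data):
        return np.asarray(data).view(cls)


def describe(x):
    # A branch like this changes behavior when a type becomes an ndarray subclass.
    return "ndarray" if isinstance(x, np.ndarray) else "other"


old = DuckTypedArray([1, 2, 3])
new = SubclassArray([1, 2, 3])
print(describe(old))  # other
print(describe(new))  # ndarray

# Normalizing with np.asarray gives a plain ndarray in both cases.
print(np.asarray(old).sum(), np.asarray(new).sum())
```

Code that normalizes inputs with `np.asarray` before using them behaves the same for both styles, so it is less sensitive to this kind of change than code that branches on the input's type.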

Migration Steps

  1. If your code relied on `jax._src.literals.TypedNdArray` being a duck type rather than an `np.ndarray` subclass (for example, in `isinstance` checks), review and update it.
  2. If you observe precision issues with narrow-width floats (like bfloat16) when using `jnp.arange(..., step=...)`, use `jnp.array(np.arange(...))` instead.
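The following sketch applies the suggested workaround to a bfloat16 range from 0 to 10 in steps of 0.1. These bounds are arbitrary illustration values, not taken from the release notes. The host path computes the sequence with NumPy in float64 and rounds each element to bfloat16 once during conversion. Whether the two results actually differ depends on the inputs and the JAX version, so the script only reports what it finds and does not assume a particular outcome:

```python
import numpy as np
import jax.numpy as jnp

# Illustrative bounds, not taken from the release notes.
start, stop, step = 0.0, 10.0, 0.1

# Computed by JAX itself (no longer on the host as of 0.9.2, per the notes).
device_values = jnp.arange(start, stop, step, dtype=jnp.bfloat16)

# Workaround from the notes: build the range on the host with NumPy (float64),
# then convert to a JAX array, rounding each element to bfloat16 once.
host_values = jnp.array(np.arange(start, stop, step), dtype=jnp.bfloat16)

print("device:", device_values[-3:])
print("host:  ", host_values[-3:])

if device_values.shape == host_values.shape:
    # Compare in float32 so the subtraction itself adds no bfloat16 rounding.
    diff = jnp.max(jnp.abs(device_values.astype(jnp.float32)
                           - host_values.astype(jnp.float32)))
    print("max abs difference:", float(diff))
else:
    print("lengths differ:", device_values.shape, host_values.shape)
```

The host-side path trades a host-to-device transfer for values computed at higher precision before the final cast. This matters mainly for narrow dtypes such as bfloat16, where accumulated rounding in the step can be visible.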

Affected Symbols