jax-v0.8.2
Breaking Changes · 📦 jax
⚠ 2 breaking · ⚡ 4 deprecations · 🔧 18 symbols
Summary
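This release removes `jax.experimental.si_vjp` and stops `jax.Tracer` from inheriting from `jax.Array` at runtime. It also deprecates `jax.lax.pvary`, complex arguments to `jax.numpy.arange`, several `jax.core` symbols, and all of `jax.interpreters.pxla`. Affected code should move to `jax.vjp` and `jax.lax.pcast`.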
⚠️ Breaking Changes
- `jax.Tracer` no longer inherits from `jax.Array` at runtime. Code that relies on subclass checks such as `issubclass(jax.Tracer, jax.Array)` may break; `isinstance(x, jax.Array)` still works for traced arrays.
- `jax.experimental.si_vjp` has been removed; code using it must switch to `jax.vjp`.
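A minimal sketch of a trace-safe type check, assuming JAX is installed. Inside `jax.jit` the argument is a tracer, and `isinstance(x, jax.Array)` is the check these notes say still works. The function `double` and the `seen` list are illustrative names, not part of the JAX API.

```python
import jax
import jax.numpy as jnp

# Illustrative trace-time log: records whether each traced argument
# passes the isinstance check against jax.Array.
seen = []

@jax.jit
def double(x):
    # Inside jit, x is a tracer. Check it with isinstance(x, jax.Array)
    # instead of an issubclass(...) test on the tracer's class, since
    # Tracer no longer inherits from Array at runtime.
    seen.append(isinstance(x, jax.Array))
    return x * 2

out = double(jnp.arange(3.0))
# Concrete results returned from jit are jax.Array instances as well.
concrete_is_array = isinstance(out, jax.Array)
```

Because `seen.append` runs only while tracing, `seen` gets one entry on the first call and none on later cached calls.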
Migration Steps
- Replace any use of `jax.experimental.si_vjp` with `jax.vjp`.
- Update code that imports or calls `jax.lax.pvary` to use `jax.lax.pcast(..., to='varying')`.
- Stop passing complex arguments to `jax.numpy.arange`; build the range from real-valued arguments and convert it to complex values afterward.
- Remove or replace uses of the deprecated `jax.core` symbols (e.g., `call_impl`, `get_aval`; the full list is under Deprecations).
- Stop relying on `jax.interpreters.pxla`; all of its symbols are deprecated.
- If your code checks `issubclass(jax.Tracer, jax.Array)` at runtime, update that logic, because `Tracer` no longer inherits from `Array`; use `isinstance(x, jax.Array)` instead, which still works for traced arrays.
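The `si_vjp` migration can be sketched with the standard `jax.vjp` pattern. These notes do not describe the old `si_vjp` call signature, so only the `jax.vjp` side is shown; the function `f` and the inputs `x` and `y` are illustrative.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # Example function to differentiate.
    return jnp.sin(x) * y

x = jnp.array(0.0)
y = jnp.array(3.0)

# jax.vjp returns the primal output and a function that maps an output
# cotangent to a tuple of input cotangents, one per positional argument.
out, f_vjp = jax.vjp(f, x, y)
dx, dy = f_vjp(jnp.ones_like(out))
```

Here `dx` is `cos(0) * 3 = 3` and `dy` is `sin(0) = 0`. The returned `f_vjp` can be called again with other cotangents without re-evaluating `f`.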
🔧 Affected Symbols
- `jax.lax.pvary`
- `jax.lax.pcast`
- `jax.numpy.arange`
- `jax.core.call_impl`
- `jax.core.get_aval`
- `jax.core.mapped_aval`
- `jax.core.subjaxprs`
- `jax.core.set_current_trace`
- `jax.core.take_current_trace`
- `jax.core.traverse_jaxpr_params`
- `jax.core.unmapped_aval`
- `jax.core.AbstractToken`
- `jax.core.TraceTag`
- `jax.interpreters.pxla`
- `jax.Tracer`
- `jax.Array`
- `jax.experimental.si_vjp`
- `jax.vjp`

⚡ Deprecations
- `jax.lax.pvary` is deprecated; use `jax.lax.pcast(..., to='varying')` instead.
- Passing complex arguments to `jax.numpy.arange` now triggers a deprecation warning.
- The following symbols from `jax.core` are deprecated: `call_impl`, `get_aval`, `mapped_aval`, `subjaxprs`, `set_current_trace`, `take_current_trace`, `traverse_jaxpr_params`, `unmapped_aval`, `AbstractToken`, `TraceTag`.
- All symbols in `jax.interpreters.pxla` are deprecated.
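For the `jax.numpy.arange` deprecation, one possible replacement builds the range from real values and adds the complex part explicitly. This is a sketch that assumes the original intent was an evenly spaced complex sequence with a real step; the bounds and the `1 + 2j` offset are illustrative.

```python
import jax.numpy as jnp

# Real-valued range: no complex arguments passed to arange.
real_steps = jnp.arange(0.0, 3.0, 1.0)          # [0., 1., 2.]

# Convert to complex, then apply the (illustrative) complex starting offset.
z = real_steps.astype(jnp.complex64)
z_offset = z + (1.0 + 2.0j)                     # [1+2j, 2+2j, 3+2j]
```

A Python complex scalar is weakly typed in JAX, so the addition keeps the `complex64` dtype.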