jax-v0.8.1
Summary
JAX adds decorator-factory support for `jax.jit`, a new `implementation` option for `jax.lax.linalg.eigh`, and a polar-decomposition algorithm for `jax.lax.linalg.svd` on CUDA GPUs. It also fixes a GPU `eigh` bug affecting large matrices and deprecates several sharding and Cloud TPU utilities.
Migration Steps
- Replace any use of `jax.sharding.PmapSharding` with `jax.NamedSharding`.
- Replace calls to `jax.device_put_replicated` with `jax.device_put` and provide the appropriate sharding argument.
- Replace calls to `jax.device_put_sharded` with `jax.device_put` and provide the appropriate sharding argument.
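- Pass `axis_types` explicitly when calling `jax.make_mesh`. This avoids the `DeprecationWarning` raised when `axis_types` is left unspecified and keeps your code unaffected by the default changing to `jax.sharding.AxisType.Explicit` in JAX v0.9.0.
- Remove imports of `jax.cloud_tpu_init`; JAX performs any needed Cloud TPU initialization automatically.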
✨ New Features
- `jax.jit` now supports the decorator factory pattern: `@jax.jit(static_argnames=['n'])` can be applied directly instead of `@functools.partial(jax.jit, static_argnames=['n'])`.
- `jax.lax.linalg.eigh` now accepts an `implementation` argument to select between QR, Jacobi, and QDWH implementations; the `EighImplementation` enum is exported from `jax.lax.linalg`.
- `jax.lax.linalg.svd` now implements an algorithm based on the polar decomposition on CUDA GPUs, selected through its `algorithm` argument; on TPUs it is an alias for the existing algorithm.
🐛 Bug Fixes
- Fixed a bug introduced in JAX 0.7.2 where `jax.lax.linalg.eigh` failed for large matrices on GPU (issue #33062).
🔧 Affected Symbols
- `jax.jit`
- `jax.lax.linalg.eigh`
- `jax.lax.linalg.svd`
- `jax.sharding.PmapSharding`
- `jax.NamedSharding`
- `jax.device_put_replicated`
- `jax.device_put_sharded`
- `jax.make_mesh`
- `jax.cloud_tpu_init`
- `jax.lax.linalg.EighImplementation`

⚡ Deprecations
- `jax.sharding.PmapSharding` is deprecated; use `jax.NamedSharding` instead.
- `jax.device_put_replicated` is deprecated; use `jax.device_put` with appropriate sharding instead.
- `jax.device_put_sharded` is deprecated; use `jax.device_put` with appropriate sharding instead.
- The default `axis_types` of `jax.make_mesh` will change to `jax.sharding.AxisType.Explicit` in JAX v0.9.0; leaving `axis_types` unspecified will raise a `DeprecationWarning`.
- `jax.cloud_tpu_init` and its contents are deprecated; users should no longer import or use this module.