Change8

jax-v0.8.1

3 features · 1 fix · 5 deprecations · 10 affected symbols

Summary

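JAX 0.8.1 adds decorator-factory support for `jax.jit`, a new `implementation` argument for `jax.lax.linalg.eigh`, and a polar-decomposition `algorithm` for `jax.lax.linalg.svd` on CUDA GPUs. It also fixes a GPU `eigh` bug affecting large matrices and deprecates several sharding and Cloud TPU utilities.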

Migration Steps

  1. Replace any use of `jax.sharding.PmapSharding` with `jax.NamedSharding`.
  2. Replace calls to `jax.device_put_replicated` with `jax.device_put` and provide the appropriate sharding argument.
  3. Replace calls to `jax.device_put_sharded` with `jax.device_put` and provide the appropriate sharding argument.
  4. Explicitly specify `axis_types` when calling `jax.make_mesh`. Leaving it unspecified now raises a `DeprecationWarning`, and the default changes to `jax.sharding.AxisType.Explicit` in JAX v0.9.0.
  5. Remove imports of `jax.cloud_tpu_init`; JAX performs this initialization automatically when needed.

✨ New Features

  • `jax.jit` now supports the decorator factory pattern: instead of writing `@functools.partial(jax.jit, static_argnames=['n'])`, you can write `@jax.jit(static_argnames=['n'])`.
  • `jax.lax.linalg.eigh` now accepts an `implementation` argument to select between QR, Jacobi, and QDWH implementations; the `EighImplementation` enum is exported from `jax.lax.linalg`.
  • `jax.lax.linalg.svd` now offers an `algorithm` choice that uses the polar decomposition on CUDA GPUs; on TPU, the same choice is an alias for the existing algorithm.

🐛 Bug Fixes

  • Fixed a bug introduced in JAX 0.7.2 where `jax.lax.linalg.eigh` failed for large matrices on GPU (issue #33062).

🔧 Affected Symbols

  • `jax.jit`
  • `jax.lax.linalg.eigh`
  • `jax.lax.linalg.svd`
  • `jax.sharding.PmapSharding`
  • `jax.NamedSharding`
  • `jax.device_put_replicated`
  • `jax.device_put_sharded`
  • `jax.make_mesh`
  • `jax.cloud_tpu_init`
  • `jax.lax.linalg.EighImplementation`

⚡ Deprecations

  • `jax.sharding.PmapSharding` is deprecated; use `jax.NamedSharding` instead.
  • `jax.device_put_replicated` is deprecated; use `jax.device_put` with appropriate sharding instead.
  • `jax.device_put_sharded` is deprecated; use `jax.device_put` with appropriate sharding instead.
  • The default `axis_types` of `jax.make_mesh` will change to `jax.sharding.AxisType.Explicit` in JAX v0.9.0; until then, leaving `axis_types` unspecified raises a `DeprecationWarning`.
  • `jax.cloud_tpu_init` and its contents are deprecated; users should no longer import or use this module.