jax-v0.5.1
✨ 3 features · 🐛 2 fixes · ⚡ 1 deprecation · 🔧 12 symbols
Summary
This release adds experimental custom DCE support, new low-level reduction ops, and column-pivoting QR. It also updates CPU collective defaults, removes the `libtpu-nightly` dependency, deprecates calling certain internal functions without debug info, and includes fixes for TPU runtime startup and the persistent compilation cache.
Migration Steps
- If you call `linear_util.wrap_init` or construct `core.Jaxpr` directly, pass a non-empty `core.DebugInfo` as a keyword argument.
- Remove the `libtpu-nightly` package if present; JAX now uses `libtpu`.
- For improved TPU v5e startup/shutdown, enable transparent hugepages: `sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled'`.
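Before changing the hugepages setting, it can help to check what the VM currently uses. The following is a minimal sketch, assuming the standard Linux sysfs format in which the active mode is wrapped in square brackets (for example `always [madvise] never`). The helper names `parse_thp_mode` and `current_thp_mode` are hypothetical and not part of JAX.

```python
from pathlib import Path
from typing import Optional

# Same sysfs file that the migration step above writes to.
THP_PATH = Path("/sys/kernel/mm/transparent_hugepage/enabled")


def parse_thp_mode(text: str) -> str:
    """Return the active mode, which the kernel marks with square brackets."""
    for token in text.split():
        if token.startswith("[") and token.endswith("]"):
            return token[1:-1]
    raise ValueError(f"no active mode found in {text!r}")


def current_thp_mode(path: Path = THP_PATH) -> Optional[str]:
    """Read the active mode, or return None if the file is absent (non-Linux hosts)."""
    if not path.exists():
        return None
    return parse_thp_mode(path.read_text())


if __name__ == "__main__":
    mode = current_thp_mode()
    if mode != "always":
        print(f"transparent hugepages mode is {mode!r}; the release notes suggest 'always'")
```

If the reported mode is already `always`, the `sudo` command in the step above is not needed.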
✨ New Features
- Added experimental `jax.experimental.custom_dce.custom_dce` decorator to customize opaque functions under JAX dead code elimination.
- Added low‑level reduction APIs in `jax.lax`: `reduce_sum`, `reduce_prod`, `reduce_max`, `reduce_min`, `reduce_and`, `reduce_or`, and `reduce_xor`.
- `jax.lax.linalg.qr` and `jax.scipy.linalg.qr` now support column‑pivoting on CPU and GPU.
🐛 Bug Fixes
- TPU runtime startup and shutdown times are significantly improved on TPU v5e and newer. If transparent hugepages are not already enabled, enable them as described in the migration steps.
- Persistent compilation cache no longer writes access‑time files when `JAX_COMPILATION_CACHE_MAX_SIZE` is unset or set to -1, improving performance on large‑scale network storage.
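The cache fix above applies when no size limit is configured. A minimal sketch of such a configuration follows, assuming the `jax_compilation_cache_dir` and `jax_compilation_cache_max_size` config options correspond to the `JAX_COMPILATION_CACHE_*` environment variables. The temporary directory is only a placeholder for a real cache location.

```python
import tempfile

import jax

# Placeholder cache location; a real setup would point at shared or network storage.
cache_dir = tempfile.mkdtemp(prefix="jax_cache_")

jax.config.update("jax_compilation_cache_dir", cache_dir)
# -1 means "no size limit"; per the fix above, no access-time files are written in this mode.
jax.config.update("jax_compilation_cache_max_size", -1)
```

Setting a positive limit instead re-enables eviction, which is the case where access-time tracking is still needed.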
🔧 Affected Symbols
- `jax.experimental.custom_dce.custom_dce`
- `jax.lax.reduce_sum`
- `jax.lax.reduce_prod`
- `jax.lax.reduce_max`
- `jax.lax.reduce_min`
- `jax.lax.reduce_and`
- `jax.lax.reduce_or`
- `jax.lax.reduce_xor`
- `jax.lax.linalg.qr`
- `jax.scipy.linalg.qr`
- `jax.extend.linear_util.wrap_init`
- `core.Jaxpr`
⚡ Deprecations
- The internal function `linear_util.wrap_init` and the constructor `core.Jaxpr` now expect a non-empty `core.DebugInfo` keyword argument. Calling them without it emits a `DeprecationWarning` and may become an error in a future release.
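One way to locate affected call sites is to promote `DeprecationWarning` to an error while exercising the code. The sketch below uses only the standard `warnings` module. `legacy_call` is a hypothetical stand-in for code that hits the deprecated path, and its warning message is invented for the example rather than taken from JAX.

```python
import warnings
from typing import Callable, Optional


def legacy_call() -> str:
    """Hypothetical stand-in for code that reaches a deprecated code path."""
    warnings.warn("debug_info will be required", DeprecationWarning, stacklevel=2)
    return "ok"


def find_deprecation(fn: Callable[[], object]) -> Optional[str]:
    """Run fn with DeprecationWarning promoted to an error.

    Returns the warning message if one was triggered, otherwise None.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("error", DeprecationWarning)
        try:
            fn()
        except DeprecationWarning as exc:
            return str(exc)
    return None
```

In a test suite, the same effect is available by running Python with `-W error::DeprecationWarning`, which makes the traceback point at the offending call.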