
jax-v0.10.0

Breaking Changes
9 breaking changes · 6 features · 4 fixes · 1 deprecation · 27 affected symbols

Summary

This release adds PyTorch-compatible cubic resizing and batch-parallel LAPACK operations on CPU, among other features. It also contains significant breaking changes, most notably the removal of the C++ pmap infrastructure and stricter input checking in the array stacking functions.

⚠️ Breaking Changes

  • `PartitionSpec` objects no longer compare equal to tuples. Convert tuples to `PartitionSpec` objects before testing equality.
  • The `.vma` property has been removed from `jax.core.ShapedArray`. Use `.manual_axis_type.varying` instead.
  • JAX CPU devices now report their names as `cpu:0`, `cpu:1`, etc. instead of `TFRT_CPU_0`, `TFRT_CPU_1`.
  • The config state `jax_pmap_shmap_merge` has been removed. `jax.pmap` will now always use the new implementation that wraps `jax.jit(jax.shard_map)`. Please see https://docs.jax.dev/en/latest/migrate_pmap.html for more information.
  • `jax.device_put_sharded` and `jax.device_put_replicated` have been removed from the public API and now raise an `AttributeError` when accessed. Please see https://docs.jax.dev/en/latest/migrate_pmap.html#drop-in-replacements for drop-in replacements.
  • The C++ pmap infrastructure has been removed. The following public APIs are no longer available: `jax.sharding.PmapSharding`, and several APIs from `jaxlib.xla_extension` and `jax.interpreters.pxla` (including `PmapFunction`, `pmap`, `NoSharding`, `Chunked`, `Unstacked`, `ShardedAxis`, `Replicated`, `ShardingSpec`, `MapTracer`, `PmapExecutable`, `parallel_callable`, `shard_args`, `xla_pmap_p`, `spec_to_indices`).
  • The deprecated keyword arguments `a`, `a_min`, and `a_max` to `jax.numpy.clip` have been removed.
  • Functions `jax.numpy.hstack`, `jax.numpy.vstack`, `jax.numpy.dstack`, `jax.numpy.column_stack`, `jax.numpy.atleast_1d`, `jax.numpy.atleast_2d`, and `jax.numpy.atleast_3d` no longer accept non-`ArrayLike` inputs.
  • `jax.scipy.stats.rankdata` now returns floating-point values in all cases, following a similar change in the SciPy 1.18 release.
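Because `jax.pmap` is now implemented on top of `jax.jit(jax.shard_map)`, one way to prepare for the pmap-related removals is to write collectives with `shard_map` directly. The sketch below is a minimal illustration under stated assumptions, not the official migration recipe: the mesh axis name `i` and the function `per_device` are arbitrary, and the `try`/`except` import is an assumption meant to cover older releases where `shard_map` is only available under `jax.experimental`. Unlike `jax.pmap`, the per-device function here still sees a leading axis of length 1.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # public top-level API in recent releases
except ImportError:  # assumption: older releases only ship the experimental module
    from jax.experimental.shard_map import shard_map

n = len(jax.devices())
# A 1-D mesh over every available device; "i" plays the role of pmap's axis_name.
mesh = Mesh(np.array(jax.devices()), ("i",))


def per_device(block):
    # Each device receives a (1, 3) block of the global (n, 3) input.
    # psum adds the blocks across the "i" mesh axis.
    return jax.lax.psum(block, "i")


summed = jax.jit(
    shard_map(per_device, mesh=mesh, in_specs=P("i"), out_specs=P())
)

x = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)
out = summed(x)  # shape (1, 3): the column sums, replicated on every device
```

With `out_specs=P()` the result is one replicated `(1, 3)` array; `out_specs=P("i")` would instead tile it to `(n, 3)`, closer to the stacked output `jax.pmap` returns. The migration guide linked above covers the remaining differences.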

Migration Steps

  1. Convert tuples to `PartitionSpec` objects before testing equality with `PartitionSpec` objects.
  2. Replace usage of the removed `.vma` property on `jax.core.ShapedArray` with `.manual_axis_type.varying`.
  3. Review the migration guide at https://docs.jax.dev/en/latest/migrate_pmap.html: the config state `jax_pmap_shmap_merge` has been removed, and `jax.pmap` now always uses the implementation that wraps `jax.jit(jax.shard_map)`.
  4. Find drop-in replacements for the removed `jax.device_put_sharded` and `jax.device_put_replicated` by consulting https://docs.jax.dev/en/latest/migrate_pmap.html#drop-in-replacements.
  5. Update code that relies on the removed C++ pmap infrastructure APIs listed under Breaking Changes, such as `jax.sharding.PmapSharding`.
  6. Update calls to `jax.numpy.clip` that use the removed keyword arguments `a`, `a_min`, and `a_max`: pass the array as the first positional argument and use the `min` and `max` keywords for the bounds.
  7. Ensure inputs to `jax.numpy.hstack`, `jax.numpy.vstack`, `jax.numpy.dstack`, `jax.numpy.column_stack`, `jax.numpy.atleast_1d`, `jax.numpy.atleast_2d`, and `jax.numpy.atleast_3d` are `ArrayLike`.
  8. Replace the `vma` argument to `jax.ShapeDtypeStruct` with `manual_axis_type` (a `jax.sharding.ManualAxisType`), and replace reads of its `.vma` property with `.manual_axis_type.varying`.
  9. Update internal code using deprecated `jax.core` APIs; consider moving to `jax.extend.core` where applicable.
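Steps 1, 6, and 7 can be illustrated together in one short sketch. The variable names are illustrative, and treating Python lists as the typical non-`ArrayLike` input to the stacking helpers is an assumption about what that change most often affects.

```python
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

# Step 1: compare PartitionSpec against PartitionSpec, never against a raw tuple.
spec = P("data", None)
legacy = ("data", None)  # e.g. a spec stored as a plain tuple in old code
same = spec == P(*legacy)

# Step 6: jnp.clip no longer accepts a=, a_min=, a_max=.
x = jnp.array([-2.0, 0.5, 3.0])
clipped = jnp.clip(x, min=0.0, max=1.0)  # array first, then min/max keywords

# Step 7: convert Python lists to arrays before stacking.
rows = [[1, 2], [3, 4]]  # assumption: lists are the non-ArrayLike input being passed
stacked = jnp.vstack([jnp.asarray(r) for r in rows])
at_least = jnp.atleast_2d(jnp.asarray([1, 2, 3]))
```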

✨ New Features

  • Added `ResizeMethod.CUBIC_PYTORCH` to `jax.image.resize` to match PyTorch's bicubic resize (#15768).
  • Added support for differentiating `jax.lax.linalg.qr` for wide matrices and when `full_matrices=True`.
  • LAPACK operations are now parallelized along the batch dimension on CPU.
  • Added `perturb_singular` argument to `jax.lax.linalg.tridiagonal_solve` to handle singular matrices by perturbing near-zero pivots in the LU decomposition.
  • `jax.scipy.linalg.eigh_tridiagonal` now supports computing eigenvectors on CPU and GPU.
  • Added the `jax.numpy.ndarray.byteswap` method.
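The LAPACK change requires no code changes: it applies to linear-algebra calls whose inputs already carry leading batch dimensions. The sketch below assumes that `jnp.linalg.solve` on a stack of matrices is a representative example of such a call; on CPU it is backed by LAPACK, but whether a particular call benefits from the new parallelism depends on the release and backend.

```python
import jax
import jax.numpy as jnp

batch, n = 8, 4
k1, k2 = jax.random.split(jax.random.PRNGKey(0))

a = jax.random.normal(k1, (batch, n, n))
# Make each matrix symmetric positive definite so every solve is well conditioned.
a = a @ jnp.swapaxes(a, -1, -2) + n * jnp.eye(n)
b = jax.random.normal(k2, (batch, n, 1))

# One solve per leading batch element; the batch axis is what gets parallelized.
x = jnp.linalg.solve(a, b)
residual = jnp.max(jnp.abs(a @ x - b))
```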

🐛 Bug Fixes

  • Fixed a bug that led to differing output between CPU and GPU for non-symmetric multidimensional IRFFTs (#29325).
  • Fixed an error when tiny matrices were passed to `jax.lax.linalg.tridiagonal_solve` on GPU (#32487).
  • Fixed a bug in `jax.scipy.fft.dctn` and `idctn` where `axes=None` incorrectly defaulted to all axes when `s` was specified, instead of the last `len(s)` axes to match SciPy behavior (#29426).
  • Fixed an `IndexError` raised when calling `jax.distributed.initialize()` on a GCE TPU Managed Instance Group; the metadata server's response format is now parsed correctly (#36593).
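Because the `axes=None` default of `dctn`/`idctn` changed when `s` is given, code that must behave identically before and after this fix can pass `axes` explicitly. The sketch below compares against SciPy (a JAX dependency); the input array and sizes are arbitrary. With `s=(4,)`, the fixed default corresponds to `axes=(1,)` for a 2-D input, i.e. the last `len(s)` axes.

```python
import numpy as np
import scipy.fft
import jax.numpy as jnp
import jax.scipy.fft

x = np.arange(15, dtype=np.float32).reshape(3, 5)

# Explicit axes: transform (and truncate to length 4) only the last axis,
# matching SciPy regardless of which default the installed JAX uses.
y_jax = jax.scipy.fft.dctn(jnp.asarray(x), s=(4,), axes=(1,))
y_ref = scipy.fft.dctn(x, s=(4,), axes=(1,))
```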


⚡ Deprecations

  • A number of internal APIs in `jax.core` have been newly deprecated and some have been moved to `jax.extend.core`. These include `CallPrimitive`, `DebugInfo`, `DropVar`, `Effect`, `Effects`, `InconclusiveDimensionOperation`, `JaxprTypeError`, `check_jaxpr`, `concrete_or_error`, `find_top_trace`, `gensym`, `get_opaque_trace_state`, `jaxprs_in_params`, `new_jaxpr_eqn`, `no_effects`, `nonempty_axis_env_DO_NOT_USE`, `primal_dtype_to_tangent_dtype`, `unsafe_am_i_under_a_jit_DO_NOT_USE`, `unsafe_am_i_under_a_vmap_DO_NOT_USE`, `unsafe_get_axis_names_DO_NOT_USE`, `valid_jaxtype`, `JaxprPpContext`, `JaxprPpSettings`, `OutputType`, `abstract_token`, `aval_mapping_handlers`, `call`, `concretization_function_error`, `custom_typechecks`, `is_concrete`, `is_constant_dim`, `is_constant_shape`, `literalable_types`, `no_axis_name`, `pytype_aval_mappings`, and `trace_ctx`.