Recent Releases of jax

jax - JAX v0.7.1

  • New features

    • JAX now ships Python 3.14 and 3.14t wheels.
    • JAX now ships Python 3.13t and 3.14t wheels on Mac. Previously we only offered free-threading builds on Linux.
  • Changes

    • Exposed jax.set_mesh which acts as a global setter and a context manager. Removed jax.sharding.use_mesh in favor of jax.set_mesh.
    • JAX is now built using CUDA 12.9. All versions of CUDA 12.1 or newer remain supported.
    • jax.lax.dot now implements the general dot product via the optional dimension_numbers argument.
  • Deprecations:

    • jax.lax.zeros_like_array is deprecated. Please use jax.numpy.zeros_like instead.
    • Attempting to import jax.experimental.host_callback now results in a DeprecationWarning, and will result in an ImportError starting in JAX v0.8.0. Its APIs have raised NotImplementedError since JAX version 0.4.35.
    • In jax.lax.dot, passing the precision and preferred_element_type arguments by position is deprecated. Pass them by explicit keyword instead.
    • Several dozen internal APIs have been deprecated from jax.interpreters.ad, jax.interpreters.batching, and jax.interpreters.partial_eval; they are used rarely if ever outside JAX itself, and most are deprecated without any public replacement.
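The two lax-level migrations above can be sketched as follows. This is a minimal illustration, assuming float32 inputs on the default backend; the shapes are arbitrary. It passes precision and preferred_element_type to jax.lax.dot by keyword and uses jax.numpy.zeros_like in place of jax.lax.zeros_like_array.

```python
import jax.numpy as jnp
from jax import lax

# Arbitrary example operands.
x = jnp.ones((2, 3), dtype=jnp.float32)
y = jnp.ones((3, 4), dtype=jnp.float32)

# Pass precision and preferred_element_type by keyword; positional use is deprecated.
out = lax.dot(
    x,
    y,
    precision=lax.Precision.HIGHEST,
    preferred_element_type=jnp.float32,
)

# Replacement for the deprecated jax.lax.zeros_like_array.
zeros = jnp.zeros_like(out)
```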

- Python
Published by jakeharmon8 10 months ago

jax - JAX v0.7.0

  • New features:

    • Added jax.P which is an alias for jax.sharding.PartitionSpec.
    • Added jax.tree.reduce_associative.
  • Breaking changes:

    • JAX is migrating from GSPMD to Shardy by default. See the migration guide for more information.
    • JAX autodiff is switching to using direct linearization by default (instead of implementing linearization via JVP and partial eval). See the migration guide for more information.
    • jax.stages.OutInfo has been replaced with jax.ShapeDtypeStruct.
    • jax.jit now requires fun to be passed by position, and additional arguments to be passed by keyword. Doing otherwise will result in an error starting in v0.7.x. This raised a DeprecationWarning in v0.6.x.
    • The minimum Python version is now 3.11. 3.11 will remain the minimum supported version until July 2026.
    • Layout API renames:
    • Layout, .layout, .input_layouts and .output_layouts have been renamed to Format, .format, .input_formats and .output_formats
    • DeviceLocalLayout, .device_local_layout have been renamed to Layout and .layout
    • jax.experimental.shard module has been deleted and all the APIs have been moved to the jax.sharding endpoint. So use jax.sharding.reshard, jax.sharding.auto_axes and jax.sharding.explicit_axes instead of their experimental endpoints.
    • lax.infeed and lax.outfeed were removed, after being deprecated in JAX 0.6. The transfer_to_infeed and transfer_from_outfeed methods were also removed from the Device objects.
    • The jax.extend.core.primitives.pjit_p primitive has been renamed to jit_p, and its name attribute has changed from "pjit" to "jit". This affects the string representations of jaxprs. The same primitive is no longer exported from the jax.experimental.pjit module.
    • The (undocumented) function jax.extend.backend.add_clear_backends_callback has been removed. Users should use jax.extend.backend.register_backend_cache instead.
  • Deprecations:

    • jax.dlpack.SUPPORTED_DTYPES is deprecated; please use the new jax.dlpack.is_supported_dtype function.
    • jax.scipy.special.sph_harm has been deprecated following a similar deprecation in SciPy; use jax.scipy.special.sph_harm_y instead.
    • From jax.interpreters.xla, the previously deprecated symbols abstractify and pytype_aval_mappings have been removed.
    • jax.interpreters.xla.canonicalize_dtype is deprecated. For canonicalizing dtypes, prefer jax.dtypes.canonicalize_dtype. For checking whether an object is a valid jax input, prefer jax.core.valid_jaxtype.
    • From jax.core, the previously deprecated symbols AxisName, ConcretizationTypeError, axis_frame, call_p, closed_call_p, get_type, trace_state_clean, typematch, and typecheck have been removed.
    • From jax.lib.xla_client, the previously deprecated symbols DeviceAssignment, get_topology_for_devices, and mlir_api_version have been removed.
    • jax.extend.ffi was removed after being deprecated in v0.5.0. Use jax.ffi instead.
    • jax.lib.xla_bridge.get_compile_options is deprecated, and replaced by jax.extend.backend.get_compile_options.

- Python
Published by MichaelHudgins 11 months ago

jax - JAX v0.6.2

  • New features:

    • Added jax.tree.broadcast which implements a pytree prefix broadcasting helper.
  • Changes

    • The minimum NumPy version is 1.26 and the minimum SciPy version is 1.12.

- Python
Published by yashk2810 12 months ago

jax - JAX v0.6.1

  • New features:

    • Added jax.lax.axis_size which returns the size of the mapped axis given its name.
  • Changes

    • Additional checking for the versions of CUDA package dependencies was re-enabled, having been accidentally disabled in a previous release.
    • JAX nightly packages are now published to artifact registry. To install these packages, see the JAX installation guide.
    • jax.sharding.PartitionSpec no longer inherits from a tuple.
    • jax.ShapeDtypeStruct is now immutable. Please use the .update method to update your ShapeDtypeStruct instead of doing in-place updates.
  • Deprecations

    • jax.custom_derivatives.custom_jvp_call_jaxpr_p is deprecated, and will be removed in JAX v0.7.0.

- Python
Published by hawkinsp about 1 year ago

jax - JAX v0.6.0

  • Breaking changes

    • jax.numpy.array no longer accepts None. This behavior had been deprecated since November 2023 and is now removed.
    • Removed the config.jax_data_dependent_tracing_fallback config option, which was added temporarily in v0.4.36 to allow users to opt out of the new "stackless" tracing machinery.
    • Removed the config.jax_eager_pmap config option.
    • Calling the lower and trace AOT APIs on the result of jax.jit is now disallowed if further wrappers have been applied to it. Previously this worked, but silently ignored the wrappers. The workaround is to apply jax.jit last among the wrappers, and similarly for jax.pmap. See #27873.
    • The cuda12_pip extra for jax has been removed; use pip install jax[cuda12] instead.
  • Changes

    • The minimum CuDNN version is v9.8.
    • JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain supported.
    • JAX package extras are now updated to use dash instead of underscore to align with PEP 685. For instance, if you were previously using pip install jax[cuda12_local] to install JAX, run pip install jax[cuda12-local] instead.
    • jax.jit now requires fun to be passed by position, and additional arguments to be passed by keyword. Doing otherwise will result in a DeprecationWarning in v0.6.X, and an error starting in v0.7.X.
  • Deprecations

    • jax.tree_util.build_tree is deprecated. Use jax.tree.unflatten instead.
    • Implemented host callback handlers for CPU and GPU devices using XLA's FFI and removed existing CPU/GPU handlers using XLA's custom call.
    • All APIs in jax.lib.xla_extension are now deprecated.
    • jax.interpreters.mlir.hlo and jax.interpreters.mlir.func_dialect, which were accidental exports, have been removed. If needed, they are available from jax.extend.mlir.
    • jax.interpreters.mlir.custom_call is deprecated. The APIs provided by jax.ffi should be used instead.
    • The deprecated use of jax.ffi.ffi_call with inline arguments is no longer supported. jax.ffi.ffi_call now unconditionally returns a callable.
    • The following exports in jax.lib.xla_client are deprecated: get_topology_for_devices, heap_profile, mlir_api_version, Client, CompileOptions, DeviceAssignment, Frame, HloSharding, OpSharding, Traceback.
    • The following internal APIs in jax.util are deprecated: HashableFunction, as_hashable_function, cache, safe_map, safe_zip, split_dict, split_list, split_list_checked, split_merge, subvals, toposort, unzip2, wrap_name, and wraps.
    • jax.dlpack.to_dlpack has been deprecated. You can usually pass a JAX Array directly to the from_dlpack function of another framework. If you need the functionality of to_dlpack, use the __dlpack__ attribute of an array.
    • jax.lax.infeed, jax.lax.infeed_p, jax.lax.outfeed, and jax.lax.outfeed_p are deprecated and will be removed in JAX v0.7.0.
    • Several previously-deprecated APIs have been removed, including:
    • From jax.lib.xla_client: ArrayImpl, FftType, PaddingType, PrimitiveType, XlaBuilder, dtype_to_etype, ops, register_custom_call_target, shape_from_pyval, Shape, XlaComputation.
    • From jax.lib.xla_extension: ArrayImpl, XlaRuntimeError.
    • From jax: jax.treedef_is_leaf, jax.tree_flatten, jax.tree_map, jax.tree_leaves, jax.tree_structure, jax.tree_transpose, and jax.tree_unflatten. Replacements can be found in jax.tree or jax.tree_util.
    • From jax.core: AxisSize, ClosedJaxpr, EvalTrace, InDBIdx, InputType, Jaxpr, JaxprEqn, Literal, MapPrimitive, OpaqueTraceState, OutDBIdx, Primitive, Token, TRACER_LEAK_DEBUGGER_WARNING, Var, concrete_aval, dedup_referents, escaped_tracer_error, extend_axis_env_nd, full_lower, get_referent, jaxpr_as_fun, join_effects, lattice_join, leaked_tracer_error, maybe_find_leaked_tracers, raise_to_shaped, raise_to_shaped_mappings, reset_trace_state, str_eqn_compact, substitute_vars_in_output_ty, typecompat, and used_axis_names_jaxpr. Most have no public replacement, though a few are available at jax.extend.core.
    • The vectorized argument to jax.pure_callback and jax.ffi.ffi_call. Use the vmap_method parameter instead.
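A minimal sketch of two migrations from this release, using arbitrary example values: applying jax.jit outermost so that lower and compile see every wrapper, and using the jax.tree module in place of the removed top-level tree helpers and the deprecated jax.tree_util.build_tree. The loss function is hypothetical.

```python
import jax
import jax.numpy as jnp


def loss(x):
    # Hypothetical scalar loss.
    return jnp.sum(x ** 2)


# Apply jax.jit last so that the AOT methods see the grad wrapper.
grad_jit = jax.jit(jax.grad(loss))
compiled = grad_jit.lower(jnp.ones(3)).compile()
g = compiled(jnp.arange(3.0))

# Replacements for the removed top-level tree functions.
tree = {"a": jnp.ones(2), "b": (jnp.zeros(1), jnp.full((3,), 2.0))}
doubled = jax.tree.map(lambda x: 2 * x, tree)   # was jax.tree_map
leaves, treedef = jax.tree.flatten(tree)         # was jax.tree_flatten
rebuilt = jax.tree.unflatten(treedef, leaves)    # was jax.tree_unflatten / build_tree
```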

- Python
Published by hawkinsp about 1 year ago

jax - JAX v0.5.3

  • New Features

    • Added an allow_negative_indices option to jax.lax.dynamic_slice, jax.lax.dynamic_update_slice and related functions. The default is True, matching the current behavior. If set to False, JAX does not need to emit code clamping negative indices, which reduces code size.
    • Added a replace option to jax.random.categorical to enable sampling without replacement.

- Python
Published by hawkinsp about 1 year ago

jax - JAX v0.5.2

Patch release of 0.5.1

  • Bug fixes
    • Fixes TPU metric logging and tpu-info, which were broken in 0.5.1.

- Python
Published by skye over 1 year ago

jax - JAX v0.5.1

  • New Features

    • Added an experimental jax.experimental.custom_dce.custom_dce decorator to support customizing the behavior of opaque functions under JAX-level dead code elimination (DCE). See #25956 for more details.
    • Added low-level reduction APIs in jax.lax: jax.lax.reduce_sum, jax.lax.reduce_prod, jax.lax.reduce_max, jax.lax.reduce_min, jax.lax.reduce_and, jax.lax.reduce_or, and jax.lax.reduce_xor.
    • jax.lax.linalg.qr, and jax.scipy.linalg.qr, now support column-pivoting on CPU and GPU. See #20282 and #25955 for more details.
  • Changes

    • JAX_CPU_COLLECTIVES_IMPLEMENTATION and JAX_NUM_CPU_DEVICES now work as env vars. Previously they could only be specified via jax.config or flags.
    • JAX_CPU_COLLECTIVES_IMPLEMENTATION now defaults to 'gloo', meaning multi-process CPU communication works out-of-the-box.
    • The jax[tpu] TPU extra no longer depends on the libtpu-nightly package. This package may safely be removed if it is present on your machine; JAX now uses libtpu instead.
  • Deprecations

    • The internal function linear_util.wrap_init and the constructor core.Jaxpr now must take a non-empty core.DebugInfo kwarg. For a limited time, a DeprecationWarning is printed if jax.extend.linear_util.wrap_init is used without debugging info. A downstream effect of this is that several other internal functions need debug info. This change does not affect public APIs. See https://github.com/jax-ml/jax/issues/26480 for more detail.
  • Bug fixes

    • TPU runtime startup and shutdown time should be significantly improved on TPU v5e and newer (from around 17s to around 8s). If not already set, you may need to enable transparent hugepages in your VM image (sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled'). We hope to improve this further in future releases.
    • Persistent compilation cache no longer writes access time file if JAX_COMPILATION_CACHE_MAX_SIZE is unset or set to -1, i.e. if the LRU eviction policy isn't enabled. This should improve performance when using the cache with large-scale network storage.

- Python
Published by hawkinsp over 1 year ago

jax - JAX v0.5.0

As of this release, JAX now uses effort-based versioning. Since this release makes a breaking change to PRNG key semantics that may require users to update their code, we are bumping the "meso" version of JAX to signify this.

  • Breaking changes

    • Enable jax_threefry_partitionable by default (see the update note).
    • This release drops support for Mac x86 wheels. Mac ARM of course remains supported. For a recent discussion, see https://github.com/jax-ml/jax/discussions/22936.

    Two key factors motivated this decision:
      • The Mac x86 build (only) has a number of test failures and crashes. We would rather ship no release than a broken release.
      • Mac x86 hardware is end-of-life and cannot be easily obtained by developers at this point, so it is difficult for us to fix this kind of problem even if we wanted to.

    We are open to re-adding support for Mac x86 if the community is willing to help support that platform: in particular, we would need the JAX test suite to pass cleanly on Mac x86 before we could ship releases again.

  • Changes:

    • The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum supported version until June 2025.
    • The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum supported version until June 2025.
    • jax.numpy.einsum now defaults to optimize='auto' rather than optimize='optimal'. This avoids exponentially-scaling trace-time in the case of many arguments (#25214).
    • jax.numpy.linalg.solve no longer supports batched 1D arguments on the right hand side. To recover the previous behavior in these cases, use solve(a, b[..., None]).squeeze(-1).
  • New Features

    • jax.numpy.fft.fftn, jax.numpy.fft.rfftn, jax.numpy.fft.ifftn, and jax.numpy.fft.irfftn now support transforms in more than 3 dimensions, which was previously the limit. See #25606 for more details.
    • Support added for user defined state in the FFI via the new jax.ffi.register_ffi_type_id function.
    • The AOT lowering .as_text() method now supports the debug_info option to include debugging information, e.g., source location, in the output.
  • Deprecations

    • From jax.interpreters.xla, abstractify and pytype_aval_mappings are now deprecated, having been replaced by symbols of the same name in jax.core.
    • jax.scipy.special.lpmn and jax.scipy.special.lpmn_values are deprecated, following their deprecation in SciPy v1.15.0. There are no plans to replace these deprecated functions with new APIs.
    • The jax.extend.ffi submodule was moved to jax.ffi, and the previous import path is deprecated.
  • Deletions

    • jax_enable_memories flag has been deleted and the behavior of that flag is on by default.
    • From jax.lib.xla_client, the previously-deprecated Device and XlaRuntimeError symbols have been removed; instead use jax.Device and jax.errors.JaxRuntimeError respectively.
    • The jax.experimental.array_api module has been removed after being deprecated in JAX v0.4.32. Since that release, jax.numpy supports the array API directly.
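The einsum and solve changes above can be sketched as follows, with arbitrary example arrays. Passing optimize="optimal" opts back into the previous default, and the b[..., None] / squeeze(-1) pattern is the workaround quoted above for batched 1D right-hand sides.

```python
import jax.numpy as jnp

p = jnp.ones((2, 3))
q = jnp.ones((3, 4))
r = jnp.ones((4, 5))

# The default is now optimize="auto"; "optimal" restores the previous path search,
# whose trace time can grow exponentially with the number of operands.
chain_auto = jnp.einsum("ij,jk,kl->il", p, q, r)
chain_optimal = jnp.einsum("ij,jk,kl->il", p, q, r, optimize="optimal")

# Batched 1D right-hand sides: add a trailing axis, solve, then drop it again.
a = jnp.array([[[2.0, 0.0], [0.0, 4.0]],
               [[1.0, 1.0], [0.0, 1.0]]])   # shape (2, 2, 2)
b = jnp.array([[2.0, 8.0],
               [3.0, 1.0]])                  # shape (2, 2)
x = jnp.linalg.solve(a, b[..., None]).squeeze(-1)
```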

- Python
Published by hawkinsp over 1 year ago

jax - JAX v0.4.38

  • Changes:

    • jax.tree.flatten_with_path and jax.tree.map_with_path are added as shortcuts of the corresponding tree_util functions.
  • Deprecations

    • A number of APIs in the internal jax.core namespace have been deprecated. Most were no-ops, were little-used, or can be replaced by APIs of the same name in jax.extend.core; see the documentation for jax.extend for information on the compatibility guarantees of these semi-public extensions.
    • Several previously-deprecated APIs have been removed, including:
    • from jax.core: check_eqn, check_type, check_valid_jaxtype, and non_negative_dim.
    • from jax.lib.xla_bridge: xla_client and default_backend.
    • from jax.lib.xla_client: _xla and bfloat16.
    • from jax.numpy: round_.
  • New Features

    • jax.export.export can be used for device-polymorphic export with shardings constructed with jax.sharding.AbstractMesh. See the jax.export documentation.
    • Added jax.lax.split. This is a primitive version of jax.numpy.split, added because it yields a more compact transpose during automatic differentiation.
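A minimal sketch of path-aware tree utilities. It calls the jax.tree_util functions that the new jax.tree.flatten_with_path and jax.tree.map_with_path shortcuts wrap, so it also runs on older releases; the params dictionary and the decay_weights helper are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import keystr

# Hypothetical parameter tree.
params = {"dense": {"w": jnp.ones((2, 2)), "b": jnp.ones(2)}}

# jax.tree.flatten_with_path is a shortcut for this function.
path_leaves, treedef = jax.tree_util.tree_flatten_with_path(params)
names = [keystr(path) for path, _ in path_leaves]


def decay_weights(path, leaf):
    # Hypothetical helper: scale only leaves whose final key is "w".
    return leaf * 0.5 if keystr(path).endswith("['w']") else leaf


# jax.tree.map_with_path is a shortcut for this function.
decayed = jax.tree_util.tree_map_with_path(decay_weights, params)
```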

- Python
Published by hawkinsp over 1 year ago

jax - JAX v0.4.37

This is a patch release of jax 0.4.36. Only "jax" was released at this version.

  • Bug fixes
    • Fixed a bug where jit would error if an argument was named f (#25329).
    • Fixed a bug where jax.lax.while_loop would throw an index-out-of-range error if the user registered a pytree node class with different aux data for flatten and flatten_with_path.
    • Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.

- Python
Published by hawkinsp over 1 year ago

jax - JAX v0.4.36

  • Breaking Changes

    • This release lands "stackless", an internal change to JAX's tracing machinery. We made trace dispatch purely a function of context rather than a function of both context and data. This let us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind, and so on. The change should only affect users that use JAX internals.

    If you do use JAX internals, you may need to update your code (see https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f for clues about how to do this). There might also be version skew issues with JAX libraries that do this. If you find that this change breaks code that does not use JAX internals, try the config.jax_data_dependent_tracing_fallback flag as a workaround, and if you need help updating your code, please file a bug.
    • jax.experimental.jax2tf.convert with native_serialization=False or with enable_xla=False has been deprecated since July 2024 (JAX version 0.4.31), and support for these use cases has now been removed. jax2tf with native serialization remains supported.
    • In jax.interpreters.xla, the xb, xc, and xe symbols have been removed after being deprecated in JAX v0.4.31. Instead use xb = jax.lib.xla_bridge, xc = jax.lib.xla_client, and xe = jax.lib.xla_extension.
    • The deprecated module jax.experimental.export has been removed. It was replaced by jax.export in JAX v0.4.30. See the migration guide for information on migrating to the new API.
    • The initial argument to jax.nn.softmax and jax.nn.log_softmax has been removed, after being deprecated in v0.4.27.
    • Calling np.asarray on typed PRNG keys (i.e. keys produced by jax.random.key) now raises an error. Previously, this returned a scalar object array.
    • The following deprecated methods and functions in jax.export have been removed:
      • jax.export.DisabledSafetyCheck.shape_assertions: it already had no effect.
      • jax.export.Exported.lowering_platforms: use platforms.
      • jax.export.Exported.mlir_module_serialization_version: use calling_convention_version.
      • jax.export.Exported.uses_shape_polymorphism: use uses_global_constants.
      • the lowering_platforms kwarg for jax.export.export: use platforms instead.
    • The kwargs symbolic_scope and symbolic_constraints of jax.export.symbolic_args_specs have been removed. They were deprecated in June 2024. Use scope and constraints instead.
    • Hashing of tracers, which has been deprecated since version 0.4.30, now results in a TypeError.
    • Refactor: the JAX build CLI (build/build.py) now uses a subcommand structure that replaces the previous build.py usage. Run python build/build.py --help for more details. Brief overview of the new subcommand options:
      • build: Builds JAX wheel packages. For example: python build/build.py build --wheels=jaxlib,jax-cuda-pjrt
      • requirements_update: Updates requirements_lock.txt files.
    • jax.scipy.linalg.toeplitz now does implicit batching on multi-dimensional inputs. To recover the previous behavior, you can call jax.numpy.ravel on the function inputs.
    • jax.scipy.special.gamma and jax.scipy.special.gammasgn now return NaN for negative integer inputs, to match the behavior of SciPy from https://github.com/scipy/scipy/pull/21827.
    • jax.clear_backends was removed after being deprecated in v0.4.26.
    • We removed the custom call "__gpu$xla.gpu.triton" from the list of custom calls for which we guarantee export stability, because this custom call relies on Triton IR, which is not guaranteed to be stable. If you need to export code that uses this custom call, you can use the disabled_checks parameter. See the documentation for more details.

  • New Features

    • jax.jit got a new compiler_options: dict[str, Any] argument, for passing compilation options to XLA. For the moment it's undocumented and may be in flux.
    • jax.tree_util.register_dataclass now allows metadata fields to be declared inline via dataclasses.field. See the function documentation for examples.
    • Added jax.numpy.put_along_axis.
    • jax.lax.linalg.eig and the related jax.numpy functions (jax.numpy.linalg.eig and jax.numpy.linalg.eigvals) are now supported on GPU. See #24663 for more details.
    • Added two new configuration flags, jax_exec_time_optimization_effort and jax_memory_fitting_effort, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0, default is 0.0.
  • Bug fixes

    • Fixed a bug where the GPU implementations of LU and QR decomposition would result in an indexing overflow for batch sizes close to int32 max. See #24843 for more details.
  • Deprecations

    • jax.lib.xla_extension.ArrayImpl and jax.lib.xla_client.ArrayImpl are deprecated; use jax.Array instead.
    • jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.

- Python
Published by hawkinsp over 1 year ago

jax - JAX v0.4.35

  • Breaking Changes

    • jax.numpy.isscalar now returns True for any array-like object with zero dimensions. Previously it only returned True for zero-dimensional array-like objects with a weak dtype.
    • jax.experimental.host_callback has been deprecated since March 2024, with JAX version 0.4.26. It has now been removed. See #20385 for a discussion of alternatives.
  • Changes:

    • jax.lax.FftType was introduced as a public name for the enum of FFT operations. The semi-public API jax.lib.xla_client.FftType has been deprecated.
    • TPU: JAX now installs TPU support from the libtpu package rather than libtpu-nightly. For the next few releases JAX will pin an empty version of libtpu-nightly as well as libtpu to ease the transition; that dependency will be removed in Q1 2025.
  • Deprecations:

    • The semi-public API jax.lib.xla_client.PaddingType has been deprecated. No JAX APIs consume this type, so there is no replacement.
    • The default behavior of jax.pure_callback and jax.extend.ffi.ffi_call under vmap has been deprecated and so has the vectorized parameter to those functions. The vmap_method parameter should be used instead for better defined behavior. See the discussion in #23881 for more details.
    • The semi-public API jax.lib.xla_client.register_custom_call_target has been deprecated. Use the JAX FFI instead.
    • The semi-public APIs jax.lib.xla_client.dtype_to_etype, jax.lib.xla_client.ops, jax.lib.xla_client.shape_from_pyval, jax.lib.xla_client.PrimitiveType, jax.lib.xla_client.Shape, jax.lib.xla_client.XlaBuilder, and jax.lib.xla_client.XlaComputation have been deprecated. Use StableHLO instead.

- Python
Published by hawkinsp over 1 year ago

jax - JAX v0.4.34

  • New Functionality

    • This release includes wheels for Python 3.13. Free-threading mode is not yet supported.
    • jax.errors.JaxRuntimeError has been added as a public alias for the formerly private XlaRuntimeError type.
  • Breaking changes

    • jax_pmap_no_rank_reduction flag is set to True by default.
    • array[0] on a pmap result now introduces a reshape (use array[0:1] instead).
    • The per-shard shape (accessible via jax_array.addressable_shards or jax_array.addressable_data(0)) now has a leading (1, ...). Update code that directly accesses shards accordingly. The rank of the per-shard shape now matches that of the global shape, which is the same behavior as jit. This avoids costly reshapes when passing results from pmap into jit.
    • jax.experimental.host_callback has been deprecated since March 2024, with JAX version 0.4.26. Now we set the default value of the --jax_host_callback_legacy configuration value to False, which means that if your code uses jax.experimental.host_callback APIs, those API calls will be implemented in terms of the new jax.experimental.io_callback API. If this breaks your code, for a very limited time, you can set --jax_host_callback_legacy to True. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs. See #20385 for a discussion.
  • Deprecations

    • In jax.numpy.trim_zeros, non-arraylike arguments or arraylike arguments with ndim != 1 are now deprecated, and in the future will result in an error.
    • Internal pretty-printing tools jax.core.pp_* have been removed, after being deprecated in JAX v0.4.30.
    • jax.lib.xla_client.Device is deprecated; use jax.Device instead.
    • jax.lib.xla_client.XlaRuntimeError has been deprecated. Use jax.errors.JaxRuntimeError instead.
  • Deletion:

    • jax.xla_computation is deleted. It has been 3 months since its deprecation in the JAX 0.4.30 release. Please use the AOT APIs to get the same functionality as jax.xla_computation.
    • jax.xla_computation(fn)(*args, **kwargs) can be replaced with jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo').
    • You can also use .out_info property of jax.stages.Lowered to get the output information (like tree structure, shape and dtype).
    • For cross-backend lowering, you can replace jax.xla_computation(fn, backend='tpu')(*args, **kwargs) with jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo').
    • jax.ShapeDtypeStruct no longer accepts the named_shape argument. The argument was only used by xmap which was removed in 0.4.31.
    • jax.tree.map(f, None, non-None), which previously emitted a DeprecationWarning, now raises an error. None is only a tree-prefix of itself. To preserve the current behavior, you can ask jax.tree.map to treat None as a leaf value by writing: jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None).
    • jax.sharding.XLACompatibleSharding has been removed. Please use jax.sharding.Sharding.
  • Bug fixes

    • Fixed a bug where jax.numpy.cumsum would produce incorrect outputs if a non-boolean input was provided and dtype=bool was specified.
    • Fixed the implementation of jax.numpy.ldexp so that it produces the correct gradient.
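Two of the migrations above can be sketched as follows, with a hypothetical f and arbitrary inputs: the is_leaf workaround for jax.tree.map when the first tree contains None, and the AOT path (jax.jit(...).lower(...)) that replaces jax.xla_computation, reading the output shape from Lowered.out_info.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical leaf-wise function.
    return x + y


a = [None, 1.0]
b = [2.0, 3.0]

# Treat None as a leaf so it lines up with the corresponding entry in b.
out = jax.tree.map(
    lambda x, y: None if x is None else f(x, y),
    a,
    b,
    is_leaf=lambda x: x is None,
)


def g(x):
    # Hypothetical function to lower ahead of time.
    return jnp.sin(x) * 2.0


# AOT replacement for jax.xla_computation: lower through jax.jit instead.
lowered = jax.jit(g).lower(jnp.ones(3))
ir_text = lowered.as_text()
out_shape = lowered.out_info.shape
```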

- Python
Published by hawkinsp over 1 year ago

jax - JAX release v0.4.33

This is a patch release on top of jax 0.4.32, that fixes two bugs found in that release.

A TPU-only data corruption bug was found in the version of libtpu pinned by JAX 0.4.32, which manifested only if multiple TPU slices were present in the same job, for example, if training on multiple v5e slices.

This release fixes that issue by pinning a fixed version of libtpu-nightly.

This release also fixes an inaccurate result for F64 tanh on CPU (#23590).

- Python
Published by hawkinsp almost 2 years ago

jax - JAX release v0.4.32

WARNING: This release has been yanked from PyPI because of a data corruption bug on TPU if there are multiple TPU slices in the job

- Python
Published by hawkinsp almost 2 years ago

jax - Jaxlib release v0.4.32

WARNING: This release has been yanked from PyPI because of a data corruption bug on TPU if there are multiple TPU slices in the job

- Python
Published by hawkinsp almost 2 years ago

jax - JAX release v0.4.31

- Python
Published by yashk2810 almost 2 years ago

jax - Jaxlib release v0.4.31

- Python
Published by yashk2810 almost 2 years ago

jax - Jaxlib release v0.4.30

- Python
Published by yashk2810 almost 2 years ago

jax - Jax release v0.4.30

- Python
Published by yashk2810 almost 2 years ago

jax - JAX v0.4.29

  • Changes

    • We anticipate that this will be the last release of JAX and jaxlib supporting a monolithic CUDA jaxlib. Future releases will use the CUDA plugin jaxlib (e.g. pip install jax[cuda12]).
    • JAX now requires ml_dtypes version 0.4.0 or newer.
    • Removed backwards-compatibility support for old usage of the jax.experimental.export API. It is no longer possible to use from jax.experimental.export import export; instead you should use from jax.experimental import export. The removed functionality has been deprecated since 0.4.24.
  • Deprecations

    • jax.sharding.XLACompatibleSharding is deprecated. Please use jax.sharding.Sharding.
    • jax.experimental.Exported.in_shardings has been renamed to jax.experimental.Exported.in_shardings_hlo. The same applies to out_shardings. The old names will be removed after 3 months.
    • Removed a number of previously-deprecated APIs:
    • from jax.core: non_negative_dim, DimSize, Shape
    • from jax.lax: tie_in
    • from jax.nn: normalize
    • from jax.interpreters.xla: backend_specific_translations, translations, register_translation, xla_destructure, TranslationRule, TranslationContext, XlaOp.
    • The tol argument of jax.numpy.linalg.matrix_rank is being deprecated and will soon be removed. Use rtol instead.
    • The rcond argument of jax.numpy.linalg.pinv is being deprecated and will soon be removed. Use rtol instead.
    • The deprecated jax.config submodule has been removed. To configure JAX use import jax and then reference the config object via jax.config.
    • jax.random APIs no longer accept batched keys, where previously some did unintentionally. Going forward, we recommend explicit use of jax.vmap in such cases.
  • New Functionality

    • Added jax.experimental.Exported.in_shardings_jax to construct shardings that can be used with the JAX APIs from the HloShardings that are stored in the Exported objects.
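A minimal sketch of the replacements named in this release, with arbitrary example values: rtol in place of tol and rcond, and jax.vmap over a batch of typed keys instead of passing batched keys to jax.random directly.

```python
import jax
import jax.numpy as jnp

# Rank-1 example matrix: the outer product of [1, 2] with itself.
m = jnp.array([[1.0, 2.0], [2.0, 4.0]])

# rtol replaces the deprecated tol (matrix_rank) and rcond (pinv) arguments.
rank = jnp.linalg.matrix_rank(m, rtol=1e-5)
m_pinv = jnp.linalg.pinv(m, rtol=1e-5)

# Map over a batch of keys explicitly with jax.vmap.
keys = jax.random.split(jax.random.key(0), 3)
samples = jax.vmap(lambda k: jax.random.normal(k, (2,)))(keys)
```

For this rank-1 matrix the pseudo-inverse is m / 25, since m = a aᵀ with a = [1, 2] and ‖a‖⁴ = 25.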

- Python
Published by hawkinsp about 2 years ago

jax - Jaxlib release v0.4.29

  • Bug fixes

    • Fixed a bug where XLA sharded some concatenation operations incorrectly, which manifested as an incorrect output for cumulative reductions (#21403).
    • Fixed a bug where XLA:CPU miscompiled certain matmul fusions (https://github.com/openxla/xla/pull/13301).
    • Fixes a compiler crash on GPU (https://github.com/google/jax/issues/21396).
  • Deprecations

    • jax.tree.map(f, None, non-None) now emits a DeprecationWarning, and will raise an error in a future version of jax. None is only a tree-prefix of itself. To preserve the current behavior, you can ask jax.tree.map to treat None as a leaf value by writing: jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None).

- Python
Published by hawkinsp about 2 years ago

jax - JAX v0.4.28

  • Bug fixes

    • Reverted a change to make_jaxpr that was breaking Equinox (#21116).
  • Deprecations & removals

    • The kind argument to jax.numpy.sort and jax.numpy.argsort is now removed. Use stable=True or stable=False instead.
    • Removed get_compute_capability from the jax.experimental.pallas.gpu module. Use the compute_capability attribute of a GPU device, returned by jax.devices or jax.local_devices, instead.
  • Changes

    • The minimum jaxlib version of this release is 0.4.27.
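A minimal sketch of the replacement for the removed kind argument, with arbitrary input values: request a stable sort explicitly with stable=True.

```python
import jax.numpy as jnp

x = jnp.array([3, 1, 2, 1])

# kind= was removed; ask for stability explicitly instead.
sorted_x = jnp.sort(x, stable=True)
order = jnp.argsort(x, stable=True)  # equal values keep their original relative order
```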

- Python
Published by hawkinsp about 2 years ago

jax - jaxlib v0.4.28

  • Bug fixes

    • Fixed a memory corruption bug in the type name of Array and JIT Python objects in Python 3.10 or earlier.
    • Fixed a warning "'+ptx84' is not a recognized feature for this target" under CUDA 12.4.
    • Fixed a slow compilation problem on CPU.
  • Changes

    • The Windows build is now built with Clang instead of MSVC.

- Python
Published by hawkinsp about 2 years ago

jax - Jax release v0.4.27

- Python
Published by yashk2810 about 2 years ago

jax - Jaxlib release v0.4.27

- Python
Published by yashk2810 about 2 years ago

jax - Jax release v0.4.26

- Python
Published by yashk2810 about 2 years ago

jax - Jaxlib release v0.4.26

- Python
Published by yashk2810 about 2 years ago

jax - JAX release v0.4.25

- Python
Published by yashk2810 over 2 years ago

jax - Jaxlib release v0.4.25

- Python
Published by yashk2810 over 2 years ago

jax - JAX release v0.4.24

- Python
Published by skye over 2 years ago

jax - jaxlib release v0.4.24

- Python
Published by skye over 2 years ago

jax - Jax release v0.4.23

- Python
Published by yashk2810 over 2 years ago

jax - Jaxlib release v0.4.23

- Python
Published by yashk2810 over 2 years ago

jax - Jaxlib release v0.4.21

- Python
Published by yashk2810 over 2 years ago

jax - Jax release v0.4.21

- Python
Published by yashk2810 over 2 years ago

jax - JAX release v0.4.20

- Python
Published by skye over 2 years ago

jax - jaxlib release v0.4.20

- Python
Published by skye over 2 years ago

jax - jaxlib release v0.4.19

- Python
Published by yashk2810 over 2 years ago

jax - Jax release v0.4.19

- Python
Published by yashk2810 over 2 years ago

jax - JAX release v0.4.18

- Python
Published by skye over 2 years ago

jax - jaxlib-v0.4.18

- Python
Published by skye over 2 years ago

jax - JAX release v0.4.17

- Python
Published by skye over 2 years ago

jax - jaxlib release v0.4.17

- Python
Published by skye over 2 years ago

jax - Jax release v0.4.16

- Python
Published by yashk2810 over 2 years ago

jax - Jaxlib release v0.4.16

- Python
Published by yashk2810 over 2 years ago

jax - JAX release v0.4.14

- Python
Published by skye almost 3 years ago

jax - jaxlib release v0.4.14

- Python
Published by skye almost 3 years ago

jax - JAX release v0.4.13

NOTE: This is the last JAX release that will include Python 3.8 support

  • Changes

    • jax.jit now allows None to be passed to in_shardings and out_shardings. The semantics are as follows:
      • For in_shardings, JAX will mark it as replicated, but this behavior can change in the future.
      • For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.
    • jax.experimental.pjit.pjit also allows None to be passed to in_shardings and out_shardings. The semantics are as follows:
      • If the mesh context manager is not provided, JAX has the freedom to choose whatever sharding it wants.
        • For in_shardings, JAX will mark it as replicated, but this behavior can change in the future.
        • For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.
      • If the mesh context manager is provided, None will imply that the value will be replicated on all devices of the mesh.
    • Executable.cost_analysis() now works on Cloud TPU.
    • Added a warning if a non-allowlisted jaxlib plugin is in use.
    • Added jax.tree_util.tree_leaves_with_path.
  • Bug fixes

    • Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel is named cudnn89 instead of cudnn88.
  • Deprecations

    • The native_serialization_strict_checks parameter to jax.experimental.jax2tf.convert is deprecated in favor of the new native_serialization_disabled_checks (#16347).
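The new jax.tree_util.tree_leaves_with_path pairs each leaf with its key path. A minimal sketch on a hypothetical nested dict, using jax.tree_util.keystr (assumed available, as in current JAX) to render each path as a string:

```python
import jax

tree = {"a": 1, "b": (2, 3)}

# Each entry is a (key_path, leaf) pair, in the same order as tree_leaves.
leaves = jax.tree_util.tree_leaves_with_path(tree)

# keystr renders a key path as a readable string.
paths = [jax.tree_util.keystr(path) for path, _ in leaves]
values = [leaf for _, leaf in leaves]
```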

- Python
Published by skye almost 3 years ago

jax - jaxlib release v0.4.13

  • Changes

    • Added Windows CPU-only wheels to the jaxlib PyPI release.
  • Bug fixes

    • __cuda_array_interface__ was broken in previous jaxlib versions and is now fixed (#16440).
    • Concurrent CUDA kernel tracing is now enabled by default on NVIDIA GPUs.

- Python
Published by skye almost 3 years ago

jax - JAX release v0.4.12

- Python
Published by skye about 3 years ago

jax - jaxlib release v0.4.12

- Python
Published by skye about 3 years ago

jax - Jax release v0.4.11

- Python
Published by yashk2810 about 3 years ago

jax - Jaxlib release v0.4.11

- Python
Published by yashk2810 about 3 years ago

jax - JAX release v0.4.10

- Python
Published by skye about 3 years ago

jax - jaxlib release v0.4.10

- Python
Published by skye about 3 years ago

jax - JAX release v0.4.9

- Python
Published by skye about 3 years ago

jax - jaxlib release v0.4.9

- Python
Published by skye about 3 years ago

jax - JAX release v0.4.8

- Python
Published by skye about 3 years ago

jax - Jaxlib release v0.4.7

- Python
Published by yashk2810 about 3 years ago

jax - Jax release v0.4.7

- Python
Published by yashk2810 about 3 years ago

jax - jaxlib release v0.4.6

- Python
Published by skye over 3 years ago

jax - JAX release v0.4.6

- Python
Published by skye over 3 years ago

jax - JAX release v0.4.5

- Python
Published by skye over 3 years ago

jax - Jaxlib release v0.4.4

- Python
Published by yashk2810 over 3 years ago

jax - Jax release v0.4.4

- Python
Published by yashk2810 over 3 years ago

jax - JAX release v0.4.3

- Python
Published by skye over 3 years ago

jax - jaxlib release v0.4.3

- Python
Published by skye over 3 years ago

jax - JAX release v0.4.2

- Python
Published by skye over 3 years ago

jax - jaxlib release v0.4.2

- Python
Published by skye over 3 years ago

jax - Jaxlib release v0.4.1

  • Changes
    • Support for Python 3.7 has been dropped, in accordance with JAX's version support policy.
    • The behavior of XLA_PYTHON_CLIENT_MEM_FRACTION=.XX has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to GPU memory allocation for more details.
    • The deprecated method .block_host_until_ready() has been removed. Use .block_until_ready() instead.
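As a minimal sketch of the new XLA_PYTHON_CLIENT_MEM_FRACTION behavior, the variable has to be set before JAX initializes its backends; ".50" here is an illustrative value. On a GPU it preallocates 50% of total GPU memory, and on CPU-only machines it has no effect:

```python
import os

# Must be set before JAX initializes its backends (here: before importing jax).
# On a GPU this preallocates 50% of *total* GPU memory; it is ignored on CPU.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax

devices = jax.devices()
```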

- Python
Published by yashk2810 over 3 years ago

jax - Jax release v0.4.1

  • Changes
    • Support for Python 3.7 has been dropped, in accordance with JAX's version support policy.
    • We introduce jax.Array, a unified array type that subsumes the DeviceArray, ShardedDeviceArray, and GlobalDeviceArray types in JAX. The jax.Array type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify jit and pjit. jax.Array has been enabled by default in JAX 0.4 and makes some breaking changes to the pjit API. The jax.Array migration guide can help you migrate your codebase to jax.Array. You can also look at the Distributed arrays and automatic parallelization tutorial to understand the new concepts.
    • PartitionSpec and Mesh are now out of experimental. The new API endpoints are jax.sharding.PartitionSpec and jax.sharding.Mesh. jax.experimental.maps.Mesh and jax.experimental.PartitionSpec are deprecated and will be removed in 3 months.
    • The new public endpoint for with_sharding_constraint is jax.lax.with_sharding_constraint.
    • If using ABSL flags together with jax.config, the ABSL flag values are no longer read or written after the JAX configuration options are initially populated from the ABSL flags. This change improves performance of reading jax.config options, which are used pervasively in JAX.
    • The jax2tf.call_tf function now uses the first TF device of the same platform as the embedding JAX computation for TF lowering. Previously, it used the 0th device for the JAX-default backend.
    • A number of jax.numpy functions now have their arguments marked as positional-only, matching NumPy.
    • jnp.msort is now deprecated, following the deprecation of np.msort in NumPy 1.24. It will be removed in a future release, in accordance with the API compatibility policy. It can be replaced with jnp.sort(a, axis=0).
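The jax.Array, jax.sharding.Mesh/PartitionSpec, and jax.lax.with_sharding_constraint endpoints above can be sketched together with a 1-D mesh over whatever devices are available (a single CPU device is enough). The axis name "x" and the function f are illustrative assumptions; the last lines show jnp.sort(a, axis=0) as the msort replacement:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1-D mesh over all available devices, with a hypothetical axis name "x".
mesh = Mesh(np.array(jax.devices()), ("x",))
sharding = NamedSharding(mesh, PartitionSpec("x"))

# The array length is a multiple of the device count so it splits evenly.
n = len(jax.devices())
x = jax.device_put(jnp.arange(4.0 * n), sharding)

@jax.jit
def f(x):
    y = x * 2
    # Ask the compiler to keep the intermediate sharded along "x".
    return jax.lax.with_sharding_constraint(y, sharding)

out = f(x)  # a jax.Array, usable in isinstance checks

# Replacement for the deprecated jnp.msort(a): sort along the first axis.
a = jnp.array([[3, 1], [1, 2], [2, 0]])
rows_sorted = jnp.sort(a, axis=0)
```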

- Python
Published by yashk2810 over 3 years ago

jax - Jax release v0.3.25

- Python
Published by yashk2810 over 3 years ago

jax - Jaxlib release v0.3.25

- Python
Published by yashk2810 over 3 years ago

jax - Jaxlib release v0.3.24

- Python
Published by yashk2810 over 3 years ago

jax - Jax release v0.3.24

- Python
Published by yashk2810 over 3 years ago

jax -

  • Changes
    • Update Colab TPU driver version for new jaxlib release.

- Python
Published by skye over 3 years ago

jax -

- Python
Published by skye over 3 years ago

jax -

  • Changes
    • Add JAX_PLATFORMS=tpu,cpu as default setting in TPU initialization, so JAX will raise an error if TPU cannot be initialized instead of falling back to CPU. Set JAX_PLATFORMS='' to override this behavior and automatically choose an available backend (the original default), or set JAX_PLATFORMS=cpu to always use CPU regardless of whether the TPU is available.
  • Deprecations
    • Several test utilities deprecated in JAX v0.3.8 are now removed from jax.test_util.
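A minimal sketch of the JAX_PLATFORMS=cpu override described above, assuming a fresh Python process: the variable must be set before JAX initializes its backends, after which only CPU devices are visible:

```python
import os

# Force the CPU backend; must be set before JAX initializes its backends.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax

# Collect the platform names of all visible devices.
platforms = {d.platform for d in jax.devices()}
```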

- Python
Published by skye over 3 years ago

jax - JAX release v0.3.21

  • Changes
    • The persistent compilation cache will now warn instead of raising an exception on error (#12582), so program execution can continue if something goes wrong with the cache. Set JAX_RAISE_PERSISTENT_CACHE_ERRORS=true to revert this behavior.

- Python
Published by skye over 3 years ago

jax - JAX release v0.3.20

Notable changes:

  • Adds .pyi files that were missing from the previous release (#12536).
  • Fixes an incompatibility between jax 0.3.19 and the libtpu version it pinned (#12550). Requires jaxlib 0.3.20.
  • Fix incorrect pip url in setup.py comment (#12528).

- Python
Published by hawkinsp over 3 years ago

jax - jaxlib release v0.3.20

Notable changes:

  • Fixes support for limiting the visible CUDA devices via jax_cuda_visible_devices in distributed jobs. This functionality is needed for the JAX/SLURM integration on GPU (#12533).

- Python
Published by hawkinsp over 3 years ago

jax - JAX release v0.3.19

Fixes the required jaxlib version

- Python
Published by skye over 3 years ago

jax - JAX release v0.3.18

  • GitHub commits.
  • Changes
    • Ahead-of-time lowering and compilation functionality (tracked in #7733) is stable and public. See the overview and the API docs for jax.stages.
    • Introduced jax.Array, intended to be used for both isinstance checks and type annotations for array types in JAX. Note that this included some subtle changes to how isinstance works for jax.numpy.ndarray for jax-internal objects, as jax.numpy.ndarray is now a simple alias of jax.Array.
  • Breaking changes
    • jax._src is no longer imported into the public jax namespace. This may break users that were using JAX internals.
    • jax.soft_pmap has been deleted. Please use pjit or xmap instead. jax.soft_pmap is undocumented. If it were documented, a deprecation period would have been provided.
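The now-public ahead-of-time workflow can be sketched as jit, then lower, then compile. This example assumes jax.ShapeDtypeStruct for the abstract input and uses Lowered.as_text() only to show that the lowered program can be inspected; f is a hypothetical function:

```python
import jax
import jax.numpy as jnp

def f(x):
    return x * 2.0 + 1.0

# Lower for an abstract input: only shape and dtype are needed, no concrete data.
lowered = jax.jit(f).lower(jax.ShapeDtypeStruct((3,), jnp.float32))
hlo_text = lowered.as_text()  # textual form of the lowered program

# Compile ahead of time, then call the compiled executable directly.
compiled = lowered.compile()
out = compiled(jnp.arange(3.0))
```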

- Python
Published by skye over 3 years ago

jax - JAX release v0.3.17

  • GitHub commits.
  • Bugs
    • Fix corner case issue in gradient of lax.pow with an exponent of zero (#12041)
  • Breaking changes
    • jax.checkpoint, also known as jax.remat, no longer supports the concrete option, following the previous version's deprecation; see JEP 11830.
  • Changes
    • Added jax.pure_callback that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with jax.jit or jax.pmap).
  • Deprecations:
    • The deprecated DeviceArray.tile() method has been removed. Use jax.numpy.tile (#11944).
    • DeviceArray.to_py() has been deprecated. Use np.asarray(x) instead.
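A minimal sketch of jax.pure_callback inside jax.jit, where host_sin is a hypothetical host-side NumPy function; the expected result shape and dtype are declared with jax.ShapeDtypeStruct. The last line uses np.asarray(x), the suggested replacement for the deprecated DeviceArray.to_py():

```python
import numpy as np
import jax
import jax.numpy as jnp

def host_sin(x):
    # Runs on the host with NumPy; must return an array matching the declared shape/dtype.
    return np.asarray(np.sin(x), dtype=x.dtype)

@jax.jit
def f(x):
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, result_shape, x)

out = f(jnp.array([0.0, 1.0]))

# Replacement for the deprecated out.to_py().
out_np = np.asarray(out)
```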

- Python
Published by jakevdp almost 4 years ago

jax - JAX release v0.3.16

  • GitHub commits.
  • Breaking changes
    • Support for NumPy 1.19 has been dropped, per the deprecation policy. Please upgrade to NumPy 1.20 or newer.
  • Changes
    • Added jax.debug that includes utilities for runtime value debugging such as jax.debug.print and jax.debug.breakpoint.
    • Added new documentation for runtime value debugging.
  • Deprecations
    • The jax.mask and jax.shapecheck APIs have been removed. See #11557.
    • jax.experimental.loops has been removed. See #10278 for an alternative API.
    • jax.tree_util.tree_multimap has been removed. It has been deprecated since JAX release 0.3.5, and jax.tree_util.tree_map is a direct replacement.
    • Removed jax.experimental.stax; it has long been a deprecated alias of jax.example_libraries.stax.
    • Removed jax.experimental.optimizers; it has long been a deprecated alias of jax.example_libraries.optimizers.
    • jax.checkpoint, also known as jax.remat, has a new implementation switched on by default, meaning the old implementation is deprecated; see JEP 11830.
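A short sketch of two items above: jax.debug.print reports a runtime value from inside a jitted function, and jax.tree_util.tree_map takes multiple trees, which is why it directly replaces the removed tree_multimap. The function and trees are illustrative:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = jnp.sin(x)
    # Prints the runtime value of y, even though f is traced and compiled.
    jax.debug.print("y = {y}", y=y)
    return y * 2

out = f(jnp.array(0.0))

# tree_map accepts several trees of the same structure, like the old tree_multimap.
summed = jax.tree_util.tree_map(lambda a, b: a + b, {"x": 1, "y": 2}, {"x": 10, "y": 20})
```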

- Python
Published by sharadmv almost 4 years ago

jax - jax version 0.3.15

- Python
Published by skye almost 4 years ago

jax - jaxlib version 0.3.15

- Python
Published by skye almost 4 years ago

jax - JAX release v0.3.14

- Python
Published by sharadmv almost 4 years ago

jax - Jaxlib release v0.3.14

- Python
Published by sharadmv almost 4 years ago

jax - JAX release v0.3.13

- Python
Published by skye about 4 years ago

jax - Jax release v0.3.12

  • Changes
    • Fixes https://github.com/google/jax/pull/10717

- Python
Published by yashk2810 about 4 years ago

jax - Jax release v0.3.11

  • Changes
    • jax.lax.eigh now accepts an optional sort_eigenvalues argument that allows users to opt out of eigenvalue sorting on TPU.
  • Deprecations
    • Non-array arguments to functions in jax.lax.linalg are now marked keyword-only. As a backward-compatibility step, passing keyword-only arguments positionally yields a warning, but in a future JAX release doing so will fail. However, most users should prefer to use jax.numpy.linalg instead.
    • jax.scipy.linalg.polar_unitary, which was a JAX extension to the SciPy API, is deprecated. Use jax.scipy.linalg.polar instead.
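As a sketch of the suggested replacement for polar_unitary, jax.scipy.linalg.polar returns both factors of the polar decomposition, and the unitary factor is the first element. The matrix is illustrative, and the tolerances assume float32 on CPU:

```python
import jax.numpy as jnp
from jax.scipy.linalg import polar

a = jnp.array([[2.0, 1.0], [0.0, 3.0]])

# Right polar decomposition (the default side): a = u @ p,
# with u unitary and p positive semidefinite.
u, p = polar(a)
```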

- Python
Published by yashk2810 about 4 years ago

jax - Jaxlib release v0.3.10

- Python
Published by yashk2810 about 4 years ago

jax - Jax release v0.3.10

- Python
Published by yashk2810 about 4 years ago

jax - Jax release 0.3.9

  • Changes
    • Added support for fully asynchronous checkpointing for GlobalDeviceArray.

- Python
Published by yashk2810 about 4 years ago

jax - JAX release v0.3.8

  • GitHub commits.
  • Changes
    • jax.numpy.linalg.svd on TPUs uses a QDWH-SVD solver.
    • jax.numpy.linalg.cond on TPUs now accepts complex input.
    • jax.numpy.linalg.pinv on TPUs now accepts complex input.
    • jax.numpy.linalg.matrix_rank on TPUs now accepts complex input.
    • jax.scipy.cluster.vq.vq has been added.
    • jax.experimental.maps.mesh has been deleted. Please use jax.experimental.maps.Mesh. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
    • jax.scipy.linalg.qr now returns a length-1 tuple rather than the raw array when mode='r', in order to match the behavior of scipy.linalg.qr (#10452).
    • jax.numpy.take_along_axis now takes an optional mode parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing mode="clip".
    • jax.numpy.take now defaults to mode="fill", which returns invalid values (e.g., NaN) for out-of-bounds indices.
    • Scatter operations, such as x.at[...].set(...), now have "drop" semantics. This has no effect on the scatter operation itself, but it means that, when differentiated, the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously, out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
    • jax.numpy.take_along_axis now raises a TypeError if its indices are not of an integer type, matching the behavior of numpy.take_along_axis. Previously, non-integer indices were silently cast to integers.
    • jax.numpy.ravel_multi_index now raises a TypeError if its dims argument is not of an integer type, matching the behavior of numpy.ravel_multi_index. Previously, non-integer dims were silently cast to integers.
    • jax.numpy.split now raises a TypeError if its axis argument is not of an integer type, matching the behavior of numpy.split. Previously, a non-integer axis was silently cast to an integer.
    • jax.numpy.indices now raises a TypeError if its dimensions are not of an integer type, matching the behavior of numpy.indices. Previously, non-integer dimensions were silently cast to integers.
    • jax.numpy.diag now raises a TypeError if its k argument is not of an integer type, matching the behavior of numpy.diag. Previously, a non-integer k was silently cast to an integer.
    • Added jax.random.orthogonal.
  • Deprecations
    • Many functions and objects available in jax.test_util are now deprecated and will raise a warning on import. This includes cases_from_list, check_close, check_eq, device_under_test, format_shape_dtype_string, rand_uniform, skip_on_devices, with_config, xla_bridge, and _default_tolerance (#10389). These, along with previously-deprecated JaxTestCase, JaxTestLoader, and BufferDonationTestCase, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard Python & NumPy testing utilities found in e.g. unittest, absl.testing, numpy.testing, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as jax.devices. Many of the deprecated utilities will still exist in jax._src.test_util, but these are not public APIs and as such may be changed or removed without notice in future releases.
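The out-of-bounds indexing changes in this release can be sketched on a tiny array. Index 5 and index 10 are deliberately out of bounds; scatter_sum is a hypothetical helper used only to show that the dropped update receives a zero gradient:

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
idx = jnp.array([0, 5])  # index 5 is out of bounds

# Default mode="fill": the out-of-bounds position comes back as NaN.
filled = jnp.take(x, idx)

# mode="clip" restores the old clamping behavior: index 5 maps to the last element.
clipped = jnp.take_along_axis(x, idx, axis=0, mode="clip")

def scatter_sum(v):
    # Index 10 is out of bounds, so its update is dropped.
    return jnp.zeros(3).at[jnp.array([1, 10])].set(v).sum()

# The dropped update gets a zero cotangent.
g = jax.grad(scatter_sum)(jnp.array([5.0, 7.0]))
```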

- Python
Published by mattjj about 4 years ago

jax - Jaxlib v0.3.7

  • Linux wheels are now built conforming to the manylinux2014 standard, instead of manylinux2010.

- Python
Published by hawkinsp about 4 years ago

jax - JAX release v0.3.7

  • Fixed a performance problem if the indices passed to jax.numpy.take_along_axis were broadcasted (#10281).
  • jax.scipy.special.expit and jax.scipy.special.logit now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
  • The DeviceArray.tile() method is deprecated, because numpy arrays do not have a tile() method. As a replacement for this, use jax.numpy.tile (#10266).
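A short sketch of the behavior described above: jax.scipy.special.expit and logit take JAX arrays and promote integer input to floating point, and jax.numpy.tile stands in for the deprecated DeviceArray.tile() method. The inputs are illustrative:

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

# Integer input is promoted to a floating-point result.
e = expit(jnp.array(0))
lg = logit(jnp.array(0.5))

# Replacement for the deprecated DeviceArray.tile() method.
t = jnp.tile(jnp.array([1, 2]), 2)
```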

- Python
Published by hawkinsp about 4 years ago