Recent Releases of jax
jax - JAX v0.7.1
New features
- JAX now ships Python 3.14 and 3.14t wheels.
- JAX now ships Python 3.13t and 3.14t wheels on Mac. Previously we only offered free-threading builds on Linux.
Changes
- Exposed `jax.set_mesh`, which acts as a global setter and a context manager. Removed `jax.sharding.use_mesh` in favor of `jax.set_mesh`.
- JAX is now built using CUDA 12.9. All versions of CUDA 12.1 or newer remain supported.
- `jax.lax.dot` now implements the general dot product via the optional `dimension_numbers` argument.
Deprecations:
- `jax.lax.zeros_like_array` is deprecated. Please use `jax.numpy.zeros_like` instead.
- Attempting to import `jax.experimental.host_callback` now results in a `DeprecationWarning`, and will result in an `ImportError` starting in JAX v0.8.0. Its APIs have raised `NotImplementedError` since JAX version 0.4.35.
- In `jax.lax.dot`, passing the `precision` and `preferred_element_type` arguments by position is deprecated. Pass them by explicit keyword instead.
- Several dozen internal APIs have been deprecated from `jax.interpreters.ad`, `jax.interpreters.batching`, and `jax.interpreters.partial_eval`; they are used rarely if ever outside JAX itself, and most are deprecated without any public replacement.
- Python
Published by jakeharmon8 10 months ago
jax - JAX v0.7.0
New features:
- Added `jax.P`, which is an alias for `jax.sharding.PartitionSpec`.
- Added `jax.tree.reduce_associative`.
Breaking changes:
- JAX is migrating from GSPMD to Shardy by default. See the migration guide for more information.
- JAX autodiff is switching to using direct linearization by default (instead of implementing linearization via JVP and partial eval). See migration guide for more information.
- `jax.stages.OutInfo` has been replaced with `jax.ShapeDtypeStruct`.
- `jax.jit` now requires `fun` to be passed by position, and additional arguments to be passed by keyword. Doing otherwise will result in an error starting in v0.7.x. This raised a `DeprecationWarning` in v0.6.x.
- The minimum Python version is now 3.11. 3.11 will remain the minimum supported version until July 2026.
- Layout API renames:
  - `Layout`, `.layout`, `.input_layouts` and `.output_layouts` have been renamed to `Format`, `.format`, `.input_formats` and `.output_formats`.
  - `DeviceLocalLayout` and `.device_local_layout` have been renamed to `Layout` and `.layout`.
- The `jax.experimental.shard` module has been deleted and all the APIs have been moved to the `jax.sharding` endpoint. So use `jax.sharding.reshard`, `jax.sharding.auto_axes` and `jax.sharding.explicit_axes` instead of their experimental endpoints.
- `lax.infeed` and `lax.outfeed` were removed, after being deprecated in JAX 0.6. The `transfer_to_infeed` and `transfer_from_outfeed` methods were also removed from the `Device` objects.
- The `jax.extend.core.primitives.pjit_p` primitive has been renamed to `jit_p`, and its `name` attribute has changed from `"pjit"` to `"jit"`. This affects the string representations of jaxprs. The same primitive is no longer exported from the `jax.experimental.pjit` module.
- The (undocumented) function `jax.extend.backend.add_clear_backends_callback` has been removed. Users should use `jax.extend.backend.register_backend_cache` instead.
Deprecations:
- `jax.dlpack.SUPPORTED_DTYPES` is deprecated; please use the new `jax.dlpack.is_supported_dtype` function.
- `jax.scipy.special.sph_harm` has been deprecated following a similar deprecation in SciPy; use `jax.scipy.special.sph_harm_y` instead.
- From `jax.interpreters.xla`, the previously deprecated symbols `abstractify` and `pytype_aval_mappings` have been removed.
- `jax.interpreters.xla.canonicalize_dtype` is deprecated. For canonicalizing dtypes, prefer `jax.dtypes.canonicalize_dtype`. For checking whether an object is a valid jax input, prefer `jax.core.valid_jaxtype`.
- From `jax.core`, the previously deprecated symbols `AxisName`, `ConcretizationTypeError`, `axis_frame`, `call_p`, `closed_call_p`, `get_type`, `trace_state_clean`, `typematch`, and `typecheck` have been removed.
- From `jax.lib.xla_client`, the previously deprecated symbols `DeviceAssignment`, `get_topology_for_devices`, and `mlir_api_version` have been removed.
- `jax.extend.ffi` was removed after being deprecated in v0.5.0. Use `jax.ffi` instead.
- `jax.lib.xla_bridge.get_compile_options` is deprecated, and replaced by `jax.extend.backend.get_compile_options`.
- Python
Published by MichaelHudgins 11 months ago
jax - JAX v0.6.2
New features:
- Added `jax.tree.broadcast`, which implements a pytree prefix broadcasting helper.
Changes
- The minimum NumPy version is 1.26 and the minimum SciPy version is 1.12.
- Python
Published by yashk2810 12 months ago
jax - JAX v0.6.1
New features:
- Added `jax.lax.axis_size`, which returns the size of the mapped axis given its name.
Changes
- Additional checking for the versions of CUDA package dependencies was reenabled, having been accidentally disabled in a previous release.
- JAX nightly packages are now published to artifact registry. To install these packages, see the JAX installation guide.
- `jax.sharding.PartitionSpec` no longer inherits from a tuple.
- `jax.ShapeDtypeStruct` is immutable now. Please use the `.update` method to update your `ShapeDtypeStruct` instead of doing in-place updates.
Deprecations
- `jax.custom_derivatives.custom_jvp_call_jaxpr_p` is deprecated, and will be removed in JAX v0.7.0.
- Python
Published by hawkinsp about 1 year ago
jax - JAX v0.6.0
Breaking changes
- `jax.numpy.array` no longer accepts `None`. This behavior was deprecated since November 2023 and is now removed.
- Removed the `config.jax_data_dependent_tracing_fallback` config option, which was added temporarily in v0.4.36 to allow users to opt out of the new "stackless" tracing machinery.
- Removed the `config.jax_eager_pmap` config option.
- Disallow the calling of `lower` and `trace` AOT APIs on the result of `jax.jit` if there have been subsequent wrappers applied. Previously this worked, but silently ignored the wrappers. The workaround is to apply `jax.jit` last among the wrappers, and similarly for `jax.pmap`. See #27873.
- The `cuda12_pip` extra for `jax` has been removed; use `pip install jax[cuda12]` instead.
Changes
- The minimum CuDNN version is v9.8.
- JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain supported.
- JAX package extras are now updated to use dash instead of underscore to align with PEP 685. For instance, if you were previously using `pip install jax[cuda12_local]` to install JAX, run `pip install jax[cuda12-local]` instead.
- `jax.jit` now requires `fun` to be passed by position, and additional arguments to be passed by keyword. Doing otherwise will result in a `DeprecationWarning` in v0.6.X, and an error starting in v0.7.X.
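The calling convention for `jax.jit` described above can be sketched as follows. This is a minimal illustration, not taken from the release notes: the function `scale` and its `factor` argument are hypothetical names chosen for the example.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy function used only to illustrate the calling convention.
def scale(x, factor):
    return x * factor

# The function goes first, by position; every jit option is given by keyword.
scale_jit = jax.jit(scale, static_argnames=("factor",))

result = scale_jit(jnp.arange(3.0), factor=2.0)
```

Writing the options as keywords keeps the call valid both in releases that only warn and in releases that raise an error.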
Deprecations
- `jax.tree_util.build_tree` is deprecated. Use `jax.tree.unflatten` instead.
- Implemented host callback handlers for CPU and GPU devices using XLA's FFI and removed existing CPU/GPU handlers using XLA's custom call.
- All APIs in `jax.lib.xla_extension` are now deprecated.
- `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect`, which were accidental exports, have been removed. If needed, they are available from `jax.extend.mlir`.
- `jax.interpreters.mlir.custom_call` is deprecated. The APIs provided by `jax.ffi` should be used instead.
- The deprecated use of `jax.ffi.ffi_call` with inline arguments is no longer supported. `jax.ffi.ffi_call` now unconditionally returns a callable.
- The following exports in `jax.lib.xla_client` are deprecated: `get_topology_for_devices`, `heap_profile`, `mlir_api_version`, `Client`, `CompileOptions`, `DeviceAssignment`, `Frame`, `HloSharding`, `OpSharding`, `Traceback`.
- The following internal APIs in `jax.util` are deprecated: `HashableFunction`, `as_hashable_function`, `cache`, `safe_map`, `safe_zip`, `split_dict`, `split_list`, `split_list_checked`, `split_merge`, `subvals`, `toposort`, `unzip2`, `wrap_name`, and `wraps`.
- `jax.dlpack.to_dlpack` has been deprecated. You can usually pass a JAX `Array` directly to the `from_dlpack` function of another framework. If you need the functionality of `to_dlpack`, use the `__dlpack__` attribute of an array.
- `jax.lax.infeed`, `jax.lax.infeed_p`, `jax.lax.outfeed`, and `jax.lax.outfeed_p` are deprecated and will be removed in JAX v0.7.0.
- Several previously-deprecated APIs have been removed, including:
  - From `jax.lib.xla_client`: `ArrayImpl`, `FftType`, `PaddingType`, `PrimitiveType`, `XlaBuilder`, `dtype_to_etype`, `ops`, `register_custom_call_target`, `shape_from_pyval`, `Shape`, `XlaComputation`.
  - From `jax.lib.xla_extension`: `ArrayImpl`, `XlaRuntimeError`.
  - From `jax`: `jax.treedef_is_leaf`, `jax.tree_flatten`, `jax.tree_map`, `jax.tree_leaves`, `jax.tree_structure`, `jax.tree_transpose`, and `jax.tree_unflatten`. Replacements can be found in `jax.tree` or `jax.tree_util`.
  - From `jax.core`: `AxisSize`, `ClosedJaxpr`, `EvalTrace`, `InDBIdx`, `InputType`, `Jaxpr`, `JaxprEqn`, `Literal`, `MapPrimitive`, `OpaqueTraceState`, `OutDBIdx`, `Primitive`, `Token`, `TRACER_LEAK_DEBUGGER_WARNING`, `Var`, `concrete_aval`, `dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`, `full_lower`, `get_referent`, `jaxpr_as_fun`, `join_effects`, `lattice_join`, `leaked_tracer_error`, `maybe_find_leaked_tracers`, `raise_to_shaped`, `raise_to_shaped_mappings`, `reset_trace_state`, `str_eqn_compact`, `substitute_vars_in_output_ty`, `typecompat`, and `used_axis_names_jaxpr`. Most have no public replacement, though a few are available at `jax.extend.core`.
  - The `vectorized` argument to `jax.pure_callback` and `jax.ffi.ffi_call`. Use the `vmap_method` parameter instead.
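For code that still uses the removed top-level `jax.tree_*` helpers or the deprecated `jax.tree_util.build_tree`, a migration to the `jax.tree` namespace might look like the sketch below. The sample pytree is invented for illustration.

```python
import jax

# Hypothetical pytree used only for illustration.
tree = {"a": 1, "b": (2, 3)}

# jax.tree.flatten / jax.tree.unflatten replace jax.tree_flatten / jax.tree_unflatten.
leaves, treedef = jax.tree.flatten(tree)
doubled = jax.tree.unflatten(treedef, [2 * x for x in leaves])

# jax.tree.map replaces the removed jax.tree_map.
also_doubled = jax.tree.map(lambda x: 2 * x, tree)
```

Both routes rebuild the same structure; `jax.tree.map` is usually the shorter choice when the leaves are transformed one by one.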
- Python
Published by hawkinsp about 1 year ago
jax - JAX v0.5.3
New Features
- Added an `allow_negative_indices` option to `jax.lax.dynamic_slice`, `jax.lax.dynamic_update_slice` and related functions. The default is true, matching the current behavior. If set to false, JAX does not need to emit code clamping negative indices, which improves code size.
- Added a `replace` option to `jax.random.categorical` to enable sampling without replacement.
- Python
Published by hawkinsp about 1 year ago
jax - JAX v0.5.2
Patch release of 0.5.1
- Bug fixes
  - Fixes TPU metric logging and `tpu-info`, which was broken in 0.5.1.
- Python
Published by skye over 1 year ago
jax - JAX v0.5.1
New Features
- Added an experimental `jax.experimental.custom_dce.custom_dce` decorator to support customizing the behavior of opaque functions under JAX-level dead code elimination (DCE). See #25956 for more details.
- Added low-level reduction APIs in `jax.lax`: `jax.lax.reduce_sum`, `jax.lax.reduce_prod`, `jax.lax.reduce_max`, `jax.lax.reduce_min`, `jax.lax.reduce_and`, `jax.lax.reduce_or`, and `jax.lax.reduce_xor`.
- `jax.lax.linalg.qr` and `jax.scipy.linalg.qr` now support column-pivoting on CPU and GPU. See #20282 and #25955 for more details.
Changes
- `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as env vars. Before they could only be specified via jax.config or flags.
- `JAX_CPU_COLLECTIVES_IMPLEMENTATION` now defaults to `'gloo'`, meaning multi-process CPU communication works out-of-the-box.
- The `jax[tpu]` TPU extra no longer depends on the `libtpu-nightly` package. This package may safely be removed if it is present on your machine; JAX now uses `libtpu` instead.
Deprecations
- The internal function `linear_util.wrap_init` and the constructor `core.Jaxpr` now must take a non-empty `core.DebugInfo` kwarg. For a limited time, a `DeprecationWarning` is printed if `jax.extend.linear_util.wrap_init` is used without debugging info. A downstream effect of this is that several other internal functions need debug info. This change does not affect public APIs. See https://github.com/jax-ml/jax/issues/26480 for more detail.
Bug fixes
- TPU runtime startup and shutdown time should be significantly improved on TPU v5e and newer (from around 17s to around 8s). If not already set, you may need to enable transparent hugepages in your VM image (`sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled'`). We hope to improve this further in future releases.
- Persistent compilation cache no longer writes access time file if `JAX_COMPILATION_CACHE_MAX_SIZE` is unset or set to -1, i.e. if the LRU eviction policy isn't enabled. This should improve performance when using the cache with large-scale network storage.
- Python
Published by hawkinsp over 1 year ago
jax - JAX v0.5.0
As of this release, JAX now uses effort-based versioning. Since this release makes a breaking change to PRNG key semantics that may require users to update their code, we are bumping the "meso" version of JAX to signify this.
Breaking changes
- Enable `jax_threefry_partitionable` by default (see the update note).
- This release drops support for Mac x86 wheels. Mac ARM of course remains supported. For a recent discussion, see https://github.com/jax-ml/jax/discussions/22936.

  Two key factors motivated this decision:
  - The Mac x86 build (only) has a number of test failures and crashes. We would prefer to ship no release than a broken release.
  - Mac x86 hardware is end-of-life and cannot be easily obtained for developers at this point. So it is difficult for us to fix this kind of problem even if we wanted to.

  We are open to readding support for Mac x86 if the community is willing to help support that platform: in particular, we would need the JAX test suite to pass cleanly on Mac x86 before we could ship releases again.
Changes:
- The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum supported version until June 2025.
- The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum supported version until June 2025.
- `jax.numpy.einsum` now defaults to `optimize='auto'` rather than `optimize='optimal'`. This avoids exponentially-scaling trace-time in the case of many arguments (#25214).
- `jax.numpy.linalg.solve` no longer supports batched 1D arguments on the right hand side. To recover the previous behavior in these cases, use `solve(a, b[..., None]).squeeze(-1)`.
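The `jax.numpy.linalg.solve` workaround quoted above can be sketched on a small batch. The matrices and right-hand sides here are made-up values for illustration only.

```python
import jax.numpy as jnp

# A batch of two 2x2 systems (illustrative values).
a = jnp.array([[[2.0, 0.0], [0.0, 4.0]],
               [[1.0, 1.0], [0.0, 1.0]]])  # shape (2, 2, 2)

# A batch of 1D right-hand sides, one vector per system.
b = jnp.array([[2.0, 4.0],
               [3.0, 1.0]])                # shape (2, 2)

# Add a trailing axis so each right-hand side is an explicit column vector,
# solve the batch, then drop the trailing axis again.
x = jnp.linalg.solve(a, b[..., None]).squeeze(-1)

# Recompute the left-hand side to check the solution against b.
residual_check = jnp.einsum("bij,bj->bi", a, x)
```

The explicit trailing axis removes the ambiguity between "a batch of vectors" and "a single matrix of right-hand sides".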
New Features
- `jax.numpy.fft.fftn`, `jax.numpy.fft.rfftn`, `jax.numpy.fft.ifftn`, and `jax.numpy.fft.irfftn` now support transforms in more than 3 dimensions, which was previously the limit. See #25606 for more details.
- Support added for user defined state in the FFI via the new `jax.ffi.register_ffi_type_id` function.
- The AOT lowering `.as_text()` method now supports the `debug_info` option to include debugging information, e.g., source location, in the output.
Deprecations
- From `jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings` are now deprecated, having been replaced by symbols of the same name in `jax.core`.
- `jax.scipy.special.lpmn` and `jax.scipy.special.lpmn_values` are deprecated, following their deprecation in SciPy v1.15.0. There are no plans to replace these deprecated functions with new APIs.
- The `jax.extend.ffi` submodule was moved to `jax.ffi`, and the previous import path is deprecated.
Deletions
- The `jax_enable_memories` flag has been deleted and the behavior of that flag is on by default.
- From `jax.lib.xla_client`, the previously-deprecated `Device` and `XlaRuntimeError` symbols have been removed; instead use `jax.Device` and `jax.errors.JaxRuntimeError` respectively.
- The `jax.experimental.array_api` module has been removed after being deprecated in JAX v0.4.32. Since that release, `jax.numpy` supports the array API directly.
- Python
Published by hawkinsp over 1 year ago
jax - JAX v0.4.38
Changes:
- `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added as shortcuts of the corresponding `tree_util` functions.
Deprecations
- A number of APIs in the internal `jax.core` namespace have been deprecated. Most were no-ops, were little-used, or can be replaced by APIs of the same name in `jax.extend.core`; see the documentation for `jax.extend` for information on the compatibility guarantees of these semi-public extensions.
- Several previously-deprecated APIs have been removed, including:
  - from `jax.core`: `check_eqn`, `check_type`, `check_valid_jaxtype`, and `non_negative_dim`.
  - from `jax.lib.xla_bridge`: `xla_client` and `default_backend`.
  - from `jax.lib.xla_client`: `_xla` and `bfloat16`.
  - from `jax.numpy`: `round_`.
New Features
- `jax.export.export` can be used for device-polymorphic export with shardings constructed with `jax.sharding.AbstractMesh`. See the jax.export documentation.
- Added `jax.lax.split`. This is a primitive version of `jax.numpy.split`, added because it yields a more compact transpose during automatic differentiation.
- Python
Published by hawkinsp over 1 year ago
jax - JAX v0.4.37
This is a patch release of jax 0.4.36. Only "jax" was released at this version.
- Bug fixes
  - Fixed a bug where `jit` would error if an argument was named `f` (#25329).
  - Fixed a bug that throws an `index out of range` error in `jax.lax.while_loop` if the user registers a pytree node class with different aux data for the flatten and flatten_with_path.
  - Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.
- Python
Published by hawkinsp over 1 year ago
jax - JAX v0.4.36
Breaking Changes
- This release lands "stackless", an internal change to JAX's tracing machinery. We made trace dispatch purely a function of context rather than a function of both context and data. This let us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, `post_process_call`, `new_base_main`, `custom_bind`, and so on. The change should only affect users that use JAX internals.

  If you do use JAX internals then you may need to update your code (see https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f for clues about how to do this). There might also be version skew issues with JAX libraries that do this. If you find this change breaks your non-JAX-internals-using code then try the `config.jax_data_dependent_tracing_fallback` flag as a workaround, and if you need help updating your code then please file a bug.
- `jax.experimental.jax2tf.convert` with `native_serialization=False` or with `enable_xla=False` has been deprecated since July 2024, with JAX version 0.4.31. Now we removed support for these use cases. `jax2tf` with native serialization will still be supported.
- In `jax.interpreters.xla`, the `xb`, `xc`, and `xe` symbols have been removed after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`, `xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`.
- The deprecated module `jax.experimental.export` has been removed. It was replaced by `jax.export` in JAX v0.4.30. See the migration guide for information on migrating to the new API.
- The `initial` argument to `jax.nn.softmax` and `jax.nn.log_softmax` has been removed, after being deprecated in v0.4.27.
- Calling `np.asarray` on typed PRNG keys (i.e. keys produced by `jax.random.key`) now raises an error. Previously, this returned a scalar object array.
- The following deprecated methods and functions in `jax.export` have been removed:
  - `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect already.
  - `jax.export.Exported.lowering_platforms`: use `platforms`.
  - `jax.export.Exported.mlir_module_serialization_version`: use `calling_convention_version`.
  - `jax.export.Exported.uses_shape_polymorphism`: use `uses_global_constants`.
  - the `lowering_platforms` kwarg for `jax.export.export`: use `platforms` instead.
- The kwargs `symbolic_scope` and `symbolic_constraints` from `jax.export.symbolic_args_specs` have been removed. They were deprecated in June 2024. Use `scope` and `constraints` instead.
- Hashing of tracers, which has been deprecated since version 0.4.30, now results in a `TypeError`.
- Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and replaces previous build.py usage. Run `python build/build.py --help` for more details. Brief overview of the new subcommand options:
  - `build`: Builds JAX wheel packages. For example, `python build/build.py build --wheels=jaxlib,jax-cuda-pjrt`.
  - `requirements_update`: Updates requirements_lock.txt files.
- `jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional inputs. To recover the previous behavior, you can call `jax.numpy.ravel` on the function inputs.
- `jax.scipy.special.gamma` and `jax.scipy.special.gammasgn` now return NaN for negative integer inputs, to match the behavior of SciPy from https://github.com/scipy/scipy/pull/21827.
- `jax.clear_backends` was removed after being deprecated in v0.4.26.
- We removed the custom call "__gpu$xla.gpu.triton" from the list of custom calls for which we guarantee export stability. This is because this custom call relies on Triton IR, which is not guaranteed to be stable. If you need to export code that uses this custom call, you can use the `disabled_checks` parameter. See more details in the documentation.
New Features
- `jax.jit` got a new `compiler_options: dict[str, Any]` argument, for passing compilation options to XLA. For the moment it's undocumented and may be in flux.
- `jax.tree_util.register_dataclass` now allows metadata fields to be declared inline via `dataclasses.field`. See the function documentation for examples.
- Added `jax.numpy.put_along_axis`.
- `jax.lax.linalg.eig` and the related `jax.numpy` functions (`jax.numpy.linalg.eig` and `jax.numpy.linalg.eigvals`) are now supported on GPU. See #24663 for more details.
- Added two new configuration flags, `jax_exec_time_optimization_effort` and `jax_memory_fitting_effort`, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0, default is 0.0.
Bug fixes
- Fixed a bug where the GPU implementations of LU and QR decomposition would result in an indexing overflow for batch sizes close to int32 max. See #24843 for more details.
Deprecations
- `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated; use `jax.Array` instead.
- `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError` instead.
- Python
Published by hawkinsp over 1 year ago
jax - JAX v0.4.35
Breaking Changes
- `jax.numpy.isscalar` now returns True for any array-like object with zero dimensions. Previously it only returned True for zero-dimensional array-like objects with a weak dtype.
- `jax.experimental.host_callback` has been deprecated since March 2024, with JAX version 0.4.26. Now we removed it. See #20385 for a discussion of alternatives.
Changes:
- `jax.lax.FftType` was introduced as a public name for the enum of FFT operations. The semi-public API `jax.lib.xla_client.FftType` has been deprecated.
- TPU: JAX now installs TPU support from the `libtpu` package rather than `libtpu-nightly`. For the next few releases JAX will pin an empty version of `libtpu-nightly` as well as `libtpu` to ease the transition; that dependency will be removed in Q1 2025.
Deprecations:
- The semi-public API `jax.lib.xla_client.PaddingType` has been deprecated. No JAX APIs consume this type, so there is no replacement.
- The default behavior of `jax.pure_callback` and `jax.extend.ffi.ffi_call` under `vmap` has been deprecated and so has the `vectorized` parameter to those functions. The `vmap_method` parameter should be used instead for better defined behavior. See the discussion in #23881 for more details.
- The semi-public API `jax.lib.xla_client.register_custom_call_target` has been deprecated. Use the JAX FFI instead.
- The semi-public APIs `jax.lib.xla_client.dtype_to_etype`, `jax.lib.xla_client.ops`, `jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`, `jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and `jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO instead.
- Python
Published by hawkinsp over 1 year ago
jax - JAX v0.4.34
New Functionality
- This release includes wheels for Python 3.13. Free-threading mode is not yet supported.
jax.errors.JaxRuntimeErrorhas been added as a public alias for the formerly privateXlaRuntimeErrortype.
Breaking changes
- The `jax_pmap_no_rank_reduction` flag is set to `True` by default.
  - `array[0]` on a pmap result now introduces a reshape (use `array[0:1]` instead).
  - The per-shard shape (accessible via `jax_array.addressable_shards` or `jax_array.addressable_data(0)`) now has a leading `(1, ...)`. Update code that directly accesses shards accordingly. The rank of the per-shard-shape now matches that of the global shape which is the same behavior as jit. This avoids costly reshapes when passing results from pmap into jit.
- `jax.experimental.host_callback` has been deprecated since March 2024, with JAX version 0.4.26. Now we set the default value of the `--jax_host_callback_legacy` configuration value to `True`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API. If this breaks your code, for a very limited time, you can set the `--jax_host_callback_legacy` to `True`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs. See #20385 for a discussion.
Deprecations
- In `jax.numpy.trim_zeros`, non-arraylike arguments or arraylike arguments with `ndim != 1` are now deprecated, and in the future will result in an error.
- Internal pretty-printing tools `jax.core.pp_*` have been removed, after being deprecated in JAX v0.4.30.
- `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead.
- `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use `jax.errors.JaxRuntimeError` instead.
Deletion:
- `jax.xla_computation` is deleted. It has been 3 months since its deprecation in 0.4.30 JAX release. Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
  - `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
  - You can also use `.out_info` property of `jax.stages.Lowered` to get the output information (like tree structure, shape and dtype).
  - For cross-backend lowering, you can replace `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
- `jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument. The argument was only used by `xmap` which was removed in 0.4.31.
- `jax.tree.map(f, None, non-None)`, which previously emitted a `DeprecationWarning`, now raises an error. `None` is only a tree-prefix of itself. To preserve the current behavior, you can ask `jax.tree.map` to treat `None` as a leaf value by writing: `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
- `jax.sharding.XLACompatibleSharding` has been removed. Please use `jax.sharding.Sharding`.
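The `jax.tree.map` workaround quoted above can be tried on a small pair of trees. The function `f` and the dictionaries `a` and `b` below are placeholder values invented for this sketch.

```python
import jax

# Placeholder binary function standing in for the f of the release note.
def f(x, y):
    return x + y

# `a` holds a None where `b` holds a real value.
a = {"w": 1.0, "b": None}
b = {"w": 2.0, "b": 3.0}

# Treat None as a leaf so it is passed to the lambda rather than
# being interpreted as an empty subtree.
out = jax.tree.map(
    lambda x, y: None if x is None else f(x, y),
    a,
    b,
    is_leaf=lambda x: x is None,
)
```

With `is_leaf` in place, entries that are `None` in the first tree stay `None` in the result and all other entries are combined with `f`.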
Bug fixes
- Fixed a bug where `jax.numpy.cumsum` would produce incorrect outputs if a non-boolean input was provided and `dtype=bool` was specified.
- Edit implementation of `jax.numpy.ldexp` to get correct gradient.
- Python
Published by hawkinsp over 1 year ago
jax - JAX release v0.4.33
This is a patch release on top of jax 0.4.32, that fixes two bugs found in that release.
A TPU-only data corruption bug was found in the version of libtpu pinned by JAX 0.4.32, which manifested only if multiple TPU slices were present in the same job, for example, if training on multiple v5e slices.
This release fixes that issue by pinning a fixed version of libtpu-nightly.
This release also fixes an inaccurate result for F64 tanh on CPU (#23590).
- Python
Published by hawkinsp almost 2 years ago
jax - JAX release v0.4.32
WARNING: This release has been yanked from PyPI because of a data corruption bug on TPU if there are multiple TPU slices in the job
- Python
Published by hawkinsp almost 2 years ago
jax - Jaxlib release v0.4.32
WARNING: This release has been yanked from PyPI because of a data corruption bug on TPU if there are multiple TPU slices in the job
- Python
Published by hawkinsp almost 2 years ago
jax - JAX v0.4.29
Changes
- We anticipate that this will be the last release of JAX and jaxlib supporting a monolithic CUDA jaxlib. Future releases will use the CUDA plugin jaxlib (e.g. `pip install jax[cuda12]`).
- JAX now requires ml_dtypes version 0.4.0 or newer.
- Removed backwards-compatibility support for old usage of the `jax.experimental.export` API. It is not possible anymore to use `from jax.experimental.export import export`, and instead you should use `from jax.experimental import export`. The removed functionality has been deprecated since 0.4.24.
Deprecations
- `jax.sharding.XLACompatibleSharding` is deprecated. Please use `jax.sharding.Sharding`.
- `jax.experimental.Exported.in_shardings` has been renamed as `jax.experimental.Exported.in_shardings_hlo`. Same for `out_shardings`. The old names will be removed after 3 months.
- Removed a number of previously-deprecated APIs:
  - from `jax.core`: `non_negative_dim`, `DimSize`, `Shape`
  - from `jax.lax`: `tie_in`
  - from `jax.nn`: `normalize`
  - from `jax.interpreters.xla`: `backend_specific_translations`, `translations`, `register_translation`, `xla_destructure`, `TranslationRule`, `TranslationContext`, `XlaOp`.
- The `tol` argument of `jax.numpy.linalg.matrix_rank` is being deprecated and will soon be removed. Use `rtol` instead.
- The `rcond` argument of `jax.numpy.linalg.pinv` is being deprecated and will soon be removed. Use `rtol` instead.
- The deprecated `jax.config` submodule has been removed. To configure JAX use `import jax` and then reference the config object via `jax.config`.
- `jax.random` APIs no longer accept batched keys, where previously some did unintentionally. Going forward, we recommend explicit use of `jax.vmap` in such cases.
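The recommendation to use `jax.vmap` with batched keys might look like the following sketch. The batch size of 4, the sample shape, and the choice of `jax.random.normal` are arbitrary picks for illustration.

```python
import jax

# Build a batch of four independent keys from a single root key.
root_key = jax.random.key(0)
keys = jax.random.split(root_key, 4)

# Map the sampler over the key batch explicitly instead of passing
# the whole batch of keys to jax.random.normal directly.
samples = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)
```

Each row of `samples` is drawn from its own key, so the batching is visible in the code rather than implied by the key's shape.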
New Functionality
- Added `jax.experimental.Exported.in_shardings_jax` to construct shardings that can be used with the JAX APIs from the HloShardings that are stored in the `Exported` objects.
- Python
Published by hawkinsp about 2 years ago
jax - Jaxlib release v0.4.29
Bug fixes
- Fixed a bug where XLA sharded some concatenation operations incorrectly, which manifested as an incorrect output for cumulative reductions (#21403).
- Fixed a bug where XLA:CPU miscompiled certain matmul fusions (https://github.com/openxla/xla/pull/13301).
- Fixes a compiler crash on GPU (https://github.com/google/jax/issues/21396).
Deprecations
- `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will raise an error in a future version of jax. `None` is only a tree-prefix of itself. To preserve the current behavior, you can ask `jax.tree.map` to treat `None` as a leaf value by writing: `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
- Python
Published by hawkinsp about 2 years ago
jax - JAX v0.4.28
Bug fixes
- Reverted a change to `make_jaxpr` that was breaking Equinox (#21116).
Deprecations & removals
- The `kind` argument to `jax.numpy.sort` and `jax.numpy.argsort` is now removed. Use `stable=True` or `stable=False` instead.
- Removed `get_compute_capability` from the `jax.experimental.pallas.gpu` module. Use the `compute_capability` attribute of a GPU device, returned by `jax.devices` or `jax.local_devices`, instead.
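A short sketch of the `stable` keyword that replaces `kind` in `jax.numpy.sort` and `jax.numpy.argsort`. The input array is an invented example with one repeated value, so that the effect of a stable sort on ties is visible.

```python
import jax.numpy as jnp

# Illustrative input with a tie: the value 1 appears at positions 1 and 3.
x = jnp.array([3, 1, 2, 1])

# stable=True keeps equal elements in their original relative order.
sorted_x = jnp.sort(x, stable=True)
order = jnp.argsort(x, stable=True)
```

Here `order` lists position 1 before position 3 for the two equal entries, which is the guarantee a stable sort provides.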
Changes
- The minimum jaxlib version of this release is 0.4.27.
- Python
Published by hawkinsp about 2 years ago
jax - jaxlib v0.4.28
Bug fixes
- Fixes a memory corruption bug in the type name of Array and JIT Python objects in Python 3.10 or earlier.
- Fixed a warning `'+ptx84' is not a recognized feature for this target` under CUDA 12.4.
- Fixed a slow compilation problem on CPU.
Changes
- The Windows build is now built with Clang instead of MSVC.
- Python
Published by hawkinsp about 2 years ago
jax - JAX release v0.4.13
NOTE: This is the last JAX release that will include Python 3.8 support
Changes
- jax.jit now allows None to be passed to in_shardings and out_shardings. The semantics are as follows:
  - For in_shardings, JAX will mark it as replicated, but this behavior can change in the future.
  - For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.
- jax.experimental.pjit.pjit also allows None to be passed to in_shardings and out_shardings. The semantics are as follows:
  - If the mesh context manager is not provided, JAX has the freedom to choose whatever sharding it wants.
    - For in_shardings, JAX will mark it as replicated, but this behavior can change in the future.
    - For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.
  - If the mesh context manager is provided, None will imply that the value will be replicated on all devices of the mesh.
- Executable.cost_analysis() works on Cloud TPU.
- Added a warning if a non-allowlisted jaxlib plugin is in use.
- Added jax.tree_util.tree_leaves_with_path.
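As a rough illustration of two items in this list, the sketch below passes None for both sharding arguments of jax.jit and then lists the leaves of a pytree with their key paths. The function scale and the params tree are hypothetical examples.

```python
import jax
import jax.numpy as jnp

def scale(x):
    return 2.0 * x

# None leaves the sharding decision to JAX and the compiler, as described above.
scaled = jax.jit(scale, in_shardings=None, out_shardings=None)(jnp.arange(4.0))

# tree_leaves_with_path pairs each leaf with its key path in the pytree.
params = {"dense": {"w": 1, "b": 2}, "scale": 3}   # hypothetical pytree
pairs = jax.tree_util.tree_leaves_with_path(params)
leaves = [leaf for _, leaf in pairs]
names = [jax.tree_util.keystr(path) for path, _ in pairs]
```

Dictionary keys are visited in sorted order, so the leaf under "b" comes before the one under "w".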
Bug fixes
- Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel is named cudnn89 instead of cudnn88.
Deprecations
- The native_serialization_strict_checks parameter to jax.experimental.jax2tf.convert is deprecated in favor of the new native_serialization_disabled_checks (#16347).
- Python
Published by skye almost 3 years ago
jax - jaxlib release v0.4.13
Changes
- Added Windows CPU-only wheels to the jaxlib PyPI release.
Bug fixes
- __cuda_array_interface__ was broken in previous jaxlib versions and is now fixed (#16440).
- Concurrent CUDA kernel tracing is now enabled by default on NVIDIA GPUs.
- Python
Published by skye almost 3 years ago
jax - Jaxlib release v0.4.1
- Changes
- Support for Python 3.7 has been dropped, in accordance with JAX's version support policy.
- The behavior of XLA_PYTHON_CLIENT_MEM_FRACTION=.XX has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to GPU memory allocation for more details.
- The deprecated method .block_host_until_ready() has been removed. Use .block_until_ready() instead.
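A small sketch of the memory-fraction setting and the block_until_ready() replacement described in this list. The value ".50" is an arbitrary example, and the variable has no effect when only a CPU backend is present.

```python
import os

# Set before JAX initializes its GPU backend; ".50" requests 50% of total
# GPU memory (an arbitrary example value).
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", ".50")

import jax.numpy as jnp

x = jnp.ones((4, 4))
# block_until_ready() waits for the asynchronously dispatched computation.
y = (x @ x).block_until_ready()
```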
- Python
Published by yashk2810 over 3 years ago
jax - Jax release v0.4.1
- Changes
- Support for Python 3.7 has been dropped, in accordance with JAX's version support policy.
- We introduce jax.Array, which is a unified array type that subsumes the DeviceArray, ShardedDeviceArray, and GlobalDeviceArray types in JAX. The jax.Array type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify jit and pjit. jax.Array has been enabled by default in JAX 0.4 and makes some breaking changes to the pjit API. The jax.Array migration guide can help you migrate your codebase to jax.Array. You can also look at the Distributed arrays and automatic parallelization tutorial to understand the new concepts.
- PartitionSpec and Mesh are now out of experimental. The new API endpoints are jax.sharding.PartitionSpec and jax.sharding.Mesh. jax.experimental.maps.Mesh and jax.experimental.PartitionSpec are deprecated and will be removed in 3 months.
- with_sharding_constraint's new public endpoint is jax.lax.with_sharding_constraint.
- If using ABSL flags together with jax.config, the ABSL flag values are no longer read or written after the JAX configuration options are initially populated from the ABSL flags. This change improves performance of reading jax.config options, which are used pervasively in JAX.
- The jax2tf.call_tf function now uses, for TF lowering, the first TF device of the same platform as used by the embedding JAX computation. Before, it was using the 0th device for the JAX-default backend.
- A number of jax.numpy functions now have their arguments marked as positional-only, matching NumPy.
- jnp.msort is now deprecated, following the deprecation of np.msort in numpy 1.24. It will be removed in a future release, in accordance with the API compatibility policy. It can be replaced with jnp.sort(a, axis=0).
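The sketch below illustrates the jax.sharding endpoints and the jnp.msort replacement from this list, written against a current JAX API. NamedSharding is not named in the release note and is used here as an assumption; the axis name "data" and the array contents are arbitrary.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over all available devices; "data" is an arbitrary name.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))

# Shard a vector along "data" (its length is a multiple of the device count).
x = jax.device_put(jnp.arange(4 * len(jax.devices())), sharding)
is_unified = isinstance(x, jax.Array)

# jnp.sort(a, axis=0) sorts each column, which is what jnp.msort did.
a = jnp.array([[3, 1], [2, 4]])
sorted_cols = jnp.sort(a, axis=0)
```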
- Python
Published by yashk2810 over 3 years ago
jax -
- Changes
- Update Colab TPU driver version for new jaxlib release.
- Python
Published by skye over 3 years ago
jax -
- Changes
- Add JAX_PLATFORMS=tpu,cpu as default setting in TPU initialization, so JAX will raise an error if TPU cannot be initialized instead of falling back to CPU. Set JAX_PLATFORMS='' to override this behavior and automatically choose an available backend (the original default), or set JAX_PLATFORMS=cpu to always use CPU regardless of whether the TPU is available.
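One way to apply the JAX_PLATFORMS=cpu setting from inside a Python program is sketched below. It assumes the variable is set before jax is first imported in the process.

```python
import os

# Must be set before JAX is imported; "cpu" forces the CPU backend.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax

backend = jax.default_backend()
platforms = {d.platform for d in jax.devices()}
```

Setting the variable in the shell before launching Python has the same effect and avoids depending on import order.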
- Deprecations
- Several test utilities deprecated in JAX v0.3.8 are now removed from jax.test_util.
- Python
Published by skye over 3 years ago
jax - JAX release v0.3.21
- Changes
- The persistent compilation cache will now warn instead of raising an exception on error (#12582), so program execution can continue if something goes wrong with the cache. Set JAX_RAISE_PERSISTENT_CACHE_ERRORS=true to revert this behavior.
- Python
Published by skye over 3 years ago
jax - JAX release v0.3.20
Notable changes:
- Adds .pyi files that were missing from the previous release (#12536).
- Fixes an incompatibility between jax 0.3.19 and the libtpu version it pinned (#12550). Requires jaxlib 0.3.20.
- Fixes an incorrect pip URL in a setup.py comment (#12528).
- Python
Published by hawkinsp over 3 years ago
jax - jaxlib release v0.3.20
Notable changes:
- Fixes support for limiting the visible CUDA devices via jax_cuda_visible_devices in distributed jobs. This functionality is needed for the JAX/SLURM integration on GPU (#12533).
- Python
Published by hawkinsp over 3 years ago
jax - JAX release v0.3.19
Fixes the required jaxlib version
- Python
Published by skye over 3 years ago
jax - JAX release v0.3.18
- GitHub commits.
- Changes
- Ahead-of-time lowering and compilation functionality (tracked in #7733) is stable and public. See the overview and the API docs for jax.stages.
- Introduced jax.Array, intended to be used for both isinstance checks and type annotations for array types in JAX. Notice that this included some subtle changes to how isinstance works for jax.numpy.ndarray for jax-internal objects, as jax.numpy.ndarray is now a simple alias of jax.Array.
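A minimal sketch of the ahead-of-time workflow and the isinstance check described in this list. The function affine and its input are hypothetical examples.

```python
import jax
import jax.numpy as jnp

def affine(x):
    return 2.0 * x + 1.0

x = jnp.arange(3.0)

lowered = jax.jit(affine).lower(x)   # trace and lower for this shape and dtype
compiled = lowered.compile()         # compile ahead of time
y = compiled(x)                      # call the compiled executable directly

is_jax_array = isinstance(y, jax.Array)
```

The compiled executable is specialized to the argument shapes and dtypes used during lowering, so it should be called with matching inputs.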
- Breaking changes
- jax._src is no longer imported into the public jax namespace. This may break users that were using JAX internals.
- jax.soft_pmap has been deleted. Please use pjit or xmap instead. jax.soft_pmap is undocumented. If it were documented, a deprecation period would have been provided.
- Python
Published by skye over 3 years ago
jax - JAX release v0.3.17
- GitHub commits.
- Bugs
- Fix corner case issue in gradient of lax.pow with an exponent of zero (#12041).
- Breaking changes
- jax.checkpoint, also known as jax.remat, no longer supports the concrete option, following the previous version's deprecation; see JEP 11830.
- Changes
- Added jax.pure_callback that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with jax.jit or jax.pmap).
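As an illustration of jax.pure_callback, the sketch below calls a NumPy function on the host from inside a jitted function. The names host_sin and apply_host_sin are hypothetical, and the call is written against a current JAX API.

```python
import numpy as np
import jax
import jax.numpy as jnp

def host_sin(x):
    # Runs on the host with NumPy; it should be pure (no side effects).
    return np.sin(np.asarray(x))

@jax.jit
def apply_host_sin(x):
    # The callback's output shape and dtype must be declared up front.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, out_type, x)

result = apply_host_sin(jnp.arange(3.0))
```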
- Deprecations:
- The deprecated DeviceArray.tile() method has been removed. Use jax.numpy.tile (#11944).
- DeviceArray.to_py() has been deprecated. Use np.asarray(x) instead.
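The two replacements in this list look like this in practice (the array is an arbitrary example):

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([1, 2])

tiled = jnp.tile(x, 2)   # replaces x.tile(2)
host = np.asarray(x)     # replaces x.to_py()
```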
- Python
Published by jakevdp almost 4 years ago
jax - JAX release v0.3.16
- GitHub commits.
- Breaking changes
- Support for NumPy 1.19 has been dropped, per the deprecation policy. Please upgrade to NumPy 1.20 or newer.
- Changes
- Added jax.debug that includes utilities for runtime value debugging such as jax.debug.print and jax.debug.breakpoint.
- Added new documentation for runtime value debugging.
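A short sketch of jax.debug.print inside a jitted function; the function double is a hypothetical example.

```python
import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    # Prints the runtime value even though the function is traced and compiled.
    jax.debug.print("x = {x}", x=x)
    return 2 * x

y = double(jnp.float32(3.0))
```

A plain Python print in the same position would show a tracer at trace time rather than the runtime value.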
- Deprecations
- The jax.mask and jax.shapecheck APIs have been removed. See #11557.
- jax.experimental.loops has been removed. See #10278 for an alternative API.
- jax.tree_util.tree_multimap has been removed. It has been deprecated since JAX release 0.3.5, and jax.tree_util.tree_map is a direct replacement.
- Removed jax.experimental.stax; it has long been a deprecated alias of jax.example_libraries.stax.
- Removed jax.experimental.optimizers; it has long been a deprecated alias of jax.example_libraries.optimizers.
- jax.checkpoint, also known as jax.remat, has a new implementation switched on by default, meaning the old implementation is deprecated; see JEP 11830.
- Python
Published by sharadmv almost 4 years ago
jax - Jax release v0.3.12
- Changes
- Fixes https://github.com/google/jax/pull/10717
- Python
Published by yashk2810 about 4 years ago
jax - Jax release v0.3.11
- Changes
- jax.lax.eigh now accepts an optional sort_eigenvalues argument that allows users to opt out of eigenvalue sorting on TPU.
- Deprecations
- Non-array arguments to functions in jax.lax.linalg are now marked keyword-only. As a backward-compatibility step, passing keyword-only arguments positionally yields a warning, but in a future JAX release passing keyword-only arguments positionally will fail. However, most users should prefer to use jax.numpy.linalg instead.
- jax.scipy.linalg.polar_unitary, which was a JAX extension to the scipy API, is deprecated. Use jax.scipy.linalg.polar instead.
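A sketch of the keyword-only calling convention, using the eigendecomposition routine as the example. It assumes the function lives at jax.lax.linalg.eigh, as in current JAX; the matrix is an arbitrary symmetric example.

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.array([[2.0, 0.0], [0.0, 1.0]])

# Non-array options are passed by keyword. This routine returns the
# eigenvectors first, then the eigenvalues.
v, w = lax_linalg.eigh(a, lower=True, sort_eigenvalues=True)

# The jax.numpy.linalg wrapper returns eigenvalues first.
w_np, v_np = jnp.linalg.eigh(a)
```

Passing sort_eigenvalues=False is the opt-out mentioned in the Changes list; the order of the returned eigenvalues is then not guaranteed.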
- Python
Published by yashk2810 about 4 years ago
jax - Jax release 0.3.9
- Changes
- Added support for fully asynchronous checkpointing for GlobalDeviceArray.
- Python
Published by yashk2810 about 4 years ago
jax - JAX release v0.3.8
- GitHub commits.
- Changes
- jax.numpy.linalg.svd on TPUs uses a qdwh-svd solver.
- jax.numpy.linalg.cond on TPUs now accepts complex input.
- jax.numpy.linalg.pinv on TPUs now accepts complex input.
- jax.numpy.linalg.matrix_rank on TPUs now accepts complex input.
- jax.scipy.cluster.vq.vq has been added.
- jax.experimental.maps.mesh has been deleted. Please use jax.experimental.maps.Mesh. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
- jax.scipy.linalg.qr now returns a length-1 tuple rather than the raw array when mode='r', in order to match the behavior of scipy.linalg.qr (#10452).
- jax.numpy.take_along_axis now takes an optional mode parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing mode="clip".
- jax.numpy.take now defaults to mode="fill", which returns invalid values (e.g., NaN) for out-of-bounds indices.
- Scatter operations, such as x.at[...].set(...), now have "drop" semantics. This has no effect on the scatter operation itself, but it means that when differentiated the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
- jax.numpy.take_along_axis now raises a TypeError if its indices are not of an integer type, matching the behavior of numpy.take_along_axis. Previously non-integer indices were silently cast to integers.
- jax.numpy.ravel_multi_index now raises a TypeError if its dims argument is not of an integer type, matching the behavior of numpy.ravel_multi_index. Previously non-integer dims was silently cast to integers.
- jax.numpy.split now raises a TypeError if its axis argument is not of an integer type, matching the behavior of numpy.split. Previously non-integer axis was silently cast to integers.
- jax.numpy.indices now raises a TypeError if its dimensions are not of an integer type, matching the behavior of numpy.indices. Previously non-integer dimensions were silently cast to integers.
- jax.numpy.diag now raises a TypeError if its k argument is not of an integer type, matching the behavior of numpy.diag. Previously non-integer k was silently cast to integers.
- Added jax.random.orthogonal.
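The out-of-bounds indexing modes described in this list can be sketched as follows. The array and indices are arbitrary examples, and the behavior shown is that of a current JAX release.

```python
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 5])                   # index 5 is out of bounds

filled = jnp.take(x, idx)                 # default mode="fill": NaN for index 5
clipped = jnp.take(x, idx, mode="clip")   # previous behavior: clamp into range

# Out-of-bounds scatter updates are dropped, so the values are unchanged.
updated = x.at[jnp.array([5])].set(-1.0)
```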
- Deprecations
- Many functions and objects available in jax.test_util are now deprecated and will raise a warning on import. This includes cases_from_list, check_close, check_eq, device_under_test, format_shape_dtype_string, rand_uniform, skip_on_devices, with_config, xla_bridge, and _default_tolerance (#10389). These, along with previously-deprecated JaxTestCase, JaxTestLoader, and BufferDonationTestCase, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard python & numpy testing utilities found in e.g. unittest, absl.testing, numpy.testing, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as jax.devices. Many of the deprecated utilities will still exist in jax._src.test_util, but these are not public APIs and as such may be changed or removed without notice in future releases.
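One possible way to replace the deprecated helpers with standard tools is sketched below. The tolerances and the input array are arbitrary choices for this example.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = np.linspace(0.0, 1.0, 5, dtype=np.float32)

# numpy.testing can compare a JAX result against a NumPy reference.
np.testing.assert_allclose(np.asarray(jnp.sin(x)), np.sin(x), rtol=1e-5, atol=1e-6)

# Public device query in place of the deprecated device_under_test helper.
platform = jax.devices()[0].platform
```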
- Python
Published by mattjj about 4 years ago
jax - Jaxlib v0.3.7
- Linux wheels are now built conforming to the manylinux2014 standard, instead of manylinux2010.
- Python
Published by hawkinsp about 4 years ago
jax - JAX release v0.3.7
- Fixed a performance problem if the indices passed to jax.numpy.take_along_axis were broadcasted (#10281).
- jax.scipy.special.expit and jax.scipy.special.logit now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
- The DeviceArray.tile() method is deprecated, because numpy arrays do not have a tile() method. As a replacement for this, use jax.numpy.tile (#10266).
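The integer promotion for expit and logit mentioned in this list can be illustrated with a short sketch (the input values are arbitrary):

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

ints = jnp.array([0, 1])         # integer JAX array
probs = expit(ints)              # result is promoted to floating point
back = logit(jnp.array([0.5]))   # logit(0.5) is 0
```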
- Python
Published by hawkinsp about 4 years ago