JAXbind

JAXbind: Bind any function to JAX - Published in JOSS (2024)

https://github.com/nifty-ppl/jaxbind

Science Score: 95.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 4 DOI reference(s) in README and JOSS metadata
  • Academic publication links
  • Committers with academic emails
    3 of 4 committers (75.0%) from academic institutions
  • Institutional organization owner
  • JOSS paper metadata
    Published in Journal of Open Source Software
Last synced: 6 months ago

Repository

Bind any function written in another language to JAX with support for JVP/VJP/batching/jit compilation

Basic Info
  • Host: GitHub
  • Owner: NIFTy-PPL
  • License: bsd-2-clause
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 335 KB
Statistics
  • Stars: 73
  • Watchers: 5
  • Forks: 0
  • Open Issues: 3
  • Releases: 3
Created about 2 years ago · Last pushed 10 months ago
Metadata Files
Readme License

README.md

JAXbind: Bind any function to JAX

JAXbind API documentation: nifty-ppl.github.io/JAXbind/ | Found a bug? github.com/NIFTy-PPL/JAXbind/issues | Need help? github.com/NIFTy-PPL/JAXbind/discussions

Summary

The existing interface in JAX for connecting fully differentiable custom code requires deep knowledge of JAX and its C++ backend. The aim of JAXbind is to drastically lower the burden of connecting custom functions implemented in other programming languages to JAX. Specifically, JAXbind provides an easy-to-use Python interface for defining custom, so-called JAX primitives. Via JAXbind, any function callable from Python can be exposed as a JAX primitive. JAXbind lets users hook custom derivatives and batching rules into the JAX function transformation engine, enabling all JAX transformations for the custom primitive. In contrast, JAX's built-in external callback interface also has a Python endpoint, but external callbacks cannot be fully integrated into the JAX transformation engine, since only the Jacobian-vector product or the vector-Jacobian product can be added, not both.

Automatic Differentiation and Code Example

Automatic differentiation is a core feature of JAX and often one of the main reasons for using it. Thus, it is essential that custom functions registered with JAX support automatic differentiation. In the following, we outline which functions JAXbind and JAX require to enable automatic differentiation. For simplicity, we assume that we want to connect the nonlinear function $f(x_1,x_2) = x_1 x_2^2$ to JAX. The JAXbind package expects the Python function for $f$ to take three positional arguments. The first argument, `out`, is a tuple into which the function results are written. The second argument, `args`, is also a tuple and contains the inputs to the function, in our case $x_1$ and $x_2$. Via the third argument, `kwargs_dump`, keyword arguments passed to the JAX primitive registered later can be forwarded to `f` in serialized form.

```python
import jaxbind


def f(out, args, kwargs_dump):
    kwargs = jaxbind.load_kwargs(kwargs_dump)
    x1, x2 = args
    # write the result into the preallocated output buffer
    out[0][()] = x1 * x2**2
```
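The calling convention can be tried out without JAX. The following minimal sketch assumes that the caller preallocates NumPy output buffers (as JAXbind is assumed to do when it invokes `f`) and drops the `jaxbind.load_kwargs` call so that it runs without JAXbind installed; the name `f_plain` is hypothetical.

```python
import numpy as np


def f_plain(out, args, kwargs_dump):
    # Same body as `f`, minus the jaxbind.load_kwargs call, so this sketch
    # runs without JAXbind; kwargs_dump is accepted but ignored here.
    x1, x2 = args
    # `out[0][()] = ...` fills the existing buffer in place instead of
    # rebinding the name to a new array.
    out[0][()] = x1 * x2**2


x1 = np.full((4, 3), 4.0)
x2 = np.full((4, 3), 2.0)
out = (np.empty((4, 3)),)  # one output buffer with the expected shape/dtype
f_plain(out, (x1, x2), None)
print(out[0][0, 0])  # 16.0
```

Writing in place matters: assigning `out[0] = ...` would fail on a tuple, and rebinding a local name would leave the caller's buffer untouched.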

JAX's automatic differentiation engine can compute the Jacobian-vector product (jvp) and the vector-Jacobian product (vjp) of JAX primitives. The Jacobian-vector product in JAX is a function applying the Jacobian of $f$ at a position $x$ to a tangent vector. In mathematical nomenclature, this operation is called the pushforward of $f$ and can be denoted as $\partial f(x): T_x X \mapsto T_{f(x)} Y$, with $T_x X$ and $T_{f(x)} Y$ being the tangent spaces of $X$ and $Y$ at the positions $x$ and $f(x)$. As the implementation of $f$ is not JAX-native, JAX cannot automatically compute the jvp. Instead, an implementation of the pushforward has to be provided, which JAXbind will register as the jvp of the JAX primitive of $f$. For our example, the Jacobian-vector product is given by $\partial f(x_1,x_2)(dx_1,dx_2) = x_2^2\,dx_1 + 2x_1x_2\,dx_2$.

```python
def f_jvp(out, args, kwargs_dump):
    kwargs = jaxbind.load_kwargs(kwargs_dump)
    # args holds the primals followed by the tangents
    x1, x2, dx1, dx2 = args
    out[0][()] = x2**2 * dx1 + 2 * x1 * x2 * dx2
```
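A hand-written jvp is easy to get wrong, so it is worth checking against a central finite difference along the tangent direction. The sketch below uses plain NumPy versions of $f$ and its pushforward (the names `f_value` and `f_jvp_value` are hypothetical, not part of JAXbind):

```python
import numpy as np


def f_value(x1, x2):
    # f(x1, x2) = x1 * x2**2, elementwise
    return x1 * x2**2


def f_jvp_value(x1, x2, dx1, dx2):
    # Pushforward from the text: x2^2 dx1 + 2 x1 x2 dx2
    return x2**2 * dx1 + 2 * x1 * x2 * dx2


rng = np.random.default_rng(0)
x1, x2, dx1, dx2 = rng.normal(size=(4, 5))
eps = 1e-6
# Central finite difference of f along the tangent direction (dx1, dx2)
fd = (
    f_value(x1 + eps * dx1, x2 + eps * dx2)
    - f_value(x1 - eps * dx1, x2 - eps * dx2)
) / (2 * eps)
err = np.max(np.abs(fd - f_jvp_value(x1, x2, dx1, dx2)))
print(err)
```

The error should be tiny (dominated by floating-point rounding); a wrong coefficient in the jvp would show up as an error of order one.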

The vector-Jacobian product (vjp) in JAX is the linear transpose of the Jacobian-vector product. In mathematical nomenclature, this is the pullback $(\partial f(x))^{T}: T_{f(x)}Y \mapsto T_x X$ of $f$. As with the jvp, the user has to implement this function because JAX cannot construct it automatically. For our example function, the vector-Jacobian product is $(\partial f(x_1,x_2))^{T}(dy) = (x_2^2\,dy,\ 2x_1x_2\,dy)$.

```python
def f_vjp(out, args, kwargs_dump):
    kwargs = jaxbind.load_kwargs(kwargs_dump)
    # args holds the primals followed by the cotangent
    x1, x2, dy = args
    # one output buffer per input of f
    out[0][()] = x2**2 * dy
    out[1][()] = 2 * x1 * x2 * dy
```
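Since the vjp must be the linear transpose of the jvp, a standard consistency check is the dot-product (adjoint) test $\langle dy, \partial f(x)\,dx \rangle = \langle (\partial f(x))^T dy, dx \rangle$. A minimal NumPy sketch for the example function, with hypothetical helper names `jvp` and `vjp`:

```python
import numpy as np


def jvp(x1, x2, dx1, dx2):
    # Pushforward: x2^2 dx1 + 2 x1 x2 dx2
    return x2**2 * dx1 + 2 * x1 * x2 * dx2


def vjp(x1, x2, dy):
    # Pullback: (x2^2 dy, 2 x1 x2 dy), one cotangent per input
    return x2**2 * dy, 2 * x1 * x2 * dy


rng = np.random.default_rng(1)
x1, x2, dx1, dx2, dy = rng.normal(size=(5, 6))
lhs = np.vdot(dy, jvp(x1, x2, dx1, dx2))  # <dy, J dx>
ct1, ct2 = vjp(x1, x2, dy)
rhs = np.vdot(ct1, dx1) + np.vdot(ct2, dx2)  # <J^T dy, dx>
print(np.isclose(lhs, rhs))  # True
```

If the two inner products disagree for random inputs, the jvp and vjp implementations are not transposes of each other.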

To just-in-time compile the function, JAX needs to evaluate the code abstractly, i.e., it needs to know the shape and dtype of the output of the custom function given only the shape and dtype of the inputs. We therefore have to provide abstract evaluation functions that return the output shapes and dtypes for given input shapes and dtypes, both for f and for the vjp. The output shape of the jvp is identical to the output shape of f itself and does not need to be specified again. Due to JAX internals, the abstract evaluation functions take normal keyword arguments, not serialized ones.

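```python
def f_abstract(*args, **kwargs):
    # f maps two same-shaped inputs to one output of that shape and dtype
    assert args[0].shape == args[1].shape
    return ((args[0].shape, args[0].dtype),)


def f_abstract_T(*args, **kwargs):
    # the vjp returns one cotangent per input of f
    return (
        (args[0].shape, args[0].dtype),
        (args[0].shape, args[0].dtype),
    )
```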
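The abstract evaluation functions only ever look at `.shape` and `.dtype`. The sketch below exercises them with a hypothetical stand-in class `AbstractArray` in place of the abstract values JAX actually passes; it also assumes that `f_abstract_T` receives the same positional arguments as `f_vjp` (x1, x2, dy).

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class AbstractArray:
    # Hypothetical stand-in for JAX's abstract values: only the `.shape`
    # and `.dtype` attributes are assumed to exist.
    shape: tuple
    dtype: np.dtype


# copied from above
def f_abstract(*args, **kwargs):
    assert args[0].shape == args[1].shape
    return ((args[0].shape, args[0].dtype),)


def f_abstract_T(*args, **kwargs):
    return (
        (args[0].shape, args[0].dtype),
        (args[0].shape, args[0].dtype),
    )


x = AbstractArray((4, 3), np.dtype("float64"))
print(f_abstract(x, x))  # one (shape, dtype) pair for the single output
print(f_abstract_T(x, x, x))  # assumed args: x1, x2, dy
```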

We have now defined all ingredients necessary to register a JAX primitive for our function $f$ using the JAXbind package.

```python
f_jax = jaxbind.get_nonlinear_call(f, (f_jvp, f_vjp), f_abstract, f_abstract_T)
```

f_jax is a JAX primitive registered via the JAXbind package supporting all JAX transformations. We can now compute the jvp and vjp of the new JAX primitive and even jit-compile and batch it.

```python
import jax
import jax.numpy as jnp

inp = (jnp.full((4, 3), 4.0), jnp.full((4, 3), 2.0))
tan = (jnp.full((4, 3), 1.0), jnp.full((4, 3), 1.0))
res, res_tan = jax.jvp(f_jax, inp, tan)  # forward mode

cotan = [jnp.full((4, 3), 6.0)]
res, f_vjp = jax.vjp(f_jax, *inp)  # reverse mode
res_cotan = f_vjp(cotan)

f_jax_jit = jax.jit(f_jax)  # jit compilation
res = f_jax_jit(*inp)
```
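For these constant inputs, the expected results can be worked out by hand: $f = 4 \cdot 2^2 = 16$, the jvp gives $2^2 \cdot 1 + 2 \cdot 4 \cdot 2 \cdot 1 = 20$, and the vjp gives $(2^2 \cdot 6,\ 2 \cdot 4 \cdot 2 \cdot 6) = (24, 96)$. The sketch below computes these reference values with plain NumPy (no JAXbind needed); one could compare them against `res`, `res_tan` and `res_cotan`, e.g. with `np.testing.assert_allclose`, keeping in mind that the exact container types JAXbind returns are not spelled out here.

```python
import numpy as np

# Same inputs as above: x1 = 4, x2 = 2, tangents 1, cotangent 6
x1 = np.full((4, 3), 4.0)
x2 = np.full((4, 3), 2.0)
dx1 = dx2 = np.full((4, 3), 1.0)
dy = np.full((4, 3), 6.0)

expected_res = x1 * x2**2  # 4 * 2^2 = 16
expected_res_tan = x2**2 * dx1 + 2 * x1 * x2 * dx2  # 4 + 16 = 20
expected_res_cotan = (x2**2 * dy, 2 * x1 * x2 * dy)  # (24, 96)
```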

Higher Order Derivatives and Linear Functions

JAX supports higher-order derivatives and can differentiate a jvp or vjp with respect to the position at which the Jacobian was taken. As with first derivatives, JAX cannot automatically compute higher derivatives of a general function $f$ that is not natively implemented in JAX; they would again need to be provided by the user. For many algorithms, first derivatives are sufficient, and high-performance codes often do not implement higher-order derivatives. Therefore, for simplicity, the current interface of JAXbind is restricted to first derivatives. In the future, the interface could easily be extended if specific use cases require higher-order derivatives.

In scientific computing, linear functions such as spherical harmonic transforms are widespread. If the function $f$ is linear, differentiation becomes trivial. Specifically, for a linear function $f$, the pushforward, i.e., the jvp of $f$, is identical to $f$ itself and independent of the position at which it is computed. Expressed in formulas, $\partial f(x)(dx) = f(dx)$ if $f$ is linear in $x$. Analogously, the pullback, i.e., the vjp, becomes independent of the position and is given by the linear transpose of $f$, thus $(\partial f(x))^{T}(dy) = f^T(dy)$. Also, all higher-order derivatives can be expressed in terms of $f$ and its transpose. To exploit these simplifications, JAXbind provides a special interface for linear functions that supports higher-order derivatives and requires only an implementation of the function and its transpose.
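The two identities can be checked numerically with a small dense matrix standing in for a real linear operator such as a spherical harmonic transform. This is only an illustration of the math in plain NumPy; it does not use JAXbind's linear-function interface, whose exact signature is documented in the API reference, and `f_lin`/`f_lin_T` are hypothetical names.

```python
import numpy as np

rng = np.random.default_rng(2)
A = rng.normal(size=(3, 5))  # a linear map f: R^5 -> R^3, f(x) = A @ x


def f_lin(x):
    return A @ x


def f_lin_T(y):
    # Linear transpose of f: f^T(y) = A^T @ y
    return A.T @ y


# jvp at two different positions: for a linear map the central difference
# equals A @ dx (up to rounding), i.e. f(dx), regardless of the position
x_a, x_b, dx = rng.normal(size=(3, 5))
eps = 1e-3
jvp_a = (f_lin(x_a + eps * dx) - f_lin(x_a - eps * dx)) / (2 * eps)
jvp_b = (f_lin(x_b + eps * dx) - f_lin(x_b - eps * dx)) / (2 * eps)
print(np.allclose(jvp_a, f_lin(dx)), np.allclose(jvp_b, f_lin(dx)))

# vjp is the transpose: <dy, f(dx)> == <f^T(dy), dx>
dy = rng.normal(size=3)
print(np.isclose(np.vdot(dy, f_lin(dx)), np.vdot(f_lin_T(dy), dx)))
```

Because neither check depends on the position, only `f` and its transpose are needed, which is what makes higher-order derivatives of linear functions straightforward to support.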

Demos and Documentation

Additional demos can be found in the demos folder. Specifically, the basic demo 01_linear_function.py showcases the interface for linear functions and custom batching rules. 02_multilinear_function.py binds a multi-linear function as a JAX primitive. Finally, 03_nonlinear_function.py demonstrates the interface for nonlinear functions and shows how to deal with fixed arguments, which cannot be differentiated. JAXbind provides bindings to parts of the functionality of the DUCC package. The DUCC bindings are also exposed as a webpage to showcase a real-world example of the usage of JAXbind. The documentation of the JAXbind API is available at nifty-ppl.github.io/JAXbind/.

Platforms

Currently, JAXbind supports only CPUs, not GPUs. With some expertise in Python bindings for GPU kernels, adding GPU support should be fairly simple, since the interfacing with the JAX automatic differentiation engine is identical for CPU and GPU. Contributions are welcome!

Installation

Binary wheels for JAXbind can be obtained and installed from PyPI via:

```shell
pip install jaxbind
```

To install JAXbind from source, clone the repository and install the package via pip.

```shell
git clone https://github.com/NIFTy-PPL/jaxbind.git
cd jaxbind
pip install .
```

Contributing

Contributions are highly appreciated! Please open an issue first if your PR changes the current code substantially. Please format your code using black. PRs affecting the public API, including new features, should update the public documentation. If possible, add appropriate tests to your PR. Feel free to open a PR early in the development process; we are happy to help and provide feedback along the way.

Licensing terms

All source code in this package is released under the 2-clause BSD license. All of JAXbind is distributed without any warranty.

Citing JAXbind

To cite JAXbind, please use the citation provided below.

```bibtex
@article{jaxbind,
  title = {JAXbind: Bind any function to JAX},
  author = {Jakob Roth and Martin Reinecke and Gordian Edenhofer},
  year = {2024},
  journal = {Journal of Open Source Software},
  publisher = {The Open Journal},
  volume = {9},
  number = {98},
  pages = {6532},
  doi = {10.21105/joss.06532},
  url = {https://doi.org/10.21105/joss.06532},
}
```

Owner

  • Name: NIFTy-PPL
  • Login: NIFTy-PPL
  • Kind: organization

JOSS Publication

JAXbind: Bind any function to JAX
Published
June 21, 2024
Volume 9, Issue 98, Page 6532
Authors
Jakob Roth ORCID
Max Planck Institute for Astrophysics, Karl-Schwarzschild-Str. 1, 85748 Garching, Germany, Ludwig Maximilian University of Munich, Geschwister-Scholl-Platz 1, 80539 Munich, Germany, Technical University of Munich, Boltzmannstr. 3, 85748 Garching, Germany
Martin Reinecke
Max Planck Institute for Astrophysics, Karl-Schwarzschild-Str. 1, 85748 Garching, Germany
Gordian Edenhofer ORCID
Max Planck Institute for Astrophysics, Karl-Schwarzschild-Str. 1, 85748 Garching, Germany, Ludwig Maximilian University of Munich, Geschwister-Scholl-Platz 1, 80539 Munich, Germany, Department of Astrophysics, University of Vienna, Türkenschanzstr. 17, A-1180 Vienna, Austria
Editor
Daniel S. Katz ORCID
Tags
Machine Learning High Performance Computing

GitHub Events

Total
  • Create event: 8
  • Release event: 2
  • Issues event: 3
  • Watch event: 9
  • Delete event: 4
  • Issue comment event: 29
  • Push event: 23
  • Pull request review event: 4
  • Pull request review comment event: 2
  • Pull request event: 13
Last Year
  • Create event: 8
  • Release event: 2
  • Issues event: 3
  • Watch event: 9
  • Delete event: 4
  • Issue comment event: 29
  • Push event: 23
  • Pull request review event: 4
  • Pull request review comment event: 2
  • Pull request event: 13

Committers

Last synced: 7 months ago

All Time
  • Total Commits: 336
  • Total Committers: 4
  • Avg Commits per committer: 84.0
  • Development Distribution Score (DDS): 0.563
Past Year
  • Commits: 21
  • Committers: 3
  • Avg Commits per committer: 7.0
  • Development Distribution Score (DDS): 0.524
Top Committers
Name Email Commits
Gordian Edenhofer g****r@g****m 147
Jakob Roth r****h@m****e 139
Martin Reinecke m****n@m****e 48
Daniel S. Katz d****z@i****g 2

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 14
  • Total pull requests: 31
  • Average time to close issues: 8 days
  • Average time to close pull requests: 1 day
  • Total issue authors: 7
  • Total pull request authors: 4
  • Average comments per issue: 3.5
  • Average comments per pull request: 1.9
  • Merged pull requests: 29
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 2
  • Pull requests: 6
  • Average time to close issues: 5 days
  • Average time to close pull requests: 2 days
  • Issue authors: 1
  • Pull request authors: 3
  • Average comments per issue: 4.0
  • Average comments per pull request: 4.5
  • Merged pull requests: 4
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • dfm (5)
  • roth-jakob (3)
  • wsmoses (2)
  • ddhendriks (1)
  • Edenhofer (1)
  • hawkinsp (1)
  • Joshuaalbert (1)
Pull Request Authors
  • roth-jakob (27)
  • Edenhofer (17)
  • mreineck (13)
  • danielskatz (2)
Top Labels
Issue Labels
documentation (1) good first issue (1)
Pull Request Labels

Packages

  • Total packages: 1
  • Total downloads:
    • pypi 3,136 last-month
  • Total dependent packages: 0
  • Total dependent repositories: 0
  • Total versions: 4
  • Total maintainers: 3
pypi.org: jaxbind

Bind any function written in another language to JAX with support for JVP/VJP/batching/jit compilation

  • Documentation: https://jaxbind.readthedocs.io/
  • License: BSD 2-Clause License Copyright (c) 2024, Max-Planck-Society Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  • Latest release: 1.2.1
    published 10 months ago
  • Versions: 4
  • Dependent Packages: 0
  • Dependent Repositories: 0
  • Downloads: 3,136 Last month
Rankings
Dependent packages count: 9.7%
Average: 36.9%
Dependent repos count: 64.1%
Maintainers (3)
Last synced: 6 months ago

Dependencies

pyproject.toml pypi
setup.py pypi
.github/actions/setup_package/action.yml actions
  • actions/checkout v4 composite
  • actions/setup-python v5 composite
.github/workflows/build_wheels.yml actions
  • actions/checkout v4 composite
  • actions/setup-python v5 composite
  • actions/upload-artifact v4 composite
  • pypa/cibuildwheel v2.17.0 composite
.github/workflows/test.yml actions
  • ./.github/actions/setup_package * composite
  • actions/checkout v4 composite
.github/workflows/pages.yml actions
  • ./.github/actions/setup_package * composite
  • actions/checkout v4 composite
  • actions/configure-pages v5 composite
  • actions/deploy-pages v4 composite
  • actions/upload-pages-artifact v3 composite