mpi4jax

mpi4jax: Zero-copy MPI communication of JAX arrays - Published in JOSS (2021)

https://github.com/mpi4jax/mpi4jax

Science Score: 100.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
    Found 7 DOI reference(s) in README and JOSS metadata
  • Academic publication links
    Links to: joss.theoj.org
  • Committers with academic emails
    1 of 7 committers (14.3%) from academic institutions
  • Institutional organization owner
  • JOSS paper metadata
    Published in Journal of Open Source Software

Keywords

gpu high-performance-computing jax jit mpi parallel-computing xla

Keywords from Contributors

mesh simulations
Last synced: 4 months ago

Repository

Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python ⚡

Basic Info
Statistics
  • Stars: 491
  • Watchers: 10
  • Forks: 32
  • Open Issues: 23
  • Releases: 85
Topics
gpu high-performance-computing jax jit mpi parallel-computing xla
Created over 5 years ago · Last pushed 4 months ago
Metadata Files
Readme License Citation

README.rst

mpi4jax
=======

|JOSS paper| |PyPI Version| |Conda Version| |Tests| |codecov| |Documentation Status|

``mpi4jax`` enables zero-copy, multi-host communication of JAX arrays, even from jitted code and from GPU memory.


But why?
--------

The JAX framework has great performance for scientific computing workloads, but its multi-host capabilities are still limited.

With ``mpi4jax``, you can scale your JAX-based simulations to *entire CPU and GPU clusters* (without ever leaving ``jax.jit``).

In the spirit of differentiable programming, ``mpi4jax`` also supports differentiating through some MPI operations.
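
As a single-process illustration of what differentiating through a sum-reduction means, the sketch below does not use ``mpi4jax`` at all. It stands in for MPI ranks with a batch axis, using ``jax.vmap`` together with ``jax.lax.psum`` (JAX's own named-axis all-reduce). The rank count ``N_RANKS`` and the helper names are hypothetical. Because every simulated rank receives the global sum, the gradient of the sum of all outputs with respect to each local input equals the number of ranks.

```python
import jax
import jax.numpy as jnp

N_RANKS = 4  # hypothetical number of simulated ranks (not real MPI processes)


def allreduce_sum(x):
    # Inside vmap, psum over the named axis "ranks" sums x across all
    # simulated ranks and gives every rank the same total (allreduce semantics).
    return jax.lax.psum(x, axis_name="ranks")


def total_of_outputs(local_values):
    # local_values[r] is the scalar held by simulated rank r.
    reduced = jax.vmap(allreduce_sum, axis_name="ranks")(local_values)
    # Add up the N identical per-rank results into one scalar.
    return reduced.sum()


local_values = jnp.arange(N_RANKS, dtype=jnp.float32)  # 0., 1., 2., 3.
grads = jax.grad(total_of_outputs)(local_values)
# Every input feeds all N_RANKS outputs, so each gradient entry is 4.0.
print(grads)
```

Real ranks live in separate processes, so this only mirrors the arithmetic of a differentiated sum-allreduce; check the ``mpi4jax`` documentation for which of its operations support differentiation.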


Installation
------------

``mpi4jax`` is available through ``pip`` and ``conda``:

.. code:: bash

   $ pip install mpi4jax                     # pip
   $ conda install -c conda-forge mpi4jax    # conda

Depending on which JAX backend you want to use, install ``mpi4jax`` in one of the following ways:

.. code:: bash

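   # CPU-only JAX (installed with: pip install 'jax[cpu]')
   $ pip install mpi4jax

   # CUDA 12 JAX from pip wheels (installed with: pip install -U 'jax[cuda12]')
   $ pip install cython
   $ pip install mpi4jax --no-build-isolation

   # CUDA 12 JAX using a local CUDA install (installed with: pip install -U 'jax[cuda12_local]')
   # replace XXX with the path to your CUDA installation
   $ CUDA_ROOT=XXX pip install mpi4jax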

(For more information on JAX GPU distributions, see the JAX installation instructions.)
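
Before installing ``mpi4jax`` against a particular JAX distribution, it can be useful to confirm which backend JAX actually picked up. This minimal sketch relies only on ``jax.default_backend()`` and ``jax.devices()``; the helper name ``active_jax_backend`` is hypothetical.

```python
def active_jax_backend():
    """Return (backend_name, device_list), or (None, []) if JAX is missing."""
    try:
        import jax
    except ImportError:
        # JAX itself is not installed in this environment.
        return None, []
    # CPU-only installs report "cpu"; CUDA installs typically report "gpu".
    return jax.default_backend(), jax.devices()


if __name__ == "__main__":
    backend, devices = active_jax_backend()
    print("JAX backend:", backend)
    for device in devices:
        print("  device:", device)
```

If this reports ``cpu`` although you installed a CUDA distribution, the JAX installation itself is the thing to fix first.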

If your MPI installation is not detected correctly, it can help to install ``mpi4py`` separately. When using a pre-installed ``mpi4py``, you *must* use ``--no-build-isolation`` when installing ``mpi4jax``:

.. code:: bash

   # if mpi4py is already installed
   $ pip install cython
   $ pip install mpi4jax --no-build-isolation

Our documentation includes some more advanced installation examples.
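
When MPI detection is the problem, a quick diagnostic is to ask ``mpi4py`` which MPI library it is linked against. The sketch below assumes only ``MPI.Get_version()`` and ``MPI.Get_library_version()`` from ``mpi4py``; the helper name ``describe_mpi`` is hypothetical. Note that importing ``mpi4py.MPI`` also initializes MPI in the calling process.

```python
def describe_mpi():
    """Report the MPI library mpi4py uses, or None if mpi4py is unavailable."""
    try:
        # Importing mpi4py.MPI also initializes MPI in this process.
        from mpi4py import MPI
    except ImportError as exc:
        # mpi4py is missing, or its MPI shared library could not be loaded.
        print("mpi4py could not be imported:", exc)
        return None
    major, minor = MPI.Get_version()  # MPI standard version, e.g. (3, 1)
    return {
        "standard": f"{major}.{minor}",
        "library": MPI.Get_library_version().strip(),
    }


if __name__ == "__main__":
    info = describe_mpi()
    if info is None:
        print("Install mpi4py separately, then mpi4jax with --no-build-isolation.")
    else:
        print("MPI standard:", info["standard"])
        print("MPI library:", info["library"])
```

If the reported library is not the MPI you intend to use on your cluster, reinstalling ``mpi4py`` against the right MPI is usually the first step.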


Example usage
-------------

.. code:: python

   from mpi4py import MPI
   import jax
   import jax.numpy as jnp
   import mpi4jax

   comm = MPI.COMM_WORLD
   rank = comm.Get_rank()

   @jax.jit
   def foo(arr):
      arr = arr + rank
      arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
      return arr_sum

   a = jnp.zeros((3, 3))
   result = foo(a)

   if rank == 0:
      print(result)

Running this script (saved as ``example.py``) on 4 processes gives:

.. code:: bash

   $ mpirun -n 4 python example.py
   [[6. 6. 6.]
    [6. 6. 6.]
    [6. 6. 6.]]
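
The 6s come from summing ``0 + 1 + 2 + 3`` elementwise: each rank adds its own rank to a zero matrix, and ``allreduce`` with ``MPI.SUM`` returns the total to every rank. This single-process sketch reproduces that arithmetic with plain ``jax.numpy`` (no MPI), assuming the 4 ranks of the ``mpirun -n 4`` call; it only checks the expected values and is not a substitute for running under MPI.

```python
import jax.numpy as jnp

n_ranks = 4  # matches "mpirun -n 4" in the example

# What each rank holds after "arr = arr + rank": a 3x3 matrix filled with its rank.
per_rank = jnp.stack([jnp.zeros((3, 3)) + rank for rank in range(n_ranks)])

# MPI.SUM allreduce: elementwise sum over all ranks, identical on every rank.
allreduced = per_rank.sum(axis=0)
print(allreduced)
```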

``allreduce`` is just one example of the MPI primitives you can use; the documentation lists all supported operations.

Community guidelines
--------------------

If you have a question or feature request, or want to report a bug, feel free to open an issue.

We welcome contributions of any kind through pull requests. For information on running our tests, debugging, and contribution guidelines, please refer to the corresponding documentation page.

How to cite
-----------

If you use ``mpi4jax`` in your work, please consider citing the following article:

::

  @article{mpi4jax,
    doi = {10.21105/joss.03419},
    url = {https://doi.org/10.21105/joss.03419},
    year = {2021},
    publisher = {The Open Journal},
    volume = {6},
    number = {65},
    pages = {3419},
    author = {Dion Häfner and Filippo Vicentini},
    title = {mpi4jax: Zero-copy MPI communication of JAX arrays},
    journal = {Journal of Open Source Software}
  }

.. |Tests| image:: https://github.com/mpi4jax/mpi4jax/workflows/Tests/badge.svg
   :target: https://github.com/mpi4jax/mpi4jax/actions?query=branch%3Amain
.. |codecov| image:: https://codecov.io/gh/mpi4jax/mpi4jax/branch/main/graph/badge.svg
   :target: https://codecov.io/gh/mpi4jax/mpi4jax
.. |PyPI Version| image:: https://img.shields.io/pypi/v/mpi4jax
   :target: https://pypi.org/project/mpi4jax/
.. |Conda Version| image:: https://img.shields.io/conda/vn/conda-forge/mpi4jax.svg
   :target: https://anaconda.org/conda-forge/mpi4jax
.. |Documentation Status| image:: https://readthedocs.org/projects/mpi4jax/badge/?version=latest
   :target: https://mpi4jax.readthedocs.io/en/latest/?badge=latest
.. |JOSS paper| image:: https://joss.theoj.org/papers/10.21105/joss.03419/status.svg
   :target: https://doi.org/10.21105/joss.03419

Owner

  • Name: mpi4jax
  • Login: mpi4jax
  • Kind: organization

JOSS Publication

mpi4jax: Zero-copy MPI communication of JAX arrays
Published
September 01, 2021
Volume 6, Issue 65, Page 3419
Authors
Dion Häfner
Niels Bohr Institute, University of Copenhagen, Copenhagen, Denmark
Filippo Vicentini
Institute of Physics, École Polytechnique Fédérale de Lausanne (EPFL), CH-1015 Lausanne, Switzerland
Editor
Kelly Rowland
Tags
Python JAX MPI high performance computing parallel computing

Citation (CITATION.cff)

authors:
  - family-names: Vicentini
    given-names: Filippo
  - family-names: Häfner
    given-names: Dion
cff-version: 1.2.0
message: "If you use this software, please cite the article from preferred-citation."
title: mpi4jax
url: "https://github.com/mpi4jax/mpi4jax"
preferred-citation:
  type: article
  title: "mpi4jax: Zero-copy MPI communication of JAX arrays"
  authors:
    - family-names: Häfner
      given-names: Dion
    - family-names: Vicentini
      given-names: Filippo
  year: 2021
  journal: "Journal of Open Source Software"
  volume: 6
  issue: 65
  pages: "3419"
  doi: "10.21105/joss.03419"
  url: "https://doi.org/10.21105/joss.03419"

GitHub Events

Total
  • Create event: 18
  • Release event: 9
  • Issues event: 8
  • Watch event: 52
  • Delete event: 12
  • Issue comment event: 45
  • Push event: 25
  • Pull request review event: 11
  • Pull request review comment event: 9
  • Pull request event: 19
  • Fork event: 2
Last Year
  • Create event: 18
  • Release event: 9
  • Issues event: 8
  • Watch event: 52
  • Delete event: 12
  • Issue comment event: 45
  • Push event: 25
  • Pull request review event: 11
  • Pull request review comment event: 9
  • Pull request event: 19
  • Fork event: 2

Committers

Last synced: 5 months ago

All Time
  • Total Commits: 298
  • Total Committers: 7
  • Avg Commits per committer: 42.571
  • Development Distribution Score (DDS): 0.601
Past Year
  • Commits: 21
  • Committers: 4
  • Avg Commits per committer: 5.25
  • Development Distribution Score (DDS): 0.476
Top Committers
  • Dion Häfner (d****r@n****k): 119 commits
  • Filippo Vicentini (f****i@g****m): 102 commits
  • dependabot[bot] (4****]): 72 commits
  • Clemens Giuliani (i****g@g****m): 2 commits
  • Marjan Macek (3****r): 1 commit
  • Jacek Czaja (j****a@i****m): 1 commit
  • Chase Roberts (c****s@g****m): 1 commit

Issues and Pull Requests

Last synced: 4 months ago

All Time
  • Total issues: 46
  • Total pull requests: 138
  • Average time to close issues: about 2 months
  • Average time to close pull requests: 8 days
  • Total issue authors: 28
  • Total pull request authors: 11
  • Average comments per issue: 4.96
  • Average comments per pull request: 2.53
  • Merged pull requests: 116
  • Bot issues: 0
  • Bot pull requests: 78
Past Year
  • Issues: 9
  • Pull requests: 22
  • Average time to close issues: 4 days
  • Average time to close pull requests: 8 days
  • Issue authors: 8
  • Pull request authors: 4
  • Average comments per issue: 3.22
  • Average comments per pull request: 1.23
  • Merged pull requests: 19
  • Bot issues: 0
  • Bot pull requests: 12
Top Authors
Issue Authors
  • PhilipVinc (11)
  • dionhaefner (4)
  • ntlm1686 (3)
  • benkirk (2)
  • Thenerdstation (2)
  • coreyjadams (2)
  • shyams2 (2)
  • Joshuaalbert (1)
  • louiskirsch (1)
  • brentmorgan1987 (1)
  • Zantares (1)
  • jwnys (1)
  • mtagliazucchi (1)
  • halvarsu (1)
  • henryiii (1)
Pull Request Authors
  • dependabot[bot] (88)
  • dionhaefner (35)
  • PhilipVinc (33)
  • jczaja (3)
  • macekmar (2)
  • Thenerdstation (2)
  • EiffL (1)
  • nutrik (1)
  • henrique (1)
  • wdphy16 (1)
  • felker (1)
Top Labels
Issue Labels
enhancement (6) help wanted (2) bug (2) question (1)
Pull Request Labels
dependencies (88) python (3)

Packages

  • Total packages: 4
  • Total downloads:
    • pypi 604 last-month
  • Total dependent packages: 3
    (may contain duplicates)
  • Total dependent repositories: 1
    (may contain duplicates)
  • Total versions: 165
  • Total maintainers: 3
proxy.golang.org: github.com/mpi4jax/mpi4jax
  • Versions: 49
  • Dependent Packages: 0
  • Dependent Repositories: 0
Rankings
Dependent packages count: 6.5%
Average: 6.7%
Dependent repos count: 6.9%
Last synced: 4 months ago
pypi.org: mpi4jax

Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python ⚡

  • Versions: 86
  • Dependent Packages: 2
  • Dependent Repositories: 1
  • Downloads: 604 Last month
  • Docker Downloads: 0
Rankings
Dependent packages count: 2.1%
Docker downloads count: 2.2%
Stargazers count: 3.6%
Average: 7.8%
Downloads: 8.2%
Forks count: 8.4%
Dependent repos count: 22.1%
Maintainers (2)
Last synced: 4 months ago
spack.io: py-mpi4jax

Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python.

  • Versions: 1
  • Dependent Packages: 0
  • Dependent Repositories: 0
Rankings
Dependent repos count: 0.0%
Stargazers count: 13.6%
Forks count: 21.7%
Average: 23.2%
Dependent packages count: 57.3%
Maintainers (1)
Last synced: 4 months ago
conda-forge.org: mpi4jax

Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python ⚡

  • Versions: 29
  • Dependent Packages: 1
  • Dependent Repositories: 0
Rankings
Stargazers count: 23.0%
Dependent packages count: 28.8%
Average: 30.6%
Dependent repos count: 34.0%
Forks count: 36.7%
Last synced: 4 months ago

Dependencies

docs/environment.yml pypi
  • jax *