Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python ⚡

mpi4jax enables zero-copy, multi-host communication of JAX arrays, even from jitted code and from GPU memory.

But why?

The JAX framework has great performance for scientific computing workloads, but its multi-host capabilities are still limited.

With mpi4jax, you can scale your JAX-based simulations to entire CPU and GPU clusters (without ever leaving jax.jit).

In the spirit of differentiable programming, mpi4jax also supports differentiating through some MPI operations.

Installation

mpi4jax is available through pip and conda:

$ pip install mpi4jax                     # Pip
$ conda install -c conda-forge mpi4jax    # conda

Depending on which JAX backend you want to use, install mpi4jax in one of the following ways:

# pip install 'jax[cpu]'
$ pip install mpi4jax

# pip install -U 'jax[cuda12]'
$ pip install cython
$ pip install mpi4jax --no-build-isolation

# pip install -U 'jax[cuda12_local]'
$ CUDA_ROOT=XXX pip install mpi4jax

(For more information on JAX GPU distributions, see the JAX installation instructions.)

In case your MPI installation is not detected correctly, it can help to install mpi4py separately. When using a pre-installed mpi4py, you must use --no-build-isolation when installing mpi4jax:

# if mpi4py is already installed
$ pip install cython
$ pip install mpi4jax --no-build-isolation

Our documentation includes some more advanced installation examples.
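Which install route applies depends on what is already present in your environment. As a rough sketch (not part of mpi4jax; the helper name `installed_packages` is hypothetical), the standard library can report whether the relevant packages are importable before you choose between a plain install and one with --no-build-isolation:

```python
import importlib.util


def installed_packages(names=("jax", "mpi4py", "Cython", "mpi4jax")):
    """Return a dict mapping each package name to True if it can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


# Report what is available in the current environment.
for name, found in installed_packages().items():
    print(f"{name}: {'found' if found else 'missing'}")
```

If mpi4py is reported as found, the --no-build-isolation route described above applies. This check only tests importability; it does not verify that the underlying MPI library works.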

Example usage

from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

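@jax.jit
def foo(arr):
    # Each process adds its own rank to the array
    arr = arr + rank
    # Sum across all processes; the second return value is not needed here
    arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
    return arr_sum

a = jnp.zeros((3, 3))
result = foo(a)

# Print from a single process to avoid duplicated output
if rank == 0:
    print(result)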

Running this script on 4 processes gives:

$ mpirun -n 4 python example.py
[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]
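
The value 6 follows from the arithmetic: each of the 4 processes adds its rank (0, 1, 2, 3) to an array of zeros, and the sum reduction adds these contributions elementwise, giving 0 + 1 + 2 + 3 = 6 in every entry. A minimal single-process sketch that mimics this without MPI or JAX (the function `simulate_allreduce_sum` is a hypothetical illustration, not an mpi4jax API):

```python
def simulate_allreduce_sum(n_procs, shape=(3, 3)):
    """Mimic the example in one process using nested lists."""
    rows, cols = shape
    # Each simulated rank starts from zeros and adds its own rank.
    per_rank = [
        [[0.0 + rank] * cols for _ in range(rows)]
        for rank in range(n_procs)
    ]
    # A sum reduction adds the per-rank arrays elementwise.
    return [
        [sum(arr[i][j] for arr in per_rank) for j in range(cols)]
        for i in range(rows)
    ]


result = simulate_allreduce_sum(4)
print(result[0])  # [6.0, 6.0, 6.0]
```

With n processes, every entry equals n * (n - 1) / 2, so running the real script on 2 processes would be expected to print 1 in every entry.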

allreduce is just one example of the MPI primitives you can use. See the documentation for the full list of supported operations.

Community guidelines

If you have a question or feature request, or want to report a bug, feel free to open an issue.

We welcome contributions of any kind through pull requests. For information on running our tests, debugging, and contribution guidelines, please refer to the corresponding documentation page.

How to cite

If you use mpi4jax in your work, please consider citing the following article:

@article{mpi4jax,
  doi = {10.21105/joss.03419},
  url = {https://doi.org/10.21105/joss.03419},
  year = {2021},
  publisher = {The Open Journal},
  volume = {6},
  number = {65},
  pages = {3419},
  author = {Dion Häfner and Filippo Vicentini},
  title = {mpi4jax: Zero-copy MPI communication of JAX arrays},
  journal = {Journal of Open Source Software}
}
