mpi4jax enables zero-copy, multi-host communication of JAX arrays, even from jitted code and from GPU memory.
But why?
The JAX framework has great performance for scientific computing workloads, but its multi-host capabilities are still limited.
With mpi4jax, you can scale your JAX-based simulations to entire CPU and GPU clusters (without ever leaving jax.jit).
In the spirit of differentiable programming, mpi4jax also supports differentiating through some MPI operations.
Quick installation
mpi4jax is available through pip and conda:
$ pip install mpi4jax # Pip
$ conda install -c conda-forge mpi4jax # conda
Our documentation includes some more advanced installation examples.
Example usage
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

@jax.jit
def foo(arr):
    arr = arr + rank
    arr_sum, _ = mpi4jax.Allreduce(arr, op=MPI.SUM, comm=comm)
    return arr_sum

a = jnp.zeros((3, 3))
result = foo(a)

if rank == 0:
    print(result)
Running this script on 4 processes gives:
$ mpirun -n 4 python example.py
[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]
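Every entry is 6 because each of the 4 processes contributes an array filled with its own rank, and 0 + 1 + 2 + 3 = 6. The following is a minimal single-process sketch of that arithmetic in plain Python. It does not use MPI or mpi4jax, and simulate_allreduce_sum is a hypothetical helper written only for this illustration.

```python
# Single-process sketch of what an elementwise SUM allreduce computes.
# This does NOT call MPI or mpi4jax; it only mirrors the arithmetic.

def simulate_allreduce_sum(per_rank_arrays):
    """Sum same-shaped 2D arrays (lists of lists) elementwise.

    Returns one copy of the reduced array per rank, since after an
    allreduce every rank holds the same result.
    """
    n_rows = len(per_rank_arrays[0])
    n_cols = len(per_rank_arrays[0][0])
    total = [
        [sum(arr[i][j] for arr in per_rank_arrays) for j in range(n_cols)]
        for i in range(n_rows)
    ]
    return [[row[:] for row in total] for _ in per_rank_arrays]


n_ranks = 4
# Each rank starts from a 3x3 array of zeros and adds its own rank,
# like `arr = arr + rank` in foo().
contributions = [
    [[0.0 + rank] * 3 for _ in range(3)] for rank in range(n_ranks)
]
results = simulate_allreduce_sum(contributions)
print(results[0])
```

In the real example the reduction happens across processes, so no single process ever holds all four contributions at once.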
Allreduce is just one example of the MPI primitives you can use. The documentation lists all supported operations.
Contributing
We use pre-commit hooks to enforce a common code format. To install them, just run:
$ pip install pre-commit
$ pre-commit install
Debugging
You can set the environment variable MPI4JAX_DEBUG to 1 to enable debug logging every time an MPI primitive is called from within a jitted function. You will then see messages like this:
$ MPI4JAX_DEBUG=1 mpirun -n 2 python send_recv.py
r0 | MPI_Send -> 1 with tag 0 and token 7fd7abc5f5c0
r1 | MPI_Recv <- 0 with tag -1 and token 7f9af7419ac0
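When many processes log at once, it can help to split these messages into fields and filter by rank or operation. The following is a rough sketch that assumes every line has the same shape as the two sample lines above. Other primitives or other mpi4jax versions may log differently, and parse_debug_line is a hypothetical helper, not part of mpi4jax.

```python
import re

# Pattern inferred only from the two sample lines above (assumption);
# it is not taken from the mpi4jax source.
DEBUG_LINE = re.compile(
    r"^r(?P<rank>\d+) \| (?P<op>\w+) (?P<arrow>->|<-) (?P<peer>\d+) "
    r"with tag (?P<tag>-?\d+) and token (?P<token>[0-9a-f]+)$"
)


def parse_debug_line(line):
    """Return the fields of one debug line as a dict, or None if it does not match."""
    match = DEBUG_LINE.match(line.strip())
    if match is None:
        return None
    fields = match.groupdict()
    # Convert the numeric fields; the token stays a hex string.
    for key in ("rank", "peer", "tag"):
        fields[key] = int(fields[key])
    return fields


sample_lines = [
    "r0 | MPI_Send -> 1 with tag 0 and token 7fd7abc5f5c0",
    "r1 | MPI_Recv <- 0 with tag -1 and token 7f9af7419ac0",
]
parsed = [parse_debug_line(line) for line in sample_lines]
for entry in parsed:
    print(entry)
```

Lines that do not match the assumed shape yield None, so unrelated program output mixed into the log can be skipped.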
Contributors
Filippo Vicentini @PhilipVinc
Dion Häfner @dionhaefner