Numba @jittable MPI wrappers tested on Linux, macOS and Windows
numba-mpi
Overview
numba-mpi provides Python wrappers to the C MPI API callable from within Numba JIT-compiled code (@jit mode). For an outline of the project, rationale, architecture, and features, refer to: numba-mpi arXiv e-print (please cite if numba-mpi is used in your research).
Support is provided for a subset of MPI routines covering: size/rank, send/recv, allreduce, bcast, scatter/gather & allgather, barrier, wtime, and basic asynchronous communication with isend/irecv (only for contiguous arrays), as well as request handling including wait/waitall/waitany and test/testall/testany.
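For instance, a rank sum computed with allreduce from within a JIT-compiled function could look like the following minimal sketch (it reuses only routines that appear elsewhere in this README; with N processes, every rank should obtain N*(N-1)/2):

import numba, numpy, numba_mpi

@numba.jit()
def sum_of_ranks():
    # each process contributes its own rank id ...
    rank_id = numpy.empty(1)
    rank_id[0] = numba_mpi.rank()
    # ... and the reduced total becomes available on every process
    total = numpy.empty_like(rank_id)
    numba_mpi.allreduce(rank_id, total, numba_mpi.Operator.SUM)
    return total[0]

print(sum_of_ranks())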
The API uses NumPy and supports both numeric and character datatypes (e.g., broadcast).
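As an illustration of the character-datatype support, the sketch below broadcasts a fixed-width byte-string buffer; note that the in-place bcast(buffer, root) call signature is an assumption here, so please consult the API docs linked below for the exact interface:

import numba, numpy, numba_mpi

@numba.jit()
def share(buffer):
    # assumed signature: bcast(buffer, root); rank 0's contents overwrite the buffer on the other ranks
    numba_mpi.bcast(buffer, 0)

# fixed-width character (bytes) buffer; only rank 0 holds the payload initially
msg = numpy.array([b"hello" if numba_mpi.rank() == 0 else b"     "], dtype="S5")
share(msg)
print(numba_mpi.rank(), msg)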
Auto-generated docstring-based API docs are published on the web: https://numba-mpi.github.io/numba-mpi
Packages can be obtained from PyPI, Conda Forge, Arch Linux, or by invoking pip install git+https://github.com/numba-mpi/numba-mpi.git.
numba-mpi is a pure-Python package. The codebase includes a test suite that is run through GitHub Actions workflows (thanks to mpi4py's setup-mpi!) for automated testing on: Linux (MPICH, OpenMPI & Intel MPI), macOS (MPICH & OpenMPI) and Windows (MS MPI).
Features that are not implemented yet include (help welcome!):
- support for non-default communicators
- support for MPI_IN_PLACE in [all]gather/scatter and allreduce
- support for MPI_Type_create_struct (NumPy structured arrays)
- ...
Hello world send/recv example:
import numba, numba_mpi, numpy

@numba.jit()
def hello():
    src = numpy.array([1., 2., 3., 4., 5.])
    dst_tst = numpy.empty_like(src)

    if numba_mpi.rank() == 0:
        numba_mpi.send(src, dest=1, tag=11)
    elif numba_mpi.rank() == 1:
        numba_mpi.recv(dst_tst, source=0, tag=11)

hello()
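To try it out on two processes, one can save the snippet as, e.g., hello.py (filename arbitrary) and launch it with: mpiexec -n 2 python hello.py.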
Example comparing numba-mpi vs. mpi4py performance:
The example below compares Numba+mpi4py vs. Numba+numba-mpi performance.
The sample code estimates $\pi$ by numerical integration of $\int_0^1 (4/(1+x^2))dx=\pi$, dividing the workload into n_intervals handled by separate MPI processes and then obtaining a sum using allreduce (see, e.g., the analogous Matlab docs example). The computation is carried out in a JIT-compiled function get_pi_part() and is repeated N_TIMES. The repetitions and the MPI-handled reduction are done outside or inside of the JIT-compiled block for mpi4py and numba-mpi, respectively. Timing is repeated 10 times (the repeat argument to timeit.repeat in the code below) and the minimum time is reported.
The generated plot shown below depicts the speedup obtained by replacing mpi4py with numba_mpi, plotted as a function of N_TIMES / n_intervals - the number of MPI calls per interval. The speedup, which stems from avoiding roundtrips between JIT-compiled and Python code, is significant (150%-300%) in all cases. The more often communication is needed (smaller n_intervals), the larger the measured speedup. Note that nothing in the actual number crunching (within the get_pi_part() function) or in the employed communication logic (handled by the same MPI library) differs between the mpi4py and numba-mpi solutions.
The measured differences in execution time stem from the overhead of mpi4py's higher-level abstractions and from the cost of repeatedly entering and leaving the JIT-compiled block when using mpi4py, both of which are eliminated by using numba-mpi.
import timeit, mpi4py, numba, numpy as np, numba_mpi

N_TIMES = 10000
RTOL = 1e-3

@numba.jit
def get_pi_part(n_intervals=1000000, rank=0, size=1):
    h = 1 / n_intervals
    partial_sum = 0.0
    for i in range(rank + 1, n_intervals, size):
        x = h * (i - 0.5)
        partial_sum += 4 / (1 + x**2)
    return h * partial_sum

@numba.jit
def pi_numba_mpi(n_intervals):
    pi = np.array([0.])
    part = np.empty_like(pi)
    for _ in range(N_TIMES):
        part[0] = get_pi_part(n_intervals, numba_mpi.rank(), numba_mpi.size())
        numba_mpi.allreduce(part, pi, numba_mpi.Operator.SUM)
        assert abs(pi[0] - np.pi) / np.pi < RTOL

def pi_mpi4py(n_intervals):
    pi = np.array([0.])
    part = np.empty_like(pi)
    for _ in range(N_TIMES):
        part[0] = get_pi_part(n_intervals, mpi4py.MPI.COMM_WORLD.rank, mpi4py.MPI.COMM_WORLD.size)
        mpi4py.MPI.COMM_WORLD.Allreduce(part, (pi, mpi4py.MPI.DOUBLE), op=mpi4py.MPI.SUM)
        assert abs(pi[0] - np.pi) / np.pi < RTOL

plot_x = [x for x in range(1, 11)]
plot_y = {'numba_mpi': [], 'mpi4py': []}

for x in plot_x:
    for impl in plot_y:
        plot_y[impl].append(min(timeit.repeat(
            f"pi_{impl}(n_intervals={N_TIMES // x})",
            globals=locals(),
            number=1,
            repeat=10
        )))

if numba_mpi.rank() == 0:
    from matplotlib import pyplot
    pyplot.figure(figsize=(8.3, 3.5), tight_layout=True)
    pyplot.plot(plot_x, np.array(plot_y['mpi4py'])/np.array(plot_y['numba_mpi']), marker='o')
    pyplot.xlabel('number of MPI calls per interval')
    pyplot.ylabel('mpi4py/numba-mpi wall-time ratio')
    pyplot.title(f'mpiexec -np {numba_mpi.size()}')
    pyplot.grid()
    pyplot.savefig('readme_plot.svg')
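The benchmark can be reproduced by saving the code as, e.g., pi_benchmark.py (filename arbitrary) and launching it under MPI, for instance with mpiexec -np 4 python pi_benchmark.py; rank 0 then writes the plot to readme_plot.svg.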
MPI resources on the web:
- MPI standard and general information:
- MPI implementations:
- MPI bindings:
Acknowledgements:
We thank all contributors and users who provided feedback to the project through GitHub issues.
Development of numba-mpi has been supported by the Polish National Science Centre (grant no. 2020/39/D/ST10/01220), the Max Planck Society and the European Union (ERC, EmulSim, 101044662). We further acknowledge Poland’s high-performance computing infrastructure PLGrid (HPC Centers: ACK Cyfronet AGH) for providing computer facilities and support within computational grant no. PLG/2023/016369.