numba-mpi: Numba @jittable MPI wrappers tested on Linux, macOS and Windows

Overview

numba-mpi provides Python wrappers to the C MPI API that are callable from within Numba JIT-compiled code (@jit mode). For an outline of the project's rationale, architecture and features, refer to the numba-mpi arXiv e-print (please cite it if numba-mpi is used in your research).
Support is provided for a subset of MPI routines covering: size/rank, send/recv, allreduce, reduce, bcast, scatter/gather & allgather, barrier, wtime; basic asynchronous communication with isend/irecv (only for contiguous arrays); and request handling including wait/waitall/waitany and test/testall/testany.
The API uses NumPy and supports both numeric and character datatypes (e.g., broadcast).
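To illustrate what the reduction routines compute, here is a plain-NumPy stand-in for the SUM-allreduce semantics (the `allreduce_sum` helper is illustrative only, not part of the numba-mpi API): after the call, every rank holds the elementwise sum of all ranks' contributions.

```python
import numpy as np

def allreduce_sum(send_buffers):
    """Local stand-in for MPI_Allreduce with MPI_SUM: every rank
    contributes a buffer and every rank receives the elementwise sum."""
    total = np.sum(np.stack(send_buffers), axis=0)
    return [total.copy() for _ in send_buffers]  # one recv buffer per rank

# three "ranks", one contribution each
parts = [np.array([1.0, 2.0]), np.array([10.0, 20.0]), np.array([100.0, 200.0])]
received = allreduce_sum(parts)
# every rank ends up with [111., 222.]
```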
Auto-generated docstring-based API docs are published on the web: https://numba-mpi.github.io/numba-mpi
Packages can be obtained from PyPI, Conda Forge, Arch Linux, or by invoking `pip install git+https://github.com/numba-mpi/numba-mpi.git`.
numba-mpi is a pure-Python package. The codebase includes a test suite run through GitHub Actions workflows (thanks to mpi4py's setup-mpi!) for automated testing on: Linux (MPICH, OpenMPI & Intel MPI), macOS (MPICH & OpenMPI) and Windows (MS MPI).
Features that are not implemented yet include (help welcome!):
- support for non-default communicators
- support for MPI_IN_PLACE in [all]gather/scatter and allreduce
- support for MPI_Type_create_struct (NumPy structured arrays)
- ...
Hello world send/recv example:

```python
import numba, numba_mpi, numpy

@numba.jit()
def hello():
    src = numpy.array([1., 2., 3., 4., 5.])
    dst_tst = numpy.empty_like(src)

    if numba_mpi.rank() == 0:
        numba_mpi.send(src, dest=1, tag=11)
    elif numba_mpi.rank() == 1:
        numba_mpi.recv(dst_tst, source=0, tag=11)

hello()  # run with, e.g.: mpiexec -n 2 python hello.py
```
Example comparing numba-mpi vs. mpi4py performance:

The example below compares Numba+mpi4py vs. Numba+numba-mpi performance. The sample code estimates $\pi$ by numerical integration of $\int_0^1 (4/(1+x^2))dx=\pi$, dividing the workload into n_intervals handled by separate MPI processes and then obtaining the sum using allreduce (see, e.g., the analogous Matlab docs example). The computation is carried out in a JIT-compiled function get_pi_part() and is repeated N_TIMES. The repetitions and the MPI-handled reduction are done outside or inside of the JIT-compiled block for mpi4py and numba-mpi, respectively. Timing is repeated N_REPEAT times and the minimum time is reported.
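The "minimum of repeated timings" pattern used below can be sketched as follows (the timed statement here is arbitrary):

```python
import timeit

# Repeat the measurement several times and report the minimum:
# the minimum filters out interference from other processes and
# approximates the best-case runtime of the timed statement.
times = timeit.repeat("sum(range(1000))", number=100, repeat=5)
best = min(times)
```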
The generated plot shown below depicts the speedup obtained by replacing mpi4py with numba_mpi, plotted as a function of N_TIMES / n_intervals - the number of MPI calls per interval. The speedup, which stems from avoiding roundtrips between JIT-compiled and Python code, is significant (150%-300%) in all cases. The more often communication is needed (the smaller n_intervals), the larger the measured speedup. Note that nothing in the actual number crunching (within the get_pi_part() function) or in the employed communication logic (handled by the same MPI library) differs between the mpi4py and numba-mpi solutions. The measured differences in execution time stem solely from the overhead of mpi4py's higher-level abstractions and from repeatedly entering and leaving the JIT-compiled block, both of which are eliminated by using numba-mpi.
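Setting MPI aside for a moment, the round-robin decomposition of the midpoint rule can be checked in pure NumPy: summing the per-rank partial results reproduces the serial estimate of $\pi$, which is exactly the job allreduce with SUM performs across processes. This standalone sketch mirrors get_pi_part() but is not part of the benchmark code:

```python
import numpy as np

def pi_part(n_intervals, rank, size):
    # midpoint rule for the integral of 4/(1+x^2) over [0, 1],
    # restricted to the interval indices handled by one rank
    # (i = rank+1, rank+1+size, rank+1+2*size, ...)
    h = 1.0 / n_intervals
    i = np.arange(rank + 1, n_intervals, size)
    x = h * (i - 0.5)
    return h * np.sum(4.0 / (1.0 + x**2))

size = 4  # pretend four MPI processes
pi_est = sum(pi_part(1_000_000, rank, size) for rank in range(size))
# pi_est agrees with the serial (size=1) result and with pi to within RTOL
```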
```python
import timeit, mpi4py.MPI, numba, numpy as np, numba_mpi

N_TIMES = 10000
N_REPEAT = 10
RTOL = 1e-3

@numba.jit
def get_pi_part(n_intervals=1000000, rank=0, size=1):
    h = 1 / n_intervals
    partial_sum = 0.0
    for i in range(rank + 1, n_intervals, size):
        x = h * (i - 0.5)
        partial_sum += 4 / (1 + x**2)
    return h * partial_sum

@numba.jit
def pi_numba_mpi(n_intervals):
    pi = np.array([0.])
    part = np.empty_like(pi)
    for _ in range(N_TIMES):
        part[0] = get_pi_part(n_intervals, numba_mpi.rank(), numba_mpi.size())
        numba_mpi.allreduce(part, pi, numba_mpi.Operator.SUM)
        assert abs(pi[0] - np.pi) / np.pi < RTOL

def pi_mpi4py(n_intervals):
    pi = np.array([0.])
    part = np.empty_like(pi)
    for _ in range(N_TIMES):
        part[0] = get_pi_part(n_intervals, mpi4py.MPI.COMM_WORLD.rank, mpi4py.MPI.COMM_WORLD.size)
        mpi4py.MPI.COMM_WORLD.Allreduce(part, (pi, mpi4py.MPI.DOUBLE), op=mpi4py.MPI.SUM)
        assert abs(pi[0] - np.pi) / np.pi < RTOL

plot_x = [x for x in range(1, 11)]
plot_y = {'numba_mpi': [], 'mpi4py': []}
for x in plot_x:
    for impl in plot_y:
        plot_y[impl].append(min(timeit.repeat(
            f"pi_{impl}(n_intervals={N_TIMES // x})",
            globals=locals(),
            number=1,
            repeat=N_REPEAT
        )))

if numba_mpi.rank() == 0:
    from matplotlib import pyplot
    pyplot.figure(figsize=(8.3, 3.5), tight_layout=True)
    pyplot.plot(plot_x, np.array(plot_y['mpi4py']) / np.array(plot_y['numba_mpi']), marker='o')
    pyplot.xlabel('number of MPI calls per interval')
    pyplot.ylabel('mpi4py/numba-mpi wall-time ratio')
    pyplot.title(f'mpiexec -np {numba_mpi.size()}')
    pyplot.grid()
    pyplot.savefig('readme_plot.svg')
```
MPI resources on the web:
- MPI standard and general information
- MPI implementations
- MPI bindings
Acknowledgements:
We thank all contributors and users who reported feedback to the project through GitHub issues.
Development of numba-mpi has been supported by the Polish National Science Centre (grant no. 2020/39/D/ST10/01220), the Max Planck Society and the European Union (ERC, EmulSim, 101044662). We further acknowledge Poland’s high-performance computing infrastructure PLGrid (HPC Centers: ACK Cyfronet AGH) for providing computer facilities and support within computational grant no. PLG/2023/016369.
File details

Details for the file numba_mpi-1.1.1.tar.gz.
File metadata
- Download URL: numba_mpi-1.1.1.tar.gz
- Upload date:
- Size: 51.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 8fc61d1c8ae27d62899d7c8abe98377fd63ee726dcc8767d5dc6a760acea7357
MD5 | 929280d86149605855cfccd620e2f696
BLAKE2b-256 | 5b54cfa33be5ef98f2d22b00eefe2bc4f2ec1a5806221e7a352a55bbc9f5f79c
File details

Details for the file numba_mpi-1.1.1-py3-none-any.whl.
File metadata
- Download URL: numba_mpi-1.1.1-py3-none-any.whl
- Upload date:
- Size: 30.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | db6c343ea50bcd6712e09d5df5350572879034618c370403b6450f6f1287cc05
MD5 | bb02ee21d05c01fb742d8beda0d3ab39
BLAKE2b-256 | bca9086fd9bea3d0476fe6eca81b52200e14da23dc49ecd90c8a491df96de09b