
Numba @njittable MPI wrappers tested on Linux, macOS and Windows


numba-mpi

Requires Python 3; licensed under GPL v3.

Overview

numba-mpi provides Python wrappers to the C MPI API callable from within Numba JIT-compiled code (@njit mode).

Support is provided for a subset of MPI routines covering: size/rank, send/recv, allreduce, bcast, scatter/gather & allgather, barrier, wtime, and basic asynchronous communication with isend/irecv (contiguous arrays only), including request handling with wait/waitall/waitany and test/testall/testany.
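To make the semantics of the collective routines listed above concrete, here is a serial, pure-Python model in which a list stands in for the per-rank buffers of a hypothetical communicator. It only illustrates what MPI's allreduce (with a sum operator), bcast, scatter, gather and allgather compute; it is not numba-mpi's implementation, and all helper names below are hypothetical.

```python
# Serial model of MPI collective semantics (illustration only, not numba-mpi).
# Each element of an outer list plays the role of one rank's buffer.
from typing import List


def allreduce_sum(send: List[List[float]]) -> List[List[float]]:
    """Every rank receives the elementwise sum of all ranks' send buffers."""
    total = [sum(values) for values in zip(*send)]
    return [list(total) for _ in send]


def bcast(buffers: List[List[float]], root: int) -> List[List[float]]:
    """Every rank ends up with a copy of the root rank's buffer."""
    return [list(buffers[root]) for _ in buffers]


def scatter(data: List[float], size: int) -> List[List[float]]:
    """The root's data is split into `size` equal contiguous chunks, one per rank.

    Assumes len(data) is divisible by size, as with a fixed per-rank count.
    """
    count = len(data) // size
    return [data[r * count:(r + 1) * count] for r in range(size)]


def gather(chunks: List[List[float]]) -> List[float]:
    """The root receives all ranks' chunks concatenated in rank order."""
    return [x for chunk in chunks for x in chunk]


def allgather(chunks: List[List[float]]) -> List[List[float]]:
    """Like gather, but every rank receives the concatenated result."""
    gathered = gather(chunks)
    return [list(gathered) for _ in chunks]


if __name__ == "__main__":
    parts = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # three hypothetical ranks
    print(allreduce_sum(parts))  # [[9.0, 12.0], [9.0, 12.0], [9.0, 12.0]]
    print(scatter([0.0, 1.0, 2.0, 3.0, 4.0, 5.0], 3))  # [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
```

In real MPI code each rank holds only its own buffer, and the library performs the data movement that this model expresses with list comprehensions.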

The API operates on NumPy arrays and supports both numeric and character datatypes (the latter, e.g., in broadcasts). Auto-generated, docstring-based API documentation is published at https://numba-mpi.github.io/numba-mpi

Packages are available from PyPI, Conda Forge and Arch Linux (AUR); the development version can be installed with pip install git+https://github.com/numba-mpi/numba-mpi.git.

numba-mpi is a pure-Python package. Its test suite is run in GitHub Actions workflows (thanks to mpi4py's setup-mpi!) on Linux (MPICH, OpenMPI & Intel MPI), macOS (MPICH & OpenMPI) and Windows (MS MPI).

Features that are not implemented yet include (help welcome!):

  • support for non-default communicators
  • support for MPI_IN_PLACE in [all]gather/scatter and allreduce
  • support for MPI_Type_create_struct (NumPy structured arrays)
  • ...

Hello world send/recv example:

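import numba, numba_mpi, numpy

# needs at least two MPI processes, e.g.: mpiexec -n 2 python hello.py
@numba.njit()
def hello():
    src = numpy.array([1., 2., 3., 4., 5.])
    dst_tst = numpy.empty_like(src)

    if numba_mpi.rank() == 0:
        # rank 0 sends its array to rank 1 with message tag 11
        numba_mpi.send(src, dest=1, tag=11)
    elif numba_mpi.rank() == 1:
        # rank 1 receives into a preallocated buffer of matching shape and dtype
        numba_mpi.recv(dst_tst, source=0, tag=11)

hello()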

Example comparing numba-mpi vs. mpi4py performance:

The example below compares the performance of Numba + mpi4py against Numba + numba-mpi. The sample code estimates $\pi$ by integrating $4/(1+x^2)$ over $[0, 1]$ with the midpoint rule: the integration range is divided into n_intervals subintervals, which are distributed among the MPI processes in a round-robin fashion, and the partial sums are then combined using allreduce. The computation is carried out in a JIT-compiled function and repeated N_TIMES times; for mpi4py, the repetition loop and the MPI reduction run outside of the JIT-compiled code, while for numba-mpi they run inside it. Each timing is repeated N_REPEAT times and the minimum is reported. The generated plot shown below depicts the speedup obtained by replacing mpi4py with numba_mpi as a function of n_intervals: the more frequent the communication (i.e., the smaller n_intervals), the larger the expected speedup.
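Written out, the quadrature behind the example is the midpoint rule, with $n$ = n_intervals, $h = 1/n$, $r$ the rank, $s$ the number of processes, and $P_r$ the partial result a rank contributes to the allreduce (the symbols $s$ and $P_r$ are introduced here for illustration):

$$
\pi = \int_0^1 \frac{4}{1+x^2}\,dx \approx h \sum_{i=1}^{n} \frac{4}{1+x_i^2}, \qquad x_i = h\left(i - \tfrac{1}{2}\right),
$$

$$
P_r = h \sum_{\substack{i = r+1,\; r+1+s,\; \ldots \\ i \le n}} \frac{4}{1+x_i^2}, \qquad \pi \approx \sum_{r=0}^{s-1} P_r .
$$

Each index $i \in \{1, \ldots, n\}$ belongs to exactly one rank, so the sum over ranks computed by allreduce equals the single-process midpoint sum.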

# import mpi4py.MPI explicitly so that the mpi4py.MPI attribute is guaranteed to exist
import timeit, mpi4py.MPI, numba, numpy as np, numba_mpi

N_TIMES = 10000
N_REPEAT = 10
RTOL = 1e-3

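@numba.njit
def get_pi_part(out, n_intervals, rank, size):
    # midpoint rule on [0, 1]; this rank sums every size-th midpoint,
    # i = rank+1, rank+1+size, ..., up to and including n_intervals
    h = 1 / n_intervals
    partial_sum = 0.0
    for i in range(rank + 1, n_intervals + 1, size):
        x = h * (i - 0.5)
        partial_sum += 4 / (1 + x**2)
    out[0] = h * partial_sum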

@numba.njit
def pi_numba_mpi(n_intervals):
    pi = np.array([0.])
    part = np.empty_like(pi)
    for _ in range(N_TIMES):
        get_pi_part(part, n_intervals, numba_mpi.rank(), numba_mpi.size())
        numba_mpi.allreduce(part, pi, numba_mpi.Operator.SUM)
        assert abs(pi[0] - np.pi) / np.pi < RTOL

def pi_mpi4py(n_intervals):
    pi = np.array([0.])
    part = np.empty_like(pi)
    for _ in range(N_TIMES):
        get_pi_part(part, n_intervals, mpi4py.MPI.COMM_WORLD.rank, mpi4py.MPI.COMM_WORLD.size)
        mpi4py.MPI.COMM_WORLD.Allreduce(part, (pi, mpi4py.MPI.DOUBLE), op=mpi4py.MPI.SUM)
        assert abs(pi[0] - np.pi) / np.pi < RTOL

plot_x = [1000 * k for k in range(1, 11)]
plot_y = {'numba_mpi': [], 'mpi4py': []}
for n_intervals in plot_x:
    for impl in plot_y:
        plot_y[impl].append(min(timeit.repeat(
            f"pi_{impl}({n_intervals})",
            globals=locals(),
            number=1,
            repeat=N_REPEAT
        )))

if numba_mpi.rank() == 0:
    from matplotlib import pyplot
    pyplot.figure(figsize=(8.3, 3.5), tight_layout=True)
    pyplot.plot(plot_x, np.array(plot_y['mpi4py'])/np.array(plot_y['numba_mpi']), marker='o')
    pyplot.xlabel('n_intervals (workload in between communication)')
    pyplot.ylabel('wall time ratio (mpi4py / numba_mpi)')
    pyplot.title(f'mpiexec -np {numba_mpi.size()}')
    pyplot.grid()
    pyplot.savefig('readme_plot.png')
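The round-robin work split can be checked without MPI at all. The pure-Python sketch below is a serial emulation, not part of numba-mpi: it loops over hypothetical ranks in a single process, and the sum over ranks stands in for allreduce with a sum operator. The function names pi_part and pi_emulated are hypothetical, and the sketch covers midpoints i = 1, ..., n_intervals.

```python
import math


def pi_part(n_intervals: int, rank: int, size: int) -> float:
    """Midpoint-rule contribution of one hypothetical rank: every size-th midpoint."""
    h = 1.0 / n_intervals
    partial_sum = 0.0
    # i = rank+1, rank+1+size, ... up to and including n_intervals
    for i in range(rank + 1, n_intervals + 1, size):
        x = h * (i - 0.5)
        partial_sum += 4.0 / (1.0 + x * x)
    return h * partial_sum


def pi_emulated(n_intervals: int, size: int) -> float:
    """Serial stand-in for allreduce(SUM): add up the contributions of all ranks."""
    return sum(pi_part(n_intervals, rank, size) for rank in range(size))


if __name__ == "__main__":
    for size in (1, 2, 4):
        approx = pi_emulated(1000, size)
        print(size, approx, abs(approx - math.pi) / math.pi)
```

The result is the same for any number of ranks up to floating-point rounding, and with n_intervals = 1000 the relative error is well below the example's RTOL of 1e-3. A cyclic split also keeps the per-rank work balanced (each rank gets either floor or ceil of n_intervals / size midpoints) without requiring n_intervals to be divisible by the number of ranks.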

[Plot: wall time ratio (mpi4py / numba_mpi) vs. n_intervals (workload in between communication)]

MPI resources on the web:

Acknowledgements:

Development of numba-mpi has been supported by the Polish National Science Centre (grant no. 2020/39/D/ST10/01220).



Download files

Source distribution: numba_mpi-0.41.tar.gz (37.7 kB)

Built distribution: numba_mpi-0.41-py3-none-any.whl (28.2 kB, Python 3)
