
Numba @jittable MPI wrappers tested on Linux, macOS and Windows


numba-mpi


Overview

numba-mpi provides Python wrappers to the C MPI API, callable from within Numba JIT-compiled code (@jit mode). For an outline of the project, its rationale, architecture and features, refer to the numba-mpi paper in SoftwareX (open access; please cite it if numba-mpi is used in your research).

Support is provided for a subset of MPI routines, covering: size/rank, send/recv, allreduce, reduce, bcast, scatter/gather & allgather, barrier, wtime; basic asynchronous communication with isend/irecv (contiguous arrays only); and request handling, including wait/waitall/waitany and test/testall/testany.

The API uses NumPy and supports both numeric and character datatypes (e.g., in broadcast). Auto-generated, docstring-based API docs are published on the web: https://numba-mpi.github.io/numba-mpi

Packages can be obtained from PyPI, Conda Forge, Arch Linux or by invoking pip install git+https://github.com/numba-mpi/numba-mpi.git.

numba-mpi is a pure-Python package. The codebase includes a test suite run through GitHub Actions workflows (thanks to mpi4py's setup-mpi!) for automated testing on: Linux (MPICH, OpenMPI & Intel MPI), macOS (MPICH & OpenMPI) and Windows (MS MPI). Note that some of those combinations may not be fully supported yet; see Known Issues for more information.

Features that are not implemented yet include (help welcome!):

  • support for non-default communicators
  • support for MPI_IN_PLACE in [all]gather/scatter and allreduce
  • support for MPI_Type_create_struct (NumPy structured arrays)
  • ...

Hello world send/recv example:

import numba, numba_mpi, numpy

@numba.jit()
def hello():
    src = numpy.array([1., 2., 3., 4., 5.])
    dst_tst = numpy.empty_like(src)

    if numba_mpi.rank() == 0:
        numba_mpi.send(src, dest=1, tag=11)
    elif numba_mpi.rank() == 1:
        numba_mpi.recv(dst_tst, source=0, tag=11)

hello()
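For the message to actually be exchanged, the script needs at least two ranks, i.e. it must be started under an MPI launcher (the launcher name, e.g. mpiexec, mpirun or srun, depends on the MPI distribution). Assuming the code above is saved as hello.py:

```shell
mpiexec -n 2 python hello.py
```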

Example comparing numba-mpi vs. mpi4py performance:

The example below compares Numba + mpi4py vs. Numba + numba-mpi performance. The sample code estimates $\pi$ by numerical integration of $\int_0^1 (4/(1+x^2))dx=\pi$, dividing the workload into n_intervals handled by separate MPI processes and then obtaining the sum using allreduce (see, e.g., the analogous Matlab docs example). The computation is carried out in a JIT-compiled function get_pi_part() and is repeated N_TIMES. For mpi4py, the repetitions and the MPI-handled reduction are done outside of the JIT-compiled block; for numba-mpi, they happen inside it. Timing is repeated 10 times and the minimum time is reported.

The generated plot shown below depicts the speedup obtained by replacing mpi4py with numba-mpi, plotted as a function of N_TIMES / n_intervals, i.e., the number of MPI calls per interval. The speedup, which stems from avoiding roundtrips between JIT-compiled and Python code, is significant (150%-300%) in all cases; the more often communication is needed (the smaller n_intervals), the larger the measured speedup. Note that neither the actual number crunching (within get_pi_part()) nor the communication logic (handled by the same MPI library) differs between the two solutions. The measured differences in execution time stem from the overhead of mpi4py's higher-level abstractions and from repeatedly entering and leaving the JIT-compiled block when using mpi4py, both of which numba-mpi eliminates.
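The quadrature itself can be sanity-checked without any MPI machinery; a pure-Python midpoint-rule sketch of the same integral (function name hypothetical):

```python
import math


def pi_midpoint(n_intervals=1_000_000):
    """midpoint-rule estimate of the integral of 4/(1+x^2) over [0, 1]"""
    h = 1 / n_intervals
    return h * sum(4 / (1 + (h * (i - 0.5)) ** 2) for i in range(1, n_intervals + 1))


print(pi_midpoint())  # close to math.pi
```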

import timeit, mpi4py.MPI, numba, numpy as np, numba_mpi

N_TIMES = 10000
RTOL = 1e-3

@numba.jit
def get_pi_part(n_intervals=1000000, rank=0, size=1):
    h = 1 / n_intervals
    partial_sum = 0.0
    for i in range(rank + 1, n_intervals, size):
        x = h * (i - 0.5)
        partial_sum += 4 / (1 + x**2)
    return h * partial_sum

@numba.jit
def pi_numba_mpi(n_intervals):
    pi = np.array([0.])
    part = np.empty_like(pi)
    for _ in range(N_TIMES):
        part[0] = get_pi_part(n_intervals, numba_mpi.rank(), numba_mpi.size())
        numba_mpi.allreduce(part, pi, numba_mpi.Operator.SUM)
        assert abs(pi[0] - np.pi) / np.pi < RTOL

def pi_mpi4py(n_intervals):
    pi = np.array([0.])
    part = np.empty_like(pi)
    for _ in range(N_TIMES):
        part[0] = get_pi_part(n_intervals, mpi4py.MPI.COMM_WORLD.rank, mpi4py.MPI.COMM_WORLD.size)
        mpi4py.MPI.COMM_WORLD.Allreduce(part, (pi, mpi4py.MPI.DOUBLE), op=mpi4py.MPI.SUM)
        assert abs(pi[0] - np.pi) / np.pi < RTOL

plot_x = [x for x in range(1, 11)]
plot_y = {'numba_mpi': [], 'mpi4py': []}
for x in plot_x:
    for impl in plot_y:
        plot_y[impl].append(min(timeit.repeat(
            f"pi_{impl}(n_intervals={N_TIMES // x})",
            globals=locals(),
            number=1,
            repeat=10
        )))

if numba_mpi.rank() == 0:
    from matplotlib import pyplot
    pyplot.figure(figsize=(8.3, 3.5), tight_layout=True)
    pyplot.plot(plot_x, np.array(plot_y['mpi4py'])/np.array(plot_y['numba_mpi']), marker='o')
    pyplot.xlabel('number of MPI calls per interval')
    pyplot.ylabel('mpi4py/numba-mpi wall-time ratio')
    pyplot.title(f'mpiexec -np {numba_mpi.size()}')
    pyplot.grid()
    pyplot.savefig('readme_plot.svg')

[plot (readme_plot.svg): mpi4py/numba-mpi wall-time ratio vs. number of MPI calls per interval]

Known Issues

NOTE: the issues listed below relate only to combinations of platforms and MPI distributions that we aim to support but which, for various reasons, are currently not working and are temporarily excluded from automated testing:

  • tests on Ubuntu 24.04 that use MPICH are not run due to failures caused by a newer version of MPICH (4.2.0); note that previous tests run against MPICH 4.0.2 (installed by default on Ubuntu 22.04 via apt) were passing (see related issue - TODO #162),
  • tests on Intel-based macOS (v13) that use OpenMPI are currently not run due to failures being under investigation (see related issue - TODO #163),
  • numba-mpi currently does not support ARM-based macOS due to required code improvements (see related issue - TODO #164).

MPI resources on the web:

Acknowledgements:

We thank all contributors and users who reported feedback to the project through GitHub issues.

Development of numba-mpi has been supported by the Polish National Science Centre (grant no. 2020/39/D/ST10/01220), the Max Planck Society and the European Union (ERC, EmulSim, 101044662). We further acknowledge Poland’s high-performance computing infrastructure PLGrid (HPC Centers: ACK Cyfronet AGH) for providing computer facilities and support within computational grant no. PLG/2023/016369.

Project details


Download files

Download the file for your platform.

Source Distribution

numba_mpi-1.1.5.tar.gz (53.3 kB)


Built Distribution

numba_mpi-1.1.5-py3-none-any.whl (31.1 kB)


File details

Details for the file numba_mpi-1.1.5.tar.gz.

File metadata

  • Download URL: numba_mpi-1.1.5.tar.gz
  • Upload date:
  • Size: 53.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for numba_mpi-1.1.5.tar.gz
Algorithm Hash digest
SHA256 5b7b7c0cde9ea04fb6d0751e1f2106d248a2ca9dedc2721cefdea9ff8dc31bdc
MD5 828aaa3210661e02a3d12d8bec06e609
BLAKE2b-256 fd51cb30201bdd87acf52eebe071984145fb6c5f10c04141c7e46fe081e6ed01


File details

Details for the file numba_mpi-1.1.5-py3-none-any.whl.

File metadata

  • Download URL: numba_mpi-1.1.5-py3-none-any.whl
  • Upload date:
  • Size: 31.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for numba_mpi-1.1.5-py3-none-any.whl
Algorithm Hash digest
SHA256 ac6d0bc3f5e930b3c862160186d43274c254a55143daed884688cc9bdd6d3168
MD5 9ebaea851dfaf92dde78763e53e64cd2
BLAKE2b-256 5a91ca971340b3dad88fae9f48a932a9629ce31db4d9446c53e6043048b79369

