
JAX bindings for the cuDecomp library

Project description

jaxDecomp: JAX Library for 3D Domain Decomposition and Parallel FFTs


Important: Version 0.2.0 includes a pure-JAX backend that no longer requires MPI. For multi-node runs, the MPI and NCCL backends are still available through cuDecomp.

JAX reimplementation and bindings for NVIDIA's cuDecomp library (Romero et al. 2022), enabling multi-node parallel FFTs and halo exchanges from your JAX code, implemented directly in low-level NCCL/CUDA-aware MPI.
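A multi-node 3D FFT is feasible because the 3D discrete Fourier transform is separable: it equals three successive 1D FFTs, one per axis. A pencil decomposition, the layout cuDecomp is built around, applies each 1D FFT while that axis is fully local and uses global transposes (all-to-all exchanges) to move between axes. The sketch below is a single-device illustration in plain jax.numpy, not jaxDecomp's implementation; the helper name fft3d_by_axes is hypothetical, and the transposes appear only as comments.

```python
import jax
import jax.numpy as jnp


def fft3d_by_axes(x):
    """Hypothetical helper: a 3D FFT written as three 1D FFTs, one per axis."""
    # With a P('x', 'y') layout, axis 2 is not sharded, so this FFT is purely local.
    x = jnp.fft.fft(x, axis=2)
    # A distributed run would transpose here (all-to-all) so that axis 1 becomes local.
    x = jnp.fft.fft(x, axis=1)
    # A second transpose would make axis 0 local before the last 1D FFT.
    x = jnp.fft.fft(x, axis=0)
    return x


x = jax.random.normal(jax.random.PRNGKey(0), (8, 8, 8))
full = jnp.fft.fftn(x)
by_axes = fft3d_by_axes(x)
print(jnp.max(jnp.abs(full - by_axes)))  # only float32 round-off remains
```

The same separability is what lets each stage run on local data, with communication confined to the transposes between stages.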

Important: Starting from version 0.2.8, jaxDecomp supports JAX's Shardy partitioner, which can be activated with jax.config.update('jax_use_shardy_partitioner', True). Shardy is enabled by default in JAX 0.7.x and later. Shardy support is an internal implementation change: users should not expect any behavioral differences beyond those introduced by the JAX sharding mechanism itself, as explained in the JAX Shardy migration documentation.


Usage

Below is a simple code snippet illustrating how to perform a 3D FFT on a distributed 3D array, followed by a halo exchange. For demonstration purposes, we force 8 CPU devices via environment variables:

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding, AxisType
import jaxdecomp

# Create a 2x4 mesh of devices on CPU
pdims = (2, 4)
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'), axis_types=(AxisType.Auto, AxisType.Auto))
sharding = NamedSharding(mesh, P('x', 'y'))

# Create a random 3D array and enforce sharding
a = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024, 1024))
a = jax.lax.with_sharding_constraint(a, sharding)

# Parallel FFTs
k_array = jaxdecomp.fft.pfft3d(a)
rec_array = jaxdecomp.fft.pifft3d(k_array)  # inverse FFT back to real space

# Parallel halo exchange
exchanged = jaxdecomp.halo_exchange(a, halo_extents=(16, 16), halo_periods=(True, True))

All these functions are JIT-compatible and support automatic differentiation (with some caveats).
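To illustrate what JIT compatibility and differentiability mean in practice, the sketch below jit-compiles a spectral "energy" loss and differentiates it. It uses jnp.fft.fftn as a single-device stand-in for jaxdecomp.fft.pfft3d so it runs anywhere; in a distributed script you would call pfft3d on the sharded array instead, keeping the caveats above in mind. Parseval's theorem gives a known answer to check against: for the unnormalized DFT, the sum of |X_k|^2 equals N times the sum of x_n^2, so the gradient is 2N x, where N is the total number of elements.

```python
import jax
import jax.numpy as jnp


@jax.jit
def spectral_energy(x):
    # Stand-in for jaxdecomp.fft.pfft3d(x) in a distributed run.
    k = jnp.fft.fftn(x)
    # |k|^2 written as k * conj(k) to stay smooth everywhere.
    return jnp.sum(jnp.real(k * jnp.conj(k)))


grad_fn = jax.jit(jax.grad(spectral_energy))

x = jax.random.normal(jax.random.PRNGKey(1), (8, 8, 8))
n = x.size
g = grad_fn(x)
# Parseval: energy = n * sum(x**2), so the gradient should be 2 * n * x.
print(jnp.max(jnp.abs(g - 2 * n * x)))
```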

See also:

Important: Multi-node FFTs work with both the JAX and cuDecomp backends.
For CPU with the JAX backend, multi-node runs are supported starting with JAX v0.5.1 (using the gloo backend).
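Because multi-node CPU runs with the pure-JAX backend need JAX v0.5.1 or later, a launcher script can check the JAX version before attempting a multi-process run. This is a minimal stdlib-only sketch; version_tuple and jax_supports_cpu_multinode are hypothetical helpers, and they compare only the numeric major.minor.patch prefix, ignoring any pre-release or dev suffix.

```python
import re
from importlib import metadata

MIN_CPU_MULTINODE = (0, 5, 1)  # minimum JAX version per the note above


def version_tuple(version):
    """Turn '0.5.1' or '0.4.35.dev1' into (0, 5, 1) or (0, 4, 35)."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)  # keep only the leading digits
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def jax_supports_cpu_multinode(version=None):
    """Hypothetical helper: True if the given (or installed) JAX is >= 0.5.1."""
    if version is None:
        version = metadata.version("jax")  # raises PackageNotFoundError if JAX is absent
    return version_tuple(version) >= MIN_CPU_MULTINODE


print(jax_supports_cpu_multinode("0.4.38"))  # False
print(jax_supports_cpu_multinode("0.5.1"))   # True
```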


Running on an HPC Cluster

On HPC clusters (e.g., Jean Zay, Perlmutter), you typically launch your script with:

srun python your_script.py

or

mpirun -n 8 python your_script.py

See the Slurm README and template script for more details.
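With the pure-JAX backend, each process started by srun must join JAX's multi-process runtime before doing any JAX computation. A minimal sketch, assuming a Slurm launch: jax.distributed.initialize() can auto-detect the coordinator address, process count, and process index from Slurm's environment. The helper name maybe_init_distributed and the SLURM_NTASKS guard are illustrative choices so the same script also runs as a single process. Setting jax_cpu_collectives_implementation to "gloo" is our assumption for how the gloo backend mentioned above is enabled on CPU, and it may be unnecessary on recent JAX versions.

```python
import os

import jax


def maybe_init_distributed(cpu_gloo=False):
    """Hypothetical helper: join JAX's multi-process runtime when run under Slurm."""
    n_tasks = int(os.environ.get("SLURM_NTASKS", "1"))
    if n_tasks > 1:
        if cpu_gloo:
            # Assumption: CPU-only multi-node runs use gloo for cross-process collectives.
            jax.config.update("jax_cpu_collectives_implementation", "gloo")
        # Must run before any other JAX computation; on Slurm the coordinator
        # address, process count and process id are detected automatically.
        jax.distributed.initialize()
    return jax.process_count(), jax.process_index()


n_proc, rank = maybe_init_distributed()
print(f"process {rank} of {n_proc}: {jax.local_device_count()} local / {jax.device_count()} global devices")
```

Outside Slurm (or with a single task) the helper skips initialization, so the script falls back to an ordinary single-process run.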


Using cuDecomp (MPI and NCCL)

To use the cuDecomp backends, compile and install jaxDecomp with cuDecomp enabled, as described in Install below:

import jaxdecomp

# Optionally select communication backends (defaults to NCCL)
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)

# Then specify backend='cudecomp' in your FFT or halo calls
# (global_array and padded_array are distributed JAX arrays, e.g. built as in Usage):
karray = jaxdecomp.fft.pfft3d(global_array, backend='cudecomp')
recarray = jaxdecomp.fft.pifft3d(karray, backend='cudecomp')
exchanged_array = jaxdecomp.halo_exchange(
    padded_array, halo_extents=(16, 16), halo_periods=(True, True), backend='cudecomp'
)
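Conceptually, a periodic halo exchange pads each shard with copies of its neighbors' boundary cells, wrapping around at the domain edges when halo_periods is True. The single-device sketch below models this along one axis with a list of blocks; periodic_halo_exchange_1d is hypothetical, not a jaxDecomp function, and jaxDecomp's exact padding and update conventions for halo_extents may differ. It checks the result against jnp.pad(..., mode="wrap") applied to the undivided array.

```python
import jax.numpy as jnp


def periodic_halo_exchange_1d(blocks, halo):
    """Hypothetical helper: pad each block along axis 0 with `halo` rows from its neighbours.

    Assumes halo <= rows per block, so halos only ever come from direct neighbours.
    """
    n = len(blocks)
    padded = []
    for i, core in enumerate(blocks):
        left = blocks[(i - 1) % n][-halo:]  # last rows of the previous block (wraps around)
        right = blocks[(i + 1) % n][:halo]  # first rows of the next block (wraps around)
        padded.append(jnp.concatenate([left, core, right], axis=0))
    return padded


field = jnp.arange(16 * 3).reshape(16, 3)
blocks = jnp.split(field, 4, axis=0)  # 4 "shards" of 4 rows each
padded = periodic_halo_exchange_1d(blocks, halo=2)
wrapped = jnp.pad(field, ((2, 2), (0, 0)), mode="wrap")
print(padded[0][:, 0])  # first column of shard 0 with its halos
```

In a real decomposition the left and right slices arrive from other devices over NCCL or MPI instead of being read from a local list.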

Install

1. Pure JAX Version (Easy / Recommended)

jaxDecomp is on PyPI:

  1. Install the appropriate JAX wheel:
    • GPU:
      pip install --upgrade "jax[cuda]"
      
    • CPU:
      pip install --upgrade "jax[cpu]"
      
  2. Install jaxdecomp:
    pip install jaxdecomp
    

This setup uses the pure-JAX backend; no MPI is required.
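After installing, a quick check shows which JAX backend and how many devices are visible, and whether jaxdecomp can be imported. This is a small sketch using only public JAX and stdlib functions; the helper name check_install is hypothetical.

```python
import importlib.util

import jax


def check_install():
    """Hypothetical helper: summarise the JAX setup jaxDecomp will run on."""
    return {
        "jax_version": jax.__version__,
        "backend": jax.default_backend(),  # e.g. "cpu" or "gpu"
        "device_count": jax.device_count(),
        "jaxdecomp_importable": importlib.util.find_spec("jaxdecomp") is not None,
    }


print(check_install())
```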

2. JAX + cuDecomp Backend (Advanced)

If you need MPI instead of NCCL on GPU, you can build from GitHub with cuDecomp enabled. This requires the NVIDIA HPC SDK. Before building, make sure that nvc, nvc++, and nvcc are on your PATH, that the CUDA, MPI, and NCCL shared libraries are on LD_LIBRARY_PATH, and that CC=nvc and CXX=nvc++ are set.

pip install -U pip
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -Ccmake.define.JD_CUDECOMP_BACKEND=ON

Alternatively, clone the repository locally and install from your checkout:

git clone https://github.com/DifferentiableUniverseInitiative/jaxDecomp.git --recursive
cd jaxDecomp
pip install -U pip
pip install . -Ccmake.define.JD_CUDECOMP_BACKEND=ON
  • If CMake cannot find NVHPC, set the following (adjusting 22.9 to your installed NVHPC version):
    export CMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH:$NVCOMPILERS/$NVARCH/22.9/cmake
    
    and then install again.

If JAX reports an incompatibility with cuSPARSE or another CUDA library, the easiest fix is to install JAX against your local CUDA libraries by running pip install "jax[cuda-local]", and then install jaxDecomp with cuDecomp support.


Machine-Specific Notes

IDRIS Jean Zay HPE SGI 8600 supercomputer

As of February 2025, loading modules in this exact order works:

module load nvidia-compilers/25.1 cuda/12.6.3 openmpi/4.1.6-cuda nccl/2.26.2-1-cuda cudnn  cmake
# Install JAX
pip install --upgrade "jax[cuda-local]"

# Install jaxDecomp with cuDecomp
export CMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake # sometimes needed
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -Ccmake.define.JD_CUDECOMP_BACKEND=ON

Note: If using only the pure-JAX backend, you do not need NVHPC.

Important for Jean Zay users: Make sure to load the correct architecture module before loading the nvidia-compilers module. For example, for A100 nodes, run module load arch/a100 first. You also need to set CXXFLAGS with export CXXFLAGS="-tp=zen2 -noswitcherror" if you are using the H100 or A100 partitions, or AMD CPUs in general. More information is available in the Jean Zay documentation.

NERSC Perlmutter HPE Cray EX supercomputer

As of November 2022:

module load PrgEnv-nvhpc python
export CRAY_ACCEL_TARGET=nvidia80

# Install JAX
pip install --upgrade "jax[cuda]"

# Install jaxDecomp w/ cuDecomp
export CMAKE_PREFIX_PATH=/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cmake
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -Ccmake.define.JD_CUDECOMP_BACKEND=ON

Backend Configuration (cuDecomp Only)

By default, cuDecomp uses NCCL for inter-device communication. You can customize this at runtime:

import jaxdecomp

# Choose MPI or NVSHMEM for halo and transpose ops
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)

This can also be managed via environment variables, as described in the docs.


Autotune Computational Mesh (cuDecomp Only)

The cuDecomp library can autotune the partition layout to maximize performance:

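import jax
import jaxdecomp
from jax.sharding import NamedSharding, PartitionSpec as P

automesh = jaxdecomp.autotune(shape=[512, 512, 512])
# 'automesh' is an optimized partition layout (process-grid dimensions).
# Assuming it is a (p0, p1) tuple, build a JAX mesh and sharding from it:
mesh = jax.make_mesh(automesh, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))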
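For intuition about the search space: with a 2D pencil decomposition, the candidate process grids for n devices are the factor pairs (p0, p1) with p0 * p1 = n. cuDecomp's autotuner benchmarks candidate configurations to pick the fastest; the stdlib sketch below only enumerates the grids whose dimensions divide the first two array axes evenly, matching the P('x', 'y') layout in the Usage example. candidate_pdims is hypothetical, and the even-divisibility filter is a simplifying assumption.

```python
def candidate_pdims(n_devices, shape):
    """Hypothetical helper: 2D process grids whose axes divide shape[0] and shape[1]."""
    grids = []
    for p0 in range(1, n_devices + 1):
        if n_devices % p0:
            continue  # p0 must be a factor of the device count
        p1 = n_devices // p0
        if shape[0] % p0 == 0 and shape[1] % p1 == 0:
            grids.append((p0, p1))
    return grids


print(candidate_pdims(8, (512, 512, 512)))  # [(1, 8), (2, 4), (4, 2), (8, 1)]
```

Grids with p0 = 1 or p1 = 1 correspond to slab decompositions; the others are true pencils.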

License: This project is licensed under the MIT License.

For more details, see the examples directory and the documentation. Contributions and issues are welcome!


Download files

Source Distribution

jaxdecomp-0.3.0.tar.gz (42.7 MB)

Uploaded: Source

Built Distributions

jaxdecomp-0.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (114.8 kB)

Uploaded: CPython 3.12, manylinux: glibc 2.17+ x86-64

jaxdecomp-0.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (119.0 kB)

Uploaded: CPython 3.11, manylinux: glibc 2.17+ x86-64

jaxdecomp-0.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (118.6 kB)

Uploaded: CPython 3.10, manylinux: glibc 2.17+ x86-64

File details

Details for the file jaxdecomp-0.3.0.tar.gz.

File metadata

  • Download URL: jaxdecomp-0.3.0.tar.gz
  • Upload date:
  • Size: 42.7 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jaxdecomp-0.3.0.tar.gz
Algorithm Hash digest
SHA256 c7ad8cb0b0132e73edda0b7d5a2e9658cbc04c4adeea61756fd76dedff380d48
MD5 53c7aa917a0d8a6ec81028560f9e1b14
BLAKE2b-256 8bea8d55d8d05b8467fdc55c2289429d91886f309fd82ef0b0dad308b8d63b17

Provenance

The following attestation bundles were made for jaxdecomp-0.3.0.tar.gz:

Publisher: github-deploy.yml on DifferentiableUniverseInitiative/jaxDecomp

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jaxdecomp-0.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for jaxdecomp-0.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 1b4ff1ced1bf1fb1ade16345d7c3767b0b3b75807c9a5909dd1e474f4685bbb7
MD5 cdb710d18a10250202f0310e331668d8
BLAKE2b-256 059104299ad85a113ec6c21853356729a7308da88f21cfe97d9c1901ccfb6f1a

Provenance

The following attestation bundles were made for jaxdecomp-0.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl:

Publisher: github-deploy.yml on DifferentiableUniverseInitiative/jaxDecomp

File details

Details for the file jaxdecomp-0.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for jaxdecomp-0.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 81da65c2ae393b9b00913dff68c9c58ce1b4efee1dbc581e358b5300bde60d32
MD5 61a3a03fe30d64e8558002b7217ade64
BLAKE2b-256 b7039b694ad4afe99f0568d4ed2c533f4c8b00de9654780047ab06fad34da9ef

Provenance

The following attestation bundles were made for jaxdecomp-0.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl:

Publisher: github-deploy.yml on DifferentiableUniverseInitiative/jaxDecomp

File details

Details for the file jaxdecomp-0.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for jaxdecomp-0.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 cc87f2d72d89db97fc6e09b14038b0c4170b93fcc3ed2cc8716087f0d47b1a1e
MD5 8c2e4ef7c67faf4b28c1a316bad7d3de
BLAKE2b-256 72e737a5bd1ea940aabbd80c7a81569fc8044d967bc87ca172ea8e266a65fa45

Provenance

The following attestation bundles were made for jaxdecomp-0.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl:

Publisher: github-deploy.yml on DifferentiableUniverseInitiative/jaxDecomp
