
JAX bindings for the cuDecomp library


jaxDecomp: JAX Library for 3D Domain Decomposition and Parallel FFTs


Important: Version 0.2.0 includes a pure JAX backend that no longer requires MPI. For multi-node runs, the MPI and NCCL backends remain available through cuDecomp.

JAX reimplementation and bindings for NVIDIA's cuDecomp library (Romero et al. 2022), enabling multi-node parallel FFTs and halo exchanges, backed by low-level NCCL/CUDA-aware MPI, directly from your JAX code.
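Distributed 3D FFT libraries such as cuDecomp typically rely on a pencil decomposition: each device holds full-length lines ("pencils") along one axis, runs 1D FFTs along that axis, and then a global transpose (all-to-all) makes the next axis local. The single-device sketch below uses plain `jax.numpy` (no jaxDecomp) only to check the identity this relies on, namely that 1D FFTs applied one axis at a time reproduce `jnp.fft.fftn`; the transposes appear as comments. The function name `fft3d_by_stages` is hypothetical.

```python
import jax
import jax.numpy as jnp


def fft3d_by_stages(x):
    """Hypothetical single-device illustration of a staged (pencil) 3D FFT."""
    # Stage 1: FFT along z, which is fully local in a z-pencil layout
    y = jnp.fft.fft(x, axis=2)
    # (distributed version: all-to-all transpose z-pencils -> y-pencils here)
    y = jnp.fft.fft(y, axis=1)
    # (distributed version: all-to-all transpose y-pencils -> x-pencils here)
    y = jnp.fft.fft(y, axis=0)
    return y


x = jax.random.normal(jax.random.PRNGKey(0), (8, 16, 32))
staged = fft3d_by_stages(x)
reference = jnp.fft.fftn(x)
print(jnp.max(jnp.abs(staged - reference)))
```

Because the 3D DFT is separable, the order of the axis-wise stages does not change the result; a distributed implementation picks the order that matches its data layout.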


Usage

Below is a simple code snippet illustrating how to perform a 3D FFT on a distributed 3D array, followed by a halo exchange. For demonstration purposes, we force 8 CPU devices via environment variables:

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import jaxdecomp

# Create a 2x4 mesh of devices on CPU
pdims = (2, 4)
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))

# Create a random 3D array and distribute it across the mesh
a = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024, 1024))
a = jax.device_put(a, sharding)

# Parallel FFTs: forward transform, then the inverse of the result
k_array = jaxdecomp.fft.pfft3d(a)
rec_array = jaxdecomp.fft.pifft3d(k_array)

# Parallel halo exchange
exchanged = jaxdecomp.halo_exchange(a, halo_extents=(16, 16), halo_periods=(True, True))
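Conceptually, a periodic halo exchange fills the halo cells around each device's block with the edge values of its neighbors, wrapping around at the domain boundary. To reason about (or test) that behavior on a single device, the sketch below builds the blocks a periodic exchange would produce by wrap-padding the full array and slicing out each block plus its halo. The helper `reference_periodic_halo` is hypothetical, not part of jaxDecomp; it assumes halos only along the two sharded axes (0 and 1), matching the two-element `halo_extents`, and it does not claim to reproduce jaxDecomp's exact memory layout.

```python
import jax.numpy as jnp


def reference_periodic_halo(global_array, pdims, halo_extents):
    """Hypothetical single-device reference for a periodic halo exchange.

    Returns a px-by-py nested list of blocks; block (i, j) holds its own
    interior plus `halo_extents` cells from its periodic neighbors along
    axes 0 and 1 (corner cells come from the diagonal neighbors).
    """
    px, py = pdims
    hx, hy = halo_extents
    nx = global_array.shape[0] // px  # interior block size along axis 0
    ny = global_array.shape[1] // py  # interior block size along axis 1
    # Wrap-padding the whole domain is what periodic boundaries mean
    padded = jnp.pad(global_array, ((hx, hx), (hy, hy), (0, 0)), mode="wrap")
    return [
        [padded[i * nx:(i + 1) * nx + 2 * hx, j * ny:(j + 1) * ny + 2 * hy]
         for j in range(py)]
        for i in range(px)
    ]


g = jnp.arange(4 * 8 * 2).reshape(4, 8, 2)
blocks = reference_periodic_halo(g, pdims=(2, 4), halo_extents=(1, 1))
print(blocks[0][0].shape)  # (4, 4, 2): a 2x2 interior plus a 1-cell halo on each side
```

In block (0, 0), the first halo row comes from the last row of the global array, which is exactly the wrap-around that `halo_periods=(True, True)` requests.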

All these functions are JIT-compatible and support automatic differentiation (with some caveats).

Important: Multi-node FFTs work with both the JAX and cuDecomp backends.
For CPU runs with the JAX backend, multi-node is supported starting with JAX v0.5.1 (using the gloo backend).


Running on an HPC Cluster

On HPC clusters (e.g., Jean Zay, Perlmutter), you typically launch your script with:

srun python demo.py

or

mpirun -n 8 python demo.py

See the Slurm README and template script for more details.
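When a script is started once per task like this, each process must join the same JAX runtime before doing any computation. A minimal sketch, assuming JAX's built-in cluster auto-detection recognizes your launcher (it covers Slurm and Open MPI, among others): call `jax.distributed.initialize()` only when more than one task was launched and, for CPU-only runs, opt into gloo collectives. The helper name `init_distributed` and the environment-variable check are illustrative choices, not part of jaxDecomp.

```python
import os

import jax


def init_distributed():
    """Hypothetical helper: join the multi-process JAX runtime if needed."""
    # srun exports SLURM_NTASKS; Open MPI's mpirun exports OMPI_COMM_WORLD_SIZE
    n_tasks = int(
        os.environ.get("SLURM_NTASKS", os.environ.get("OMPI_COMM_WORLD_SIZE", "1"))
    )
    if n_tasks > 1:
        if os.environ.get("JAX_PLATFORMS") == "cpu":
            # Cross-process CPU collectives via gloo (flag name assumed from
            # recent JAX releases; check it against your JAX version)
            jax.config.update("jax_cpu_collectives_implementation", "gloo")
        # Must run before any other JAX call; coordinator address, process
        # count and process id are auto-detected from the launcher environment
        jax.distributed.initialize()
    return jax.process_count()


if __name__ == "__main__":
    n_procs = init_distributed()  # initialize first, then query the runtime
    print(f"process {jax.process_index()} of {n_procs}")
```

Started with `srun` or `mpirun -n 8`, each task would report its own index; run directly, the helper leaves JAX in single-process mode.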


Using cuDecomp (MPI and NCCL)

To use the cuDecomp backends (MPI and NCCL) and other cuDecomp-only features such as autotuning, build and install jaxDecomp with cuDecomp enabled, as described in Install:

import jaxdecomp

# Optionally select communication backends (defaults to NCCL)
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)

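# Then pass backend='cudecomp' to your FFT or halo calls.
# global_array: a sharded 3D array, as in the Usage example.
# padded_array: an array whose local blocks already include the halo cells.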
karray = jaxdecomp.fft.pfft3d(global_array, backend='cudecomp')
recarray = jaxdecomp.fft.pifft3d(karray, backend='cudecomp')
exchanged_array = jaxdecomp.halo_exchange(
    padded_array, halo_extents=(16, 16), halo_periods=(True, True), backend='cudecomp'
)

Install

1. Pure JAX Version (Easy / Recommended)

jaxDecomp is on PyPI:

  1. Install the appropriate JAX wheel:
    • GPU:
      pip install --upgrade "jax[cuda]"
      
    • CPU:
      pip install --upgrade "jax[cpu]"
      
  2. Install jaxdecomp:
    pip install jaxdecomp
    

This setup uses the pure-JAX backend; no MPI is required.

2. JAX + cuDecomp Backend (Advanced)

If you need MPI instead of NCCL (GPU) or gloo (CPU) for communication, build from GitHub with cuDecomp enabled. This requires the NVIDIA HPC SDK or a similar environment that provides a CUDA-aware MPI toolchain.

pip install -U pip
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -Ccmake.define.JD_CUDECOMP_BACKEND=ON
  • If CMake cannot find NVHPC, set:
    export CMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH:$NVCOMPILERS/$NVARCH/22.9/cmake
    
    and then install again.

Machine-Specific Notes

IDRIS Jean Zay HPE SGI 8600 supercomputer

As of February 2025, loading modules in this exact order works:

module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake

# Install JAX
pip install --upgrade "jax[cuda]"

# Install jaxDecomp with cuDecomp
export CMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake # sometimes needed
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -Ccmake.define.JD_CUDECOMP_BACKEND=ON

Note: If using only the pure-JAX backend, you do not need NVHPC.

NERSC Perlmutter HPE Cray EX supercomputer

As of November 2022:

module load PrgEnv-nvhpc python
export CRAY_ACCEL_TARGET=nvidia80

# Install JAX
pip install --upgrade "jax[cuda]"

# Install jaxDecomp w/ cuDecomp
export CMAKE_PREFIX_PATH=/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cmake
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -Ccmake.define.JD_CUDECOMP_BACKEND=ON

Backend Configuration (cuDecomp Only)

By default, cuDecomp uses NCCL for inter-device communication. You can customize this at runtime:

import jaxdecomp

# Use MPI instead of the default NCCL for transpose and halo ops
# (cuDecomp also offers NVSHMEM-based backends)
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)

This can also be managed via environment variables, as described in the docs.


Autotune Computational Mesh (cuDecomp Only)

The cuDecomp library can autotune the partition layout to maximize performance:

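import jax
import jaxdecomp
from jax.sharding import NamedSharding, PartitionSpec as P

automesh = jaxdecomp.autotune(shape=[512, 512, 512])
# 'automesh' is an optimized partition layout, assumed here to be the
# (px, py) process-grid shape. PositionalSharding is deprecated in recent
# JAX releases, so build a named mesh and sharding from it instead:
mesh = jax.make_mesh(automesh, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))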
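Without the cuDecomp autotuner, a simple starting point is to list the process grids that split the array evenly. The hypothetical helper below enumerates `(px, py)` pairs whose product equals the device count and that divide axes 0 and 1, matching the `P('x', 'y')` layout from the Usage example. It checks only the initial sharded axes, says nothing about performance, and is not what `jaxdecomp.autotune` does internally.

```python
def candidate_pdims(n_devices, shape):
    """Hypothetical helper: (px, py) grids with px * py == n_devices that
    evenly divide axes 0 and 1 of `shape`."""
    grids = []
    for px in range(1, n_devices + 1):
        if n_devices % px:
            continue  # px must divide the device count
        py = n_devices // px
        if shape[0] % px == 0 and shape[1] % py == 0:
            grids.append((px, py))
    return grids


print(candidate_pdims(8, (512, 512, 512)))  # [(1, 8), (2, 4), (4, 2), (8, 1)]
print(candidate_pdims(8, (512, 6, 512)))    # [(4, 2), (8, 1)]
```

Timing a few of these candidates on the target machine is a crude, manual stand-in for what an autotuner automates.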

License: This project is licensed under the MIT License.

For more details, see the examples directory and the documentation. Contributions and issues are welcome!


Download files

Source Distribution

jaxdecomp-0.2.6.tar.gz (16.0 MB)

Uploaded Source

Built Distributions

jaxdecomp-0.2.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (155.7 kB)

Uploaded CPython 3.12, manylinux: glibc 2.17+ x86-64

jaxdecomp-0.2.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (156.3 kB)

Uploaded CPython 3.11, manylinux: glibc 2.17+ x86-64

jaxdecomp-0.2.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (154.8 kB)

Uploaded CPython 3.10, manylinux: glibc 2.17+ x86-64

File details

Details for the file jaxdecomp-0.2.6.tar.gz.

File metadata

  • Download URL: jaxdecomp-0.2.6.tar.gz
  • Size: 16.0 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for jaxdecomp-0.2.6.tar.gz
Algorithm Hash digest
SHA256 8516c1e5d742ca9cfeb64b0091aabf697dc607ed68d85937060c415c0fe04c33
MD5 fb69cd2757ca9242227b9f8e8d4cee5a
BLAKE2b-256 a91a8c0c1ad65a401fef3acda14046519a54df4f70bc68688d63eed3243a10d6

Provenance

The following attestation bundles were made for jaxdecomp-0.2.6.tar.gz:

Publisher: github-deploy.yml on DifferentiableUniverseInitiative/jaxDecomp

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jaxdecomp-0.2.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for jaxdecomp-0.2.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 ca0fc269379816cc0e6c1788a2eac76793a6d8ebcc8d12ae73b0dd341c2a8673
MD5 5996dded634b7e7779f79fa75b80c590
BLAKE2b-256 4ccc44231f1d63208fef5a9bdd67b5d003f9c377be9796b18c9671f5c03ee2b9

Provenance

The following attestation bundles were made for jaxdecomp-0.2.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl:

Publisher: github-deploy.yml on DifferentiableUniverseInitiative/jaxDecomp

File details

Details for the file jaxdecomp-0.2.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for jaxdecomp-0.2.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 c77ab40c4e7e2916bfc0787fee242cdb5aef98da3232b8536ec473167c974da8
MD5 7034b1ad7064e5f78ac389c04682f2ec
BLAKE2b-256 4c0b4a6c5c687ba6c369c669a4ce05248b7a83a296096ddbc7565d9cd99fe8a1

Provenance

The following attestation bundles were made for jaxdecomp-0.2.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl:

Publisher: github-deploy.yml on DifferentiableUniverseInitiative/jaxDecomp

File details

Details for the file jaxdecomp-0.2.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for jaxdecomp-0.2.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 54ce996d9a257dab2a6f5de70f8cae10377272fda656ebbe22f8d9153a595659
MD5 9a085dde8ea836e183ff08ec1ad70586
BLAKE2b-256 9fa85b8e8565da7b26a6fec4d9ca919be268a1716037de12ee6802f76d26b551

Provenance

The following attestation bundles were made for jaxdecomp-0.2.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl:

Publisher: github-deploy.yml on DifferentiableUniverseInitiative/jaxDecomp
