
JAX bindings for the cuDecomp library

Project description

jaxDecomp: JAX Library for 3D Domain Decomposition and Parallel FFTs


[!IMPORTANT] Version 0.2.0 ships a pure JAX backend and no longer requires MPI. The MPI and NCCL backends are still available through cuDecomp.

JAX bindings for NVIDIA's cuDecomp library (Romero et al. 2022), allowing for efficient multi-node parallel FFTs and halo exchanges directly from your JAX code, backed by low-level NCCL/CUDA-aware MPI :tada:

Usage

Here is an example of how to use jaxDecomp to perform a 3D FFT on a 3D array distributed across multiple GPUs. It also includes a halo exchange, a common operation in many scientific computing applications.

import jax
from jax.experimental import mesh_utils, multihost_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax import numpy as jnp
import jaxdecomp
from functools import partial
# Initialize jax distributed to instruct jax local process which GPU to use
jax.distributed.initialize()
rank = jax.process_index()

# Set up a processor mesh (its product must equal the total number of processes)
pdims = (2, 4)
global_shape = (1024, 1024, 1024)

# Initialize an array with the expected global size
local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0],
               global_shape[2])
# Remap to the global array from the local slice
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
global_array = jax.make_array_from_callback(
    global_shape,
    sharding,
    data_callback=lambda _: jax.random.normal(
        jax.random.PRNGKey(rank), local_shape))

padding_width = ((32, 32), (32, 32), (0, 0))  # Has to be a tuple of tuples


@partial(
    shard_map, mesh=mesh, in_specs=(P('x', 'y'), P()), out_specs=P('x', 'y'))
def pad(arr, padding):
  return jnp.pad(arr, padding)


@partial(
    shard_map, mesh=mesh, in_specs=(P('x', 'y'), P()), out_specs=P('x', 'y'))
def reduce_halo(x, pad_width):

  halo_x, _ = pad_width[0]
  halo_y, _ = pad_width[1]
  # Apply corrections along x
  x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
  x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
  # Apply corrections along y
  x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
  x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])

  return x[halo_x:-halo_x, halo_y:-halo_y]


@jax.jit
def modify_array(array):
  return 2 * array + 1


# Forward FFT
karray = jaxdecomp.fft.pfft3d(global_array)
# Do some operation on your array
karray = modify_array(karray)
kvec = jaxdecomp.fft.fftfreq3d(karray)
# Do a gradient in the X axis
karray_gradient = 1j * kvec[0] * karray
# Reverse FFT
recarray = jaxdecomp.fft.pifft3d(karray_gradient).real
# Add halo regions to our array
padded_array = pad(recarray, padding_width)
# Perform a halo exchange
exchanged_array = jaxdecomp.halo_exchange(
    padded_array, halo_extents=(16, 16), halo_periods=(True, True))
# Reduce the halo regions and remove the padding
reduced_array = reduce_halo(exchanged_array, padding_width)

# Gather the results (only if it fits on CPU memory)
gathered_array = multihost_utils.process_allgather(recarray, tiled=True)

# Finalize the distributed JAX
jax.distributed.shutdown()

Note: All these functions are jittable and have well-defined derivatives!
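For instance, here is a minimal sketch (assuming the global_array and imports from the example above) of jitting and differentiating through the distributed FFT:

@jax.jit
def loss_fn(x):
  # Scalar loss built from the distributed FFT; jaxdecomp ops trace like
  # any other JAX operation, so jit and grad compose with them.
  k = jaxdecomp.fft.pfft3d(x)
  return jnp.sum(jnp.abs(k)**2)


loss = loss_fn(global_array)
# Gradient with respect to the distributed input array
grad = jax.grad(loss_fn)(global_array)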

This script has to be run on 8 GPUs in total, for example with

mpirun -n 8 python demo.py

or on a Slurm cluster like Jean Zay

srun -n 8 python demo.py
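Note that jax.distributed.initialize() with no arguments relies on the launcher (e.g. Slurm or OpenMPI) to describe the process layout. If auto-detection does not work on your system, you can pass the layout explicitly; the address, port and ranks below are placeholder values, not part of jaxDecomp:

import jax

# Explicit multi-process setup; the values normally come from your launcher
# or job environment.
jax.distributed.initialize(
    coordinator_address='10.0.0.1:1234',  # hypothetical host:port of process 0
    num_processes=8,                      # total number of processes
    process_id=0)                         # rank of the current process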

Using cuDecomp (MPI and NCCL)

You can also use the cuDecomp backend by compiling the library with the right flag (check the installation instructions) and setting the backend to use MPI or NCCL. Here is how you can do it:

import jaxdecomp
# Initialise the library and optionally select a communication backend (defaults to NCCL)
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)

# and then call the functions with the cuDecomp backends
karray = jaxdecomp.fft.pfft3d(global_array , backend='cudecomp')
recarray = jaxdecomp.fft.pifft3d(karray , backend='cudecomp')
exchanged_array = jaxdecomp.halo_exchange(
    padded_array, halo_extents=(16, 16), halo_periods=(True, True), backend='cudecomp')

Please check the tests in the tests folder for more examples.

On an HPC cluster like Jean Zay you should do this:

$ srun python demo.py

Check the Slurm README and template for more information on how to run on Jean Zay.

Install

Installing the pure JAX version (Easy)

jaxDecomp is available on PyPI and can be installed via pip:

First, install the desired JAX version.

For GPU

pip install -U "jax[cuda12]"

For CPU

pip install -U "jax[cpu]"

Then you can pip install jaxdecomp

pip install jaxdecomp
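As a quick sanity check (a minimal sketch, assuming only the install above), you can verify that the package imports and that JAX sees your devices:

import jax
import jaxdecomp  # the pure JAX backend imports without MPI or the NVHPC SDK

print(jax.devices())  # lists the GPUs (or CPUs) visible to this process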

Installing JAX and cuDecomp (Advanced)

You need to install from this GitHub repository after installing or loading the correct modules.

This install procedure assumes that the NVIDIA HPC SDK is available in your environment. You can either install it from the NVIDIA website, or better yet, it may be available as a module on your cluster.

Building jaxDecomp

pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -CCmake.define.JD_CUDECOMP_BACKEND=ON

If CMake complains about not finding the NVHPC SDK, you can manually specify the location of the SDK's CMake files like so:

$ export CMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH:$NVCOMPILERS/$NVARCH/22.9/cmake
$ pip install --user .

Install Notes for Specific Machines

IDRIS Jean Zay HPE SGI 8600 supercomputer

As of October 2024, the following works:

You need to load the modules in exactly this order.

# Load NVHPC 23.9 because it ships CUDA 12.2
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
# Installing jax
pip install --upgrade "jax[cuda12]"
# Installing jaxdecomp
export CMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake # Not always needed
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -CCmake.define.JD_CUDECOMP_BACKEND=ON

Note: This is needed only if you want to use the cuDecomp backend. If you are using the pure JAX backend, you can skip the NVHPC SDK installation and just pip install jaxdecomp after installing the correct JAX version for your hardware.

NERSC Perlmutter HPE Cray EX supercomputer

As of Nov. 2022, the following works:

module load PrgEnv-nvhpc python
export CRAY_ACCEL_TARGET=nvidia80
# Installing jax
pip install --upgrade "jax[cuda12]"
# Installing jaxdecomp
export CMAKE_PREFIX_PATH=/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cmake
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -CCmake.define.JD_CUDECOMP_BACKEND=ON

Backend configuration (Only for cuDecomp)

Note: For the JAX backend, only NCCL is available.

We can set the default communication backend for cuDecomp operations either through a config module or through environment variables. This lets users choose the communication backend at startup (it can still be changed afterwards), making it possible to use CUDA-aware MPI or NVSHMEM as preferred.

Here is how it would look:

jaxdecomp.config.update('transpose_comm_backend', 'NCCL')
# We could for instance time how long it takes to execute in this mode
%timeit jaxdecomp.fft.pfft3d(y)

# And then update the backend
jaxdecomp.config.update('transpose_comm_backend', 'MPI')
# And measure again
%timeit jaxdecomp.fft.pfft3d(y)

Autotune computational mesh (Only for cuDecomp)

We can also make things fancier: since cuDecomp is able to autotune, we can ask it for the best way to partition the data given the available GPUs, something like this:

from jax.sharding import PositionalSharding

automesh = jaxdecomp.autotune(shape=[512, 512, 512])
# This is a JAX Sharding spec object, optimized for the given GPUs
# and shape of the tensor
sharding = PositionalSharding(automesh)

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

jaxdecomp-0.2.0.tar.gz (16.0 MB)

Uploaded Source

Built Distributions

jaxdecomp-0.2.0-py3-none-musllinux_1_2_x86_64.whl (1.1 MB)

Uploaded Python 3 musllinux: musl 1.2+ x86-64

jaxdecomp-0.2.0-py3-none-musllinux_1_2_i686.whl (1.2 MB)

Uploaded Python 3 musllinux: musl 1.2+ i686

jaxdecomp-0.2.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (153.9 kB)

Uploaded Python 3 manylinux: glibc 2.17+ x86-64

jaxdecomp-0.2.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl (162.1 kB)

Uploaded Python 3 manylinux: glibc 2.17+ i686

File details

Details for the file jaxdecomp-0.2.0.tar.gz.

File metadata

  • Download URL: jaxdecomp-0.2.0.tar.gz
  • Upload date:
  • Size: 16.0 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for jaxdecomp-0.2.0.tar.gz
Algorithm Hash digest
SHA256 de1a2d9288cd84256005acc9e56cea3e6f2dbc3be222cf0646888570dd189183
MD5 b679667f76057fb246bedcf9ad96e89b
BLAKE2b-256 1b0e059e89f00e0879f788ac6c1c319243e1c8dd1ac5b6e75bd5d80175d201c3


File details

Details for the file jaxdecomp-0.2.0-py3-none-musllinux_1_2_x86_64.whl.

File metadata

File hashes

Hashes for jaxdecomp-0.2.0-py3-none-musllinux_1_2_x86_64.whl
Algorithm Hash digest
SHA256 17fbc6cfc485e6851832354c25398e3b23dd3c1a50db4e0a52bfae766c2f1659
MD5 93b2d35eb46cf3be8685537466295cbe
BLAKE2b-256 3f6e89b9bd5cbd8519c62d53fdaf0d4dca103f89bd0334075eda7ebb41055818


File details

Details for the file jaxdecomp-0.2.0-py3-none-musllinux_1_2_i686.whl.

File metadata

File hashes

Hashes for jaxdecomp-0.2.0-py3-none-musllinux_1_2_i686.whl
Algorithm Hash digest
SHA256 c567c7800f31adf0ee58491b4df7fcd34189ab467a1117be78a38a213674ab27
MD5 e05567898676c524ecd64b304eced01d
BLAKE2b-256 2e52aaa35e309ae78204633a6784c58225051e9023f134548df3347f409268a1


File details

Details for the file jaxdecomp-0.2.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File metadata

File hashes

Hashes for jaxdecomp-0.2.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 35fcfa8a79ae3ddb0d693326de7d98ee811fd28b74e3aaf812020fce277c12ac
MD5 7d97b696a3f4fc729c182133a039dc09
BLAKE2b-256 623d22d5fabd5964877a653e9dc60191edc5ebda5e02dd0dbf249d2b06ddc999


File details

Details for the file jaxdecomp-0.2.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl.

File metadata

File hashes

Hashes for jaxdecomp-0.2.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl
Algorithm Hash digest
SHA256 ebe19a62e43483053c21d00d3411b60f239d3fda6ceacba618a0e6dcae4f63f2
MD5 739d44a0e376094b1d0e98feaf46b90a
BLAKE2b-256 377136b7880e491fa8f0abe893a4ce4ba153cc917320cd472d4038d94cacd3ca

