JAX bindings for the cuDecomp library

jaxDecomp: JAX Library for 3D Domain Decomposition and Parallel FFTs

License: MIT

[!IMPORTANT] Version 0.2.0 has a pure JAX backend and no longer requires MPI. The MPI and NCCL backends are still available through cuDecomp.

JAX bindings for NVIDIA's cuDecomp library (Romero et al. 2022), allowing for efficient multi-node parallel FFTs and halo exchanges directly in low-level NCCL/CUDA-aware MPI from your JAX code.

Usage

Here is an example of how to use jaxDecomp to perform a 3D FFT on a 3D array distributed across multiple GPUs. This example also includes a halo exchange operation, which is a common operation in many scientific computing applications.

import jax
from jax.experimental import mesh_utils, multihost_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax import numpy as jnp
import jaxdecomp
from functools import partial
# Initialize JAX distributed so that each local process knows which GPU to use
jax.distributed.initialize()
rank = jax.process_index()

# Set up a processor mesh (pdims[0] * pdims[1] must equal the total number of devices, here 8)
pdims = (2, 4)
global_shape = (1024, 1024, 1024)

# Shape of the local slice held by each device
local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0],
               global_shape[2])
# Build the global array from the local slices
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
global_array = jax.make_array_from_callback(
    global_shape,
    sharding,
    data_callback=lambda _: jax.random.normal(
        jax.random.PRNGKey(rank), local_shape))

padding_width = ((32, 32), (32, 32), (0, 0))  # Has to be a tuple of tuples


@partial(
    shard_map, mesh=mesh, in_specs=(P('x', 'y'), P()), out_specs=P('x', 'y'))
def pad(arr, padding):
  return jnp.pad(arr, padding)


@partial(
    shard_map, mesh=mesh, in_specs=(P('x', 'y'), P()), out_specs=P('x', 'y'))
def reduce_halo(x, pad_width):

  halo_x, _ = pad_width[0]
  halo_y, _ = pad_width[1]
  # Apply corrections along x
  x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
  x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
  # Apply corrections along y
  x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
  x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])

  return x[halo_x:-halo_x, halo_y:-halo_y]


@jax.jit
def modify_array(array):
  return 2 * array + 1


# Forward FFT
karray = jaxdecomp.fft.pfft3d(global_array)
# Do some operation on your array
karray = modify_array(karray)
kvec = jaxdecomp.fft.fftfreq3d(karray)
# Do a gradient in the X axis
karray_gradient = 1j * kvec[0] * karray
# Reverse FFT
recarray = jaxdecomp.fft.pifft3d(karray_gradient).real
# Add halo regions to our array
padded_array = pad(recarray, padding_width)
# Perform a halo exchange
exchanged_array = jaxdecomp.halo_exchange(
    padded_array, halo_extents=(16, 16), halo_periods=(True, True))
# Reduce the halo regions and remove the padding
reduced_array = reduce_halo(exchanged_array, padding_width)

# Gather the results (only if the full array fits in CPU memory)
gathered_array = multihost_utils.process_allgather(recarray, tiled=True)

# Finalize the distributed JAX
jax.distributed.shutdown()
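The spectral gradient step in the example (multiplying by 1j * kvec[0] between the forward and inverse FFT) can be checked on a single process with plain NumPy. The sketch below is only a reference for the mathematics, not jaxDecomp code: the helper name spectral_gradient_x and the box_size parameter are illustrative, and it assumes angular wavenumbers on a periodic box, whereas the normalisation returned by jaxdecomp.fft.fftfreq3d is not specified here.

```python
import numpy as np


def spectral_gradient_x(field, box_size=2 * np.pi):
    """Differentiate a periodic 3D field along axis 0 using FFTs."""
    n = field.shape[0]
    # Angular wavenumbers along axis 0 for a periodic box of length box_size.
    kx = 2 * np.pi * np.fft.fftfreq(n, d=box_size / n)
    field_k = np.fft.fftn(field)
    # Multiplying by 1j * kx in Fourier space is a derivative in real space.
    grad_k = 1j * kx[:, None, None] * field_k
    return np.fft.ifftn(grad_k).real


n = 32
x = np.linspace(0.0, 2 * np.pi, n, endpoint=False)
# A field that varies as sin(x) along axis 0 and is constant along the others.
field = np.sin(x)[:, None, None] * np.ones((n, 4, 4))
grad = spectral_gradient_x(field)
# The exact derivative of sin(x) is cos(x).
max_error = np.max(np.abs(grad - np.cos(x)[:, None, None]))
print(f"max error vs cos(x): {max_error:.2e}")
```

Comparing against an analytic derivative like this is a cheap sanity check to run on a small grid before moving to the distributed version.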

Note: All the jaxDecomp functions used in the demo above are jittable and have well-defined derivatives!
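For readers new to halo exchanges, the idea behind the halo exchange step in the demo can be pictured in one dimension: each shard is extended with a few cells copied from its neighbours, with wrap-around at the ends when the domain is periodic. The NumPy sketch below shows that generic technique on a single process. It is not how jaxdecomp.halo_exchange is implemented, it does not reproduce the exact meaning of halo_extents or of the zero padding used above, and the function name exchange_halos_periodic is illustrative.

```python
import numpy as np


def exchange_halos_periodic(shards, halo):
    """Extend each 1D shard with `halo` cells copied from its two neighbours."""
    n = len(shards)
    extended = []
    for i, shard in enumerate(shards):
        left = shards[(i - 1) % n][-halo:]   # tail of the left neighbour
        right = shards[(i + 1) % n][:halo]   # head of the right neighbour
        extended.append(np.concatenate([left, shard, right]))
    return extended


global_array = np.arange(16.0)
shards = np.split(global_array, 4)   # 4 "devices" holding 4 cells each
extended = exchange_halos_periodic(shards, halo=2)

# Each extended shard matches a window of the periodically wrapped global array.
wrapped = np.pad(global_array, 2, mode="wrap")
for i, ext in enumerate(extended):
    assert np.array_equal(ext, wrapped[4 * i:4 * i + 8])
print(extended[0])
```

After such an exchange, stencil-like operations can read neighbouring values locally, which is why the demo pads the array before exchanging and strips the padding afterwards.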

The demo script above (saved as demo.py) has to be run on 8 GPUs in total, with something like

mpirun -n 8 python demo.py

or on a Slurm cluster like Jean Zay

srun -n 8 python demo.py
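The number of processes passed to mpirun or srun has to match the processor mesh. As a small sketch, the helper below derives the process count and the per-device slice shape from global_shape and pdims, using the same convention as the demo (axis 0 split by pdims[1], axis 1 split by pdims[0]). The function name decomposition_summary is hypothetical and not part of jaxDecomp.

```python
def decomposition_summary(global_shape, pdims):
    """Return (number of devices, local slice shape) for a 2D processor mesh."""
    n_devices = pdims[0] * pdims[1]
    # Same convention as the demo: axis 0 is split by pdims[1], axis 1 by pdims[0].
    if global_shape[0] % pdims[1] or global_shape[1] % pdims[0]:
        raise ValueError("global_shape is not evenly divisible by pdims")
    local_shape = (global_shape[0] // pdims[1],
                   global_shape[1] // pdims[0],
                   global_shape[2])
    return n_devices, local_shape


n_devices, local_shape = decomposition_summary((1024, 1024, 1024), (2, 4))
print(n_devices, local_shape)  # 8 (256, 512, 1024)
```

Running a check like this before submitting a job catches a mismatch between pdims and the -n argument early.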

Using cuDecomp (MPI and NCCL)

You can also use the cuDecomp backend by compiling the library with the right flag (check the installation instructions) and setting the backend to use MPI or NCCL. Here is how you can do it:

import jaxdecomp
# Optionally select a communication backend (defaults to NCCL)
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)

# and then call the functions with the cuDecomp backends
karray = jaxdecomp.fft.pfft3d(global_array, backend='cudecomp')
recarray = jaxdecomp.fft.pifft3d(karray, backend='cudecomp')
exchanged_array = jaxdecomp.halo_exchange(
    padded_array, halo_extents=(16, 16), halo_periods=(True, True), backend='cudecomp')

Please check the tests in the tests folder for more examples.

On an HPC cluster like Jean Zay, launch the script with srun:

$ srun python demo.py

Check the Slurm README and template for more information on how to run on Jean Zay.

Install

Installing the pure JAX version (Easy)

jaxDecomp is available on PyPI and can be installed via pip.

First install the desired JAX version

For GPU

pip install -U "jax[cuda12]"

For CPU

pip install -U "jax[cpu]"

Then you can pip install jaxdecomp

pip install jaxdecomp
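After installing, a quick way to confirm that both packages are visible to the current Python environment is to look them up without importing them. This is a minimal sketch using only the standard library; the helper name installed_packages is illustrative.

```python
import importlib.util


def installed_packages(names=("jax", "jaxdecomp")):
    """Map each top-level package name to whether Python can locate it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


status = installed_packages()
for name, found in status.items():
    print(f"{name}: {'found' if found else 'missing'}")
```

A "missing" entry usually means pip installed into a different environment than the one running the script.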

Installing JAX and cuDecomp (Advanced)

You need to install from the GitHub repository after installing or loading the correct modules.

This install procedure assumes that the NVIDIA HPC SDK is available in your environment. You can either install it from the NVIDIA website, or better yet, it may be available as a module on your cluster.

Building jaxDecomp

pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -CCmake.define.JD_CUDECOMP_BACKEND=ON

If CMake complains that it cannot find the NVHPC SDK, you can manually specify the location of the SDK's CMake files like so:

$ export CMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH:$NVCOMPILERS/$NVARCH/22.9/cmake
$ pip install --user .

Specific Install Notes for Specific Machines

IDRIS Jean Zay HPE SGI 8600 supercomputer

As of October 2024, the following works:

You need to load the modules in exactly this order.

# Load NVHPC 23.9 because it has CUDA 12.2
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
# Installing jax
pip install --upgrade "jax[cuda12]"
# Installing jaxdecomp
export CMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake # Not always needed
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -CCmake.define.JD_CUDECOMP_BACKEND=ON

Note: This is needed only if you want to use the cuDecomp backend. If you are using the pure JAX backend, you can skip the NVHPC SDK installation and just pip install jaxdecomp after installing the correct JAX version for your hardware.

NERSC Perlmutter HPE Cray EX supercomputer

As of Nov. 2022, the following works:

module load PrgEnv-nvhpc python
export CRAY_ACCEL_TARGET=nvidia80
# Installing jax
pip install --upgrade "jax[cuda12]"
# Installing jaxdecomp
export CMAKE_PREFIX_PATH=/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cmake
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -CCmake.define.JD_CUDECOMP_BACKEND=ON

Backend configuration (Only for cuDecomp)

Note: For the JAX backend, only NCCL is available.

We can set the default communication backend used for cuDecomp operations either through the config module or through environment variables. This lets users choose the communication backend at startup (it can still be changed afterwards), making it possible to use CUDA-aware MPI or NVSHMEM as preferred.

Here is how it would look:

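jaxdecomp.config.update('transpose_comm_backend', 'NCCL')
# We could for instance time how long it takes to execute in this mode
# (%timeit is an IPython magic; y is an already sharded global array)
%timeit jaxdecomp.fft.pfft3d(y, backend='cudecomp')

# And then update the backend
jaxdecomp.config.update('transpose_comm_backend', 'MPI')
# And measure again
%timeit jaxdecomp.fft.pfft3d(y, backend='cudecomp')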
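Since %timeit only exists in IPython, a script launched with srun needs its own timing loop. The sketch below is one possible helper, not part of jaxDecomp; the name time_call is illustrative. It waits on block_until_ready() when the result provides it, because JAX dispatches work asynchronously and a timer that stops at dispatch would under-report the cost.

```python
import time


def time_call(fn, *args, repeats=5):
    """Return the best wall-clock time in seconds over `repeats` calls."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn(*args)
        # JAX arrays expose block_until_ready(); wait so the work is included.
        if hasattr(result, "block_until_ready"):
            result.block_until_ready()
        best = min(best, time.perf_counter() - start)
    return best


# Stand-in workload so the sketch runs without JAX installed.
elapsed = time_call(sum, range(100_000))
print(f"best of 5: {elapsed:.6f} s")
```

With jaxDecomp installed, the stand-in workload would be replaced by the FFT call being compared, and the first call is best discarded since it includes compilation.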

Autotune computational mesh (Only for cuDecomp)

We can also make things fancier: since cuDecomp is able to autotune, we could use it to tell us the best way to partition the data given the available GPUs, something like this:

automesh = jaxdecomp.autotune(shape=[512,512,512])
# This is a JAX Sharding spec object, optimized for the given GPUs
# and shape of the tensor
sharding = PositionalSharding(automesh)
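Without autotuning, the candidate processor grids can simply be listed by hand and timed one by one. The sketch below enumerates every pdims pair whose product equals the device count and which divides the global shape evenly, using the same axis convention as the usage example. It is a plain-Python illustration, not cuDecomp's autotuner, and the name candidate_pdims is hypothetical.

```python
def candidate_pdims(n_devices, global_shape):
    """List (p0, p1) grids with p0 * p1 == n_devices that evenly split the array."""
    candidates = []
    for p0 in range(1, n_devices + 1):
        if n_devices % p0:
            continue
        p1 = n_devices // p0
        # Usage-example convention: axis 0 is split by p1, axis 1 by p0.
        if global_shape[0] % p1 == 0 and global_shape[1] % p0 == 0:
            candidates.append((p0, p1))
    return candidates


print(candidate_pdims(8, (512, 512, 512)))
# [(1, 8), (2, 4), (4, 2), (8, 1)]
```

Which of these grids is fastest depends on the interconnect and the array shape, which is the question an autotuner answers by measurement.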
