
PM++: a JAX-based differentiable multi-GPU Particle-Mesh cosmology simulator.

Project description

PM++: Multi-GPU Particle-Mesh Cosmology


PM++ is a JAX-based, differentiable particle-mesh cosmology code built on top of PMWD ideas and extended for multi-GPU simulations. The active implementation is imported as pmpp and lives in src/pmpp/; the pmwd/ directory is kept as a reference implementation for validation.

Current Scope

  • Multi-GPU PM N-body simulation with JAX.
  • Preferred mesh_halo multi-GPU mode.
  • PMWD comparison tests for forward and gradient correctness.
  • Distributed FFT support for sharded meshes.
  • LPT, Boltzmann/growth utilities, scatter/gather, and power-spectrum tools.
  • Potential-correction models under src/pmpp/corrections/.
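Distributed FFTs over slab-sharded meshes are commonly built from local FFTs plus a global transpose. The NumPy sketch below emulates that general scheme in a single process to show why it reproduces a full 3D FFT. It illustrates the technique, not the code in FFT_distributed.py, and slab_fft3d is a hypothetical name.

```python
import numpy as np


def slab_fft3d(field, num_slabs):
    """Forward 3D FFT computed with a 1D ("slab") decomposition.

    Each of `num_slabs` workers first owns a contiguous block of x-planes and
    transforms its locally complete y and z axes. The data are then
    re-partitioned into blocks of y-planes (the all-to-all step on real
    hardware) so that x becomes local and can be transformed last.
    """
    nx, ny, nz = field.shape
    assert nx % num_slabs == 0 and ny % num_slabs == 0

    # Stage 1: x-slabs, 2D FFT over the (y, z) axes each worker holds fully.
    x_slabs = [np.fft.fftn(s, axes=(1, 2)) for s in np.split(field, num_slabs, axis=0)]

    # Stage 2: transpose. Worker j gathers the j-th y-block from every x-slab.
    y_slabs = []
    for j in range(num_slabs):
        pieces = [np.split(s, num_slabs, axis=1)[j] for s in x_slabs]
        y_slabs.append(np.concatenate(pieces, axis=0))

    # Stage 3: x is now complete on each worker; finish with a 1D FFT over x.
    y_slabs = [np.fft.fft(s, axis=0) for s in y_slabs]

    # Reassemble the full spectrum (only needed to compare against fftn).
    return np.concatenate(y_slabs, axis=1)


rng = np.random.default_rng(0)
delta = rng.standard_normal((8, 8, 8))
print(np.allclose(slab_fft3d(delta, num_slabs=2), np.fft.fftn(delta)))
```

On real devices, stage 2 is an all-to-all exchange, which is typically the dominant communication cost of a sharded FFT.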

Repository Layout

PMpp/
|-- src/pmpp/                    # Active importable PM++ package
|   |-- configuration.py         # Simulation configuration
|   |-- multigpu_configuration.py # Multi-GPU mode/configuration object
|   |-- particles.py             # Particle state and ownership
|   |-- scatter.py               # Particle-to-mesh assignment
|   |-- gather.py                # Mesh-to-particle interpolation
|   |-- gravity.py               # PM force solve
|   |-- steps.py                 # Drift, kick, force, adjoint pieces
|   |-- nbody.py                 # Full N-body integration and VJP
|   |-- FFT_distributed.py       # Distributed FFT construction
|   |-- mesh_halo.py             # Mesh halo exchange helpers
|   |-- modes.py                 # White noise and linear modes
|   |-- lpt.py                   # LPT initialization
|   |-- power_spectrum.py        # Density and particle P(k)
|   |-- corrections/             # Potential corrections
|   `-- potential_correction.py  # Backward-compatible correction facade
|-- pmwd/                        # Reference PMWD implementation
|-- tests/                       # Regression and gradient tests
|-- scripts/                     # Benchmarks and diagnostics
|-- notebooks/                   # Examples and exploratory notebooks
`-- docs/                        # Project documentation

Minimal Multi-GPU Setup

New code should use the nested MultiGPUConfiguration object. The older top-level compute_mesh= compatibility path still exists, but is not preferred.

import jax
import jax.numpy as jnp

from pmpp.configuration import Configuration
from pmpp.multigpu_configuration import MultiGPUConfiguration
from pmpp.utils import create_compute_mesh

res = 256
box_size = 1000.0  # Mpc/h
ptcl_grid_shape = (res, res, res)
ptcl_spacing = box_size / res

gpu_devices = [device for device in jax.devices() if device.platform == "gpu"]
if len(gpu_devices) < 2:
    raise RuntimeError("This multi-GPU example requires at least 2 GPUs.")
compute_mesh = create_compute_mesh(gpu_devices)
num_devices = len(gpu_devices)

conf = Configuration(
    ptcl_spacing,
    ptcl_grid_shape,
    mesh_shape=1,
    multigpu=MultiGPUConfiguration(
        compute_mesh=compute_mesh,
        mode="mesh_halo",
    ),
    max_ptcl_per_slice=int((res**3 / num_devices) * 1.8),  # 1.8x the mean particle count per device
    max_share_ptcl=50_000,
    max_halo_share_ptcl=50_000,
    max_share_gather_ptcl=200_000,
    float_dtype=jnp.float32,
)
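The max_ptcl_per_slice value above is the mean number of particles per device scaled by a 1.8 safety factor, so a slab can hold up to 80% more particles than average as structure forms. The small helper below is hypothetical (not part of pmpp) and only reproduces that arithmetic. For res = 256 on 2 GPUs, it gives 15,099,494 slots against a mean load of 8,388,608.

```python
def slab_particle_capacity(res, num_devices, headroom=1.8):
    """Per-device particle buffer size, mirroring the expression in the example.

    The mean load per slab is res**3 / num_devices; `headroom` is the
    multiplicative margin for clustering-driven load imbalance.
    """
    return int((res**3 / num_devices) * headroom)


for n_dev in (2, 4, 8):
    cap = slab_particle_capacity(256, n_dev)
    mean_load = 256**3 // n_dev
    print(f"{n_dev} GPUs: mean {mean_load:,} particles/slab, capacity {cap:,}")
```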

Capacity overflows are correctness failures, not performance warnings: results from a run that overflows should not be trusted. If a run reports overflow in particle migration, halo rebuild, or gather exchange buffers, increase the corresponding capacity and rerun.
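The reason overflow is fatal can be sketched with a fixed-size send buffer. JIT-compiled collectives need static shapes, so anything beyond the buffer's capacity cannot be sent. This is a generic NumPy illustration, not PM++'s exchange code, and pack_for_neighbor is a hypothetical name.

```python
import numpy as np


def pack_for_neighbor(positions, dest, neighbor, capacity):
    """Copy the particles bound for `neighbor` into a fixed-size send buffer.

    The buffer always has `capacity` rows. Particles beyond that cannot be
    sent, so the returned overflow flag must be checked rather than ignored.
    """
    idx = np.flatnonzero(dest == neighbor)
    n_packed = min(idx.size, capacity)
    buf = np.zeros((capacity,) + positions.shape[1:], dtype=positions.dtype)
    buf[:n_packed] = positions[idx[:n_packed]]
    overflow = idx.size > capacity
    return buf, n_packed, overflow


positions = np.arange(15, dtype=np.float32).reshape(5, 3)
dest = np.array([0, 1, 1, 0, 1])  # owning slab of each particle after a drift
buf, n_packed, overflow = pack_for_neighbor(positions, dest, neighbor=1, capacity=2)
print(n_packed, overflow)
```

Here three particles target slab 1 but only two fit, so the third would be lost. The remedy is the one given above: raise the corresponding capacity and rerun.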

Minimal Forward Run

import jax
import jax.numpy as jnp

from pmpp.boltzmann import boltzmann
from pmpp.configuration import Configuration
from pmpp.cosmo import SimpleLCDM
from pmpp.lpt import lpt
from pmpp.modes import linear_modes, white_noise
from pmpp.nbody import nbody
from pmpp.scatter import scatter

res = 32
box_size = 100.0
conf = Configuration(
    box_size / res,
    (res, res, res),
    mesh_shape=1,
    float_dtype=jnp.float32,
)

cosmo = boltzmann(SimpleLCDM(conf), conf)
modes = white_noise(0, conf)
modes = linear_modes(modes, cosmo, conf)
ptcl = lpt(modes, cosmo, conf)

nbody_jit = jax.jit(nbody, static_argnames=("conf", "reverse"))
ptcl_final = nbody_jit(ptcl, cosmo, conf)
density = scatter(ptcl_final, conf)

print(density.shape)
print(float(density.mean()))

Expected sanity checks:

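  • density shape matches the mesh shape (here (32, 32, 32), since mesh_shape=1 uses one mesh cell per particle-grid cell);
  • density mean is close to 1.0 (one particle per cell on average);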
  • no capacity warnings appear.
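The README does not say which assignment kernel scatter.py and gather.py use. Assuming cloud-in-cell (CIC), the standard choice in PM codes, the NumPy sketch below shows why the mean-density check holds: each unit-mass particle deposits weights that sum to 1, so with as many cells as particles, the mean is 1 up to rounding. cic_scatter and cic_gather are hypothetical names, not the pmpp API.

```python
import numpy as np


def _cic_stencil(pos, cell_size):
    """Lower cell index and fractional offset of each particle (cell units)."""
    x = pos / cell_size
    i0 = np.floor(x).astype(int)
    return i0, x - i0


def cic_scatter(pos, mesh_shape, cell_size):
    """Cloud-in-cell assignment of unit-mass particles onto a periodic mesh."""
    mesh = np.zeros(mesh_shape)
    shape = np.array(mesh_shape)
    i0, f = _cic_stencil(pos, cell_size)
    for offset in np.ndindex(2, 2, 2):
        o = np.array(offset)
        w = np.prod(np.where(o == 1, f, 1.0 - f), axis=1)  # trilinear weight
        idx = (i0 + o) % shape                               # periodic wrap
        np.add.at(mesh, tuple(idx.T), w)                     # handles repeated cells
    return mesh


def cic_gather(mesh, pos, cell_size):
    """Interpolate mesh values to particles with the same CIC weights."""
    shape = np.array(mesh.shape)
    i0, f = _cic_stencil(pos, cell_size)
    out = np.zeros(len(pos))
    for offset in np.ndindex(2, 2, 2):
        o = np.array(offset)
        w = np.prod(np.where(o == 1, f, 1.0 - f), axis=1)
        out += w * mesh[tuple(((i0 + o) % shape).T)]
    return out


rng = np.random.default_rng(0)
res, box = 8, 8.0
grid = np.stack(np.meshgrid(*[np.arange(res)] * 3, indexing="ij"), -1).reshape(-1, 3)
pos = (grid + 0.5 + 0.3 * rng.standard_normal(grid.shape)) % box
density = cic_scatter(pos, (res, res, res), box / res)
print(density.shape, density.mean())
```

Because both functions use the same weights, gather is the transpose (adjoint) of scatter. Reverse-mode differentiation through the mesh relies on this property.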

Potential Corrections

Correction implementations now live in pmpp.corrections (src/pmpp/corrections/).

from pmpp.corrections import (
    apply_potential_correction,
    evaluate_potential_transfer,
    init_potential_correction,
)

pmpp.potential_correction remains as a compatibility facade for old scripts, but new code and tests should import from pmpp.corrections.

Supported correction families:

  • radial, radial_spline, neural_spline
  • mesh_cnn, cnn
  • combined, hybrid, spline_cnn
  • pm_window, cic_compensation, cic_window_compensation
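The README does not document what the pm_window and cic_compensation families compute. As a hedged illustration of what a CIC compensation usually means, the sketch below divides a Fourier-space field by the standard CIC window W(k) = Π_i sinc²(k_i Δ/2), where Δ is the cell size. It is not taken from pmpp.corrections, and cic_window and compensate_cic are hypothetical names.

```python
import numpy as np


def cic_window(mesh_shape):
    """CIC window W(k) = prod_i sinc^2(k_i * cell / 2) on a full FFT grid.

    With k_i = 2 pi n_i / (N_i * cell), the argument k_i * cell / 2 equals
    pi * n_i / N_i. Since np.sinc(x) = sin(pi x) / (pi x), np.sinc takes the
    plain frequency n_i / N_i returned by np.fft.fftfreq.
    """
    wx, wy, wz = (np.sinc(np.fft.fftfreq(n)) ** 2 for n in mesh_shape)
    return wx[:, None, None] * wy[None, :, None] * wz[None, None, :]


def compensate_cic(field_k, power=1):
    """Divide `power` CIC windows out of a complex field from np.fft.fftn."""
    return field_k / cic_window(field_k.shape) ** power


delta = np.random.default_rng(0).standard_normal((16, 16, 16))
delta_k = np.fft.fftn(delta)
delta_k_comp = compensate_cic(delta_k, power=2)  # e.g. one scatter plus one gather
```

W never vanishes on the grid. Its minimum, at the Nyquist corner, is (2/π)⁶ ≈ 0.066, so each division amplifies the highest-k modes by up to a factor of about 15.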

Multi-GPU Modes

Prefer mesh_halo for current production multi-GPU work:

  • particles are stored authoritatively on their owning slab;
  • particles migrate between slabs when needed;
  • mesh halos are exchanged for local stencil operations;
  • in current testing at 256^3 on 2 GPUs, it has generally been faster than the older particle_halo path.

particle_halo remains useful for comparison and legacy validation.
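The mesh_halo ownership and halo rules can be sketched in NumPy for a 1D slab decomposition along x. This single-process illustration rests on that decomposition assumption and is not PM++'s mesh_halo.py. owning_slab and add_mesh_halos are hypothetical names, and on GPUs the halo planes would arrive from neighbouring devices rather than by indexing a Python list.

```python
import numpy as np


def owning_slab(x, box_size, num_slabs):
    """Index of the x-slab that owns each particle (periodic box, equal slabs)."""
    slab_width = box_size / num_slabs
    # The final modulo guards against x % box_size rounding up to box_size.
    return np.floor((x % box_size) / slab_width).astype(int) % num_slabs


def add_mesh_halos(slabs, halo):
    """Pad each x-slab with `halo` planes copied from its periodic neighbours."""
    n = len(slabs)
    padded = []
    for i, s in enumerate(slabs):
        left = slabs[(i - 1) % n][-halo:]   # last planes of the left neighbour
        right = slabs[(i + 1) % n][:halo]   # first planes of the right neighbour
        padded.append(np.concatenate([left, s, right], axis=0))
    return padded


mesh = np.arange(8 * 4 * 4, dtype=float).reshape(8, 4, 4)
slabs = np.split(mesh, 2, axis=0)        # 2 devices, 4 x-planes each
padded = add_mesh_halos(slabs, halo=1)   # one plane per side covers a CIC stencil
print(padded[0].shape)                   # (6, 4, 4)
print(owning_slab(np.array([10.0, 60.0, 99.9, 100.0]), box_size=100.0, num_slabs=2))
```

After padding, each slab can apply a local stencil without further communication. When a particle's owning_slab changes after a drift, the particle is migrated.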

Testing

Focused correction and gravity checks:

python -m pytest \
  tests/test_potential_correction.py \
  tests/test_grad_gravity.py \
  tests/test_gravity_particle_nyquist_filter.py \
  -q

Mesh-halo scatter/gather:

python -m pytest tests/test_mesh_halo_scatter_gather.py -q

End-to-end gradient:

python -m pytest tests/test_grad_nbody.py -q
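The README does not describe how test_grad_nbody.py validates gradients; the scope list mentions comparisons against PMWD. A common complementary check compares an analytic or autodiff gradient with central finite differences. The sketch below applies that technique to a toy function rather than reproducing the test suite, and central_difference_grad is a hypothetical helper.

```python
import numpy as np


def central_difference_grad(f, x, eps=1e-6):
    """Gradient of scalar f at x from central differences, one coordinate at a time."""
    x = np.asarray(x, dtype=np.float64)
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step.flat[i] = eps
        grad.flat[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad


def f(x):
    return np.sum(np.sin(x) * x**2)


def grad_f(x):
    # Analytic gradient; in a JAX code this role is played by jax.grad.
    return np.cos(x) * x**2 + 2 * x * np.sin(x)


x0 = np.linspace(-1.0, 1.0, 5)
numeric = central_difference_grad(f, x0)
print(np.max(np.abs(numeric - grad_f(x0))))
```

Finite differences need float64 and a step size that balances truncation against round-off error. They are practical only for a handful of parameters, whereas comparing against a reference implementation such as PMWD scales to full fields.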

Notebooks

The primary example notebooks are:

  • notebooks/pmpp_showcase.ipynb
  • notebooks/mGPU_pmwd_local.ipynb

Restart notebook kernels after code changes. Stale kernels can keep old module objects, especially around pmpp.corrections and multi-GPU configuration.

License

PM++ is distributed under the BSD-3-Clause license; see LICENSE. PM++ is based on PMWD and retains the original PMWD BSD 3-Clause notice in THIRD_PARTY_NOTICES.md.



Download files


Source Distribution

pmpp-0.1.3.tar.gz (121.0 kB)

Uploaded Source

Built Distribution


pmpp-0.1.3-py3-none-any.whl (141.5 kB)

Uploaded Python 3

File details

Details for the file pmpp-0.1.3.tar.gz.

File metadata

  • Download URL: pmpp-0.1.3.tar.gz
  • Upload date:
  • Size: 121.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for pmpp-0.1.3.tar.gz
Algorithm Hash digest
SHA256 3bc18a353d9deb10595f7df9807f00714488d8cc6a1c1bfc623f4aa596293354
MD5 9c5e8852e4c3714311a09d67fce4cc48
BLAKE2b-256 d1421265a29dcb16ea4e18321aafaf18ea09f816928f549e17e2ac576e6a8e71


Provenance

The following attestation bundles were made for pmpp-0.1.3.tar.gz:

Publisher: publish-to-pypi.yml on rouzib/PMpp

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file pmpp-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: pmpp-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 141.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for pmpp-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 7d2952d70e6b817ab9bce9c5f6d937b7f58057fdf57fbace1cfb5abaeddf0c41
MD5 6d386bf74929a8323cf97afd10dd0366
BLAKE2b-256 7a827b340e831e48b06807869d2f4d536161dc24469da0499336c93d7c2ca72a


Provenance

The following attestation bundles were made for pmpp-0.1.3-py3-none-any.whl:

Publisher: publish-to-pypi.yml on rouzib/PMpp

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
