PM++: a JAX-based differentiable multi-GPU Particle-Mesh cosmology simulator.
PM++: Multi-GPU Particle-Mesh Cosmology
PM++ is a JAX-based, differentiable particle-mesh cosmology code built on top
of PMWD ideas and extended for multi-GPU simulations. The active implementation
is imported as pmpp and lives in src/pmpp/; the pmwd/ directory is kept as a reference implementation for
validation.
Current Scope
- Multi-GPU PM N-body simulation with JAX.
- Preferred mesh_halo multi-GPU mode.
- PMWD comparison tests for forward and gradient correctness.
- Distributed FFT support for sharded meshes.
- LPT, Boltzmann/growth utilities, scatter/gather, and power-spectrum tools.
- Potential-correction models under src/pmpp/corrections/.
Repository Layout
PMpp/
|-- src/pmpp/                       # Active importable PM++ package
|   |-- configuration.py            # Simulation configuration
|   |-- multigpu_configuration.py   # Multi-GPU mode/configuration object
|   |-- particles.py                # Particle state and ownership
|   |-- scatter.py                  # Particle-to-mesh assignment
|   |-- gather.py                   # Mesh-to-particle interpolation
|   |-- gravity.py                  # PM force solve
|   |-- steps.py                    # Drift, kick, force, adjoint pieces
|   |-- nbody.py                    # Full N-body integration and VJP
|   |-- FFT_distributed.py          # Distributed FFT construction
|   |-- mesh_halo.py                # Mesh halo exchange helpers
|   |-- modes.py                    # White noise and linear modes
|   |-- lpt.py                      # LPT initialization
|   |-- power_spectrum.py           # Density and particle P(k)
|   |-- corrections/                # Potential corrections
|   `-- potential_correction.py     # Backward-compatible correction facade
|-- pmwd/                           # Reference PMWD implementation
|-- tests/                          # Regression and gradient tests
|-- scripts/                        # Benchmarks and diagnostics
|-- notebooks/                      # Examples and exploratory notebooks
`-- docs/                           # Project documentation
Minimal Multi-GPU Setup
New code should use the nested MultiGPUConfiguration object. The older
top-level compute_mesh= compatibility path still exists, but is not preferred.
import jax
import jax.numpy as jnp
from pmpp.configuration import Configuration
from pmpp.multigpu_configuration import MultiGPUConfiguration
from pmpp.utils import create_compute_mesh
res = 256
box_size = 1000.0 # Mpc/h
ptcl_grid_shape = (res, res, res)
ptcl_spacing = box_size / res
gpu_devices = [device for device in jax.devices() if device.platform == "gpu"]
if len(gpu_devices) < 2:
    raise RuntimeError("This multi-GPU example requires at least 2 GPUs.")
compute_mesh = create_compute_mesh(gpu_devices)
num_devices = len(gpu_devices)
conf = Configuration(
ptcl_spacing,
ptcl_grid_shape,
mesh_shape=1,
multigpu=MultiGPUConfiguration(
compute_mesh=compute_mesh,
mode="mesh_halo",
),
max_ptcl_per_slice=int((res**3 / num_devices) * 1.8),
max_share_ptcl=50_000,
max_halo_share_ptcl=50_000,
max_share_gather_ptcl=200_000,
float_dtype=jnp.float32,
)
Capacity overflows are correctness failures. If a run reports overflow in particle migration, halo rebuild, or gather exchange buffers, increase the corresponding capacity and rerun.
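The example sets max_ptcl_per_slice to 1.8 times the mean number of particles per device. Because gravitational clustering can concentrate particles in some slabs, that headroom factor matters. The sketch below reproduces the same arithmetic and treats an overflow as a hard error, in the spirit of the rule above. The helpers slab_capacity and check_capacity are hypothetical and not part of pmpp; the 1.8 default is simply the value used in the example.

```python
def slab_capacity(n_particles: int, num_devices: int, headroom: float = 1.8) -> int:
    """Hypothetical helper: per-device particle buffer size.

    Mirrors max_ptcl_per_slice=int((res**3 / num_devices) * 1.8) from the example.
    """
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")
    return int((n_particles / num_devices) * headroom)


def check_capacity(counts: list[int], capacity: int) -> float:
    """Hypothetical helper: fail loudly if any slab exceeds its buffer.

    Returns the peak fill fraction so the remaining headroom can be monitored.
    """
    peak = max(counts)
    if peak > capacity:
        raise RuntimeError(
            f"capacity overflow: {peak} particles > {capacity}; "
            "increase the buffer size and rerun"
        )
    return peak / capacity


res, num_devices = 256, 2
capacity = slab_capacity(res**3, num_devices)
print(capacity)  # 15099494
print(check_capacity([9_000_000, 7_777_216], capacity))
```

Raising instead of clipping keeps a silently truncated particle set from propagating into the density field and gradients.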
Minimal Forward Run
import jax
import jax.numpy as jnp
from pmpp.boltzmann import boltzmann
from pmpp.configuration import Configuration
from pmpp.cosmo import SimpleLCDM
from pmpp.lpt import lpt
from pmpp.modes import linear_modes, white_noise
from pmpp.nbody import nbody
from pmpp.scatter import scatter
res = 32
box_size = 100.0
conf = Configuration(
box_size / res,
(res, res, res),
mesh_shape=1,
float_dtype=jnp.float32,
)
cosmo = boltzmann(SimpleLCDM(conf), conf)
modes = white_noise(0, conf)
modes = linear_modes(modes, cosmo, conf)
ptcl = lpt(modes, cosmo, conf)
nbody_jit = jax.jit(nbody, static_argnames=("conf", "reverse"))
ptcl_final = nbody_jit(ptcl, cosmo, conf)
density = scatter(ptcl_final, conf)
print(density.shape)
print(float(density.mean()))
Expected sanity checks:
- the density shape matches the mesh shape;
- the density mean is close to 1.0;
- no capacity warnings appear.
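The mean-density check works because a mass-conserving assignment scheme gives every particle weights that sum to one. The total deposited mass therefore equals the particle count, and the density divided by its mean averages to exactly 1 up to rounding. A mean that drifts away from 1.0 usually means particles were lost, for example through an overflowed exchange buffer. The NumPy sketch below illustrates this with a cloud-in-cell (CIC) deposit on a periodic mesh the same size as the particle grid. It is an independent illustration, not pmpp.scatter; the function name cic_density and its normalization are assumptions.

```python
import numpy as np


def cic_density(pos: np.ndarray, mesh_shape: tuple[int, int, int]) -> np.ndarray:
    """Hypothetical CIC deposit of unit-mass particles onto a periodic mesh.

    pos has shape (N, 3) and is given in mesh-cell units.
    Returns the density divided by the mean density (so its mean is 1).
    """
    shape = np.asarray(mesh_shape)
    counts = np.zeros(mesh_shape, dtype=np.float64)
    base = np.floor(pos).astype(np.int64)  # lower-corner cell of each particle
    frac = pos - base                       # position within that cell, in [0, 1)
    for dx in (0, 1):
        for dy in (0, 1):
            for dz in (0, 1):
                offset = np.array([dx, dy, dz])
                # Trilinear weight: frac along axes stepping +1, (1 - frac) otherwise.
                weight = np.prod(np.where(offset == 1, frac, 1.0 - frac), axis=1)
                idx = (base + offset) % shape  # periodic wrap-around
                np.add.at(counts, (idx[:, 0], idx[:, 1], idx[:, 2]), weight)
    mean_count = pos.shape[0] / counts.size
    return counts / mean_count


# A perturbed lattice stands in for evolved particle positions.
res = 16
rng = np.random.default_rng(0)
grid = np.stack(np.meshgrid(*[np.arange(res)] * 3, indexing="ij"), axis=-1)
pos = (grid.reshape(-1, 3) + rng.normal(scale=0.3, size=(res**3, 3))) % res
density = cic_density(pos, (res, res, res))
print(density.shape, density.mean())  # mean is 1.0 up to float rounding
```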
Potential Corrections
Correction implementations now live in pmpp.corrections (src/pmpp/corrections/).
from pmpp.corrections import (
    apply_potential_correction,
    evaluate_potential_transfer,
    init_potential_correction,
)
pmpp.potential_correction remains as a compatibility facade for old scripts,
but new code and tests should import from pmpp.corrections.
Supported correction families:
- radial, radial_spline, neural_spline
- mesh_cnn, cnn
- combined, hybrid, spline_cnn
- pm_window, cic_compensation, cic_window_compensation
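Two of the correction families above are named after the CIC window (cic_compensation, cic_window_compensation). Assuming they refer to the standard correction for CIC assignment, which is an inference from the names rather than something this README states, the core operation divides a mesh field in Fourier space by the CIC window W(k) = prod_i sinc^2(k_i * dx / 2). The NumPy sketch below shows only that textbook operation; cic_window and compensate_cic are hypothetical names, and pmpp's implementation may differ (for example in how it treats the Nyquist modes).

```python
import numpy as np


def cic_window(mesh_shape: tuple[int, ...]) -> np.ndarray:
    """CIC window W(k) = prod_i sinc^2(k_i dx / 2) on the np.fft.rfftn frequency grid."""
    # Frequencies in cycles per cell: k_i dx / 2 = pi * f_i, and np.sinc(f) = sin(pi f) / (pi f).
    freqs = [np.fft.fftfreq(n) for n in mesh_shape[:-1]] + [np.fft.rfftfreq(mesh_shape[-1])]
    window = np.ones([len(f) for f in freqs])
    for axis, f in enumerate(freqs):
        bshape = [1] * len(freqs)
        bshape[axis] = len(f)
        window = window * np.sinc(f).reshape(bshape) ** 2
    return window


def compensate_cic(field: np.ndarray) -> np.ndarray:
    """Deconvolve the CIC window from a real periodic mesh field."""
    field_k = np.fft.rfftn(field)
    field_k /= cic_window(field.shape)  # W > 0 everywhere; per axis its minimum is (2/pi)^2 at Nyquist
    return np.fft.irfftn(field_k, s=field.shape)


w = cic_window((8, 8, 8))
print(w.shape)     # (8, 8, 5)
print(w[0, 0, 0])  # 1.0: the mean (k = 0) mode is untouched
```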
Multi-GPU Modes
Prefer mesh_halo for current production multi-GPU work:
- particles are stored authoritatively on their owning slab;
- particles migrate between slabs when needed;
- mesh halos are exchanged for local stencil operations;
- in current testing at 256^3 on 2 GPUs, it has generally been faster than
  the older particle_halo path.
particle_halo remains useful for comparison and legacy validation.
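The mesh_halo ideas above can be illustrated on a single process. The NumPy sketch below assumes a one-dimensional slab decomposition along the first mesh axis and a halo one plane thick; both are assumptions, and pmpp's decomposition, halo width, and device communication will differ. owning_slab assigns each particle to the slab containing it, and split_with_halo copies neighboring planes into each slab so that a local stencil reproduces the global result. All three function names are hypothetical.

```python
import numpy as np


def owning_slab(x: np.ndarray, box_size: float, num_slabs: int) -> np.ndarray:
    """Slab index owning each particle, for equal slabs along the first axis."""
    frac = (x % box_size) / box_size  # periodic wrap; the final modulo handles frac == 1.0 from rounding
    return np.floor(frac * num_slabs).astype(np.int64) % num_slabs


def split_with_halo(mesh: np.ndarray, num_slabs: int, halo: int = 1) -> list[np.ndarray]:
    """Split a periodic mesh into slabs along axis 0 and pad each slab with
    `halo` planes copied from its neighbors (a stand-in for the halo exchange)."""
    slabs = np.split(mesh, num_slabs, axis=0)
    if not 1 <= halo <= slabs[0].shape[0]:
        raise ValueError("halo must be between 1 and the slab thickness")
    padded = []
    for i, slab in enumerate(slabs):
        left = slabs[(i - 1) % num_slabs][-halo:]  # trailing planes of the left neighbor
        right = slabs[(i + 1) % num_slabs][:halo]  # leading planes of the right neighbor
        padded.append(np.concatenate([left, slab, right], axis=0))
    return padded


def second_difference_local(padded: np.ndarray) -> np.ndarray:
    """Three-point stencil f[i+1] - 2 f[i] + f[i-1] along axis 0; needs a halo of 1."""
    return padded[2:] - 2.0 * padded[1:-1] + padded[:-2]


rng = np.random.default_rng(0)
mesh = rng.normal(size=(16, 8, 8))
global_d2 = np.roll(mesh, -1, axis=0) - 2.0 * mesh + np.roll(mesh, 1, axis=0)
local_d2 = np.concatenate(
    [second_difference_local(p) for p in split_with_halo(mesh, num_slabs=2)], axis=0
)
print(np.allclose(local_d2, global_d2))  # True
print(owning_slab(np.array([10.0, 600.0, -1.0]), box_size=1000.0, num_slabs=2))  # [0 1 1]
```

On real hardware the neighbor planes would arrive through inter-device communication rather than array copies, but the correctness condition is the same: each slab's stencil output must match the corresponding rows of the global result.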
Testing
Run the tests from the repository root with the Python interpreter of the environment where pmpp is installed.
Focused correction and gravity checks:
python -m pytest \
    tests/test_potential_correction.py \
    tests/test_grad_gravity.py \
    tests/test_gravity_particle_nyquist_filter.py \
    -q
Mesh-halo scatter/gather:
python -m pytest tests/test_mesh_halo_scatter_gather.py -q
End-to-end gradient:
python -m pytest tests/test_grad_nbody.py -q
Notebooks
The primary example notebooks are:
- notebooks/pmpp_showcase.ipynb
- notebooks/mGPU_pmwd_local.ipynb
Restart notebook kernels after code changes. Stale kernels can keep old module
objects, especially around pmpp.corrections and multi-GPU configuration.
License
PM++ is distributed under the BSD-3-Clause license; see LICENSE.
PM++ is based on PMWD and retains the original PMWD BSD 3-Clause notice in
THIRD_PARTY_NOTICES.md. The pmwd/ directory is kept
as a reference implementation for validation.