
OpenMM plugin for exported JAX/XLA models


OpenMM-JAX

OpenMM plugin for running JAX force fields through JaxForce. For pre-trained machine learning force field models, see bio-mlff.

Install the package that matches your CUDA major version:

pip install openmmjax-cu12   # CUDA 12
pip install openmmjax-cu13   # CUDA 13

Using JaxForce

import jax
import jax.numpy as jnp
from openmmjax import JaxForce
from openmmjax_export import configure_pjrt_plugin, export_jax_model

# Export separate energy, force, and energy + force functions so OpenMM only computes what it requests.
def compute_energy(positions, box_vectors):
    box_size = jnp.diag(box_vectors)
    wrapped = positions - jnp.floor(positions / box_size) * box_size
    return jnp.sum(wrapped**2)

def compute_forces(positions, box_vectors):
    return -jax.grad(compute_energy)(positions, box_vectors)

def compute_energy_and_forces(positions, box_vectors):
    energy, gradients = jax.value_and_grad(compute_energy)(positions, box_vectors)
    return energy, -gradients

# Apply to particles at indices 0 and 2
selected = [0, 2]

force_mlir, energy_mlir, energy_and_forces_mlir, compile_options = export_jax_model(
    num_model_atoms=len(selected),  # exported shapes are specialized to this atom count
    force_function=compute_forces,
    energy_function=compute_energy,
    energy_and_forces_function=compute_energy_and_forces,
    periodic=True,
)

force = JaxForce(force_mlir, energy_mlir, energy_and_forces_mlir, compile_options)

# Without PBC, export functions that take only positions (no box_vectors argument), pass periodic=False, and set this to False
force.setUsesPeriodicBoundaryConditions(True)
force.setParticles(selected)
configure_pjrt_plugin(force)
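Because OpenMM may call any of the three exported functions depending on what it requests, they need to agree: the forces must be the negative gradient of the energy. A cheap check before exporting is a central finite difference. The sketch below is a hypothetical, dependency-free mirror of compute_energy (the helper names wrap, energy, analytic_forces and finite_difference_forces are illustrative, not part of openmmjax) and assumes a rectangular box given by its three edge lengths. With JAX installed, the same comparison can be run directly against compute_forces. Keep test positions away from the box faces, where the wrapped energy is discontinuous.

```python
import math


def wrap(x, box_length):
    # Same wrapping as the JAX example: x - floor(x / L) * L
    return x - math.floor(x / box_length) * box_length


def energy(positions, box_lengths):
    # Pure-Python mirror of compute_energy for a rectangular box (hypothetical helper).
    return sum(
        wrap(x, length) ** 2
        for atom in positions
        for x, length in zip(atom, box_lengths)
    )


def analytic_forces(positions, box_lengths):
    # F = -dE/dx = -2 * wrap(x); the wrap only shifts x by whole box lengths,
    # so its derivative is 1 everywhere except exactly on a box face.
    return [[-2.0 * wrap(x, length) for x, length in zip(atom, box_lengths)]
            for atom in positions]


def finite_difference_forces(positions, box_lengths, h=1e-6):
    # Central difference of the energy, one coordinate at a time.
    forces = []
    for i, atom in enumerate(positions):
        row = []
        for d in range(len(atom)):
            plus = [list(a) for a in positions]
            minus = [list(a) for a in positions]
            plus[i][d] += h
            minus[i][d] -= h
            dE = energy(plus, box_lengths) - energy(minus, box_lengths)
            row.append(-dE / (2.0 * h))
        forces.append(row)
    return forces


def max_abs_diff(a, b):
    # Largest elementwise difference between two N x 3 force arrays.
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


if __name__ == "__main__":
    positions = [[0.3, 1.2, 2.9], [3.4, -0.5, 1.0]]
    box = (3.0, 3.0, 3.0)
    print(max_abs_diff(analytic_forces(positions, box),
                       finite_difference_forces(positions, box)))
```

A small maximum difference indicates that the force function is consistent with the energy function; a large one usually points to a sign error or a mismatch between the separately written functions.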

Building from Source

git clone https://github.com/mitkotak/openmm-jax.git
cd openmm-jax
micromamba create -f environment.yml
micromamba activate openmm-jax
cmake -S . -B build \
  -DOPENMM_DIR="$CONDA_PREFIX" \
  -DCMAKE_INSTALL_PREFIX="$CONDA_PREFIX"
cmake --build build --target install --parallel
cmake --build build --target PythonInstall --parallel

Design Notes

  • Most of the frontend is borrowed directly from openmm-torch, with the following main changes:

    • JaxForce expects separate exported functions for energy, forces, and energy + forces instead of a single general energy + forces function. This saves compute when OpenMM requests only the energy or only the forces.
    • Instead of compiling and storing checkpoints on disk (e.g. .pt, .hlo), the exported functions are serialized to MLIR strings, which are compiled into PJRT executables at runtime (PjrtRuntime::initialize / compileStablehloExecutable). This avoids creating extra files during testing; on-disk support can be added if needed.
  • On the backend, the main complexity lies in managing memory ownership and stream synchronization as data moves from OpenMM to PJRT and back to OpenMM. The OpenMM-to-PJRT handoff is relatively straightforward because OpenMM owns all the memory up to that point; a CUDA event coordinates the input handoff, avoiding stream syncs and device-to-device (D2D) copies. The PJRT-to-OpenMM handoff is more involved because PJRT now owns the output memory. The output pointer is extracted from PJRT, OpenMM's addForce kernel is launched against it to add the PJRT output to OpenMM's global state, and release of the PJRT buffer is deferred until that kernel has finished (OpenMmPjrtOutputLifetime.h/.cpp). This part of the design is based on our understanding of PJRT_Event_Await, so if there are other PJRT APIs we missed, let us know.

  • There is PJRT boilerplate for loading plugins (PjrtPlugin.h/.cpp), managing PJRT client sessions (PjrtClientSession.h/.cpp), wrapping device-buffer interop (PjrtBufferInterop.h/.cpp), and compiling/executing loaded executables (PjrtLoadedExecutable.h/.cpp). Much of this machinery is borrowed from the PJRT C++ API, which unfortunately requires a heavy XLA build if we depend on it directly. Instead, we copy PJRT's header-only C API (pjrt_c_api.h). There are also RAII-style guards over PJRT handles (PjrtHandles.h/.cpp) and CUDA contexts (CudaPrimaryContextGuard.h).
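The deferred release of the PJRT output buffer can be illustrated with a small host-side simulation. This is only an analogy under stated assumptions: an in-order GPU stream is modeled as a FIFO of work items, and SimulatedStream, PjrtOutputBuffer, add_forces_kernel and hand_off_output are hypothetical names, not classes from OpenMmPjrtOutputLifetime.h/.cpp. What it shows is the ordering constraint: the buffer release is enqueued behind the kernel that reads the buffer, so it cannot run before that kernel finishes.

```python
from collections import deque


class SimulatedStream:
    """Hypothetical in-order stream: work items run in submission order."""

    def __init__(self):
        self._queue = deque()

    def launch(self, fn, *args):
        # Enqueue work; nothing runs until synchronize(), like an async launch.
        self._queue.append((fn, args))

    def synchronize(self):
        # Drain the queue in FIFO order.
        while self._queue:
            fn, args = self._queue.popleft()
            fn(*args)


class PjrtOutputBuffer:
    """Stand-in for a PJRT-owned device buffer holding model forces."""

    def __init__(self, values):
        self.values = list(values)
        self.released = False

    def release(self):
        self.released = True
        self.values = None


def add_forces_kernel(global_forces, output):
    # Stand-in for OpenMM's addForce kernel reading the PJRT output.
    assert not output.released, "kernel read a buffer that was already released"
    for i, f in enumerate(output.values):
        global_forces[i] += f


def hand_off_output(stream, global_forces, output):
    # Launch the accumulation kernel, then enqueue the release behind it
    # instead of releasing immediately on the host.
    stream.launch(add_forces_kernel, global_forces, output)
    stream.launch(output.release)


if __name__ == "__main__":
    stream = SimulatedStream()
    forces = [1.0, 2.0]
    out = PjrtOutputBuffer([0.5, -1.0])
    hand_off_output(stream, forces, out)
    stream.synchronize()
    print(forces, out.released)
```

Calling out.release() directly before synchronize() would make the simulated kernel fail its read-after-release check, which is the class of bug the deferred release guards against.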

Acknowledgements

https://github.com/openmm/openmm/issues/4594 for the idea.

@abhijeetgangan for discussions on API design, openmm-torch and openmm for the MD API, and xla for the PJRT code and the PJRT tutorial.

Also show some love to our friends at lammps-jax.

License

This project is licensed under the MIT License. See LICENSE.
