Differentiable, Hardware Accelerated, Molecular Dynamics

JAX, M.D.

Accelerated, Differentiable, Molecular Dynamics

Quickstart | Reference docs | Paper | NeurIPS 2020

Molecular dynamics is a workhorse of modern computational condensed matter physics. It is frequently used to simulate materials to observe how small-scale interactions can give rise to complex large-scale phenomenology. Most molecular dynamics packages (e.g. HOOMD Blue or LAMMPS) are complicated, specialized pieces of code that are many thousands of lines long. They typically involve significant code duplication to allow for running simulations on CPU and GPU. Additionally, a large amount of code is often devoted to taking derivatives of quantities to compute functions of interest (e.g. gradients of energies to compute forces).

However, recent work in machine learning has led to significant software developments that might make it possible to write more concise molecular dynamics simulations that offer a range of benefits. Here we target JAX, which allows us to write python code that gets compiled to XLA and allows us to run on CPU, GPU, or TPU. Moreover, JAX allows us to take derivatives of python code. Thus, not only is this molecular dynamics simulation automatically hardware accelerated, it is also end-to-end differentiable. This should allow for some interesting experiments that we're excited to explore.

JAX MD is a research project that is currently under development. Expect sharp edges and possibly some API-breaking changes as we continue to support a broader set of simulations. JAX MD is a functional and data-driven library. Data is stored in arrays or tuples of arrays, and functions transform data from one state to another.

Getting Started

For a video introducing JAX MD along with a demo, check out this talk from the Physics meets Machine Learning series:

Science Meets ML Talk

To get started with JAX MD without installing anything, check out the following Colab notebooks, which run on Google Cloud. For a very simple introduction, we recommend the Minimization example. For a tour of many of the features of JAX MD, check out the JAX MD cookbook.

JAX MD also comes with self-contained Python scripts, which you can run locally if you have JAX MD installed:

You can install JAX MD locally with pip,

pip install jax-md --upgrade

If you want to build the latest version then you can grab the most recent version from head,

git clone https://github.com/jax-md/jax-md
pip install -e jax-md

Overview

We now summarize the main components of the library.

Spaces (space.py)

In general, we must have a way of computing the pairwise distance between atoms. We must also have efficient strategies for moving atoms in some space that may or may not be globally isomorphic to R^N. For example, periodic boundary conditions are commonplace in simulations and must be respected. Spaces are defined as a pair of functions, (displacement_fn, shift_fn). Given two points, displacement_fn(R_1, R_2) computes the displacement vector between them. To compute displacement vectors between all pairs of points in a given (N, dim) matrix, the function space.map_product appropriately vectorizes displacement_fn. It is often useful to define a metric instead of a displacement function, in which case the helper function space.metric converts a displacement function to a metric function. Given a point and a shift, shift_fn(R, dR) displaces the point R by an amount dR.
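To make the (displacement_fn, shift_fn) idea concrete, here is a minimal sketch in plain JAX of what a periodic space could look like. It is not the JAX MD implementation: the local functions `periodic`, `map_product`, and `metric` are hypothetical stand-ins written for this illustration, and the minimum-image wrapping is an assumption about how one might handle a cubic box.

```python
import jax.numpy as jnp
from jax import vmap

def periodic(box_size):
    """Illustrative stand-in: returns a (displacement_fn, shift_fn) pair."""
    def displacement_fn(Ra, Rb):
        dR = Ra - Rb
        # Minimum-image convention: wrap each component into [-box/2, box/2).
        return jnp.mod(dR + 0.5 * box_size, box_size) - 0.5 * box_size

    def shift_fn(R, dR):
        # Move the point and wrap it back into the box.
        return jnp.mod(R + dR, box_size)

    return displacement_fn, shift_fn

def map_product(fn):
    # Vectorize a two-point function over all pairs of rows of two (N, dim) arrays.
    return vmap(vmap(fn, (None, 0)), (0, None))

def metric(displacement_fn):
    # Turn a displacement function into a scalar distance function.
    return lambda Ra, Rb: jnp.linalg.norm(displacement_fn(Ra, Rb))

displacement_fn, shift_fn = periodic(10.0)
R = jnp.array([[0.5, 0.5], [9.5, 9.5], [5.0, 5.0]])

dR = map_product(displacement_fn)(R, R)            # shape (3, 3, 2)
dist = map_product(metric(displacement_fn))(R, R)  # shape (3, 3)
```

In this sketch the first two points are a distance sqrt(2) apart across the periodic boundary, rather than the roughly 12.7 one would get in free space.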

The following spaces are currently supported:

Example:

from jax_md import space
box_size = 25.0
displacement_fn, shift_fn = space.periodic(box_size)

Potential Energy (energy.py)

In the simplest case, molecular dynamics calculations are often based on a pair potential that is defined by a user. This is then used to compute a total energy whose negative gradient gives forces. One of the very nice things about JAX is that we get forces for free! The second part of the code is devoted to computing energies.
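The "forces for free" point can be sketched with nothing but `jax.grad`. The toy harmonic pair energy below is an assumption chosen for illustration (it is not one of the potentials shipped with JAX MD), and the local `force` helper is a hypothetical stand-in for the idea of wrapping an energy function.

```python
import jax.numpy as jnp
from jax import grad

def energy_fn(R):
    # Toy pair energy: E = 1/2 * sum_{i<j} |R_i - R_j|^2.
    dR = R[:, None, :] - R[None, :, :]
    # The double sum counts every pair twice, hence the factor 1/4.
    return 0.25 * jnp.sum(dR ** 2)

def force(energy_fn):
    # Force is the negative gradient of the energy with respect to positions.
    return lambda R: -grad(energy_fn)(R)

R = jnp.array([[0.0, 0.0], [1.0, 0.0]])
F = force(energy_fn)(R)  # each particle is pulled toward the other
```

No derivative was written by hand here; swapping in a different `energy_fn` would give the matching forces without further changes.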

We provide the following classical potentials:

We also provide the following neural network potentials:

For finite-ranged potentials it is often useful to consider only interactions within a certain neighborhood. We include the _neighbor_list modifier to the above potentials that uses a list of neighbors (see below) for optimization.

Example:

import jax.numpy as np
from jax import random
from jax_md import energy, quantity
N = 1000
spatial_dimension = 2
key = random.PRNGKey(0)
R = random.uniform(key, (N, spatial_dimension), minval=0.0, maxval=1.0)
energy_fn = energy.lennard_jones_pair(displacement_fn)
print('E = {}'.format(energy_fn(R)))
force_fn = quantity.force(energy_fn)
print('Total Squared Force = {}'.format(np.sum(force_fn(R) ** 2)))

Dynamics (simulate.py, minimize.py)

Given an energy function and a system, there are a number of dynamics that are useful to simulate. The simulation code is based on the structure of the optimizers found in JAX. In particular, each simulation function returns an initialization function and an update function. The initialization function takes a set of positions and creates the necessary dynamical state variables. The update function applies a single step of dynamics to the dynamical state variables and returns an updated state.
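The init/update pattern described above can be sketched in plain JAX as follows. This is a hypothetical velocity Verlet integrator written for illustration, assuming unit masses and a free-space shift; the names `State` and `velocity_verlet` are not part of JAX MD.

```python
from typing import NamedTuple
import jax.numpy as jnp
from jax import grad

class State(NamedTuple):
    position: jnp.ndarray
    velocity: jnp.ndarray
    force: jnp.ndarray

def velocity_verlet(energy_fn, shift_fn, dt):
    """Illustrative simulation function: returns (init_fn, update_fn)."""
    force_fn = lambda R: -grad(energy_fn)(R)

    def init_fn(R, V):
        # Build the dynamical state from initial positions and velocities.
        return State(R, V, force_fn(R))

    def update_fn(state):
        R, V, F = state
        V_half = V + 0.5 * dt * F      # half kick (unit mass assumed)
        R = shift_fn(R, dt * V_half)   # drift, delegated to the space
        F = force_fn(R)
        V = V_half + 0.5 * dt * F      # second half kick
        return State(R, V, F)

    return init_fn, update_fn

energy_fn = lambda R: 0.5 * jnp.sum(R ** 2)  # harmonic well
shift_fn = lambda R, dR: R + dR              # free space

def total_energy(state):
    return energy_fn(state.position) + 0.5 * jnp.sum(state.velocity ** 2)

init_fn, update_fn = velocity_verlet(energy_fn, shift_fn, dt=1e-2)
state = init_fn(jnp.array([[1.0, 0.0]]), jnp.zeros((1, 2)))
E_start = total_energy(state)
for _ in range(100):
    state = update_fn(state)
E_end = total_energy(state)
```

Because the state is an immutable tuple of arrays and `update_fn` is a pure function, the loop can be wrapped in `jax.jit` or `jax.lax.fori_loop` without restructuring.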

We include several different kinds of dynamics. However, there is certainly room to add more, for example constant strain simulations.

It is often desirable to find an energy minimum of the system. We provide two methods to do this. The first is simple gradient descent minimization, which is mostly for pedagogical purposes since it often performs poorly. The second is the FIRE algorithm, which often converges significantly faster. Moreover, a common experiment to run in the context of molecular dynamics is to simulate a system with a fixed volume and temperature.
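For intuition, gradient descent minimization fits the same init/update shape. The sketch below is a plain-JAX illustration under assumed names (`gradient_descent`, `target`), using a quadratic energy whose minimum is known; it is not the JAX MD minimizer.

```python
import jax.numpy as jnp
from jax import grad

def gradient_descent(energy_fn, shift_fn, step_size):
    """Illustrative minimizer: returns (init_fn, update_fn)."""
    force_fn = lambda R: -grad(energy_fn)(R)

    def init_fn(R):
        # The only state needed for plain gradient descent is the positions.
        return R

    def update_fn(R):
        # Step along the force, i.e. down the energy gradient.
        return shift_fn(R, step_size * force_fn(R))

    return init_fn, update_fn

target = jnp.array([[1.0, 2.0], [3.0, 4.0]])
energy_fn = lambda R: 0.5 * jnp.sum((R - target) ** 2)
shift_fn = lambda R, dR: R + dR

init_fn, update_fn = gradient_descent(energy_fn, shift_fn, step_size=0.1)
R = init_fn(jnp.zeros((2, 2)))
for _ in range(200):
    R = update_fn(R)
```

On this well-conditioned quadratic the iteration converges quickly; on rough energy landscapes a fixed step size tends to be much less effective, which is the motivation for methods like FIRE.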

We provide the following dynamics:

Example:

from jax_md import simulate
temperature = 1.0
dt = 1e-3
init, update = simulate.nvt_nose_hoover(energy_fn, shift_fn, dt, temperature)
state = init(key, R)
for _ in range(100):
  state = update(state)
R = state.position

Spatial Partitioning (partition.py)

In many applications, it is useful to construct spatial partitions of particles / objects in a simulation.

We provide the following methods:

Cell List Example:

from jax_md import partition

cell_size = 5.0
capacity = 10
cell_list_fn = partition.cell_list(box_size, cell_size, capacity)
cell_list_data = cell_list_fn.allocate(R)
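As a rough picture of what a cell list organizes, the following plain-JAX sketch bins particles into a uniform grid and counts the occupancy of each cell. The helper `cell_indices` and its flattening convention are assumptions made for this illustration and do not reflect how JAX MD stores its cell list data.

```python
import jax.numpy as jnp

def cell_indices(R, box_size, cell_size):
    """Illustrative binning: map each particle to a flat cell index."""
    cells_per_side = int(box_size // cell_size)
    width = box_size / cells_per_side
    # Integer grid coordinates of every particle, clipped to stay in the box.
    ij = jnp.clip(jnp.floor(R / width).astype(jnp.int32), 0, cells_per_side - 1)
    # Flatten grid coordinates to one index (first axis varies fastest).
    strides = jnp.array([cells_per_side ** d for d in range(R.shape[1])])
    return jnp.sum(ij * strides, axis=1), cells_per_side

R = jnp.array([[1.0, 1.0], [6.0, 1.0], [24.0, 24.0]])
flat, cells_per_side = cell_indices(R, box_size=25.0, cell_size=5.0)

# Occupancy of each of the cells_per_side ** 2 cells.
counts = jnp.bincount(flat, length=cells_per_side ** 2)
```

With particles grouped this way, a search for nearby particles only needs to look at a particle's own cell and the adjacent cells instead of at all N particles.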

Neighbor List Example:

from jax_md import partition

neighbor_list_fn = partition.neighbor_list(displacement_fn, box_size, cell_size)
neighbors = neighbor_list_fn.allocate(R) # Create a new neighbor list.

# Do some simulating....

neighbors = neighbors.update(R)  # Update the neighbor list without resizing.
if neighbors.did_buffer_overflow:  # Couldn't fit all the neighbors into the list.
  neighbors = neighbor_list_fn.allocate(R)  # So create a new neighbor list.

Three neighbor list formats are supported: Dense, Sparse, and OrderedSparse. Dense neighbor lists store neighbors in a (particle_count, neighbors_per_particle) array. Sparse neighbor lists store neighbors in a (2, total_neighbors) array of pairs. OrderedSparse neighbor lists are like Sparse neighbor lists, but they only store pairs such that i < j.
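The three layouts can be compared on a tiny system by building them by brute force in plain JAX. This sketch only illustrates the array shapes described above; it does not call JAX MD, and padding empty Dense slots with the value particle_count is an assumption of the example.

```python
import jax.numpy as jnp

R = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
r_cutoff = 1.2
N = R.shape[0]

# Brute-force neighbor mask: within the cutoff, excluding self-pairs.
dist = jnp.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
mask = (dist < r_cutoff) & ~jnp.eye(N, dtype=bool)

# Sparse: a (2, total_neighbors) array holding every (i, j) pair.
i_idx, j_idx = jnp.nonzero(mask)
sparse = jnp.stack([i_idx, j_idx])

# OrderedSparse: keep only the pairs with i < j.
ordered_sparse = sparse[:, sparse[0] < sparse[1]]

# Dense: a (particle_count, neighbors_per_particle) array, padded with N.
max_neighbors = int(mask.sum(axis=1).max())
order = jnp.argsort((~mask).astype(jnp.int32), axis=1)  # neighbors sort first
idx = order[:, :max_neighbors]
valid = jnp.take_along_axis(mask, idx, axis=1)
dense = jnp.where(valid, idx, N)
```

Here particle 0 neighbors particles 1 and 2, particle 3 has no neighbors, and the OrderedSparse array holds half as many pairs as the Sparse one.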

Development

JAX MD is under active development. We have very limited development resources, so we typically focus on adding features that will have high impact for researchers using JAX MD (including us). Please don't hesitate to open feature requests to help us guide development. We more than welcome contributions!

Technical gotchas

GPU

You must follow JAX's GPU installation instructions to enable GPU support.

64-bit precision

To enable 64-bit precision, set the respective JAX flag before importing jax_md (see the JAX guide), for example:

import jax
# Set the flag before importing jax_md or creating any arrays.
jax.config.update("jax_enable_x64", True)

Publications

JAX MD has been used in the following publications. If you don't see your paper on the list, but you used JAX MD let us know and we'll add it to the list!

  1. Molecular Simulations with a Pretrained Neural Network and Universal Pairwise Force Fields. (J. Am. Chem. Soc. 2025)
    A. Kabylda, J. T. Frank, S. Suárez-Dou, A. Khabibrakhmanov, L. Medrano Sandonas, O. T. Unke, S. Chmiela, K.-R. Müller, and A. Tkatchenko
  2. Designing precise dynamical steady states in disordered networks. (Machine Learning: Science and Technology 2025)
    M. Berneman and D. Hexner
  3. Generalized design of sequence-ensemble-function relationships for intrinsically disordered proteins
    R. K. Krueger, M. P. Brenner, and K. Shrinivas
  4. Tuning colloidal reactions. (PRL 2024)
    R. K. Krueger, E. M. King, and M. P. Brenner
  5. Programming patchy particles for materials assembly design. (PNAS 2024)
    E. M. King, CX. Du, QZ. Zhu, S. S. Schoenholz, and M. P. Brenner
  6. LATTE: an atomic environment descriptor based on Cartesian tensor contractions. (arXiv 2024)
    F. Pellegrini, S. Gironcoli, E. Küçükbenli
  7. PySAGES: flexible, advanced sampling methods accelerated with GPUs. (npj Computational Materials 2024)
    P. F. Zubieta Rico, et al.
  8. Scaling deep learning for materials discovery (Nature 2023)
    A. Merchant, et al.
  9. LapTrack: linear assignment particle tracking with tunable metrics. (Bioinformatics 2023)
    Yohsuke T Fukai and Kyogo Kawaguchi
  10. A Differentiable Neural-Network Force Field for Ionic Liquids. (J. Chem. Inf. Model. 2022)
    H. Montes-Campos, J. Carrete, S. Bichelmaier, L. M. Varela, and G. K. H. Madsen
  11. Correlation Tracking: Using simulations to interpolate highly correlated particle tracks. (Phys. Rev. E. 2022)
    E. M. King, Z. Wang, D. A. Weitz, F. Spaepen, and M. P. Brenner
  12. Optimal Control of Nonequilibrium Systems Through Automatic Differentiation.
    M. C. Engel, J. A. Smith, and M. P. Brenner
  13. Graph Neural Networks Accelerated Molecular Dynamics. (J. Chem. Phys. 2022)
    Z. Li, K. Meidani, P. Yadav, and A. B. Farimani
  14. Gradients are Not All You Need.
    L. Metz, C. D. Freeman, S. S. Schoenholz, and T. Kachman
  15. Lagrangian Neural Network with Differential Symmetries and Relational Inductive Bias.
    R. Bhattoo, S. Ranu, and N. M. A. Krishnan
  16. Efficient and Modular Implicit Differentiation.
    M. Blondel, Q. Berthet, M. Cuturi, R. Frostig, S. Hoyer, F. Llinares-López, F. Pedregosa, and J.-P. Vert
  17. Learning neural network potentials from experimental data via Differentiable Trajectory Reweighting. (Nature Communications 2021)
    S. Thaler and J. Zavadlav
  18. Learn2Hop: Learned Optimization on Rough Landscapes. (ICML 2021)
    A. Merchant, L. Metz, S. S. Schoenholz, and E. D. Cubuk
  19. Designing self-assembling kinetics with differentiable statistical physics models. (PNAS 2021)
    C. P. Goodrich, E. M. King, S. S. Schoenholz, E. D. Cubuk, and M. P. Brenner

Citation

If you use the code in a publication, please cite the repo using the .bib,

@inproceedings{jaxmd2020,
 author = {Schoenholz, Samuel S. and Cubuk, Ekin D.},
 booktitle = {Advances in Neural Information Processing Systems},
 publisher = {Curran Associates, Inc.},
 title = {JAX M.D. A Framework for Differentiable Physics},
 url = {https://papers.nips.cc/paper/2020/file/83d3d4b6c9579515e1679aca8cbc8033-Paper.pdf},
 volume = {33},
 year = {2020}
}

If you use functionalities related to RigidBody, please cite the following paper using the .bib,

@article{king2024programming,
  title={Programming patchy particles for materials assembly design},
  author={King, Ella M and Du, Chrisy Xiyu and Zhu, Qian-Ze and Schoenholz, Samuel S and Brenner, Michael P},
  journal={Proceedings of the National Academy of Sciences},
  volume={121},
  number={27},
  pages={e2311891121},
  year={2024},
  publisher={National Academy of Sciences}
}
