
Machine Learning Interatomic Potentials in JAX


🪩 MLIP: Machine Learning Interatomic Potentials


🚀 mlip v2 is now available

mlip v2 introduces a targeted API redesign focused on greater modularity, flexibility, and user control, alongside many new features and quality-of-life improvements across training, inference, and simulation workflows.

⚠️ Please note: v2 contains API changes that may require updates to existing codebases. We strongly recommend checking out the Migration Guide for upgrade instructions, breaking changes, and examples.

👀 Overview

mlip is a Python library, written in JAX, for training and deploying Machine Learning Interatomic Potentials (MLIPs). It provides the following functionality:

  • 🧠 Multiple model architectures (currently: MACE, NequIP, eSEN, and ViSNet) and a modular API for developing new architectures
  • 🧩 Mixture-of-Experts (MoE) formalism for the eSEN architecture
  • 📦 Dataset loading and preprocessing
  • 🎯 Training and fine-tuning MLIP models
  • 💨 Built-in distributed training for both GPU and TPU
  • ⚡ Batched inference with trained MLIP models
  • 🧪 MD simulations with MLIP models using multiple simulation backends (currently: JAX-MD and ASE)
  • 🌡️ Support for both NVT and NPT ensembles in MD
  • ⛰️ Energy minimizations using the same simulation backends as for MD
  • 🚀 Batched MD simulations and energy minimizations with the JAX-MD backend
  • 🔎 Transition state search with the nudged elastic band (NEB) method
  • 🌐 Global charge conditioning, partial charge predictions, and support for long-range interactions
  • 📈 Training on Hessian labels
  • ⚙️ Integration with the e3j backend for accelerated equivariant operations (currently in beta)
  • 🔬 To validate and benchmark your trained models, make sure to check out MLIPAudit

The purpose of the library is to give users a toolbox for working with MLIP models in a truly end-to-end fashion. In doing so, we follow three key design principles: (1) ease of use, including for non-expert users who mainly want to apply pre-trained models to relevant biological or materials science applications; (2) extensibility and flexibility for users more experienced with MLIP and JAX; and (3) a focus on high inference speed, which enables long MD simulations on large systems and which we believe is necessary to bring MLIP to large-scale industrial applications. See our inference speed benchmarks below.

🎙️ For further information on the design principles and story behind the mlip library, also check out our Let's Talk Research podcast episode on the topic.

See the Installation section for details on how to install mlip and the example Jupyter notebooks linked below for a quick way to get started. For detailed instructions, visit our extensive code documentation.

This repository currently supports implementations of:

  • MACE
  • NequIP
  • eSEN
  • ViSNet

📦 Installation

To install the regular CPU version of mlip via pip, use this command:

pip install mlip

However, we recommend running the library on GPU. To install the CUDA 13 versions of JAX and the e3j binaries alongside mlip, run:

pip install "mlip[cuda13]"

To install the CUDA 12 version instead, run:

pip install "mlip[cuda12]"

mlip also defines the "mlip[cuda13_local]" and "mlip[cuda12_local]" extras, following JAX's naming conventions for pip extras. For any other custom setup, install mlip without a CUDA extra and refer to the installation guides for JAX and e3j.

We also support installation of the TPU version of JAX via this command:

pip install "mlip[tpu]"
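After installing one of the accelerator extras, it can be worth confirming that JAX actually picks up the GPU or TPU before launching training or MD runs. The sketch below is not part of mlip; it relies only on the standard `importlib` machinery and on JAX's public `jax.devices()` and `jax.default_backend()` functions. It prints the installed versions and visible devices, and returns `None` if JAX is not installed.

```python
from __future__ import annotations

import importlib.metadata
import importlib.util


def report_environment() -> str | None:
    """Print installed mlip/JAX versions and return JAX's default backend.

    Returns None if JAX is not installed; otherwise returns the platform name
    of JAX's default backend (e.g. "cpu", "gpu" or "tpu").
    """
    for package in ("mlip", "jax", "jaxlib"):
        try:
            print(f"{package}: {importlib.metadata.version(package)}")
        except importlib.metadata.PackageNotFoundError:
            print(f"{package}: not installed")

    if importlib.util.find_spec("jax") is None:
        return None

    import jax  # imported lazily so the check also runs without JAX

    # List every device JAX can use; a CPU-only list after a CUDA/TPU install
    # usually means the accelerator wheels were not picked up.
    for device in jax.devices():
        print(device)
    return jax.default_backend()


if __name__ == "__main__":
    print(f"JAX default backend: {report_environment()}")
```

If the reported backend is "cpu" on a machine with a GPU, the installed extra most likely does not match the local CUDA setup.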

⚡ Examples

In addition to the in-depth tutorials provided as part of our documentation here, we also provide example Jupyter notebooks that can be used as simple templates to build your own MLIP pipelines:

To run the tutorials, install Jupyter Notebook via pip and launch it from a directory that contains the notebooks:

pip install notebook && jupyter notebook

The installation of mlip itself is handled within the notebooks. We recommend running them with GPU acceleration enabled.

Alternatively, this repository provides a Dockerfile that you can use to run the tutorial notebooks. To do so, run the following commands from the directory containing the downloaded Dockerfile:

docker build . -t mlip_tutorials
docker run -p 8888:8888 --gpus all mlip_tutorials

Note that this only works on machines with NVIDIA GPUs. Once the container is running, you can access the Jupyter notebook server by opening the URL printed to the console, which has the form "http://127.0.0.1:8888/tree?token=abcdef...".

🤗 Pre-trained models (via HuggingFace)

We provide pre-trained versions of each model included in this repository, trained on a curated version of the SPICE2 subset of OMol25. They are available in InstaDeep's MLIP collection on HuggingFace, together with our curated dataset, and can also be downloaded through the huggingface_hub Python API:

from huggingface_hub import hf_hub_download

# Download one pre-trained checkpoint archive per architecture into the current directory
for filename in ["mace_organics_02.zip", "visnet_organics_02.zip", "nequip_organics_02.zip", "esen_organics_02.zip"]:
    hf_hub_download(repo_id="InstaDeepAI/mlip_models_organics_v2", filename=filename, local_dir=".")

# Download the curated SPICE2 dataset
hf_hub_download(repo_id="InstaDeepAI/SPICE2_curated_v2", filename="SPICE2_curated_v2.zip", local_dir=".")
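`hf_hub_download` returns the local path of the downloaded file, and both the checkpoints and the dataset are distributed as `.zip` archives. The following is a minimal sketch for unpacking them with Python's standard `zipfile` module. The helper name `extract_archive` is hypothetical, and the sketch makes no assumption about the archive contents or about how mlip loads the unpacked files; refer to the documentation for that step.

```python
from __future__ import annotations

import zipfile
from pathlib import Path


def extract_archive(archive_path: str | Path, target_dir: str | Path | None = None) -> list[str]:
    """Extract a .zip archive and return the names of its members.

    By default, files are extracted next to the archive into a folder named
    after it, e.g. mace_organics_02.zip -> mace_organics_02/.
    """
    archive_path = Path(archive_path)
    if target_dir is None:
        target_dir = archive_path.with_suffix("")
    with zipfile.ZipFile(archive_path) as zf:
        zf.extractall(target_dir)
        return zf.namelist()


# Usage (requires network access and huggingface_hub):
# from huggingface_hub import hf_hub_download
# path = hf_hub_download(repo_id="InstaDeepAI/mlip_models_organics_v2",
#                        filename="mace_organics_02.zip", local_dir=".")
# print(extract_archive(path))
```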

Note that the pre-trained models are released under a different license from this library; please refer to the model cards in the relevant HuggingFace repositories.

🚀 Inference time benchmarks

To showcase runtime efficiency, we benchmarked all four models on two systems, Chignolin (1UAO, 138 atoms) and alpha-bungarotoxin (1ABT, 1205 atoms), each run for 1 ns of MD simulation in the NVT ensemble on an NVIDIA H100 GPU. All of these JAX-based model implementations are our own and should not be considered representative of the performance of the code developed by the original authors of the methods. The tables below compare our integrations with the JAX-MD and ASE simulation engines. Further details can be found in our white paper (see below).

MACE (3,274,016 parameters):

System   JAX-MD (ms/step)   ASE (ms/step)
1UAO     2.4                7.3
1ABT     19.2               43.8

ViSNet (1,172,676 parameters):

System   JAX-MD (ms/step)   ASE (ms/step)
1UAO     1.9                7.1
1ABT     13.7               30.2

NequIP (1,327,792 parameters):

System   JAX-MD (ms/step)   ASE (ms/step)
1UAO     3.4                8.9
1ABT     22.0               44.6

eSEN (3,210,498 parameters):

System   JAX-MD (ms/step)   ASE (ms/step)
1UAO     3.0                8.9
1ABT     22.8               46.7
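The per-step timings are easier to compare when expressed as JAX-MD vs. ASE speedups and as simulated nanoseconds per wall-clock day. The sketch below does that arithmetic on the numbers from the tables above. The timestep used in the benchmarks is not stated here, so `TIMESTEP_FS = 1.0` is an assumption: the ns/day figures scale linearly with it, while the speedups do not depend on it.

```python
# Per-step timings in ms/step, copied from the tables above: (JAX-MD, ASE)
TIMINGS_MS = {
    "MACE": {"1UAO": (2.4, 7.3), "1ABT": (19.2, 43.8)},
    "ViSNet": {"1UAO": (1.9, 7.1), "1ABT": (13.7, 30.2)},
    "NequIP": {"1UAO": (3.4, 8.9), "1ABT": (22.0, 44.6)},
    "eSEN": {"1UAO": (3.0, 8.9), "1ABT": (22.8, 46.7)},
}

# ASSUMPTION: 1 fs integration timestep (not stated in the benchmark description).
TIMESTEP_FS = 1.0
MS_PER_DAY = 24 * 60 * 60 * 1000


def ns_per_day(ms_per_step: float, timestep_fs: float = TIMESTEP_FS) -> float:
    """Simulated nanoseconds per wall-clock day at the given per-step cost."""
    steps_per_day = MS_PER_DAY / ms_per_step
    return steps_per_day * timestep_fs / 1e6  # 1 ns = 1e6 fs


def speedup(jax_md_ms: float, ase_ms: float) -> float:
    """How many times faster the JAX-MD integration is than the ASE one."""
    return ase_ms / jax_md_ms


if __name__ == "__main__":
    for model, systems in TIMINGS_MS.items():
        for system, (jax_md_ms, ase_ms) in systems.items():
            print(
                f"{model:<7}{system}  speedup {speedup(jax_md_ms, ase_ms):.2f}x  "
                f"JAX-MD {ns_per_day(jax_md_ms):.1f} ns/day  ASE {ns_per_day(ase_ms):.1f} ns/day"
            )
```

For example, MACE on 1UAO at 2.4 ms/step corresponds to 36 ns/day under the 1 fs assumption, and JAX-MD is about 3.04x faster than ASE for that system.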

🙏 Acknowledgments

This work was supported by Cloud TPUs from Google’s TPU Research Cloud (TRC). We would also like to thank Bohan Cao from Nankai University / Zhongguancun Academy for the numerous suggestions and conversations.

📚 Citing our work

We kindly request that you cite our white paper when using this library:

C. Brunken, T. Cormier, L. Walewski, M. Carobene, Y. Khanfir, Z. Weller-Davies, M. Bragança, A. Picard, A. Pichard, L. Wehrhan, H. Chomet, E. Varga-Umbrich, M. Bluntzer, M. Bortone, V. Heyraud, S. Acosta-Gutiérrez, J. Tilly, and O. Peltre, Machine Learning Interatomic Potentials: Advancing Open-Source Software for Efficient and Scalable Molecular Simulation, arXiv, 2026, arXiv:2605.22698.

The BibTeX formatted citation:

@misc{brunken2026machinelearninginteratomicpotentials,
      title={Machine Learning Interatomic Potentials: Advancing Open-Source Software for Efficient and Scalable Molecular Simulation},
      author={Christoph Brunken and Titouan Cormier and Lucien Walewski and Marco Carobene and Yessine Khanfir and Zachary Weller-Davies and Miguel Bragança and Armand Picard and Adrien Pichard and Leon Wehrhan and Heloise Chomet and Eszter Varga-Umbrich and Marie Bluntzer and Massimo Bortone and Valentin Heyraud and Silvia Acosta-Gutiérrez and Jules Tilly and Olivier Peltre},
      year={2026},
      eprint={2605.22698},
      archivePrefix={arXiv},
      primaryClass={physics.chem-ph},
      url={https://arxiv.org/abs/2605.22698},
}


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

mlip-0.2.0.tar.gz (648.9 kB)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

mlip-0.2.0-py3-none-any.whl (403.1 kB)

Uploaded Python 3

File details

Details for the file mlip-0.2.0.tar.gz.

File metadata

  • Download URL: mlip-0.2.0.tar.gz
  • Upload date:
  • Size: 648.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.11.16 {"installer":{"name":"uv","version":"0.11.16","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Debian GNU/Linux","version":"13","id":"trixie","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":true}

File hashes

Hashes for mlip-0.2.0.tar.gz
Algorithm Hash digest
SHA256 06b63c19faa61f625e6ab42a94af855aab85cd678f6f766f12006bbc90a4d658
MD5 e949b675cf59f83b45891386ef836671
BLAKE2b-256 cfd99644fc74de291e6c55e1ce8334b9681a5a742aae64a8aa7158845d0592aa

See more details on using hashes here.
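The digests listed above let you check that a downloaded file was not corrupted or tampered with. Below is a minimal sketch using Python's standard `hashlib`, streaming the file in chunks so large archives need not fit in memory. The function name `sha256_matches` is our own, and the expected value in the usage block is the SHA256 digest listed above for `mlip-0.2.0.tar.gz`.

```python
from __future__ import annotations

import hashlib
from pathlib import Path


def sha256_matches(path: str | Path, expected_hex: str, chunk_size: int = 1 << 20) -> bool:
    """Return True if the SHA256 digest of the file at `path` equals `expected_hex`."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in 1 MiB chunks (by default) to keep memory use constant.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_hex.strip().lower()


if __name__ == "__main__":
    # Usage, assuming the source distribution was downloaded to the current directory.
    expected = "06b63c19faa61f625e6ab42a94af855aab85cd678f6f766f12006bbc90a4d658"
    archive = Path("mlip-0.2.0.tar.gz")
    if archive.exists():
        print("OK" if sha256_matches(archive, expected) else "MISMATCH")
```

pip can enforce the same check at install time through its hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file).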

File details

Details for the file mlip-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: mlip-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 403.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.11.16 {"installer":{"name":"uv","version":"0.11.16","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Debian GNU/Linux","version":"13","id":"trixie","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":true}

File hashes

Hashes for mlip-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 ae1d47a4afed469718bd3dceb0a6b71f867aa507fc382879827103018c6d4479
MD5 9ff714abb823d2572dad9997052e0083
BLAKE2b-256 1b7af58e675343c56d5174db74cd48944b5d602eeef9f4846d5125eb5038e57c

See more details on using hashes here.
