
experimental tooling for training machine learning interatomic potentials in jax


marathon


modular training infrastructure for machine-learning interatomic potentials in JAX

status: experimental | code style: black

pheidippides would be a great name for a message-passing neural network


marathon is an experimental JAX/Flax-oriented toolkit for prototyping machine-learning interatomic potentials. It does not provide a finished, polished training loop; instead, it provides a few composable parts that can be assembled and adapted as needed for experiments. It is therefore not intended as user-facing production code; rather, it aims to make experiments faster and more pleasant.

Modules:
  • marathon.data: Process ase.Atoms into padded batches with a flexible properties system
  • marathon.evaluate: Predict energy, forces, stress; compute loss (MSE/Huber) and metrics (MAE, RMSE, R²)
  • marathon.emit: Checkpointing, logging (text, W&B), diagnostic plots
  • marathon.io: Read/write msgpack and yaml; serialize flax.linen.Module instances
  • marathon.elemental: Per-element energy baselines via linear regression
  • marathon.grain: Scalable data pipelines with grain for large datasets
  • marathon.extra.edge_to_edge: Fixed-size neighborhood batching for PET-style edge transformers
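The idea behind per-element energy baselines can be sketched as a plain least-squares fit: each structure's energy is modelled as a sum of one constant per element, weighted by how many atoms of that element the structure contains. This is a minimal NumPy illustration, not the API of marathon.elemental; the function names are hypothetical, and the actual module may differ in details such as regularisation or how compositions are extracted (with ASE, atoms.get_atomic_numbers() gives the per-atom element list used here).

```python
import numpy as np


def fit_elemental_baseline(numbers_per_structure, energies):
    """Least-squares fit of per-element energies e[Z] such that E ≈ sum_i e[Z_i]."""
    elements = sorted({int(z) for numbers in numbers_per_structure for z in numbers})
    column = {z: j for j, z in enumerate(elements)}
    # composition matrix: counts[s, j] = number of atoms of element j in structure s
    counts = np.zeros((len(energies), len(elements)))
    for s, numbers in enumerate(numbers_per_structure):
        for z in numbers:
            counts[s, column[int(z)]] += 1
    coeffs, *_ = np.linalg.lstsq(counts, np.asarray(energies, dtype=float), rcond=None)
    return dict(zip(elements, coeffs))


def subtract_baseline(numbers, energy, baseline):
    """Residual energy left after removing the per-element contributions."""
    return energy - sum(baseline[int(z)] for z in numbers)


# H2, O2 and H2O with made-up energies consistent with e_H = -1 and e_O = -5
baseline = fit_elemental_baseline([[1, 1], [8, 8], [1, 1, 8]], [-2.0, -10.0, -7.0])
# baseline[1] ≈ -1.0, baseline[8] ≈ -5.0
```

Subtracting the fitted baseline leaves residual energies with a much smaller spread than the raw totals, which is the usual motivation for fitting such a baseline before training.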
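The loss and metric names listed for marathon.evaluate have standard definitions, sketched below over flat arrays of predictions and targets. The sketch uses NumPy, and since it only calls functions that jax.numpy mirrors, swapping the import should give a JAX version. The function names are hypothetical, and marathon's own Huber loss may use a different scaling convention (for example, without the factor of 0.5).

```python
import numpy as np


def mse(pred, target):
    return np.mean((pred - target) ** 2)


def huber(pred, target, delta=1.0):
    # quadratic for |residual| <= delta, linear beyond; both branches meet at delta
    r = np.abs(pred - target)
    return np.mean(np.where(r <= delta, 0.5 * r**2, delta * (r - 0.5 * delta)))


def mae(pred, target):
    return np.mean(np.abs(pred - target))


def rmse(pred, target):
    return np.sqrt(mse(pred, target))


def r2(pred, target):
    # coefficient of determination: 1 - (residual sum of squares) / (total sum of squares)
    ss_res = np.sum((target - pred) ** 2)
    ss_tot = np.sum((target - np.mean(target)) ** 2)
    return 1.0 - ss_res / ss_tot
```

The Huber loss behaves like MSE for small residuals and like MAE for large ones, so a few badly predicted structures do not dominate the gradient.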

Since the library is aimed at active research and is used and adapted as needed, there is no documentation beyond the README.md files at each module level, which explain terminology, notation, and sometimes the idea behind the design of a subpackage. This avoids the risk of documentation and code going out of sync, at the cost of requiring more code reading. (Luckily, the computers can do some of the reading nowadays...)

Anyhow, you are encouraged to fork and adapt marathon for your personal experiments. PRs with self-contained and reusable features are welcome.

Installation

The main dependency is JAX; detailed installation instructions are in the JAX documentation. Typically, the standard install works well:

pip install "jax[cuda13]"   # or jax[cpu] for CPU-only
pip install -e .

marathon provides a number of extras, installable together via pip install -e ".[all]" or individually as shown below. They are required to run some parts of the code but are not installed automatically, to avoid dependency-resolution hell on HPC systems.

pip install -e ".[grain]"   # grain pipelines: grain, mmap_ninja, numba
pip install -e ".[dev]"     # development: pytest, ruff
pip install -e ".[wandb]"   # Weights & Biases logging
pip install -e ".[plot]"    # plotting: matplotlib, scipy

For convenience, marathon looks for an environment variable named DATASETS and turns it into a Path at marathon.data.datasets. If the variable is not set, it defaults to the current working directory.
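The lookup described above amounts to roughly the following. This is a hypothetical stand-in written for illustration, not marathon's code; since marathon.data.datasets is a module attribute, it is presumably evaluated once at import time, so DATASETS should be set before marathon.data is imported.

```python
import os
from pathlib import Path


def resolve_datasets_dir():
    # hypothetical equivalent of marathon.data.datasets:
    # the DATASETS environment variable if set, otherwise the current working directory
    return Path(os.environ.get("DATASETS", os.getcwd()))
```

Dataset paths can then be written relative to that root, e.g. resolve_datasets_dir() / "my_dataset" (the dataset name is a placeholder).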

Quick start

For datasets that fit entirely into (GPU) memory, build fixed-size batches ahead of time and shuffle on the GPU:

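from marathon.data import to_sample, batch_samples, determine_max_sizes

# my_atoms: a list of ase.Atoms with reference labels (e.g. from ase.io.read(..., index=":"))
samples = [to_sample(atoms, cutoff=5.0) for atoms in my_atoms]
# padded sizes (atoms, pairs) for batches of 4 samples
num_atoms, num_pairs = determine_max_sizes(samples, batch_size=4)
batch = batch_samples(samples, num_atoms, num_pairs, keys=["energy", "forces"])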
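Fixed sizes matter because jax.jit traces and compiles a function once per distinct input shape: padding every batch to the same num_atoms and num_pairs avoids recompilation, at the price of carrying dummy entries that must be masked out. A minimal NumPy sketch of that pattern for a per-atom quantity follows; pad_to is a hypothetical helper, and marathon's actual batch layout (field names, where masks are stored, how pairs are padded) is not shown here.

```python
import numpy as np


def pad_to(array, size, fill=0.0):
    """Pad the leading axis of `array` to `size` entries; return the padded array and a mask."""
    n = array.shape[0]
    if n > size:
        raise ValueError(f"{n} entries do not fit into padded size {size}")
    pad_width = [(0, size - n)] + [(0, 0)] * (array.ndim - 1)
    padded = np.pad(array, pad_width, constant_values=fill)
    mask = np.arange(size) < n  # True for real entries, False for padding
    return padded, mask


# two structures with different atom counts end up with identical shapes
per_atom_energies_a, mask_a = pad_to(np.array([-1.0, -2.0]), size=4)
per_atom_energies_b, mask_b = pad_to(np.array([-3.0, -1.5, -0.5]), size=4)
# masking keeps padded entries out of reductions such as the total energy
total_a = np.sum(np.where(mask_a, per_atom_energies_a, 0.0))
```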

For large-scale training with grain pipelines (streaming data from disk through a series of transforms):

from marathon.grain import DataSource, DataLoader, IndexSampler, ToSample, ToFixedLengthBatch

ds = DataSource("path/to/prepared/dataset")
sampler = IndexSampler(len(ds), shuffle=True, seed=0)
loader = DataLoader(
    data_source=ds,
    sampler=sampler,
    operations=[ToSample(cutoff=5.0), ToFixedLengthBatch(batch_size=4)],
)

for batch in loader:
    ...  # one fixed-size batch of 4 samples per training step

Development

pip install -e ".[dev]"
ruff format . && ruff check --fix .
python -m pytest

Linting and formatting are done by ruff. We use a line length of 92, but it is only enforced by the formatter, not by the linter. This avoids hassle when lines can't be shortened automatically. We also suppress some rules that get in the way of research code: ambiguous single-letter variable names such as l or O (E741), assigning lambdas to names (E731), and module-level imports that are not at the top of the file (E402). Import ordering groups numpy and jax before other third-party packages.

The code itself tends to be concise and functional: descriptive names, minimal docstrings (only where behaviour isn't obvious from context), and liberal use of lambdas and comprehensions. Many modules include inline tests at the bottom that run on import.
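To make the style description concrete, here is a hypothetical module written along those lines: a small helper bound to a lambda (the pattern E731 would otherwise flag), a docstring only where it adds information, and an inline sanity check at the bottom that executes whenever the module is imported. The names are made up for illustration and are not part of marathon.

```python
import numpy as np

# small helpers as lambdas, allowed because E731 is suppressed
squared_norm = lambda x: np.sum(x**2, axis=-1)


def pairwise_distances(positions):
    """Distances between all pairs of rows in an (N, 3) array of positions."""
    diff = positions[:, None, :] - positions[None, :, :]
    return np.sqrt(squared_norm(diff))


# inline test: a cheap check that runs on import and fails loudly if the code breaks
_positions = np.array([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]])
assert np.allclose(pairwise_distances(_positions), [[0.0, 5.0], [5.0, 0.0]])
```

Keeping such checks next to the code trades a little import time for immediate feedback while the module is being changed.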


Logo designed by overripemango.



Download files


Source Distribution

marathon_train-0.2.1.tar.gz (45.7 kB)

Built Distribution

marathon_train-0.2.1-py3-none-any.whl (59.5 kB)

File details

Details for the file marathon_train-0.2.1.tar.gz.

File metadata

  • Download URL: marathon_train-0.2.1.tar.gz
  • Size: 45.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for marathon_train-0.2.1.tar.gz
Algorithm Hash digest
SHA256 2cda712c1d02f2ac0bf9c163eeb3030e2ad8aa607e465d1511d2526fa04c45fc
MD5 797dd48bb826f07fd5a3846b9b7da0d8
BLAKE2b-256 02fb479278f1893e2e7c5bfc2aba012d7e059e3020eab24848f740dbb3f19d3d


Provenance

The following attestation bundles were made for marathon_train-0.2.1.tar.gz:

Publisher: release.yml on sirmarcel/marathon


File details

Details for the file marathon_train-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: marathon_train-0.2.1-py3-none-any.whl
  • Size: 59.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for marathon_train-0.2.1-py3-none-any.whl
Algorithm Hash digest
SHA256 6df620e136ae93f52c41295b9c6e74edaf61fd170e42608402274daebf2509cc
MD5 3d73e9abf6354d042cecc9ce0905167f
BLAKE2b-256 6a25772801535fd78879e1d58715dfa9f7a43635b4623c9a12fb658f643d1746


Provenance

The following attestation bundles were made for marathon_train-0.2.1-py3-none-any.whl:

Publisher: release.yml on sirmarcel/marathon

