# marathon

*modular training infrastructure for machine-learning interatomic potentials in JAX*

> pheidippides would be a great name for a message-passing neural network
marathon is an experimental JAX/Flax-oriented toolkit for prototyping machine-learning interatomic potentials. It does not provide a finished, polished training loop; instead, it provides a few composable parts that can be assembled and adapted as needed for experiments. It is therefore not intended as user-facing production code; rather, it aims to make experiments faster and more pleasant.
| Module | Description |
|---|---|
| `marathon.data` | Process `ase.Atoms` into padded batches with a flexible properties system |
| `marathon.evaluate` | Predict energy, forces, stress; compute loss (MSE/Huber) and metrics (MAE, RMSE, R²) |
| `marathon.emit` | Checkpointing, logging (text, W&B), diagnostic plots |
| `marathon.io` | Read/write msgpack and yaml; serialize `flax.nn.Module` instances |
| `marathon.elemental` | Per-element energy baselines via linear regression |
| `marathon.grain` | Scalable data pipelines with grain for large datasets |
| `marathon.extra.edge_to_edge` | Fixed-size neighborhood batching for PET-style edge transformers |
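The idea behind `marathon.elemental` can be sketched as a least-squares fit of one reference energy per element. A toy illustration (not marathon's actual implementation; the data is made up):

```python
import numpy as np

# Toy per-element energy baselines via linear regression (illustrative
# only, not marathon's code). Each row counts the elements (H, O) in a
# structure; we fit one baseline energy per element.
counts = np.array([[2, 1],    # H2O
                   [2, 0],    # H2
                   [0, 2]])   # O2
energies = np.array([-12.0, -2.0, -20.0])  # made-up total energies

# Solve counts @ baselines ~= energies in the least-squares sense.
baselines, *_ = np.linalg.lstsq(counts, energies, rcond=None)
# baselines == [-1.0, -10.0]: e_H = -1 and e_O = -10 reproduce all three energies.
```

Subtracting such baselines from the targets leaves the model to learn only the (much smaller) interaction energies.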
Since the library is aimed at active research and is used and adapted as needed, there is no documentation beyond README.md files at each module level explaining terminology, notation, and sometimes the idea behind the design of a subpackage. This avoids the risk of documentation and code going out of sync -- at the cost of requiring more code reading. (Luckily, the computers can do some of the reading nowadays...)
Anyhow, you are encouraged to fork and adapt marathon for your personal experiments. PRs with self-contained and reusable features are welcome.
## Installation
The main dependency is jax; detailed installation instructions can be found in the JAX documentation. Typically, the standard install works fairly well:
```shell
pip install "jax[cuda13]"  # or jax[cpu] for CPU-only
pip install -e .
```
marathon provides a number of extras, installable via `pip install -e ".[all]"`. They are required to run some parts of the code but are not installed automatically, to avoid dependency-resolution trouble on HPC systems.
```shell
pip install -e ".[grain]"  # grain pipelines: grain, mmap_ninja, numba
pip install -e ".[dev]"    # development: pytest, ruff
pip install -e ".[wandb]"  # Weights & Biases logging
pip install -e ".[plot]"   # plotting: matplotlib, scipy
```
For convenience, marathon looks for an environment variable named `DATASETS` and exposes it as a `Path` at `marathon.data.datasets`. If the variable is not set, it defaults to the current working directory.
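A minimal sketch of that lookup (marathon's actual implementation may differ):

```python
import os
from pathlib import Path

# DATASETS environment variable -> Path, falling back to the current
# working directory when unset (behaviour as described above).
datasets = Path(os.environ.get("DATASETS", os.getcwd()))
```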
## Quick start
For datasets that fit entirely into (GPU) memory, batch ahead of time to fixed sizes and shuffle on the GPU:
```python
from marathon.data import to_sample, batch_samples, determine_max_sizes

samples = [to_sample(atoms, cutoff=5.0) for atoms in my_atoms]
num_atoms, num_pairs = determine_max_sizes(samples, batch_size=4)
batch = batch_samples(samples, num_atoms, num_pairs, keys=["energy", "forces"])
```
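The fixed sizes matter because JAX's `jit` requires static array shapes. A toy illustration of the padding principle (pure Python, not marathon's internals):

```python
# Structures have varying atom counts; pad each to the largest count in the
# batch so array shapes stay static, and keep a mask marking real atoms.
structures = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
max_atoms = max(len(s) for s in structures)

padded = [s + [0.0] * (max_atoms - len(s)) for s in structures]
mask = [[1] * len(s) + [0] * (max_atoms - len(s)) for s in structures]
# padded == [[1.0, 2.0, 0.0], [3.0, 0.0, 0.0], [4.0, 5.0, 6.0]]
# mask   == [[1, 1, 0], [1, 0, 0], [1, 1, 1]]
```

The mask is then used to zero out contributions from padding atoms in losses and metrics.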
For large-scale training with grain pipelines (streaming data from disk through a series of transforms):
```python
from marathon.grain import DataSource, DataLoader, IndexSampler, ToSample, ToFixedLengthBatch

ds = DataSource("path/to/prepared/dataset")
sampler = IndexSampler(len(ds), shuffle=True, seed=0)
loader = DataLoader(
    data_source=ds,
    sampler=sampler,
    operations=[ToSample(cutoff=5.0), ToFixedLengthBatch(batch_size=4)],
)
```
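Conceptually, such a pipeline just streams records from the source through the operations in order. A generic sketch of that pattern (not the actual grain API):

```python
# Each record from the source passes through every operation in sequence;
# yielding keeps the pipeline lazy, so data streams rather than loading at once.
def run_pipeline(source, operations):
    for record in source:
        for op in operations:
            record = op(record)
        yield record

# Toy usage with stand-in transforms:
out = list(run_pipeline([1, 2, 3], [lambda x: x * 2, lambda x: x + 1]))
# out == [3, 5, 7]
```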
## Development
```shell
pip install -e ".[dev]"
ruff format . && ruff check --fix .
python -m pytest
```
Linting and formatting are done by ruff. We use a line length of 92, enforced only by the formatter, not the linter; this avoids hassle when lines can't be shortened automatically. We also suppress some rules that get in the way of research code: ambiguous variable names (E741), lambda assignments (E731), and non-top-level imports (E402). Import ordering groups numpy and jax before other third-party packages.
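A `pyproject.toml` fragment matching that description might look like this (a sketch; the project's actual configuration may differ):

```toml
[tool.ruff]
line-length = 92  # used by the formatter; line length is not linted

[tool.ruff.lint]
ignore = ["E741", "E731", "E402"]

[tool.ruff.lint.isort]
section-order = ["future", "standard-library", "numpy-jax", "third-party", "first-party", "local-folder"]

[tool.ruff.lint.isort.sections]
"numpy-jax" = ["numpy", "jax"]
```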
The code itself tends towards concise and functional: descriptive names, minimal docstrings (only where behaviour isn't obvious from context), and liberal use of lambdas and comprehensions. Many modules include inline tests at the bottom that run on import.
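The inline-test pattern can look like this (a hypothetical module, not taken from marathon):

```python
# A tiny module-level test: the asserts at the bottom run on import, so a
# broken helper is caught as soon as the module is first used.
def round_up(n, k):
    """Round n up to the nearest multiple of k (e.g. for padding sizes)."""
    return ((n + k - 1) // k) * k

assert round_up(7, 4) == 8
assert round_up(8, 4) == 8
```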
Logo designed by overripemango.
## File details

Details for the file `marathon_train-0.2.1.tar.gz`.

### File metadata

- Download URL: marathon_train-0.2.1.tar.gz
- Size: 45.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7

### File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | `2cda712c1d02f2ac0bf9c163eeb3030e2ad8aa607e465d1511d2526fa04c45fc` |
| MD5 | `797dd48bb826f07fd5a3846b9b7da0d8` |
| BLAKE2b-256 | `02fb479278f1893e2e7c5bfc2aba012d7e059e3020eab24848f740dbb3f19d3d` |
### Provenance

The following attestation bundles were made for `marathon_train-0.2.1.tar.gz`:

Publisher: `release.yml` on sirmarcel/marathon

Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: marathon_train-0.2.1.tar.gz
- Subject digest: `2cda712c1d02f2ac0bf9c163eeb3030e2ad8aa607e465d1511d2526fa04c45fc`
- Sigstore transparency entry: 1146440566
- Permalink: sirmarcel/marathon@4dc51f159c8d486635597d92b35b172632d5f474
- Branch / Tag: refs/tags/v0.2.1
- Owner: https://github.com/sirmarcel
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@4dc51f159c8d486635597d92b35b172632d5f474
- Trigger Event: push
## File details

Details for the file `marathon_train-0.2.1-py3-none-any.whl`.

### File metadata

- Download URL: marathon_train-0.2.1-py3-none-any.whl
- Size: 59.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7

### File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | `6df620e136ae93f52c41295b9c6e74edaf61fd170e42608402274daebf2509cc` |
| MD5 | `3d73e9abf6354d042cecc9ce0905167f` |
| BLAKE2b-256 | `6a25772801535fd78879e1d58715dfa9f7a43635b4623c9a12fb658f643d1746` |
### Provenance

The following attestation bundles were made for `marathon_train-0.2.1-py3-none-any.whl`:

Publisher: `release.yml` on sirmarcel/marathon

Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: marathon_train-0.2.1-py3-none-any.whl
- Subject digest: `6df620e136ae93f52c41295b9c6e74edaf61fd170e42608402274daebf2509cc`
- Sigstore transparency entry: 1146440639
- Permalink: sirmarcel/marathon@4dc51f159c8d486635597d92b35b172632d5f474
- Branch / Tag: refs/tags/v0.2.1
- Owner: https://github.com/sirmarcel
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@4dc51f159c8d486635597d92b35b172632d5f474
- Trigger Event: push