marathon

modular training infrastructure for machine-learning interatomic potentials in JAX

status: experimental | code style: black

pheidippides would be a great name for a message-passing neural network


marathon is an experimental JAX/Flax-oriented toolkit for prototyping machine-learning interatomic potentials. It does not provide a finished, polished training loop; instead, it provides a few composable parts that can be assembled and adapted as needed for experiments. It is therefore not intended as user-facing production code; its aim is to make experiments faster and more pleasant.

Module                         Description
marathon.data                  Process ase.Atoms into padded batches with a flexible properties system
marathon.evaluate              Predict energy, forces, and stress; compute losses (MSE/Huber) and metrics (MAE, RMSE, R2)
marathon.emit                  Checkpointing, logging (text, W&B), diagnostic plots
marathon.io                    Read/write msgpack and yaml; serialize flax.nn.Module instances
marathon.elemental             Per-element energy baselines via linear regression
marathon.grain                 Scalable data pipelines with grain for large datasets
marathon.extra.edge_to_edge    Fixed-size neighborhood batching for PET-style edge transformers
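The table describes marathon.elemental only as "per-element energy baselines via linear regression". A minimal NumPy sketch of that idea, assuming each structure's energy is modelled as a sum of per-element constants fitted by least squares, is shown below, together with the MAE, RMSE, and R2 metrics listed for marathon.evaluate. All function names here are hypothetical and are not marathon's API.

```python
import numpy as np


def element_counts(structures, elements):
    """Count matrix: row i holds how often each element occurs in structure i."""
    index = {z: j for j, z in enumerate(elements)}
    counts = np.zeros((len(structures), len(elements)))
    for i, numbers in enumerate(structures):
        for z in numbers:
            counts[i, index[z]] += 1
    return counts


def fit_elemental_baseline(structures, energies):
    """Least-squares fit of E_i ~ sum_z n_iz * e_z; returns {atomic number: e_z}."""
    elements = sorted({z for numbers in structures for z in numbers})
    counts = element_counts(structures, elements)
    coeffs, *_ = np.linalg.lstsq(counts, np.asarray(energies, dtype=float), rcond=None)
    return dict(zip(elements, coeffs))


def metrics(pred, true):
    """MAE, RMSE, and coefficient of determination R2."""
    pred, true = np.asarray(pred, dtype=float), np.asarray(true, dtype=float)
    err = pred - true
    ss_res = np.sum(err**2)
    ss_tot = np.sum((true - true.mean()) ** 2)
    return {
        "mae": float(np.mean(np.abs(err))),
        "rmse": float(np.sqrt(np.mean(err**2))),
        "r2": float(1.0 - ss_res / ss_tot),
    }


# Toy data: structures as lists of atomic numbers (H=1, O=8). Energies are built
# exactly from e_H = -0.5 and e_O = -75.0, so the fit should recover those values.
structures = [[1, 1, 8], [1, 8], [8, 8], [1, 1]]
energies = [2 * -0.5 - 75.0, -0.5 - 75.0, 2 * -75.0, 2 * -0.5]
baseline = fit_elemental_baseline(structures, energies)
pred = [sum(baseline[z] for z in numbers) for numbers in structures]
print(baseline, metrics(pred, energies))
```

A common use of such a baseline is to subtract it from the reference energies so that the model only has to learn the remaining, structure-dependent part.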

Since the library is aimed at active research and is used and adapted as needed, there is no documentation beyond the README.md file in each module, which explains terminology, notation, and sometimes the idea behind a subpackage's design. This avoids the risk of documentation and code going out of sync, at the cost of requiring more code reading. (Luckily, the computers can do some of the reading nowadays...)

Anyhow, you are encouraged to fork and adapt marathon for your personal experiments. PRs with self-contained and reusable features are welcome.

Installation

The main dependency is jax; the JAX installation guide has detailed instructions. Typically, the standard install works fairly well:

pip install "jax[cuda13]"   # or jax[cpu] for CPU-only
pip install -e .            # run from the root of a clone of the repository

marathon provides a number of extras, all of which can be installed via pip install -e ".[all]". They are required for some parts of the code but are not installed automatically, to avoid dependency-resolution hell on HPC systems.

pip install -e ".[grain]"   # grain pipelines: grain, mmap_ninja, numba
pip install -e ".[dev]"     # development: pytest, ruff
pip install -e ".[wandb]"   # Weights & Biases logging
pip install -e ".[plot]"    # plotting: matplotlib, scipy
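Because the extras are not installed automatically, it can help to check which of them are present in an environment before running the parts of the code that need them. The stdlib-only sketch below does this without importing anything heavy; the module names are taken from the install comments above and are assumed to be importable under those names. The helper itself is hypothetical and not part of marathon.

```python
import importlib.util

# Import names for the optional dependencies listed under each extra above.
# Assumption: each package is importable under the name shown in the install comments.
EXTRAS = {
    "grain": ["grain", "mmap_ninja", "numba"],
    "dev": ["pytest", "ruff"],
    "wandb": ["wandb"],
    "plot": ["matplotlib", "scipy"],
}


def available_extras(extras=EXTRAS):
    """Map each extra to True if all of its top-level modules can be found (not imported)."""
    return {
        name: all(importlib.util.find_spec(module) is not None for module in modules)
        for name, modules in extras.items()
    }


if __name__ == "__main__":
    for name, ok in available_extras().items():
        print(f"{name:6s} {'installed' if ok else 'missing'}")
```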

For convenience, marathon looks for an environment variable named DATASETS and exposes it as a Path at marathon.data.datasets. If the variable is not set, it defaults to the current working directory.
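The lookup described above amounts to a few lines of pathlib. The sketch below reproduces the described behaviour in a standalone, hypothetical helper (resolve_datasets_dir is not a marathon function); in marathon itself you would presumably just use the ready-made Path via from marathon.data import datasets.

```python
import os
from pathlib import Path


def resolve_datasets_dir(environ=os.environ):
    """Return $DATASETS as a Path, or the current working directory if it is unset."""
    if "DATASETS" in environ:
        return Path(environ["DATASETS"])
    return Path.cwd()


# Hypothetical usage: locate a dataset folder relative to $DATASETS.
dataset_dir = resolve_datasets_dir() / "my_dataset"
print(dataset_dir)
```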

Quick start

For datasets that fit entirely into (GPU) memory, pad them to fixed sizes ahead of time and shuffle on the GPU:

from ase.io import read
from marathon.data import to_sample, batch_samples, determine_max_sizes

my_atoms = read("structures.xyz", index=":")  # hypothetical file; any list of ase.Atoms works
samples = [to_sample(atoms, cutoff=5.0) for atoms in my_atoms]
num_atoms, num_pairs = determine_max_sizes(samples, batch_size=4)
batch = batch_samples(samples, num_atoms, num_pairs, keys=["energy", "forces"])
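The point of padding to fixed sizes is that every batch has identical array shapes, so a jit-compiled training step only compiles once. The NumPy sketch below illustrates the idea independently of marathon: it builds a neighbour pair list within a cutoff (ignoring periodic boundary conditions), takes one conservative upper bound on atoms and pairs per batch, and pads with masks. The function names and the sizing strategy are assumptions for illustration; to_sample, determine_max_sizes, and batch_samples may work differently.

```python
import numpy as np


def pairs_within_cutoff(positions, cutoff):
    """All ordered pairs (i, j), i != j, closer than cutoff (no periodic images)."""
    diff = positions[None, :, :] - positions[:, None, :]
    dist = np.linalg.norm(diff, axis=-1)
    i, j = np.nonzero((dist < cutoff) & ~np.eye(len(positions), dtype=bool))
    return np.stack([i, j], axis=1)


def max_sizes(samples, batch_size):
    """Upper bound for any batch: summed sizes of the batch_size largest samples."""
    atoms = sorted((len(s["positions"]) for s in samples), reverse=True)
    pairs = sorted((len(s["pairs"]) for s in samples), reverse=True)
    return sum(atoms[:batch_size]), sum(pairs[:batch_size])


def pad_batch(samples, num_atoms, num_pairs):
    """Concatenate samples and pad to fixed sizes; masks mark the real entries."""
    positions = np.zeros((num_atoms, 3))
    pairs = np.zeros((num_pairs, 2), dtype=int)  # padded pairs point at atom 0
    atom_mask = np.zeros(num_atoms, dtype=bool)
    pair_mask = np.zeros(num_pairs, dtype=bool)
    a = p = 0
    for s in samples:
        n, m = len(s["positions"]), len(s["pairs"])
        positions[a : a + n] = s["positions"]
        pairs[p : p + m] = s["pairs"] + a  # shift atom indices into batch numbering
        atom_mask[a : a + n] = True
        pair_mask[p : p + m] = True
        a, p = a + n, p + m
    return {"positions": positions, "pairs": pairs, "atom_mask": atom_mask, "pair_mask": pair_mask}


rng = np.random.default_rng(0)
samples = []
for n in (3, 5, 4, 2, 6):
    pos = rng.uniform(0.0, 4.0, size=(n, 3))
    samples.append({"positions": pos, "pairs": pairs_within_cutoff(pos, cutoff=3.0)})

num_atoms, num_pairs = max_sizes(samples, batch_size=2)
batches = [pad_batch(samples[i : i + 2], num_atoms, num_pairs) for i in range(0, len(samples), 2)]
# Shuffle batch order each epoch; with JAX on a GPU, jax.random.permutation(key, len(batches))
# would produce the permutation on device instead.
order = rng.permutation(len(batches))
```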

For large-scale training with grain pipelines (streaming data from disk through a series of transforms):

from marathon.grain import DataSource, DataLoader, IndexSampler, ToSample, ToFixedLengthBatch

ds = DataSource("path/to/prepared/dataset")
sampler = IndexSampler(len(ds), shuffle=True, seed=0)
loader = DataLoader(
    data_source=ds,
    sampler=sampler,
    operations=[ToSample(cutoff=5.0), ToFixedLengthBatch(batch_size=4)],
)
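Conceptually, the loader above draws record indices from a seeded sampler, maps each record through per-sample transforms, and groups the results into batches. The stdlib-only sketch below mirrors that flow with plain generators so the order of the steps is easy to follow. It is not grain's implementation, and index_sampler, map_op, batch_op, and loader are hypothetical stand-ins for IndexSampler, ToSample, ToFixedLengthBatch, and DataLoader.

```python
import random
from itertools import islice


def index_sampler(num_records, shuffle=True, seed=0, num_epochs=1):
    """Yield record indices epoch by epoch, reshuffled each epoch with a derived seed."""
    for epoch in range(num_epochs):
        indices = list(range(num_records))
        if shuffle:
            random.Random(seed + epoch).shuffle(indices)
        yield from indices


def map_op(fn, records):
    """Apply a per-record transform lazily."""
    for record in records:
        yield fn(record)


def batch_op(records, batch_size, drop_remainder=True):
    """Group consecutive records into lists of batch_size; optionally drop a short tail."""
    it = iter(records)
    while batch := list(islice(it, batch_size)):
        if len(batch) < batch_size and drop_remainder:
            return
        yield batch


def loader(data_source, transform, batch_size, seed=0):
    """Sampler -> random access into the source -> transform -> batching."""
    indices = index_sampler(len(data_source), shuffle=True, seed=seed)
    records = (data_source[i] for i in indices)
    return batch_op(map_op(transform, records), batch_size)


data = [{"id": i, "n_atoms": 2 + i % 3} for i in range(10)]
for batch in loader(data, lambda r: {**r, "prepared": True}, batch_size=4):
    print([r["id"] for r in batch])
```

Because every stage is lazy, only the records needed for the current batch are read and transformed, which is the property that makes streaming from disk practical for large datasets.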

Development

pip install -e ".[dev]"
ruff format . && ruff check --fix .
python -m pytest

Linting and formatting are done by ruff. We use a line length of 92, which the formatter applies but the linter does not check; this avoids spurious errors for lines that cannot be shortened automatically. We also suppress some rules that get in the way of research code: ambiguous single-letter variable names such as l, O, or I (E741), assigning lambdas to names (E731), and imports that are not at the top of a module (E402). Import ordering groups numpy and jax before other third-party packages.

The code itself tends to be concise and functional: descriptive names, minimal docstrings (only where behaviour isn't obvious from context), and liberal use of lambdas and comprehensions. Many modules include inline tests at the bottom that run on import.
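The inline-test convention can look like the hypothetical module below: helpers first, then plain module-level assertions at the bottom, so that importing the module (or running it as a script) exercises them. This is a sketch of the pattern only; the functions are made up for illustration and are not taken from marathon.

```python
# Hypothetical helpers in the concise, functional style described above.
to_pairs = lambda xs: list(zip(xs[:-1], xs[1:]))  # lambda assignment, normally flagged by E731


def segment_sum(values, segments, num_segments):
    """Sum values into num_segments bins by segment id (e.g. atomic energies per structure)."""
    out = [0.0] * num_segments
    for v, s in zip(values, segments):
        out[s] += v
    return out


# --- inline tests: these run every time the module is imported ---
assert to_pairs([1, 2, 3]) == [(1, 2), (2, 3)]
assert segment_sum([1.0, 2.0, 3.0], [0, 1, 0], 2) == [4.0, 2.0]
```

The trade-off is a small import-time cost in exchange for tests that live next to the code and fail loudly as soon as a change breaks them.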


Logo designed by overripemango.

Project details

Download files

Source distribution: marathon_train-0.2.2.tar.gz (45.7 kB)
Built distribution: marathon_train-0.2.2-py3-none-any.whl (59.5 kB, tag py3-none-any)

Both files were uploaded via twine/6.1.0 (CPython/3.13.7) using Trusted Publishing, with attestation bundles from the publisher release.yml on sirmarcel/marathon.

File hashes

marathon_train-0.2.2.tar.gz
  • SHA256: 18d7c5121a00b9dfa58db6227778bc4efe6c9bc9a3405766f56e8c6410f1ca4c
  • MD5: 6066db318fac9f9a9600093a7ffa9e7b
  • BLAKE2b-256: 07f5b8778721c07e7ef6c2de473aad165e4602426cf61ceb6f15c8bda62e2f48

marathon_train-0.2.2-py3-none-any.whl
  • SHA256: f3010b133658a47c529ab3bfa721570ace44ca877de3e80ede49f2c8d5e740be
  • MD5: ca1909984318683b0b54762cdcc28888
  • BLAKE2b-256: 4329775201f9b1201fcd4f4330601c0531d091110a239e8fd8e4de8378079e4b
