
Boilerplate-free, reproducible ML experiment workflows built on PyTorch Lightning and hydra-zen. Carved out of MIT-LL's responsible-ai-toolbox.



mushin


Boilerplate-free, reproducible machine-learning experiment workflows built on PyTorch Lightning and hydra-zen.

mushin is a standalone carve-out of the rai_toolbox.mushin subpackage from MIT Lincoln Laboratory's responsible-ai-toolbox. The upstream toolbox is no longer maintained (last release May 2023), but the mushin workflow layer still works against current versions of its dependencies. This package extracts just that layer so it can be maintained and used on its own.

Quickstart: run a sweep, get a dataset

Define your experiment as a function, sweep over parameters, and get the results back as a labeled xarray.Dataset — not rows in a dashboard you have to export.

import torch as tr
from mushin import multirun
from mushin.workflows import MultiRunMetricsWorkflow

class LRSweep(MultiRunMetricsWorkflow):
    @staticmethod
    def task(lr: float, seed: int) -> dict:
        tr.manual_seed(seed)
        # ... train a model with this lr/seed, then evaluate it ...
        acc = tr.rand(()).item()  # placeholder so the snippet runs; use your validation accuracy
        return dict(accuracy=acc)  # whatever you return becomes a data variable

wf = LRSweep()
wf.run(lr=multirun([0.01, 0.1, 1.0]), seed=multirun([0, 1, 2]))  # 9 runs

ds = wf.to_xarray()
# <xarray.Dataset> Dimensions: (lr: 3, seed: 3)
#   Data variables: accuracy (lr, seed)

ds["accuracy"].mean("seed")   # average over seeds, per learning rate

The full runnable version is in examples/sweep_to_dataset.py:

uv run python examples/sweep_to_dataset.py
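If Hydra multiruns are new to you, the sketch below shows in plain Python (no mushin, Hydra, or xarray) what the quickstart sweep amounts to. It makes one call of `task` per point of the `lr × seed` grid, then reduces over the `seed` axis. The toy metric inside `task` is made up for illustration and stands in for real training and evaluation.

```python
from itertools import product
from statistics import mean

lrs = [0.01, 0.1, 1.0]
seeds = [0, 1, 2]


def task(lr: float, seed: int) -> dict:
    # Hypothetical toy metric standing in for "train, then evaluate":
    # it peaks at lr=0.1 and adds a tiny seed-dependent offset.
    return dict(accuracy=1.0 / (1.0 + abs(lr - 0.1)) + 0.001 * seed)


# One run per point of the Cartesian product lr x seed (3 x 3 = 9 runs),
# keyed by its coordinates so the grid stays labeled.
results = {(lr, seed): task(lr, seed)["accuracy"] for lr, seed in product(lrs, seeds)}

# Plain-dict analogue of ds["accuracy"].mean("seed"): average out the seed axis.
mean_over_seed = {lr: mean(results[(lr, s)] for s in seeds) for lr in lrs}
best_lr = max(mean_over_seed, key=mean_over_seed.get)
```

Keeping results keyed by their coordinates, rather than as a flat list of 9 runs, is what makes "average over seeds" a one-liner. `to_xarray()` gives you the same structure with named dimensions.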

Compare methods, with statistics

Evaluate trained models on a standard metric battery and get back a labeled dataset plus significance tests. Metrics are delegated to torchmetrics, statistics to scipy:

from mushin.benchmark import compare

result = compare(
    methods={"ours": [m0, m1, m2], "baseline": [b0, b1, b2]},  # one trained model per seed
    data=test_loader, task="classification", num_classes=10, test="welch",
)

result.summary()       # mean ± CI per method, with significance markers — paper-ready
result.comparisons     # tidy DataFrame: pairwise p-values + effect sizes
result.data            # the labeled xarray (method × seed) to slice and plot
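`test="welch"` presumably selects Welch's unequal-variance t-test, the usual choice when per-method variances may differ. The sketch below is an assumption-labeled illustration, not mushin's code. It computes the Welch statistic and its Welch–Satterthwaite degrees of freedom from per-seed scores using only the standard library. Cohen's d is shown as one common effect size; which effect size mushin actually reports is not stated here. The accuracies are hypothetical.

```python
import math
from statistics import mean, variance


def welch_t(a, b):
    """Welch's t statistic and Welch-Satterthwaite degrees of freedom."""
    na, nb = len(a), len(b)
    se2_a, se2_b = variance(a) / na, variance(b) / nb  # squared standard errors
    t = (mean(a) - mean(b)) / math.sqrt(se2_a + se2_b)
    df = (se2_a + se2_b) ** 2 / (se2_a**2 / (na - 1) + se2_b**2 / (nb - 1))
    return t, df


def cohens_d(a, b):
    """Cohen's d using the pooled sample standard deviation."""
    na, nb = len(a), len(b)
    pooled_var = ((na - 1) * variance(a) + (nb - 1) * variance(b)) / (na + nb - 2)
    return (mean(a) - mean(b)) / math.sqrt(pooled_var)


# Hypothetical per-seed accuracies, one value per trained model.
ours = [0.91, 0.92, 0.90]
baseline = [0.88, 0.89, 0.87]

t, df = welch_t(ours, baseline)
d = cohens_d(ours, baseline)
```

To turn `t` and `df` into a two-sided p-value, `scipy.stats.ttest_ind(ours, baseline, equal_var=False)` runs the same Welch test directly.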

Don't have the trained models in memory yet? Study runs the multi-seed training sweep (via Hydra) and feeds the trained models straight into compare, so define → train → evaluate → report happens in one call:

from mushin import Study

study = Study(
    methods={"cnn": train_cnn, "mlp": train_mlp},   # train_fn(seed) -> checkpoint path
    load_fn=LitClassifier.load_from_checkpoint,       # path -> model
    seeds=[0, 1, 2], data=test_loader, num_classes=10, test="welch",
)
result = study.run()                                  # -> BenchmarkResult

# ...or compare checkpoints you already have, no training:
Study.from_checkpoints(
    checkpoints={"cnn": ["cnn_0.ckpt", ...], "mlp": ["mlp_0.ckpt", ...]},
    load_fn=LitClassifier.load_from_checkpoint,
    data=test_loader, num_classes=10, test="welch",
).run()
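The define → train → evaluate → report flow can be pictured as two nested loops. The sketch below is a plain-Python illustration that follows the contracts in the comments above (`train_fn(seed) -> checkpoint path`, `load_fn(path) -> model`). `run_study`, the `evaluate` scorer, and the toy callables are hypothetical. This is not how `Study` is implemented, since `Study` launches training through Hydra.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence


def run_study(
    load_fn: Callable[[str], Any],                              # path -> model
    evaluate: Callable[[Any], float],                           # model -> score (hypothetical)
    seeds: Sequence[int] = (),
    methods: Optional[Dict[str, Callable[[int], str]]] = None,  # train_fn(seed) -> path
    checkpoints: Optional[Dict[str, List[str]]] = None,         # eval-only mode
) -> Dict[str, List[float]]:
    if checkpoints is None:
        # Training mode: one run per (method, seed), each yielding a checkpoint path.
        checkpoints = {name: [train(s) for s in seeds] for name, train in (methods or {}).items()}
    # Shared evaluation step: load every checkpoint and score it, per method.
    return {name: [evaluate(load_fn(p)) for p in paths] for name, paths in checkpoints.items()}


# Toy stand-ins: "training" only names a file, "loading" returns the name.
def train_cnn(seed: int) -> str:
    return f"cnn_{seed}.ckpt"


def train_mlp(seed: int) -> str:
    return f"mlp_{seed}.ckpt"


def load_model(path: str) -> str:
    return path


def toy_score(model: str) -> float:
    return 0.9 if model.startswith("cnn") else 0.8


scores = run_study(load_model, toy_score, seeds=[0, 1, 2],
                   methods={"cnn": train_cnn, "mlp": train_mlp})
eval_only = run_study(load_model, toy_score, checkpoints={"cnn": ["cnn_0.ckpt"]})
```

Both modes end in the same `{method: [score per seed]}` shape. That per-method list of seed scores is exactly what a significance test like Welch's consumes.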

What it provides

  • benchmark.compare — run a standard metric battery (torchmetrics) across trained seeds and get a labeled dataset + significance (scipy): BenchmarkResult with .summary(), .comparisons, and .data.
  • Study — orchestrate a multi-seed training sweep and route the trained models into compare, in one call; Study.from_checkpoints(...) for eval-only.
  • BaseWorkflow, MultiRunMetricsWorkflow, RobustnessCurve — declarative, reproducible experiment workflows that record configs, checkpoints, and metrics, and load results back as labeled xarray datasets.
  • MetricsCallback — a Lightning callback for capturing metrics.
  • HydraDDP — a Hydra/Lightning strategy for multi-GPU (DDP) launches.
  • multirun, hydra_list, load_experiment, load_from_checkpoint — helpers.
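The difference between the multirun and hydra_list helpers listed above matters when a parameter is itself a list. The sketch below assumes they behave like their hydra-zen namesakes. Under that assumption, a multirun value becomes a comma-separated Hydra sweep (`lr=0.01,0.1`), while a hydra_list value is passed as one list (`layers=[64,32]`). `Sweep`, `AsList`, and `to_overrides` are hypothetical stand-ins written for illustration, not mushin's classes.

```python
from typing import Any, Dict, List


class Sweep(list):
    """Hypothetical marker (like multirun): launch one run per element."""


class AsList(list):
    """Hypothetical marker (like hydra_list): pass the whole list as one value."""


def to_overrides(params: Dict[str, Any]) -> List[str]:
    """Render parameters as Hydra command-line override strings."""
    overrides = []
    for key, value in params.items():
        if isinstance(value, Sweep):
            # Hydra multirun sweep syntax: comma-separated values, one job each.
            overrides.append(f"{key}=" + ",".join(map(str, value)))
        elif isinstance(value, AsList):
            # A single override whose value is a list.
            overrides.append(f"{key}=[" + ",".join(map(str, value)) + "]")
        else:
            overrides.append(f"{key}={value}")
    return overrides


overrides = to_overrides({"lr": Sweep([0.01, 0.1]), "layers": AsList([64, 32]), "seed": 0})
```

In Hydra's multirun mode (`-m`), the `lr` override launches two jobs. The `layers` override configures every job with the single list `[64, 32]`.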

Install

pip install mushin-py

Already use uv? uv pip install mushin-py (or uv add mushin-py inside a project) is faster.

Install name vs. import name: the PyPI distribution is mushin-py, but you import mushin (the same pattern as scikit-learn → sklearn).
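Because the two names differ, a package-metadata lookup needs the distribution name, while `import` needs the module name. A minimal sketch using the standard library's `importlib.metadata`; the helper `installed_version` is our own illustration:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "mushin-py") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        # Metadata is looked up by *distribution* name ("mushin-py"),
        # whereas the code itself is imported as `mushin`.
        return version(dist_name)
    except PackageNotFoundError:
        return None


print(installed_version())
```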

Optional runtime extras: viz (matplotlib, for RobustnessCurve plotting) and netcdf (netCDF4) — e.g. pip install "mushin-py[viz]".

For a development environment (runtime deps + dev tooling), this project uses uv: uv sync.

Develop

uv run pytest tests/ --hypothesis-profile fast   # tests (DDP test needs >=2 GPUs)
uv run ruff check .                              # lint
uv run ruff format .                             # format
uv run codespell src tests                       # spell check

Or use the make shortcuts (make help to list them): make check runs lint + format-check + spell + tests (what CI runs); make test-py PYTHON=3.12 runs the suite on a specific Python version.

Supported Python versions: 3.9 – 3.14.

Relationship to upstream

This is a fork/extraction, not a replacement endorsed by MIT-LL. The configuration engine it depends on, hydra-zen, is actively maintained by the same group. See LICENSE.txt for attribution; the original MIT copyright is retained.


File details

Details for the file mushin_py-0.2.1.tar.gz.

File metadata

  • Download URL: mushin_py-0.2.1.tar.gz
  • Upload date:
  • Size: 43.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for mushin_py-0.2.1.tar.gz
Algorithm Hash digest
SHA256 f4f9eb407340cae8a1f3d8fe50f64978cd0a2bf549bd0f372942ad225dbc51d7
MD5 4af6d0a01c91a93b46281fc7e9a5dfb6
BLAKE2b-256 1bc7e625602039daec6fb8161efecda601821a5a4acc7c62f6f1880c00390ccb
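To check a downloaded archive against the SHA256 digest listed above, hash the file in chunks and compare hex digests. A minimal sketch with `hashlib`; the helper names are ours, and the path you pass is wherever you saved the download:

```python
import hashlib
from pathlib import Path

# SHA256 digest of mushin_py-0.2.1.tar.gz, copied from the table above.
EXPECTED_SHA256 = "f4f9eb407340cae8a1f3d8fe50f64978cd0a2bf549bd0f372942ad225dbc51d7"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Stream the file through SHA256 so large files need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's SHA256 equals the expected hex digest."""
    return sha256_of(path) == expected.lower()
```

pip can perform the same check at install time through its hash-checking mode, using a requirements line such as `mushin-py==0.2.1 --hash=sha256:<digest>`. Once any requirement carries a hash, pip requires pinned hashes for every dependency as well.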


Provenance

The following attestation bundles were made for mushin_py-0.2.1.tar.gz:

Publisher: publish.yml on martinez-hub/mushin

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file mushin_py-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: mushin_py-0.2.1-py3-none-any.whl
  • Upload date:
  • Size: 40.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for mushin_py-0.2.1-py3-none-any.whl
Algorithm Hash digest
SHA256 e6fdb5a743780c173625e25dc71c85fcf575a93bbc71efbebb4c9ce0a2f7f623
MD5 07fb8b334113c61b8c9c931b9f5caa0e
BLAKE2b-256 e002de68d19131d5641ac40cf51c366bc56ce34154471ba275ba84e1bcded6c1


Provenance

The following attestation bundles were made for mushin_py-0.2.1-py3-none-any.whl:

Publisher: publish.yml on martinez-hub/mushin

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
