
Boilerplate-free, reproducible ML experiment workflows built on PyTorch Lightning and hydra-zen. Carved out of MIT-LL's responsible-ai-toolbox.


mushin


Boilerplate-free, reproducible machine-learning experiment workflows built on PyTorch Lightning and hydra-zen.

mushin is a standalone carve-out of the rai_toolbox.mushin subpackage from MIT Lincoln Laboratory's responsible-ai-toolbox. The upstream toolbox is no longer maintained (last release May 2023), but the mushin workflow layer still works against current versions of its dependencies. This package extracts just that layer so it can be maintained and used on its own.

Quickstart: run a sweep, get a dataset

Define your experiment as a function, sweep over parameters, and get the results back as a labeled xarray.Dataset — not rows in a dashboard you have to export.

import torch as tr
from mushin import multirun
from mushin.workflows import MultiRunMetricsWorkflow

class LRSweep(MultiRunMetricsWorkflow):
    @staticmethod
    def task(lr: float, seed: int) -> dict:
        tr.manual_seed(seed)
        # ... train a model with this lr/seed, then evaluate it ...
        acc = ...  # your validation accuracy
        return dict(accuracy=acc)  # whatever you return becomes a data variable

wf = LRSweep()
wf.run(lr=multirun([0.01, 0.1, 1.0]), seed=multirun([0, 1, 2]))  # 9 runs

ds = wf.to_xarray()
# <xarray.Dataset> Dimensions: (lr: 3, seed: 3)
#   Data variables: accuracy (lr, seed)

ds["accuracy"].mean("seed")   # average over seeds, per learning rate

The full runnable version is in examples/sweep_to_dataset.py:

uv run python examples/sweep_to_dataset.py
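To make the shape of the sweep concrete, here is a minimal plain-Python sketch of what a 3 × 3 grid over lr and seed amounts to, and what averaging over the seed dimension means. It uses neither mushin nor xarray. fake_accuracy is a hypothetical stand-in for a real training run, and expanding the grid with itertools.product is an assumption about what multirun means, not a description of how mushin implements it.

```python
from itertools import product
from statistics import mean

lrs = [0.01, 0.1, 1.0]
seeds = [0, 1, 2]


def fake_accuracy(lr: float, seed: int) -> float:
    """Hypothetical stand-in for 'train a model, then evaluate it'."""
    return round(0.9 - abs(lr - 0.1) * 0.5 + 0.01 * seed, 4)


# One entry per (lr, seed) combination: 3 x 3 = 9 runs.
results = {(lr, seed): fake_accuracy(lr, seed) for lr, seed in product(lrs, seeds)}

# Average over seeds for each learning rate: the plain-Python analogue of
# ds["accuracy"].mean("seed") on the labeled dataset.
mean_over_seeds = {lr: mean(results[lr, s] for s in seeds) for lr in lrs}

best_lr = max(mean_over_seeds, key=mean_over_seeds.get)
print(len(results), best_lr)
```

With a labeled dataset the same reduction is a single call on a named dimension, which is the point of getting results back as an xarray.Dataset rather than as a flat table.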

Compare methods, with statistics

Evaluate trained models on a standard metric battery and get back a labeled dataset plus statistical significance. Metrics are delegated to torchmetrics, statistics to scipy:

from mushin.benchmark import compare

result = compare(
    methods={"ours": [m0, m1, m2], "baseline": [b0, b1, b2]},  # one trained model per seed
    data=test_loader, task="classification", num_classes=10, test="welch",
)

result.summary()       # mean ± CI per method, with significance markers — paper-ready
result.comparisons     # tidy DataFrame: pairwise p-values + effect sizes
result.data            # the labeled xarray (method × seed) to slice and plot
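For readers unfamiliar with the test named by test="welch", the sketch below computes Welch's unequal-variance t statistic and the Welch-Satterthwaite degrees of freedom by hand, using only the standard library. It assumes that "welch" refers to this standard test. The source does not say exactly how compare calls scipy, and the per-seed accuracies here are made-up numbers for illustration.

```python
from math import sqrt
from statistics import mean, variance


def welch_t(a, b):
    """Return (t statistic, Welch-Satterthwaite degrees of freedom)."""
    na, nb = len(a), len(b)
    # Squared standard error of each group mean, using the sample (n - 1) variance.
    sa, sb = variance(a) / na, variance(b) / nb
    t = (mean(a) - mean(b)) / sqrt(sa + sb)
    df = (sa + sb) ** 2 / (sa ** 2 / (na - 1) + sb ** 2 / (nb - 1))
    return t, df


# Hypothetical per-seed accuracies for two methods (made-up numbers).
ours = [0.91, 0.93, 0.92]
baseline = [0.88, 0.87, 0.89]

t, df = welch_t(ours, baseline)
print(f"t = {t:.3f}, df = {df:.1f}")
```

Turning t and df into a p-value needs the t-distribution CDF. With scipy installed, scipy.stats.ttest_ind(ours, baseline, equal_var=False) returns the statistic and the p-value for this test directly.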

Don't have the trained models in memory yet? Study runs the multi-seed training sweep (via Hydra) and feeds the results straight into compare — define → train → evaluate → report in one call:

from mushin import Study

study = Study(
    methods={"cnn": train_cnn, "mlp": train_mlp},   # train_fn(seed) -> checkpoint path
    load_fn=LitClassifier.load_from_checkpoint,       # path -> model
    seeds=[0, 1, 2], data=test_loader, num_classes=10, test="welch",
)
result = study.run()                                  # -> BenchmarkResult

# ...or compare checkpoints you already have, no training:
Study.from_checkpoints(
    checkpoints={"cnn": ["cnn_0.ckpt", ...], "mlp": ["mlp_0.ckpt", ...]},
    load_fn=LitClassifier.load_from_checkpoint,
    data=test_loader, num_classes=10, test="welch",
).run()
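The comments in the snippet above describe two callables: train_fn(seed) returns a checkpoint path, and load_fn turns a path back into a model. The toy sketch below illustrates that contract with JSON files standing in for real checkpoints. train_toy, load_toy, and the file layout are hypothetical names invented for illustration. They are not part of mushin, and a real load_fn would return a Lightning module rather than a dict.

```python
import json
import tempfile
from pathlib import Path

# Scratch directory for the toy "checkpoints".
CKPT_DIR = Path(tempfile.mkdtemp())


def train_toy(seed: int) -> str:
    """Hypothetical train_fn: 'trains', saves a checkpoint, returns its path."""
    path = CKPT_DIR / f"toy_{seed}.json"
    path.write_text(json.dumps({"seed": seed, "weights": [seed * 0.1, 1.0]}))
    return str(path)


def load_toy(path: str) -> dict:
    """Hypothetical load_fn: turns a checkpoint path back into a 'model'."""
    return json.loads(Path(path).read_text())


# One checkpoint per seed, then load each one back for evaluation.
checkpoints = [train_toy(seed) for seed in [0, 1, 2]]
models = [load_toy(p) for p in checkpoints]
print([m["seed"] for m in models])
```

Keeping training and loading as separate callables is what lets the same evaluation run either right after a fresh sweep or on checkpoints that already exist on disk.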

What it provides

  • benchmark.compare — run a standard metric battery (torchmetrics) across trained seeds and get a labeled dataset + significance (scipy): BenchmarkResult with .summary(), .comparisons, and .data.
  • Study — orchestrate a multi-seed training sweep and route the trained models into compare, in one call; Study.from_checkpoints(...) for eval-only.
  • BaseWorkflow, MultiRunMetricsWorkflow, RobustnessCurve — declarative, reproducible experiment workflows that record configs, checkpoints, and metrics, and load results back as labeled xarray datasets.
  • MetricsCallback — a Lightning callback for capturing metrics.
  • HydraDDP — a Hydra/Lightning strategy for multi-GPU (DDP) launches.
  • multirun, hydra_list, load_experiment, load_from_checkpoint — helpers.

Install

pip install mushin-py

Already use uv? uv pip install mushin-py (or uv add mushin-py inside a project) is faster.

Install name vs. import name: the PyPI distribution is mushin-py, but you import mushin (same pattern as scikit-learn → sklearn).
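The distinction matters when checking an installation programmatically: distribution metadata is looked up by the PyPI name, not the import name. A minimal sketch using the standard library's importlib.metadata follows. dist_version is a hypothetical helper name, and the call returns None when the distribution is not installed in the current environment.

```python
from importlib import metadata
from typing import Optional


def dist_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Look up by the distribution name ("mushin-py"), not the import name ("mushin").
print(dist_version("mushin-py"))
```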

Optional runtime extras: viz (matplotlib, for RobustnessCurve plotting) and netcdf (netCDF4) — e.g. pip install "mushin-py[viz]".
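One way to check whether the optional dependencies behind these extras are importable, without importing them, is importlib.util.find_spec. This is a generic standard-library sketch, not something mushin provides. has_module is a hypothetical helper, and the module names are the usual import names of the packages mentioned above.

```python
from importlib.util import find_spec


def has_module(name: str) -> bool:
    """True if a top-level module can be found without importing it."""
    return find_spec(name) is not None


# Import names of the packages the extras are described as pulling in.
for extra, module in [("viz", "matplotlib"), ("netcdf", "netCDF4")]:
    status = "available" if has_module(module) else "missing"
    print(f"{extra}: {module} is {status}")
```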

For a development environment (runtime deps + dev tooling), this project uses uv: uv sync.

Develop

uv run pytest tests/ --hypothesis-profile fast   # tests (DDP test needs >=2 GPUs)
uv run ruff check .                              # lint
uv run ruff format .                             # format
uv run codespell src tests                       # spell check

Or use the make shortcuts (make help to list them): make check runs lint + format-check + spell + tests (what CI runs); make test-py PYTHON=3.12 runs the suite on a specific Python version.

Supported Python versions: 3.9 – 3.14.

Relationship to upstream

This is a fork/extraction, not a replacement endorsed by MIT-LL. The configuration engine it depends on, hydra-zen, is actively maintained by the same group. See LICENSE.txt for attribution; the original MIT copyright is retained.

