
Domain-robust PyTorch and sklearn: train on site A, deploy on site B with the same labels.


matching-pmh

Train on site A. Deploy on site B. Same labels.

Add a domain-robustness regularizer to your PyTorch model or sklearn pipeline — without replacing your architecture or reading a research paper first.


PyPI · GitHub · What is this? · First hour · Open In Colab


Developer API (start here)

```python
from pmh import check_applicability, robust_fit, evaluate_baseline_vs_pmh

# Go / marginal / no-go verdict before training
print(check_applicability(stack="pytorch", n_source=500, n_target=400).summary())

# PyTorch: one call (model, train_loader, src and tgt are your own model and loaders)
out = robust_fit(model, train_loader, source_batches=src, target_batches=tgt, hook="auto", epochs=20)

# sklearn: baseline vs PMH, scored on a target holdout (x_*/y_* are your own arrays)
report = evaluate_baseline_vs_pmh(x_source, y_source, x_target, y_target, compare_to=("coral",))
print(report.summary())
```
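Before any training it helps to confirm that site A and site B actually differ. A minimal sketch of one generic check, assuming plain NumPy feature arrays: a two-sample permutation test on the distance between the feature means. This is not what `check_applicability` does internally; the function name `mean_shift_pvalue` and the toy data are hypothetical, and the test only detects mean shifts, not covariance changes.

```python
import numpy as np


def mean_shift_pvalue(x_src, x_tgt, n_perm=500, seed=0):
    """Permutation test: is the source/target mean gap larger than random splits give?"""
    rng = np.random.default_rng(seed)
    pooled = np.vstack([x_src, x_tgt])
    n_src = len(x_src)

    def gap(a, b):
        # Euclidean distance between the two feature means
        return float(np.linalg.norm(a.mean(axis=0) - b.mean(axis=0)))

    observed = gap(x_src, x_tgt)
    exceed = 0
    for _ in range(n_perm):
        idx = rng.permutation(len(pooled))  # reshuffle the site labels
        if gap(pooled[idx[:n_src]], pooled[idx[n_src:]]) >= observed:
            exceed += 1
    # add-one correction keeps the p-value strictly positive
    return observed, (exceed + 1) / (n_perm + 1)


rng = np.random.default_rng(1)
x_src = rng.normal(size=(500, 8))
x_tgt = rng.normal(size=(400, 8)) + np.r_[1.0, np.zeros(7)]  # site B shifted on feature 0
x_same = rng.normal(size=(400, 8))                            # no shift at all

gap_shifted, p_shifted = mean_shift_pvalue(x_src, x_tgt)
gap_same, p_same = mean_shift_pvalue(x_src, x_same)
```

A small p-value only says the shift is detectable with these sample sizes; it does not say that a robustness penalty will improve target accuracy.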

Three paths only: GOLDEN_PATHS.md · pmh-train wizard · starter template


Who this is for

| You have | Start here |
|---|---|
| PyTorch model + source/target data | PMHTrainer + nuisance="domain_shift" · Colab |
| Frozen .npy / sklearn features | PMHMatcher in a Pipeline · Colab (sklearn) |
| LLM style/format drift | NLP walkthrough |

Not for: classes that appear only at test time, label definitions that differ between sites, or “make any model robust to everything.”

When to use it / when not →


5-minute try

Open In Colab, or run locally:

```bash
pip install matching-pmh torch
git clone https://github.com/vishalstark512/matching-pmh.git
cd matching-pmh
python examples/00_first_run_domain_shift.py
```

The script reports baseline vs PMH accuracy on the target domain for synthetic data; once it runs, copy the pattern into your project.
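To see the kind of gap the first run measures, here is a toy sketch of a source-trained classifier losing accuracy at the deployment site. It is not the code in examples/00_first_run_domain_shift.py: the data generator `make_site` and the nearest-centroid classifier are hypothetical choices made for brevity. Feature 0 carries the real label signal at both sites, while feature 1 is a site-specific cue that tracks the label only at site A.

```python
import numpy as np


def make_site(n, cue_strength, rng):
    """Hypothetical two-class data: feature 0 is the real signal, feature 1 a site cue."""
    y = rng.integers(0, 2, size=n)
    sign = 2 * y - 1                                      # labels as -1 / +1
    signal = sign + rng.normal(size=n)                    # same at every site
    cue = cue_strength * sign + 0.5 * rng.normal(size=n)  # label-correlated only if strength > 0
    return np.column_stack([signal, cue]), y


def fit_centroids(x, y):
    # one mean vector per class
    return np.stack([x[y == c].mean(axis=0) for c in (0, 1)])


def predict(centroids, x):
    dists = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return dists.argmin(axis=1)                           # nearest class centroid


rng = np.random.default_rng(0)
x_a, y_a = make_site(2000, cue_strength=2.0, rng=rng)    # site A (train)
x_b, y_b = make_site(2000, cue_strength=0.0, rng=rng)    # site B (deploy): cue is pure noise

centroids = fit_centroids(x_a, y_a)
acc_a = float((predict(centroids, x_a) == y_a).mean())
acc_b = float((predict(centroids, x_b) == y_b).mean())
```

The classifier leans on the cue because it is the strongest signal at site A, so its site B accuracy falls well below its site A accuracy even though the labeling rule never changed.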

Expected output: docs/DEMO_OUTPUT.md

pmh-train wizard

Default integration (PyTorch)

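```python
from pmh import PMHTrainer, PMHConfig

# model: your torch.nn.Module; backbone: the submodule whose output is the
# representation hook h; classifier: the prediction head on top of it
trainer = PMHTrainer(
    model,
    hook=backbone,
    head=classifier,
    nuisance="domain_shift",
    pmh_config=PMHConfig.balanced(),
    artifact_path="artifacts/my_run.pt",  # output path for the run artifact
)
trainer.fit(
    train_loader,                  # training data for the task loss
    source_batches=source_loader,  # site A batches
    target_batches=target_loader,  # site B batches
    epochs=20,
)
```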
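The core idea of training with a penalty on hook h can be sketched in plain PyTorch. This is a hypothetical illustration, not PMHTrainer's implementation: the toy `backbone`/`head` split mirrors the `hook=backbone, head=classifier` arguments above, the nuisance directions `U` are assumed to have been estimated once before training, and the finite-difference penalty, `lam` and `eps` are assumptions. `torch.func.jvp` would give the exact directional derivative instead of a finite difference.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy model: a backbone (the hook h) and a head, as in the hook/head split above
backbone = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 8))
head = nn.Linear(8, 2)

# Hypothetical nuisance directions in input space (one unit vector per row),
# assumed to have been estimated once before training
U = torch.tensor([[0.0, 0.0, 1.0, 0.0]])

params = list(backbone.parameters()) + list(head.parameters())
opt = torch.optim.SGD(params, lr=0.05)


def sensitivity_penalty(x, eps=1e-2):
    """Finite-difference estimate of the mean ||J_h(x) u||^2 over the rows u of U."""
    h = backbone(x)
    total = x.new_zeros(())
    for u in U:
        h_moved = backbone(x + eps * u)  # nudge every input along u
        total = total + ((h_moved - h) ** 2).sum(dim=1).mean() / eps**2
    return total / len(U)


def train_step(x, y, lam=1.0):
    """One SGD step on task loss + lam * penalty; returns both terms as floats."""
    task = F.cross_entropy(head(backbone(x)), y)
    penalty = sensitivity_penalty(x)
    loss = task + lam * penalty
    opt.zero_grad()
    loss.backward()
    opt.step()
    return task.item(), penalty.item()


x = torch.randn(64, 4)
y = (x[:, 0] > 0).long()  # the label depends on feature 0 only
history = [train_step(x, y) for _ in range(100)]
```

Because the task never needs feature 2, the optimizer can drive the hook's sensitivity along that direction down without hurting the task loss.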

sklearn Pipeline (frozen features)

You already have embeddings x_source and x_target with the same feature dimension. Fit the adapter, then classify, just like any other preprocessing step:

```bash
pip install "matching-pmh[sklearn]"
```

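```python
from pmh import PMHMatcher
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression

# x_source, x_target: embeddings with the same number of columns; y_source: source labels
pipe = Pipeline([
    ("adapt", PMHMatcher(nuisance="domain_shift").fit(x_source, x_target)),
    ("clf", LogisticRegression(max_iter=500)),
])
# Note: Pipeline.fit calls fit_transform(X, y) on every transformer step, so check
# how PMHMatcher handles that call after being pre-fitted on (x_source, x_target).
pipe.fit(x_source, y_source)
# Score on held-out TARGET rows, not source-only accuracy, e.g.
# pipe.score(x_target_holdout, y_target_holdout)  # hypothetical holdout arrays
```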
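Scoring on held-out target rows means splitting the target data: one part (labels unused) feeds the adaptation step, the other part is only ever scored. A minimal sketch of that protocol, assuming synthetic data where site B adds a constant offset to the measured features. The adapter here is a deliberately simple stand-in (centering each domain on its own mean), not PMHMatcher, and all variable names are hypothetical.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
n, d = 1000, 5
w_true = rng.normal(size=d)
offset = 1.5 * w_true / np.linalg.norm(w_true)  # site B: constant measurement offset


def make_site(n, shift):
    z = rng.normal(size=(n, d))       # underlying signal
    y = (z @ w_true > 0).astype(int)  # same labeling rule at both sites
    return z + shift, y               # what each site actually measures


x_source, y_source = make_site(n, 0.0)
x_target, y_target = make_site(n, offset)

# Split target rows: the adaptation part is used WITHOUT labels, the holdout part is scored
perm = rng.permutation(len(x_target))
adapt_idx, hold_idx = perm[: n // 2], perm[n // 2:]
x_t_adapt = x_target[adapt_idx]
x_t_hold, y_t_hold = x_target[hold_idx], y_target[hold_idx]

# Stand-in adaptation (NOT PMHMatcher): center each domain on its own mean
mu_s, mu_t = x_source.mean(axis=0), x_t_adapt.mean(axis=0)

baseline = LogisticRegression(max_iter=500).fit(x_source, y_source)
adapted = LogisticRegression(max_iter=500).fit(x_source - mu_s, y_source)

acc_source = baseline.score(x_source, y_source)  # in-domain, optimistic
acc_base_target = baseline.score(x_t_hold, y_t_hold)
acc_adapted_target = adapted.score(x_t_hold - mu_t, y_t_hold)
```

The source accuracy looks excellent for both models; only the target holdout reveals how much the unadapted baseline degrades.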

Gallery: tabular · examples/06_office31_sklearn.py


How it works (30 seconds)

```mermaid
flowchart LR
  A[Site A data] --> E[Estimate shift geometry once]
  B[Site B data] --> E
  E --> T[Train your model]
  T --> H[Representation hook h]
  H --> L[Task loss + robustness penalty]
```

Use the same hook layer for estimation and training. More: First hour · Getting started · Troubleshooting glossary


How it relates to CORAL / domain adaptation

PMH keeps your task loss and adds a penalty on the sensitivity of the representation along the directions in which the training and deployment environments differ. For a side-by-side comparison, see: vs CORAL.
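For contrast, CORAL does not add a penalty: it transforms the source features so their covariance matches the target's. A minimal NumPy sketch of the CORAL recipe (Sun, Feng and Saenko, 2016), whitening with the source covariance and re-coloring with the target covariance. This is a generic illustration, not necessarily how `compare_to=("coral",)` is implemented in this library; the helper names are hypothetical.

```python
import numpy as np


def _sym_power(c, p):
    """Matrix power c**p for a symmetric positive-definite matrix via eigh."""
    vals, vecs = np.linalg.eigh(c)
    return (vecs * vals**p) @ vecs.T


def coral(x_src, x_tgt, eps=1.0):
    """CORAL: whiten source features, then re-color them with the target covariance.

    eps * I is added to both covariances, as in the original recipe (eps=1).
    Only second-order statistics are matched; means are left alone.
    """
    d = x_src.shape[1]
    c_src = np.cov(x_src, rowvar=False) + eps * np.eye(d)
    c_tgt = np.cov(x_tgt, rowvar=False) + eps * np.eye(d)
    return x_src @ _sym_power(c_src, -0.5) @ _sym_power(c_tgt, 0.5)


rng = np.random.default_rng(0)
x_s = rng.normal(size=(1000, 3))
mix = np.array([[2.0, 0.0, 0.0], [0.5, 1.0, 0.0], [0.0, 0.0, 0.3]])
x_t = rng.normal(size=(1000, 3)) @ mix  # site B: different covariance
aligned = coral(x_s, x_t, eps=1e-8)     # tiny eps so the match is near-exact
```

The design difference is where the correction lives: CORAL changes the inputs the classifier sees, while a sensitivity penalty leaves the features alone and changes what the model learns to depend on.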

Under the hood: estimate the deployment geometry once, then train with an additional matched penalty on hook h. The details are optional reading: Theory.
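A minimal sketch of the two steps in the simplest possible setting, assuming a linear hook h(x) = W x, whose Jacobian is W, so its sensitivity along a direction u is ||W u||. Step one estimates shift directions once, here as the eigenvectors of C_tgt - C_src with the largest |eigenvalue|; step two scores a hook by ||W U||_F^2, which would be added to the task loss as lam * penalty. The estimator and both function names are illustrative assumptions, not the library's estimators.

```python
import numpy as np


def shift_directions(x_src, x_tgt, k=1):
    """Top-k directions where source and target covariances differ most.

    Illustrative estimator: eigenvectors of (C_tgt - C_src) ranked by |eigenvalue|.
    Returns a (d, k) matrix with orthonormal columns.
    """
    diff = np.cov(x_tgt, rowvar=False) - np.cov(x_src, rowvar=False)
    vals, vecs = np.linalg.eigh(diff)       # diff is symmetric
    order = np.argsort(np.abs(vals))[::-1]  # largest |eigenvalue| first
    return vecs[:, order[:k]]


def sensitivity_penalty(W, U):
    """||W U||_F^2: for a linear hook h(x) = W x, how far h moves per unit step along U."""
    return float(np.sum((W @ U) ** 2))


rng = np.random.default_rng(0)
x_src = rng.normal(size=(2000, 3))
x_tgt = rng.normal(size=(2000, 3))
x_tgt[:, 2] *= 3.0                              # site B is noisier along feature 2

U = shift_directions(x_src, x_tgt, k=1)         # estimated once, before training
p_sensitive = sensitivity_penalty(np.eye(3), U)              # hook keeps feature 2
p_robust = sensitivity_penalty(np.diag([1.0, 1.0, 0.0]), U)  # hook ignores feature 2
```

A hook that passes the shifted direction through pays the full penalty, while one that ignores it pays almost nothing, which is the pressure the extra loss term puts on training.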


Documentation map

| I want to… | Read |
|---|---|
| Understand in plain language | WHAT_IS_PMH.md |
| Run + copy code in one hour | FIRST_HOUR.md |
| Integrate into my repo | GETTING_STARTED.md |
| ResNet / ViT / HF hooks | hooks.md · Gallery |
| Prove claims before production | Controls walkthrough |
| Replicate paper benchmarks | CORRECT_USAGE · Paper alignment |

18 walkthroughs · Developer onboarding plan


Install

```bash
pip install matching-pmh
pip install "matching-pmh[sklearn]"   # classical ML path
pip install "matching-pmh[vision]"    # ResNet / timm examples
pip install "matching-pmh[hf]"        # LLM style shift
```

Advanced (estimators D1–D7, research)

The Grand Unification paper unifies many nuisance estimators (subspace, domain Gram, augmentations, style, …). The library exposes them for cases where your shift is not a plain cross-domain shift; list them with:

```bash
pmh-train list-methods
pmh-train list-presets
```
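As a loose illustration of just one estimator family named above (augmentation-based), here is a hypothetical sketch: apply an augmentation to each sample and take the top principal directions of the changes it makes. This is not one of the library's D1–D7 estimators; `directions_from_pairs`, the toy augmentation, and the choice of an uncentered SVD (so a consistent shift direction is not removed by centering) are all assumptions.

```python
import numpy as np


def directions_from_pairs(x, x_aug, k=2):
    """Top-k principal directions of the changes an augmentation makes (unit columns)."""
    delta = x_aug - x  # one change vector per sample
    # SVD of the raw (uncentered) differences: the leading right singular vectors
    # span the directions along which the augmentation moves the data most
    _, s, vt = np.linalg.svd(delta, full_matrices=False)
    return vt[:k].T, s[:k]


rng = np.random.default_rng(0)
x = rng.normal(size=(500, 6))
# Hypothetical augmentation: a random brightness-like shift along one fixed direction
nuisance = np.array([1.0, 1.0, 0.0, 0.0, 0.0, 0.0]) / np.sqrt(2)
x_aug = x + rng.normal(scale=2.0, size=(500, 1)) * nuisance + rng.normal(scale=0.05, size=x.shape)

U, strengths = directions_from_pairs(x, x_aug, k=1)
```

The recovered column of U lines up with the direction the augmentation perturbs, which is the kind of nuisance subspace a sensitivity penalty would then act on.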
| Research / benchmark | Doc |
|---|---|
| Paper block presets | paper-presets-by-block.md |
| Office-31 reference table | BENCHMARKS.md |
| Lemma-level math | THEORY.md |

Citation

```bibtex
@software{matching_pmh,
  title  = {matching-pmh: Matched PMH training from estimated deployment nuisance geometry},
  author = {Rajput, Vishal},
  year   = {2026},
  url    = {https://github.com/vishalstark512/matching-pmh}
}
```

License

MIT — see LICENSE.



Download files


Source Distribution

matching_pmh-1.5.0.tar.gz (176.8 kB)

Uploaded Source

Built Distribution


matching_pmh-1.5.0-py3-none-any.whl (97.3 kB)

Uploaded Python 3

File details

Details for the file matching_pmh-1.5.0.tar.gz.

File metadata

  • Download URL: matching_pmh-1.5.0.tar.gz
  • Upload date:
  • Size: 176.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.4

File hashes

Hashes for matching_pmh-1.5.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0c44e2f8b377dfd9b455a8924e911262800dfa3988ff0eaf7ff6b4412add21d6 |
| MD5 | cc28cb1e3b431eed63b8af3d9dc46a69 |
| BLAKE2b-256 | 0771d0ecb03809ea80608d99177c4604dfd0e6ec2b17455ba7da146e8df5989f |


File details

Details for the file matching_pmh-1.5.0-py3-none-any.whl.

File metadata

  • Download URL: matching_pmh-1.5.0-py3-none-any.whl
  • Upload date:
  • Size: 97.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.4

File hashes

Hashes for matching_pmh-1.5.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 89753ba0c96aacd082edca815ee81d54373fb8ffebe85c83b5f5fcdd679cdb4b |
| MD5 | c59fcb7781d231d2efe38fed727a9b28 |
| BLAKE2b-256 | 1e2f8f4d6961dc67f1ed53b077ec899fdd859d44b510355cc96bba24688958c9 |

