
Multi-site Optimal-transport Shift Alignment with Interval Calibration for clinical data harmonization


MOSAIC

Multi-site Optimal-transport Shift Alignment with Interval Calibration


MOSAIC is a Python package for harmonizing clinical tabular data collected across multiple sites. It combines 1-D optimal transport for distribution alignment, anchor regression for domain-robust prediction, and weighted conformal inference for uncertainty quantification. The three components can be used independently or chained through a single pipeline.

The package was developed for multi-center IVF (in vitro fertilization) outcome prediction, but the methods are general and apply to any multi-site clinical or biomedical dataset with batch effects.

Overview

MOSAIC has three tiers, each usable on its own:

| Tier | Class | What it does |
|---|---|---|
| 1. Harmonization | OTHarmonizer | Per-feature, quantile-based optimal transport mapping to a reference distribution. Reduces cross-center distribution shift while preserving within-center rank order. |
| 2. Robust learning | AnchorEstimator | Wraps any sklearn estimator with anchor regression (via anchorboosting) or V-REx reweighting. Penalizes predictions that rely on center-specific patterns. |
| 3. Uncertainty | ConformalCalibrator | Split conformal prediction with optional covariate-shift correction (Tibshirani et al., NeurIPS 2019). Produces prediction intervals (regression) or prediction sets (classification) with finite-sample marginal coverage guarantees. For the covariate-shift variant, the guarantee is exact when the likelihood ratio between test and calibration covariates is known and approximate when it is estimated. |
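To make Tier 1 concrete, here is a minimal NumPy sketch of 1-D optimal transport by quantile mapping, assuming OTHarmonizer follows the standard construction T = F_ref^{-1}(F_src(x)) for each feature. It is not the package's code: `fit_quantile_map` is a hypothetical helper, and the min_samples check, tie handling, and most missing-value handling are omitted.

```python
import numpy as np


def fit_quantile_map(x_source, x_reference, n_quantiles=1000):
    """Hypothetical helper: monotone 1-D OT map from a source to a reference sample.

    In one dimension the optimal transport map is T(x) = F_ref^{-1}(F_src(x)).
    Both distributions are summarized by n_quantiles empirical quantiles.
    """
    levels = np.linspace(0.0, 1.0, n_quantiles)
    src_q = np.quantile(x_source[~np.isnan(x_source)], levels)
    ref_q = np.quantile(x_reference[~np.isnan(x_reference)], levels)

    def transform(x):
        # Approximate F_src(x): the quantile level of x within the source data.
        u = np.interp(x, src_q, levels)
        # Apply F_ref^{-1}: read off the reference value at that quantile level.
        return np.interp(u, levels, ref_q)

    return transform


rng = np.random.default_rng(0)
reference = rng.normal(loc=0.0, scale=1.0, size=2000)  # reference distribution
shifted = rng.normal(loc=3.0, scale=2.0, size=2000)    # one center with a batch effect
aligned = fit_quantile_map(shifted, reference)(shifted)
print(aligned.mean(), aligned.std())  # both close to the reference's mean and sd
```

Both interpolation steps are non-decreasing, so the composed map keeps each center's rank order intact, which is the property the table attributes to OTHarmonizer.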

MOSAICPipeline chains all three into a single fit / predict interface.

Installation

Core (OT harmonization + anchor regression + conformal):

pip install mosaic-harmonize

With all optional dependencies (LightGBM, anchorboosting, MAPIE, matplotlib):

pip install "mosaic-harmonize[full]"

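Individual extras: boosting, conformal, and viz, which can be combined (for example, pip install "mosaic-harmonize[boosting,viz]"). For development, install the dev extra.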

Requires Python 3.9+.

Quick start

Full pipeline

from mosaic import MOSAICPipeline
from lightgbm import LGBMRegressor

pipe = MOSAICPipeline(
    harmonizer="ot",
    robust_learner="anchor",
    uncertainty="weighted_conformal",
    base_estimator=LGBMRegressor(),
)

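# X_train, y_train: training data; center_ids: one site label per row of X_train
pipe.fit(X_train, y_train, center_ids=train_centers)

# center_id: a single site label for all rows of X_test
# (centers not seen during fit can be added first with pipe.register_center)
result = pipe.predict(X_test, center_id="new_hospital")
print(result.prediction)           # point predictions
print(result.lower, result.upper)  # 90% prediction intervals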
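The quick start assumes X_train, y_train, train_centers, and X_test already exist. To try it end to end without clinical data, a synthetic multi-site table with a deliberate per-center batch effect can stand in. Everything below (column names, offsets, scales, the held-out center_C) is made up for illustration and is not the IVF benchmark data.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
# Hypothetical batch effect: each center measures the same latent signal with
# its own offset and scale (e.g. a different assay or lab calibration).
offsets = {"center_A": 0.0, "center_B": 5.0, "center_C": -3.0}
scales = {"center_A": 1.0, "center_B": 2.0, "center_C": 0.5}

n_per_center = 500
frames, targets, labels = [], [], []
for center in offsets:
    age = rng.uniform(25, 42, n_per_center)
    latent = rng.normal(0.0, 1.0, n_per_center)         # site-independent signal
    assay = offsets[center] + scales[center] * latent   # what the site records
    y = 10.0 - 0.2 * (age - 30) + 2.0 * latent + rng.normal(0.0, 1.0, n_per_center)
    frames.append(pd.DataFrame({"age": age, "assay": assay}))
    targets.append(y)
    labels.extend([center] * n_per_center)

X_all = pd.concat(frames, ignore_index=True)
y_all = np.concatenate(targets)
centers_all = np.array(labels)

# Hold out center_C entirely to mimic a hospital unseen during training.
train_mask = centers_all != "center_C"
X_train, y_train = X_all[train_mask], y_all[train_mask]
train_centers = centers_all[train_mask]
X_test, y_test = X_all[~train_mask], y_all[~train_mask]
```

A DataFrame is used because the `features` parameter of OTHarmonizer takes column names; whether NumPy arrays are also accepted is not stated here.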

Individual components

from lightgbm import LGBMRegressor
from mosaic import OTHarmonizer, AnchorEstimator, ConformalCalibrator

# Tier 1: align distributions
ot = OTHarmonizer(n_quantiles=1000, reference="global")
X_harmonized = ot.fit_transform(X_train, center_ids=train_centers)

# Inspect shift reduction
print(ot.wasserstein_distances())
print(ot.feature_shift_report())

# Tier 2: train a domain-robust model
anchor = AnchorEstimator(base_estimator=LGBMRegressor(), task_type="regression")
anchor.fit(X_harmonized, y_train, anchors=train_centers)
print(f"Best gamma: {anchor.best_gamma_}")
print(f"Cross-center stability: {anchor.stability_score_:.3f}")

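# Tier 3: calibrate with conformal prediction.
# The model was fit on harmonized features, so calibration and test rows must go
# through the same OT maps first. cal_centers / test_centers are per-row site
# labels (placeholders like X_cal, assumed to be centers seen by ot.fit_transform).
X_cal_h = ot.transform(X_cal, center_ids=cal_centers)
X_test_h = ot.transform(X_test, center_ids=test_centers)
cal = ConformalCalibrator(method="weighted", alpha=0.10)
cal.calibrate(anchor, X_cal_h, y_cal, X_test=X_test_h)
result = cal.predict(X_test_h)
print(f"Mean interval width: {(result.upper - result.lower).mean():.2f}")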

Save and load

pipe.save("model.mosaic")
pipe = MOSAICPipeline.load("model.mosaic")

Register a new center at inference time

pipe.register_center("hospital_B", X_new_center)
result = pipe.predict(X_query, center_id="hospital_B")

API reference

OTHarmonizer

| Parameter | Type | Default | Description |
|---|---|---|---|
| n_quantiles | int | 1000 | Number of quantile points for the OT map |
| features | list[str] or None | None | Columns to harmonize (None = all numeric) |
| reference | str | "global" | Reference distribution: "global" or a center name |
| min_samples | int | 50 | Minimum non-null samples to build a map |

Methods: fit(X, center_ids), transform(X, center_id=..., center_ids=...), fit_transform(X, center_ids), wasserstein_distances(), feature_shift_report().
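The README does not say which distance wasserstein_distances() reports. Assuming it is the usual 1-D Wasserstein-1 distance between a center's feature and the reference, the distance has a closed form in terms of quantile functions, W1 = integral over u in [0, 1] of |F^{-1}(u) - G^{-1}(u)|, which the sketch below approximates on a quantile grid. `wasserstein_1d` is a hypothetical helper, not the package's implementation.

```python
import numpy as np


def wasserstein_1d(x, y, n_quantiles=1000):
    """Hypothetical helper: approximate W1 between two 1-D samples.

    In one dimension W1 is the average absolute gap between matched quantiles,
    so it can be computed on a grid of quantile levels without solving an LP.
    """
    # Midpoint grid of quantile levels, avoiding the extreme levels 0 and 1.
    levels = (np.arange(n_quantiles) + 0.5) / n_quantiles
    qx = np.quantile(x[~np.isnan(x)], levels)
    qy = np.quantile(y[~np.isnan(y)], levels)
    return float(np.mean(np.abs(qx - qy)))


rng = np.random.default_rng(1)
a = rng.normal(0.0, 1.0, 5000)
b = rng.normal(3.0, 1.0, 5000)  # same shape, shifted by 3
print(wasserstein_1d(a, b))     # roughly 3, the size of the location shift
```

Because W1 is measured in the feature's own units, per-feature values are comparable before and after harmonization but not directly across features with different scales.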

AnchorEstimator

| Parameter | Type | Default | Description |
|---|---|---|---|
| base_estimator | sklearn estimator or None | None | Base learner (None = Ridge/LogisticRegression) |
| gammas | list[float] or None | [1.5, 3.0, 7.0] | Anchor penalty strengths to search |
| task_type | str | "auto" | "auto", "regression", "binary", or "multiclass" |
| n_vrex_rounds | int | 5 | V-REx reweighting iterations (fallback mode) |

Methods: fit(X, y, anchors, X_val=None, y_val=None), predict(X), predict_proba(X). Properties: best_gamma_, stability_score_.
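What the gamma penalty does is easiest to see in the linear anchor regression objective of Rothenhäusler et al. (2021), minimizing ||(I - P_A)(y - Xb)||² + gamma·||P_A(y - Xb)||², where P_A projects onto the anchors. The sketch below assumes the anchors are one-hot center labels, in which case P_A replaces each value by its center mean and the problem reduces to least squares on transformed data. AnchorEstimator delegates boosted models to anchorboosting, so this linear version is only an illustration; `anchor_regression` is hypothetical and does not reproduce the gamma search or stability score.

```python
import numpy as np


def anchor_regression(X, y, anchors, gamma):
    """Hypothetical helper: linear anchor regression with categorical anchors.

    With one-hot center anchors, P_A maps each value to its center mean, and the
    anchor objective equals OLS after v -> v - (1 - sqrt(gamma)) * center_mean(v).
    """
    X = np.column_stack([np.ones(len(X)), np.asarray(X, dtype=float)])  # intercept
    y = np.asarray(y, dtype=float)
    anchors = np.asarray(anchors)
    shrink = 1.0 - np.sqrt(gamma)

    X_t, y_t = X.copy(), y.copy()
    for a in np.unique(anchors):
        rows = anchors == a
        # Rescale the between-center component of X and y by sqrt(gamma).
        X_t[rows] -= shrink * X[rows].mean(axis=0)
        y_t[rows] -= shrink * y[rows].mean()

    coef, *_ = np.linalg.lstsq(X_t, y_t, rcond=None)
    return coef  # [intercept, slope_1, ..., slope_p]


rng = np.random.default_rng(0)
centers = np.repeat(["site_a", "site_b", "site_c"], 200)
X = rng.normal(size=(600, 2))
y = X @ np.array([1.0, -2.0]) + rng.normal(size=600)
for gamma in [1.5, 3.0, 7.0]:  # the default grid from the table above
    print(gamma, anchor_regression(X, y, centers, gamma))
```

gamma = 1 recovers ordinary least squares; larger values weight residual variation that is explained by the center more heavily, pushing the fit toward relationships that hold within every center.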

ConformalCalibrator

| Parameter | Type | Default | Description |
|---|---|---|---|
| method | str | "weighted" | "weighted", "standard", or "lac" |
| alpha | float | 0.10 | Miscoverage level (0.10 = 90% target coverage) |

Methods: calibrate(model, X_cal, y_cal, X_test=None), predict(X_test) returning ConformalResult.
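The weighted method follows the idea of Tibshirani et al. (2019): calibration residuals are reweighted by the likelihood ratio w(x) = dP_test(x)/dP_cal(x), and the test point's own weight is placed on +infinity before taking the 1 - alpha quantile. With all weights equal to 1 this reduces to standard split conformal. The sketch assumes absolute-residual scores (regression) and takes the weights as given; how ConformalCalibrator estimates them from X_test is not specified in this README, and `weighted_conformal_interval` is a hypothetical helper.

```python
import numpy as np


def weighted_conformal_interval(y_hat, cal_scores, cal_weights, test_weight, alpha=0.10):
    """Hypothetical helper: (lower, upper) interval for one test prediction y_hat."""
    order = np.argsort(cal_scores)
    scores = np.asarray(cal_scores, dtype=float)[order]
    weights = np.asarray(cal_weights, dtype=float)[order]
    total = weights.sum() + test_weight
    # Normalized mass on each calibration score; the remaining
    # test_weight / total sits on +infinity.
    cum_mass = np.cumsum(weights / total)
    # Smallest score whose cumulative mass reaches 1 - alpha (the tiny tolerance
    # guards against floating-point error in the cumulative sum).
    idx = np.searchsorted(cum_mass, 1.0 - alpha - 1e-12)
    q = scores[idx] if idx < len(scores) else np.inf
    return y_hat - q, y_hat + q


def predict(x):
    return 2.0 * x  # stand-in for a fitted model


# Exchangeable example with unit weights, i.e. standard split conformal.
rng = np.random.default_rng(0)
x_cal, x_test = rng.uniform(0, 1, 500), rng.uniform(0, 1, 2000)
y_cal = 2.0 * x_cal + rng.normal(0, 0.5, 500)
y_test = 2.0 * x_test + rng.normal(0, 0.5, 2000)
scores = np.abs(y_cal - predict(x_cal))
bounds = [weighted_conformal_interval(predict(x), scores, np.ones(500), 1.0) for x in x_test]
lower, upper = map(np.array, zip(*bounds))
print(np.mean((y_test >= lower) & (y_test <= upper)))  # close to 0.90
```

If a test point's weight dominates the calibration weights, the interval becomes infinite, which is the honest answer when the calibration data say little about that region of covariate space.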

MOSAICPipeline

| Parameter | Type | Default | Description |
|---|---|---|---|
| harmonizer | str or None | "ot" | "ot" or None |
| robust_learner | str or None | "anchor" | "anchor" or None |
| uncertainty | str or None | "weighted_conformal" | "weighted_conformal", "standard", "lac", or None |
| base_estimator | sklearn estimator | None | Base learner passed to AnchorEstimator |

Methods: fit(X_train, y_train, center_ids, X_cal=None, y_cal=None), predict(X, center_id=..., center_ids=...), register_center(name, X_new), diagnose(X, center_id), save(path), load(path).

Benchmarks

Evaluated on a multi-center IVF dataset (334K rows, 5 centers, 15 prediction targets). Full results in benchmarks/results/.

Ablation (Exp 1): each tier adds value

| Target (R²) | Baseline | +OT | +OT+Anchor | Full MOSAIC |
|---|---|---|---|---|
| HCG_Day_E2 | -0.665 | 0.127 | 0.229 | 0.229 |
| egg_num | 0.210 | 0.205 | 0.352 | 0.352 |
| HCG_Day_Endo | -0.775 | -0.177 | 0.120 | 0.120 |

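OT alone gives the largest gains on high-shift targets (HCG_Day_E2: R² from -0.665 to 0.127; HCG_Day_Endo: -0.775 to -0.177) and leaves egg_num essentially unchanged (0.210 to 0.205). Anchor regression adds further gains on top of OT (egg_num: 0.205 to 0.352; HCG_Day_Endo: -0.177 to 0.120). Full MOSAIC matches +OT+Anchor in R² because the conformal tier adds prediction intervals around the point predictions without changing them.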

Cross-center generalization gap (Exp 2)

On an external test center unseen during training, MOSAIC reduces the validation-to-test performance gap by 42-76% across high-shift targets (for example, HCG_Day_E2: 75%, HCG_Day_P: 52%, HCG_Day_Endo: 72%).

Conformal coverage (Exp 3)

Across all 11 regression targets, empirical coverage ranges from 81% to 93% at the 90% nominal level, so some targets fall short of nominal coverage. Weighted conformal consistently produces narrower intervals than standard split conformal at comparable coverage.
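Empirical coverage is presumably the share of held-out outcomes that fall inside their intervals, and interval width is what separates weighted from standard conformal at comparable coverage. A minimal sketch of both metrics follows; `coverage_and_width` is a hypothetical helper, and the scripts behind benchmarks/results/ may compute them differently.

```python
import numpy as np


def coverage_and_width(y_true, lower, upper):
    """Hypothetical helper: empirical coverage and mean interval width."""
    y_true, lower, upper = (np.asarray(a, dtype=float) for a in (y_true, lower, upper))
    covered = (y_true >= lower) & (y_true <= upper)  # interval contains the outcome
    return float(covered.mean()), float((upper - lower).mean())


cov, width = coverage_and_width([1.0, 2.0, 3.0, 10.0], [0, 1, 2, 3], [2, 3, 4, 5])
print(cov, width)  # 0.75 2.0  (10.0 falls outside its interval [3, 5])
```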

Comparison with existing methods (Exp 5)

| Target (metric) | No harmonization | Z-score | ComBat | MOSAIC (OT) |
|---|---|---|---|---|
| HCG_Day_E2 (R²) | -0.665 | 0.069 | -0.383 | 0.127 |
| Clinical_pregnancy (AUC) | 0.838 | 0.837 | 0.837 | 0.839 |
| total_Gn (R²) | -0.121 | -0.327 | -0.034 | -0.022 |

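MOSAIC outperforms Z-score and ComBat on high-shift features (HCG_Day_E2: R² 0.127 vs. 0.069 for Z-score and -0.383 for ComBat) while maintaining comparable performance on low-shift targets (Clinical_pregnancy: all methods within 0.002 AUC). total_Gn R² stays below zero for every method, although MOSAIC is closest to zero.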
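The Z-score baseline is not defined in this README. Assuming it means per-center standardization, the pandas sketch below shows the key difference from OT: it matches only each center's mean and standard deviation, so differences in distribution shape (skew, multiple modes) survive, whereas a quantile map aligns the whole marginal distribution. `zscore_per_center` is hypothetical; ComBat, which additionally shrinks the per-batch location and scale estimates with empirical Bayes, is not sketched.

```python
import numpy as np
import pandas as pd


def zscore_per_center(X, center_ids):
    """Hypothetical baseline: standardize each column within each center."""
    X = pd.DataFrame(X)
    groups = X.groupby(np.asarray(center_ids))
    # Columns with zero variance inside a center would produce NaN/inf here.
    return (X - groups.transform("mean")) / groups.transform("std")


rng = np.random.default_rng(0)
centers = np.repeat(["a", "b"], 1000)
# Center "a" is symmetric; center "b" is right-skewed (log-normal).
values = np.concatenate([rng.normal(5, 1, 1000), rng.lognormal(0, 1, 1000)])
z = zscore_per_center({"E2": values}, centers)
print(z.groupby(centers)["E2"].agg(["mean", "std"]))  # ~0 and ~1 in both centers
print(z["E2"][centers == "b"].skew())                 # skew is untouched by z-scoring
```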

Citation

If you use MOSAIC in your research, please cite:

@article{chen2026mosaic,
  title={MOSAIC: Multi-site Optimal-transport Shift Alignment with Interval
         Calibration for Clinical Data Harmonization},
  author={Chen, Peigen},
  journal={npj Digital Medicine},
  year={2026},
  note={Manuscript in preparation}
}

License

Apache-2.0. See LICENSE for details.


Download files


Source Distribution

mosaic_harmonize-0.1.0.tar.gz (27.1 kB)

Built Distribution


mosaic_harmonize-0.1.0-py3-none-any.whl (22.3 kB)

File details

Details for the file mosaic_harmonize-0.1.0.tar.gz.

File metadata

  • Download URL: mosaic_harmonize-0.1.0.tar.gz
  • Upload date:
  • Size: 27.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.7

File hashes

Hashes for mosaic_harmonize-0.1.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | a1f90ba0f81c13728a4443400594025861a0bcbfd660bfb639d51eafb2298148 |
| MD5 | 807ae8dcd8ad5ff9650826fcf4a2286f |
| BLAKE2b-256 | 26bb74515ae3cee0e96fba5492ce18ad3866d7aac6c4a73752b81e084d427211 |


File details

Details for the file mosaic_harmonize-0.1.0-py3-none-any.whl.

File metadata

File hashes

Hashes for mosaic_harmonize-0.1.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | d7c2bca4790c097903271d30a9f2ffa93a73b048fdc0a1135853b871196e681f |
| MD5 | 890172cb6045607defbfb51677c447cd |
| BLAKE2b-256 | 398b5579531ded477472d7268988574d97d5b53a567bbe8b3f52d72ef36a15a7 |

