Multi-site Optimal-transport Shift Alignment with Interval Calibration for clinical data harmonization
MOSAIC
Multi-site Optimal-transport Shift Alignment with Interval Calibration
MOSAIC is a Python package for harmonizing clinical tabular data collected across multiple sites. It combines 1-D optimal transport for distribution alignment, anchor regression for domain-robust prediction, and weighted conformal inference for uncertainty quantification. The three components can be used independently or chained through a single pipeline.
The package was developed for multi-center IVF (in vitro fertilization) outcome prediction, but the methods are general and apply to any multi-site clinical or biomedical dataset with batch effects.
Overview
MOSAIC has three tiers, each usable on its own:
| Tier | Class | What it does |
|---|---|---|
| 1. Harmonization | OTHarmonizer | Per-feature quantile-based optimal transport mapping to a reference distribution. Reduces cross-center distribution shift while preserving within-center rank order. |
| 2. Robust learning | AnchorEstimator | Wraps any sklearn estimator with anchor regression (via anchorboosting) or V-REx reweighting. Penalizes predictions that rely on center-specific patterns. |
| 3. Uncertainty | ConformalCalibrator | Split conformal prediction with optional covariate-shift correction (Tibshirani et al., NeurIPS 2019). Produces prediction intervals (regression) or prediction sets (classification) with finite-sample marginal coverage guarantees, assuming exchangeability or, for the weighted variant, correct covariate-shift weights. |
MOSAICPipeline chains all three into a single fit / predict interface.
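For intuition about Tier 1: in one dimension, the optimal transport map between two continuous distributions is the monotone rearrangement T(x) = F_ref^{-1}(F_src(x)). Because T is non-decreasing, mapping each feature this way keeps the within-center rank order intact. The sketch below implements that map with NumPy on a grid of quantiles; the function name quantile_ot_map and the linear interpolation between quantile points are assumptions for illustration, not OTHarmonizer's internals.

```python
import numpy as np


def quantile_ot_map(x_src, x_ref, n_quantiles=1000):
    """Build the 1-D OT map from a source sample to a reference sample.

    Hypothetical helper: T(x) = F_ref^{-1}(F_src(x)), with both CDFs approximated
    by n_quantiles empirical quantiles and linear interpolation in between.
    """
    levels = np.linspace(0.0, 1.0, n_quantiles)
    q_src = np.quantile(np.asarray(x_src, dtype=float), levels)  # source quantile function
    q_ref = np.quantile(np.asarray(x_ref, dtype=float), levels)  # reference quantile function

    def transform(x):
        # F_src(x): locate x among the source quantiles (values outside are clipped)
        u = np.interp(np.asarray(x, dtype=float), q_src, levels)
        # F_ref^{-1}(u): read off the reference quantile at the same level
        return np.interp(u, levels, q_ref)

    return transform


# Usage: pull a shifted, rescaled feature from one site onto the reference scale.
rng = np.random.default_rng(0)
site_values = rng.normal(5.0, 2.0, 20_000)
reference_values = rng.normal(0.0, 1.0, 20_000)
T = quantile_ot_map(site_values, reference_values)
aligned = T(site_values)
```

Discrete features with many ties (for example, counts) would need an explicit tie-breaking rule, which this sketch omits.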
Installation
Core (OT harmonization + anchor regression + conformal):
```bash
pip install mosaic-harmonize
```
With all optional dependencies (LightGBM, anchorboosting, MAPIE, matplotlib):
```bash
# Quotes stop shells such as zsh from treating the brackets as a glob pattern
pip install "mosaic-harmonize[full]"
```
Individual extras are also available (boosting, conformal, viz), plus dev for development.
Requires Python 3.9+.
Quick start
Full pipeline
```python
from lightgbm import LGBMRegressor
from mosaic import MOSAICPipeline

pipe = MOSAICPipeline(
    harmonizer="ot",
    robust_learner="anchor",
    uncertainty="weighted_conformal",
    base_estimator=LGBMRegressor(),
)

# X_train, y_train, X_test: your data
# train_centers: array of site labels, one per row of X_train
pipe.fit(X_train, y_train, center_ids=train_centers)
result = pipe.predict(X_test, center_id="new_hospital")

print(result.prediction)  # point predictions
print(result.lower, result.upper)  # 90% prediction intervals
```
Individual components
```python
from lightgbm import LGBMRegressor
from mosaic import OTHarmonizer, AnchorEstimator, ConformalCalibrator

# Tier 1: align distributions
ot = OTHarmonizer(n_quantiles=1000, reference="global")
X_harmonized = ot.fit_transform(X_train, center_ids=train_centers)

# Inspect shift reduction
print(ot.wasserstein_distances())
print(ot.feature_shift_report())

# Tier 2: train a domain-robust model
anchor = AnchorEstimator(base_estimator=LGBMRegressor(), task_type="regression")
anchor.fit(X_harmonized, y_train, anchors=train_centers)
print(f"Best gamma: {anchor.best_gamma_}")
print(f"Cross-center stability: {anchor.stability_score_:.3f}")

# Tier 3: calibrate with conformal prediction.
# The model was trained on harmonized features, so calibration and test rows
# go through the same harmonizer (cal_centers / test_centers: their site labels).
X_cal_h = ot.transform(X_cal, center_ids=cal_centers)
X_test_h = ot.transform(X_test, center_ids=test_centers)
cal = ConformalCalibrator(method="weighted", alpha=0.10)
cal.calibrate(anchor, X_cal_h, y_cal, X_test=X_test_h)
result = cal.predict(X_test_h)
print(f"Mean interval width: {(result.upper - result.lower).mean():.2f}")
```
Save and load
```python
pipe.save("model.mosaic")
pipe = MOSAICPipeline.load("model.mosaic")
```
Register a new center at inference time
```python
# X_new_center: rows from the new site (no labels needed)
pipe.register_center("hospital_B", X_new_center)
result = pipe.predict(X_query, center_id="hospital_B")
```
API reference
OTHarmonizer
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_quantiles | int | 1000 | Number of quantile points for the OT map |
| features | list[str] or None | None | Columns to harmonize (None = all numeric) |
| reference | str | "global" | Reference distribution: "global" or a center name |
| min_samples | int | 50 | Minimum non-null samples to build a map |
Methods: fit(X, center_ids), transform(X, center_id=..., center_ids=...), fit_transform(X, center_ids), wasserstein_distances(), feature_shift_report().
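The exact estimator behind wasserstein_distances() is not documented here. As an illustration of the quantity involved, the 1-D Wasserstein-1 distance between two samples can be computed from their quantile functions alone, W1 = ∫₀¹ |F_a^{-1}(u) - F_b^{-1}(u)| du, which also makes it cheap to evaluate per feature before and after harmonization. The helper wasserstein_1d below is hypothetical and uses a midpoint rule on the quantile grid.

```python
import numpy as np


def wasserstein_1d(a, b, n_quantiles=1000):
    """Approximate the 1-D Wasserstein-1 distance between samples a and b.

    Hypothetical helper: averages |F_a^{-1}(u) - F_b^{-1}(u)| over n_quantiles
    midpoint levels u in (0, 1), a Riemann sum for the W1 integral.
    """
    levels = (np.arange(n_quantiles) + 0.5) / n_quantiles
    qa = np.quantile(np.asarray(a, dtype=float), levels)
    qb = np.quantile(np.asarray(b, dtype=float), levels)
    return float(np.mean(np.abs(qa - qb)))


# Usage: a pure location shift of 3 units gives a distance of 3.
rng = np.random.default_rng(1)
center_a = rng.normal(0.0, 1.0, 5000)
shifted_distance = wasserstein_1d(center_a, center_a + 3.0)
```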
AnchorEstimator
| Parameter | Type | Default | Description |
|---|---|---|---|
| base_estimator | sklearn estimator or None | None | Base learner (None = Ridge/LogisticRegression) |
| gammas | list[float] or None | [1.5, 3.0, 7.0] | Anchor penalty strengths to search |
| task_type | str | "auto" | "auto", "regression", "binary", or "multiclass" |
| n_vrex_rounds | int | 5 | V-REx reweighting iterations (fallback mode) |
Methods: fit(X, y, anchors, X_val=None, y_val=None), predict(X), predict_proba(X). Properties: best_gamma_, stability_score_.
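AnchorEstimator relies on anchorboosting for nonlinear learners, but the role of gamma is easiest to see in linear anchor regression (Rothenhäusler et al.). With one-hot center indicators as anchors and P_A the projection onto them, the objective ||(I - P_A)(y - Xb)||^2 + gamma * ||P_A (y - Xb)||^2 reduces to ordinary least squares on W X and W y, where W = I - (1 - sqrt(gamma)) P_A. For one-hot anchors, P_A simply replaces each row by its center mean. The sketch below is that closed form, not the package's implementation; the function name is hypothetical and no intercept is fitted (append a column of ones if you need one).

```python
import numpy as np


def anchor_regression(X, y, centers, gamma):
    """Linear anchor regression with one-hot center indicators as anchors (hypothetical helper).

    Solves OLS on W X, W y with W = I - (1 - sqrt(gamma)) * P_A, where P_A maps each
    row to its center mean. No intercept column is added.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    centers = np.asarray(centers)
    shrink = 1.0 - np.sqrt(gamma)

    Xw, yw = X.copy(), y.copy()
    for c in np.unique(centers):
        mask = centers == c
        # Subtract a gamma-dependent share of each center's mean: rows of W X and W y
        Xw[mask] -= shrink * X[mask].mean(axis=0)
        yw[mask] -= shrink * y[mask].mean()
    coef, *_ = np.linalg.lstsq(Xw, yw, rcond=None)
    return coef


# Usage: three centers, 100 rows each.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(300, 3))
centers_demo = np.repeat(np.arange(3), 100)
y_demo = X_demo @ np.array([1.0, -2.0, 0.5]) + rng.normal(size=300)
coef_robust = anchor_regression(X_demo, y_demo, centers_demo, gamma=7.0)
```

gamma = 1 recovers ordinary least squares and gamma = 0 uses only within-center variation; values above 1, as in the default grid, increasingly penalize residuals that correlate with center membership.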
ConformalCalibrator
| Parameter | Type | Default | Description |
|---|---|---|---|
| method | str | "weighted" | "weighted", "standard", or "lac" |
| alpha | float | 0.10 | Miscoverage level (0.10 = 90% target coverage) |
Methods: calibrate(model, X_cal, y_cal, X_test=None), predict(X_test) returning ConformalResult.
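For method="weighted", the construction of Tibshirani et al. (2019) replaces the plain empirical quantile of calibration scores with a weighted one: each calibration score is weighted by the likelihood ratio w(x) = dP_test(x) / dP_cal(x), and the test point contributes a point mass at +infinity. With equal weights this reduces to standard split conformal, whose threshold is the ceil((n + 1)(1 - alpha))-th smallest score. The sketch below covers regression with absolute-residual scores; the function names are hypothetical, and how ConformalCalibrator estimates w(x) is not described here (a probabilistic classifier separating calibration from test covariates is one common choice).

```python
import numpy as np


def weighted_conformal_quantile(scores, cal_weights, test_weight, alpha):
    """(1 - alpha) quantile of sum_i p_i * delta(score_i) + p_test * delta(+inf).

    p_i = w_i / (sum_j w_j + w_test), p_test = w_test / (sum_j w_j + w_test).
    """
    scores = np.asarray(scores, dtype=float)
    w = np.asarray(cal_weights, dtype=float)
    order = np.argsort(scores)
    cdf = np.cumsum(w[order]) / (w.sum() + test_weight)  # weighted CDF at sorted scores
    idx = np.searchsorted(cdf, 1.0 - alpha)  # first position where cdf >= 1 - alpha
    return scores[order][idx] if idx < len(scores) else np.inf


def conformal_intervals(pred_test, cal_scores, alpha=0.10, cal_weights=None, test_weights=None):
    """Intervals pred +/- q(x); all weights equal to 1 gives standard split conformal."""
    pred_test = np.asarray(pred_test, dtype=float)
    cal_w = np.ones(len(cal_scores)) if cal_weights is None else cal_weights
    test_w = np.ones(len(pred_test)) if test_weights is None else np.asarray(test_weights, dtype=float)
    q = np.array([weighted_conformal_quantile(cal_scores, cal_w, tw, alpha) for tw in test_w])
    return pred_test - q, pred_test + q


# Usage: calibration scores are |y_cal - model.predict(X_cal)|; here simply 1..100.
cal_scores_demo = np.arange(1, 101)
lower_demo, upper_demo = conformal_intervals([0.0, 10.0], cal_scores_demo, alpha=0.10)
```

When too little weight sits below the target level, the threshold becomes infinite, signaling that the calibration set cannot support the requested coverage for that test point.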
MOSAICPipeline
| Parameter | Type | Default | Description |
|---|---|---|---|
| harmonizer | str or None | "ot" | "ot" or None |
| robust_learner | str or None | "anchor" | "anchor" or None |
| uncertainty | str or None | "weighted_conformal" | "weighted_conformal", "standard", "lac", or None |
| base_estimator | sklearn estimator or None | None | Base learner passed to AnchorEstimator |
Methods: fit(X_train, y_train, center_ids, X_cal=None, y_cal=None), predict(X, center_id=..., center_ids=...), register_center(name, X_new), diagnose(X, center_id), save(path), load(path).
Benchmarks
Evaluated on a multi-center IVF dataset (334K rows, 5 centers, 15 prediction targets). Full results in benchmarks/results/.
Ablation (Exp 1): each tier adds value
| Target | Baseline R² | +OT | +OT+Anchor | Full MOSAIC |
|---|---|---|---|---|
| HCG_Day_E2 | -0.665 | 0.127 | 0.229 | 0.229 |
| egg_num | 0.210 | 0.205 | 0.352 | 0.352 |
| HCG_Day_Endo | -0.775 | -0.177 | 0.120 | 0.120 |
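OT corrects much of the distribution shift (HCG_Day_E2: R² from -0.665 to 0.127; HCG_Day_Endo: -0.775 to -0.177). Anchor regression adds further gains for regression targets, including egg_num (0.21 to 0.35), where OT alone leaves R² essentially unchanged. The conformal tier adds prediction intervals without changing point predictions, so Full MOSAIC matches +OT+Anchor in R².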
Cross-center generalization gap (Exp 2)
On an external test center unseen during training, MOSAIC reduces the validation-to-test performance gap by 42-76% for high-shift features (HCG_Day_E2: 75%, HCG_Day_P: 52%, HCG_Day_Endo: 72%).
Conformal coverage (Exp 3)
All 11 regression targets achieve 81-93% empirical coverage at the 90% nominal level. Weighted conformal consistently produces narrower intervals than standard split conformal at comparable coverage.
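Empirical coverage and interval width are the standard held-out summaries: the fraction of test targets that fall inside their interval, and the average interval length. A short sketch of that evaluation follows, assuming lower and upper are arrays such as ConformalResult.lower and ConformalResult.upper; the name coverage_report and the optional per-center breakdown are illustrative additions, not part of the package.

```python
import numpy as np


def coverage_report(y_true, lower, upper, centers=None):
    """Empirical coverage and mean width of prediction intervals (hypothetical helper)."""
    y_true, lower, upper = (np.asarray(a, dtype=float) for a in (y_true, lower, upper))
    covered = (lower <= y_true) & (y_true <= upper)
    report = {
        "coverage": float(covered.mean()),
        "mean_width": float((upper - lower).mean()),
    }
    if centers is not None:
        # Marginal coverage can hide under-coverage at individual sites
        centers = np.asarray(centers)
        report["by_center"] = {c: float(covered[centers == c].mean()) for c in np.unique(centers)}
    return report


# Usage with toy intervals: two of four targets are covered.
demo = coverage_report(
    y_true=[1.0, 2.0, 3.0, 4.0],
    lower=[0.0, 2.5, 2.0, 5.0],
    upper=[2.0, 3.0, 4.0, 6.0],
    centers=["a", "b", "a", "b"],
)
```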
Comparison with existing methods (Exp 5)
| Target (metric) | No harmonization | Z-score | ComBat | MOSAIC (OT) |
|---|---|---|---|---|
| HCG_Day_E2 (R²) | -0.665 | 0.069 | -0.383 | 0.127 |
| Clinical_pregnancy (AUC) | 0.838 | 0.837 | 0.837 | 0.839 |
| total_Gn (R²) | -0.121 | -0.327 | -0.034 | -0.022 |
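MOSAIC outperforms Z-score and ComBat on high-shift targets while maintaining comparable performance on low-shift targets such as Clinical_pregnancy. For total_Gn, MOSAIC gives the best R² of the four approaches, although R² stays below zero for every method.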
Citation
If you use MOSAIC in your research, please cite:
```bibtex
@article{chen2026mosaic,
  title={MOSAIC: Multi-site Optimal-transport Shift Alignment with Interval
         Calibration for Clinical Data Harmonization},
  author={Chen, Peigen},
  journal={npj Digital Medicine},
  year={2026},
  note={Manuscript in preparation}
}
```
License
Apache-2.0. See LICENSE for details.
File details
mosaic_harmonize-0.1.0.tar.gz (source distribution, 27.1 kB; uploaded via twine/6.2.0 CPython/3.12.7; Trusted Publishing: no)

| Algorithm | Hash digest |
|---|---|
| SHA256 | a1f90ba0f81c13728a4443400594025861a0bcbfd660bfb639d51eafb2298148 |
| MD5 | 807ae8dcd8ad5ff9650826fcf4a2286f |
| BLAKE2b-256 | 26bb74515ae3cee0e96fba5492ce18ad3866d7aac6c4a73752b81e084d427211 |
mosaic_harmonize-0.1.0-py3-none-any.whl (built distribution, tag Python 3, 22.3 kB; uploaded via twine/6.2.0 CPython/3.12.7; Trusted Publishing: no)

| Algorithm | Hash digest |
|---|---|
| SHA256 | d7c2bca4790c097903271d30a9f2ffa93a73b048fdc0a1135853b871196e681f |
| MD5 | 890172cb6045607defbfb51677c447cd |
| BLAKE2b-256 | 398b5579531ded477472d7268988574d97d5b53a567bbe8b3f52d72ef36a15a7 |