
Cross-framework causal-ML ensembles for ATE and CATE with honest bootstrap inference


MetaCausal

Cross-framework ensembling of causal machine-learning estimators for ATE and pointwise CATE, with honest bootstrap inference.


What it is

MetaCausal orchestrates multiple causal-ML estimators from different libraries — EconML, DoubleML, CausalML, stochtree, or arbitrary user-supplied callables — behind a single protocol, and aggregates their treatment-effect estimates into a single ensemble estimate. Seven aggregation strategies are provided, grouped into three tiers:

  • Pointwise robust — Median (default), Mean, Trimmed Mean.
  • Agreement-based — Consensus Based Averaging, which selects a high-agreement subset of components from pairwise Kendall's τ.
  • Outcome-supervised — Causal Stacking, R-Stacking, Q-Aggregation, which learn weights by optimising a causal loss on cross-fitted out-of-fold predictions.
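The three pointwise strategies combine component estimates unit by unit. Below is a minimal NumPy/SciPy sketch. It assumes the K components' CATE predictions are stacked into a K × n array. The `aggregate_pointwise` helper and its default trimming fraction are hypothetical illustrations, not MetaCausal's API or defaults:

```python
import numpy as np
from scipy.stats import trim_mean


def aggregate_pointwise(preds, method="median", trim=0.1):
    """Combine a (K components x n units) array of CATE predictions into
    one ensemble estimate per unit by aggregating down each column.

    Hypothetical helper for illustration; not part of MetaCausal.
    """
    preds = np.asarray(preds, dtype=float)
    if method == "median":
        return np.median(preds, axis=0)
    if method == "mean":
        return preds.mean(axis=0)
    if method == "trimmed_mean":
        # Drops int(trim * K) components from each tail before averaging.
        return trim_mean(preds, proportiontocut=trim, axis=0)
    raise ValueError(f"unknown aggregation method: {method!r}")


# Three components, four units; the third component is erratic on units 0 and 1.
preds = np.array([
    [1.0, 2.0, 3.0, 4.0],
    [1.5, 2.5, 3.5, 4.5],
    [9.0, -5.0, 3.2, 4.1],
])
print(aggregate_pointwise(preds, "median"))
print(aggregate_pointwise(preds, "mean"))
```

With only three components, `trim=0.1` removes nothing because int(0.3) == 0, so the trimmed mean equals the plain mean. Trimming only takes effect once int(trim * K) is at least 1.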

A full-pipeline honest bootstrap, which re-runs component fitting and aggregation on every resample, supplies comparable confidence intervals for both ATE and pointwise CATE across heterogeneous components whose native inference machinery is otherwise incomparable.
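One plausible reading of a full-pipeline bootstrap works like this. Each replicate resamples units with replacement and re-runs every component and the aggregation step. The interval therefore reflects the variability of the whole ensemble rather than any single component's own inference. The sketch below assumes that reading and uses a percentile interval. The three toy components and all function names are hypothetical stand-ins, not MetaCausal code:

```python
import numpy as np


def diff_in_means(X, T, Y):
    return float(Y[T == 1].mean() - Y[T == 0].mean())


def ols_adjusted(X, T, Y):
    # Regress Y on [1, T, X]; the T coefficient is the ATE under a linear model.
    design = np.column_stack([np.ones(len(Y)), T, X])
    coef, *_ = np.linalg.lstsq(design, Y, rcond=None)
    return float(coef[1])


def t_learner_linear(X, T, Y):
    # One OLS fit per arm; average the difference of predictions over all units.
    design = np.column_stack([np.ones(len(Y)), X])
    b1, *_ = np.linalg.lstsq(design[T == 1], Y[T == 1], rcond=None)
    b0, *_ = np.linalg.lstsq(design[T == 0], Y[T == 0], rcond=None)
    return float((design @ (b1 - b0)).mean())


COMPONENTS = (diff_in_means, ols_adjusted, t_learner_linear)


def ensemble_ate(X, T, Y):
    # Full pipeline: fit every component, then aggregate by median.
    return float(np.median([f(X, T, Y) for f in COMPONENTS]))


def bootstrap_ci(X, T, Y, n_boot=200, alpha=0.05, seed=0):
    rng = np.random.default_rng(seed)
    n = len(Y)
    reps = np.empty(n_boot)
    for b in range(n_boot):
        idx = rng.integers(0, n, size=n)                # resample units with replacement
        reps[b] = ensemble_ate(X[idx], T[idx], Y[idx])  # re-run the whole pipeline
    lower, upper = np.quantile(reps, [alpha / 2, 1 - alpha / 2])
    return ensemble_ate(X, T, Y), float(lower), float(upper)


# Simulated data with a true ATE of 2.
rng = np.random.default_rng(1)
X = rng.normal(size=(500, 2))
T = rng.integers(0, 2, size=500)
Y = 2.0 * T + X @ np.array([1.0, -1.0]) + rng.normal(size=500)
print(bootstrap_ci(X, T, Y))
```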

Why

No single causal-ML estimator dominates across data-generating processes, model selection for heterogeneous treatment effects is empirically unreliable, and individual methods can fail catastrophically under specific violations of their own assumptions (overlap breakdown, nuisance misspecification, tree extrapolation). MetaCausal's default pointwise median aggregation attains the maximal 50% breakdown point with no tuning: strictly fewer than half of the component estimators (at most ⌊(K−1)/2⌋ of K) can produce arbitrarily bad estimates without pushing the ensemble estimate outside the range of the well-behaved components. When outcome data allow learning weights, MetaCausal also ships the three outcome-supervised stackers from the recent CATE-ensemble literature.
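The breakdown claim can be checked numerically. With K components, the median stays within the range of the well-behaved estimates as long as at most ⌊(K−1)/2⌋ of them are corrupted. The mean, by contrast, can be dragged arbitrarily far by a single bad component. A small NumPy illustration with made-up component ATEs (not MetaCausal output):

```python
import numpy as np

# Five hypothetical component ATE estimates clustered around 2.
good = np.array([1.90, 2.00, 2.10, 2.05, 1.95])
K = len(good)
max_tolerated = (K - 1) // 2  # 2 for K = 5


def corrupt(estimates, k, value=1e6):
    # Replace the first k component estimates with an arbitrarily bad value.
    out = estimates.copy()
    out[:k] = value
    return out


for k in range(K):
    bad = corrupt(good, k)
    print(f"k={k}: median={np.median(bad):.2f}, mean={bad.mean():.2f}")
```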

Installation

pip install metacausal

This installs the core package and its required dependencies (numpy, pandas, scipy). Estimator libraries are optional extras:

# Individual libraries
pip install "metacausal[econml]"
pip install "metacausal[doubleml]"
pip install "metacausal[causalml]"
pip install "metacausal[stochtree]"

# Visualisation helpers (matplotlib)
pip install "metacausal[plots]"

# Everything (frameworks + plots)
pip install "metacausal[all]"

Python 3.10 or later is required.
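The framework back-ends are optional, so it can help to check which ones are importable before building a mixed ensemble. Below is a small standard-library sketch. The `available_extras` and `python_supported` helpers are hypothetical. The mapping from extras to import names (for example `plots` to `matplotlib`) is an assumption based on the list above:

```python
import sys
from importlib.util import find_spec

# Assumed import name for each optional extra listed above.
EXTRA_IMPORTS = {
    "econml": "econml",
    "doubleml": "doubleml",
    "causalml": "causalml",
    "stochtree": "stochtree",
    "plots": "matplotlib",
}


def available_extras():
    """Return {extra: True/False} without importing the heavy packages."""
    return {extra: find_spec(mod) is not None for extra, mod in EXTRA_IMPORTS.items()}


def python_supported():
    # The package requires Python 3.10 or later.
    return sys.version_info >= (3, 10)


print("Python OK:", python_supported())
for extra, ok in available_extras().items():
    print(f"{extra:<10} {'installed' if ok else 'missing'}")
```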

Quick start

from metacausal import CausalEnsemble
from metacausal.datasets import load_lalonde

X, T, Y = load_lalonde()

# Default ensemble: ten estimators spanning EconML, DoubleML,
# CausalML, and stochtree, aggregated by pointwise median.
ens = CausalEnsemble()
ens.fit(X, T, Y, random_state=42)

# Point estimate
ate = ens.ate()
print(f"Ensemble ATE: {ate.ate:.1f}")
for name, est in ate.component_estimates.items():
    print(f"  {name:<25} {est.ate:>9.1f}")

# Honest bootstrap confidence interval
boot = ens.bootstrap(n_boot=200, random_state=42, n_jobs=-1)
print(f"95% CI: [{boot.ate_ci_lower:.1f}, {boot.ate_ci_upper:.1f}]")

The three-step fit → ate / cate → bootstrap pattern is recommended because it lets you inspect intermediate state and swap aggregation strategies on an already-fitted ensemble. The convenience wrapper ens.estimate(X, T, Y, n_boot=200, ...) runs fit followed by bootstrap (or fit followed by ate) in a single call.

Aggregation strategies at a glance

| Tier | Strategy | String alias / class | Data used |
|---|---|---|---|
| Pointwise | Median (default) | "median" / Median | Component predictions only |
| Pointwise | Mean | "mean" / Mean | Component predictions only |
| Pointwise | Trimmed Mean | "trimmed_mean" / TrimmedMean | Component predictions only |
| Agreement | Consensus Based Averaging | "cba" / CBA | Component CATE predictions on training data |
| Supervised | Causal Stacking | CausalStacking | Cross-fitted OOF predictions + nuisance |
| Supervised | R-Stacking | RStacking | Cross-fitted OOF predictions + nuisance |
| Supervised | Q-Aggregation | QAggregation | Cross-fitted OOF predictions + nuisance |

# By string alias (default configuration)
ens = CausalEnsemble(aggregation="trimmed_mean")

# By object (lets you configure hyperparameters)
from metacausal.aggregation import QAggregation
ens = CausalEnsemble(aggregation=QAggregation(nu=0.5, greedy=True))

See the paper §2 for the mathematical details of each strategy.
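To make the outcome-supervised idea concrete, here is a sketch of stacking under the R-loss of Nie and Wager's R-learner. Write m(x) = E[Y | X = x] for the outcome regression and e(x) = P(T = 1 | X = x) for the propensity. The sketch chooses nonnegative weights w to minimise Σᵢ [(Yᵢ − m(Xᵢ)) − (Tᵢ − e(Xᵢ)) Σₖ wₖ τₖ(Xᵢ)]². This is an assumption about the general shape of R-Stacking, not MetaCausal's implementation. The package's exact constraints and regularisation are not given here. For brevity the sketch plugs in oracle nuisances, whereas the package uses cross-fitted out-of-fold ones. `r_stacking_weights` is a hypothetical name:

```python
import numpy as np
from scipy.optimize import nnls


def r_stacking_weights(Y, T, m_hat, e_hat, component_cates):
    """Nonnegative weights minimising the R-loss of a weighted CATE combination.

    component_cates: (K, n) array of component CATE predictions
    (out-of-fold in practice). Hypothetical helper, not MetaCausal API.
    """
    residual_y = Y - m_hat                     # outcome residual
    residual_t = T - e_hat                     # treatment residual
    design = (residual_t * component_cates).T  # (n, K): column k is (T - e) * tau_k
    weights, _ = nnls(design, residual_y)
    return weights


rng = np.random.default_rng(0)
n = 2000
X = rng.normal(size=(n, 2))
tau = 1.0 + X[:, 0]                  # true CATE
e = np.full(n, 0.5)                  # randomised treatment
T = rng.binomial(1, e)
Y = X[:, 1] + T * tau + rng.normal(size=n)
m = X[:, 1] + e * tau                # oracle E[Y | X]

component_cates = np.vstack([
    1.0 + X[:, 0],   # correct CATE
    np.ones(n),      # right ATE, no heterogeneity
    X[:, 0] ** 2,    # wrong shape
])
w = r_stacking_weights(Y, T, m, e, component_cates)
print(np.round(w, 3))
```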

Usage recipes

Mixed-framework method list

Estimators from EconML and CausalML are auto-detected by module prefix, and CausalML estimators can also be wrapped explicitly with CausalMLAdapter, as done below; DoubleML, stochtree, and arbitrary callables go through explicit adapters.

from metacausal import CausalEnsemble, GenericATEAdapter
from metacausal.adapters import DoubleMLAdapter, CausalMLAdapter
from econml.dml import CausalForestDML
from econml.metalearners import TLearner, XLearner
from doubleml import DoubleMLIRM
from causalml.inference.meta import BaseDRRegressor
from sklearn.ensemble import (
    HistGradientBoostingRegressor as HGBR,
    HistGradientBoostingClassifier as HGBC,
)

def naive_diff(X, T, Y):
    return float(Y[T == 1].mean() - Y[T == 0].mean())

ens = CausalEnsemble(
    methods=[
        CausalForestDML(discrete_treatment=True),              # auto-wrapped (EconML)
        TLearner(models=HGBR()),                               # auto-wrapped (EconML)
        XLearner(models=HGBR(), propensity_model=HGBC()),      # auto-wrapped (EconML)
        DoubleMLAdapter(DoubleMLIRM, ml_g=HGBR(), ml_m=HGBC()),
        CausalMLAdapter(BaseDRRegressor(learner=HGBR())),
        GenericATEAdapter(naive_diff, name="naive_diff"),
    ],
    aggregation="median",
)
ens.fit(X, T, Y, random_state=42)
print(ens.ate().ate)

CATE estimation with a supervised strategy

from metacausal import CausalEnsemble
from metacausal.aggregation import CausalStacking

ens = CausalEnsemble(aggregation=CausalStacking())
ens.fit(X, T, Y, random_state=42)

# Pointwise CATE CIs at held-out covariate rows X_eval (user-supplied, same columns as X)
boot = ens.bootstrap(X_eval, n_boot=200, random_state=42, n_jobs=-1)

print(boot.cate)           # ensemble CATE at X_eval, shape (n_eval,)
print(boot.cate_ci_lower)  # pointwise 95% lower bound
print(boot.cate_ci_upper)  # pointwise 95% upper bound

# Inspect the learned ensemble weights
for name, w in zip(boot.ensemble_weights.model_names,
                   boot.ensemble_weights.weights):
    print(f"  {name:<25} {w:>6.3f}")
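Pointwise CATE intervals can be read as percentiles taken separately at each evaluation point across bootstrap replicates. Below is a NumPy sketch, assuming the replicate ensemble CATEs are collected in an (n_boot, n_eval) array. The helper name is hypothetical, and the README does not state whether MetaCausal uses percentile intervals or another construction. Such bands are pointwise: each interval targets its own point at roughly the nominal rate, not the whole curve simultaneously.

```python
import numpy as np


def pointwise_percentile_band(replicates, alpha=0.05):
    """replicates: (n_boot, n_eval) array; row b holds replicate b's ensemble CATE."""
    replicates = np.asarray(replicates, dtype=float)
    lower = np.quantile(replicates, alpha / 2, axis=0)
    upper = np.quantile(replicates, 1 - alpha / 2, axis=0)
    return lower, upper


rng = np.random.default_rng(0)
# 1,000 fake replicates at three evaluation points centred on 0, 1, and 2.
replicates = rng.normal(loc=[0.0, 1.0, 2.0], scale=0.1, size=(1000, 3))
lower, upper = pointwise_percentile_band(replicates)
print(np.round(lower, 2), np.round(upper, 2))
```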

Compare aggregation strategies without refitting

Passing aggregation=... to ate() or cate() re-aggregates the cached component predictions without refitting any components, which is useful for quick comparisons.

ens = CausalEnsemble(aggregation="median")
ens.fit(X, T, Y, random_state=42)

for agg in ["median", "mean", "trimmed_mean", "cba"]:
    ate = ens.ate(aggregation=agg)
    print(f"{agg:<15} ATE = {ate.ate:.1f}")
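Re-aggregation is cheap because component outputs are computed once and stored. Below is a hypothetical sketch of that caching pattern for ATE-level aggregation. The class, its attributes, and the 10% trimming fraction are illustrative assumptions. "cba" is omitted because it needs CATE predictions rather than scalar ATEs:

```python
import numpy as np
from scipy.stats import trim_mean

AGGREGATORS = {
    "median": np.median,
    "mean": np.mean,
    "trimmed_mean": lambda a: trim_mean(a, proportiontocut=0.1),
}


class CachedEnsemble:
    """Hypothetical: fit components once, then aggregate cached ATEs on demand."""

    def __init__(self, components):
        self.components = components  # name -> callable(X, T, Y) returning an ATE
        self.n_fits = 0
        self._cache = None

    def fit(self, X, T, Y):
        self._cache = {name: float(f(X, T, Y)) for name, f in self.components.items()}
        self.n_fits += len(self.components)
        return self

    def ate(self, aggregation="median"):
        if self._cache is None:
            raise RuntimeError("call fit() first")
        values = np.array(list(self._cache.values()))
        return float(AGGREGATORS[aggregation](values))  # no component is refitted here


cached = CachedEnsemble({
    "a": lambda X, T, Y: 1.0,
    "b": lambda X, T, Y: 2.0,
    "c": lambda X, T, Y: 10.0,
}).fit(None, None, None)
for agg in ["median", "mean", "trimmed_mean"]:
    print(f"{agg:<15} ATE = {cached.ate(agg):.2f}")
print("component fits:", cached.n_fits)
```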

Visualisation helpers

The optional metacausal.plots submodule (installed via the [plots] extra) provides four matplotlib helpers that take a fitted ensemble or the result objects above:

  • forest(boot) — component and ensemble ATEs with bootstrap CIs.
  • weights(ens) — aggregation weight bars (agreement-based and supervised strategies).
  • cate_profile(source, x, xlabel=...) — ensemble CATE along one covariate, with optional bootstrap band and per-component overlay.
  • disagreement(ens, X) — pairwise component-CATE rank-correlation heatmap.
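
from metacausal.plots import forest, cate_profile

# boot: result of ens.bootstrap(X_eval, ...) as in the CATE recipe above
# grid: re74 values for the rows of X_eval (e.g. X_eval["re74"] if X_eval is a DataFrame)
forest(boot)
cate_profile(boot, x=grid, xlabel="re74 (1974 earnings, USD)")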
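The disagreement heatmap and Consensus Based Averaging both rest on pairwise rank agreement between component CATE predictions. Below is a sketch of the pairwise Kendall's τ matrix, plus a deliberately simplified consensus rule: keep components whose average τ with the others clears a threshold, then average them. The helpers and the selection rule are hypothetical and are not the CBA algorithm described in the paper:

```python
import numpy as np
from scipy.stats import kendalltau


def kendall_matrix(preds):
    """Pairwise Kendall's tau between rows of a (K, n) CATE prediction array."""
    K = preds.shape[0]
    tau = np.eye(K)
    for i in range(K):
        for j in range(i + 1, K):
            t, _ = kendalltau(preds[i], preds[j])
            tau[i, j] = tau[j, i] = t
    return tau


def consensus_subset(preds, threshold=0.0):
    # Simplified rule (not CBA): keep components whose mean tau with the others
    # is at least `threshold`, then average the kept components pointwise.
    tau = kendall_matrix(preds)
    K = len(tau)
    mean_agreement = (tau.sum(axis=1) - 1.0) / (K - 1)  # exclude self-correlation
    keep = np.flatnonzero(mean_agreement >= threshold)
    if keep.size == 0:
        keep = np.arange(K)  # fall back to all components
    return keep, preds[keep].mean(axis=0)


x = np.linspace(0.0, 1.0, 50)
preds = np.vstack([x, 2 * x + 0.1, x ** 2, -x])  # last component ranks units backwards
keep, consensus = consensus_subset(preds, threshold=0.0)
print(np.round(kendall_matrix(preds), 2))
print("kept components:", keep)
```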

Extending MetaCausal

MetaCausal exposes five injection points that let researchers extend the package without forking it: custom component adapters, custom aggregation strategies, replacement nuisance pipelines (fit_nuisance_fn), replacement pseudo-outcome functions (pseudo_outcome_fn), and custom cross-fitting splitters. Full details are in paper §3.6.

The lowest-effort path for adding a new estimator is GenericCATEAdapter, which wraps a fit function, a CATE prediction function, and (optionally) an ATE prediction function into a component without implementing the full protocol:

from metacausal import CausalEnsemble, GenericCATEAdapter

def fit_fn(X, T, Y, **kwargs):
    # Train your model and return any state you need.
    ...
    return state

def cate_fn(state, X):
    # Return per-observation CATE estimates, shape (n,).
    return state.predict_cate(X)

def ate_fn(state, X):  # optional; defaults to mean of cate_fn(state, X)
    return float(cate_fn(state, X).mean())

my_method = GenericCATEAdapter(
    fit_fn, cate_fn, fn_ate=ate_fn, name="my_method",
)

ens = CausalEnsemble(methods=[my_method, ...])
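Here is a concrete instance of the three-function contract above: a linear T-learner written with NumPy only. It is an illustrative sketch, not a recommended estimator. `TLearnerState` and `_design` are hypothetical names. The adapter call is left as a comment so the block runs without MetaCausal installed:

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class TLearnerState:
    coef_treated: np.ndarray
    coef_control: np.ndarray


def _design(X):
    # Prepend an intercept column to the covariates.
    X = np.asarray(X, dtype=float)
    return np.column_stack([np.ones(len(X)), X])


def fit_fn(X, T, Y, **kwargs):
    # One linear outcome model per arm; extra keyword arguments are ignored here.
    D, T, Y = _design(X), np.asarray(T), np.asarray(Y, dtype=float)
    b1, *_ = np.linalg.lstsq(D[T == 1], Y[T == 1], rcond=None)
    b0, *_ = np.linalg.lstsq(D[T == 0], Y[T == 0], rcond=None)
    return TLearnerState(b1, b0)


def cate_fn(state, X):
    # Per-observation CATE: difference of the two arms' linear predictions.
    return _design(X) @ (state.coef_treated - state.coef_control)


def ate_fn(state, X):
    return float(cate_fn(state, X).mean())


# With metacausal installed, the functions plug in as in the example above:
# my_method = GenericCATEAdapter(fit_fn, cate_fn, fn_ate=ate_fn, name="linear_t_learner")

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))
T = rng.integers(0, 2, size=1000)
Y = X @ np.array([1.0, 1.0]) + T * (1.0 + 2.0 * X[:, 0]) + 0.1 * rng.normal(size=1000)
state = fit_fn(X, T, Y)
print(round(ate_fn(state, X), 3))
```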

Reproducibility and parallelism

A single random_state seed deterministically propagates to every stochastic sub-step (component models, their sub-estimators, cross-fitting folds, nuisance fits, and bootstrap replicates), so reruns with the same seed, data, and software environment are bit-identical.
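The README does not say how the seed is propagated internally. One standard NumPy pattern with the same property is to spawn independent child streams from a single `SeedSequence`: one per kind of sub-step and one per bootstrap replicate. Results then do not depend on which worker runs which replicate. A sketch under that assumption (`spawn_streams` is hypothetical):

```python
import numpy as np


def spawn_streams(random_state, n_components, n_boot):
    """Derive independent, reproducible generators from one integer seed."""
    root = np.random.SeedSequence(random_state)
    comp_seq, fold_seq, boot_seq = root.spawn(3)  # one child per kind of sub-step
    return {
        "components": [np.random.default_rng(s) for s in comp_seq.spawn(n_components)],
        "folds": np.random.default_rng(fold_seq),
        # One generator per replicate: replicate b draws the same indices
        # regardless of scheduling order across parallel workers.
        "bootstrap": [np.random.default_rng(s) for s in boot_seq.spawn(n_boot)],
    }


a = spawn_streams(42, n_components=3, n_boot=5)
b = spawn_streams(42, n_components=3, n_boot=5)
print(np.array_equal(a["bootstrap"][4].integers(0, 100, size=10),
                     b["bootstrap"][4].integers(0, 100, size=10)))  # True
```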

A single n_jobs knob on fit, bootstrap, and estimate routes parallelism to the outermost applicable level (bootstrap replicates when n_boot > 0; otherwise supervised cross-fitting or component fits) and pins BLAS/OpenMP threads inside each worker to prevent oversubscription. See paper §3.4 for the rationale.

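from metacausal import CausalEnsemble
from metacausal.aggregation import CausalStacking

# Parallelise supervised cross-fitting (deterministic for a fixed random_state):
ens = CausalEnsemble(aggregation=CausalStacking())
ens.fit(X, T, Y, random_state=42, n_jobs=-1)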

# Or: full fit + bootstrap pipeline with bootstrap-level parallelism:
boot = ens.estimate(X, T, Y, n_boot=500, random_state=42, n_jobs=-1)
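The routing rule described above amounts to a small decision function. Below is a hypothetical sketch: the function and its return labels are illustrative, and mapping supervised strategies to cross-fitting is an assumption read from that sentence. Thread pinning inside workers is a separate step. In Python it is commonly done with threadpoolctl's `threadpool_limits`, though the README does not say which mechanism MetaCausal uses.

```python
def parallel_level(n_boot: int, supervised_aggregation: bool) -> str:
    """Pick the single outermost loop that n_jobs should parallelise."""
    if n_boot > 0:
        return "bootstrap_replicates"  # each worker runs whole fit + aggregate pipelines
    if supervised_aggregation:
        return "cross_fitting_folds"   # supervised strategies cross-fit their inputs
    return "component_fits"            # otherwise fit components side by side


print(parallel_level(500, supervised_aggregation=True))  # bootstrap_replicates
print(parallel_level(0, supervised_aggregation=True))    # cross_fitting_folds
print(parallel_level(0, supervised_aggregation=False))   # component_fits
```

Parallelising only one level avoids nested pools, where every outer worker spawns its own inner workers and BLAS threads. That nesting is the oversubscription the per-worker pinning guards against.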

Citation

A BibTeX entry will be added here when the arXiv preprint of the accompanying manuscript is posted. For interim references to the software itself, see the PyPI listing.

Further reading

  • Paper (open access): arXiv link — to be added when the preprint is posted. The paper covers the mathematical details of each aggregation strategy (§2), the package architecture (§3), and the extensibility hooks (§3.6).
  • Replication material: will be included as ancillary files with the forthcoming arXiv submission.

Release notes

0.1.0 — 2026-04-25

Initial public release.

License

MetaCausal is distributed under the MIT License. See LICENSE.



Download files


Source Distribution

metacausal-0.1.0.tar.gz (102.2 kB)

Uploaded Source

Built Distribution


metacausal-0.1.0-py3-none-any.whl (71.8 kB)

Uploaded Python 3

File details

Details for the file metacausal-0.1.0.tar.gz.

File metadata

  • Download URL: metacausal-0.1.0.tar.gz
  • Size: 102.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.3

File hashes

Hashes for metacausal-0.1.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 09d3aebce2c82c9e322725b2f69f778d671abf33c8a42d318cdc63a480a7be7f |
| MD5 | 20eb36853d99343c847ca2d664752274 |
| BLAKE2b-256 | 8e2be878c09ceeeda0835ab889c31f3abf50e1092d3aa7e29222b8ed0a7aaa44 |


File details

Details for the file metacausal-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: metacausal-0.1.0-py3-none-any.whl
  • Size: 71.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.3

File hashes

Hashes for metacausal-0.1.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 93c8b037cc81358fa047c09cbb2c387838180cc6b2b120c79a97ba55ad1b9065 |
| MD5 | bbffe75f26254a0e93f722d37f86a5ee |
| BLAKE2b-256 | 674afb362823b41e6d84e676385d0e320aea9ebcc2abbb05ba583a79053edc8a |

