
Cross-framework causal-ML ensembles for ATE and CATE with full-pipeline bootstrap inference


MetaCausal

Cross-framework ensembling of causal machine-learning estimators for ATE and pointwise CATE, with full-pipeline bootstrap inference.


What it is

MetaCausal orchestrates multiple causal-ML estimators from different libraries — EconML, DoubleML, CausalML, stochtree, or arbitrary user-supplied callables — behind a single protocol, and aggregates their treatment-effect estimates into a single ensemble estimate. Seven aggregation strategies are provided, grouped into three tiers:

  • Pointwise robust — Median (default), Mean, Trimmed Mean.
  • Agreement-based — Consensus Based Averaging, which selects a high-agreement subset of components from pairwise Kendall's τ.
  • Outcome-supervised — Causal Stacking, R-Stacking, Q-Aggregation, which learn weights by optimising a causal loss on cross-fitted out-of-fold predictions.

A full-pipeline bootstrap supplies comparable confidence intervals for both ATE and pointwise CATE across heterogeneous components whose native inference machinery is otherwise incomparable.
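The idea behind a full-pipeline bootstrap can be sketched in plain Python: every replicate resamples units, reruns every component, and re-aggregates, so the interval reflects the variability of the whole procedure rather than of any one component. The toy components, the within-arm resampling, and the helper names (`arm_contrast`, `ensemble_ate`, `pipeline_bootstrap_ci`) are assumptions made for illustration; none of them is MetaCausal's API or its actual resampling scheme.

```python
import random
import statistics

def arm_contrast(summary):
    """Build a toy component: summary(treated Y) minus summary(control Y)."""
    def component(rows):
        treated = [y for t, y in rows if t == 1]
        control = [y for t, y in rows if t == 0]
        return summary(treated) - summary(control)
    return component

def midrange(values):
    return (min(values) + max(values)) / 2

# Three deliberately different toy components (hypothetical stand-ins).
COMPONENTS = [arm_contrast(statistics.fmean),
              arm_contrast(statistics.median),
              arm_contrast(midrange)]

def ensemble_ate(rows):
    """Run every component on `rows`, then aggregate by median."""
    return statistics.median([component(rows) for component in COMPONENTS])

def pipeline_bootstrap_ci(rows, n_boot=200, alpha=0.05, seed=0):
    """Percentile CI in which each replicate reruns the whole pipeline."""
    rng = random.Random(seed)
    treated = [r for r in rows if r[0] == 1]
    control = [r for r in rows if r[0] == 0]
    reps = []
    for _ in range(n_boot):
        # Resample within each arm (an assumption) so neither arm is empty.
        sample = (rng.choices(treated, k=len(treated))
                  + rng.choices(control, k=len(control)))
        reps.append(ensemble_ate(sample))
    reps.sort()
    lower = reps[int(alpha / 2 * (n_boot - 1))]
    upper = reps[round((1 - alpha / 2) * (n_boot - 1))]
    return lower, upper

# Synthetic (t, y) rows with a true effect of 2.0.
data_rng = random.Random(1)
rows = [(t, 1.0 + 2.0 * t + data_rng.gauss(0, 1)) for t in [0, 1] * 100]
point = ensemble_ate(rows)
lower, upper = pipeline_bootstrap_ci(rows)
print(f"ensemble ATE {point:.2f}, 95% CI [{lower:.2f}, {upper:.2f}]")
```

Because the same resample-refit-aggregate loop applies to any set of components, the resulting intervals are comparable even when the components' native inference methods are not.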

Why

No single causal-ML estimator dominates across data-generating processes, model selection for heterogeneous treatment effects is empirically unreliable, and individual methods can fail catastrophically under specific violations of their own assumptions (overlap breakdown, nuisance misspecification, tree extrapolation). MetaCausal's default pointwise median aggregation has a 50% breakdown point and needs no tuning: as long as fewer than half of the component estimators produce arbitrarily bad estimates, the ensemble estimate stays within the range of the well-behaved ones. When outcome data allow learning weights, MetaCausal also ships the three outcome-supervised stackers from the recent CATE-ensemble literature.
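A small standard-library illustration of the breakdown-point argument, using made-up component estimates (the numbers and the helper `pointwise_median` are hypothetical, not package output or API):

```python
import statistics

# Hypothetical ATE estimates from five components; two have failed badly.
estimates = [1.9, 2.0, 2.1, 250.0, -400.0]

robust = statistics.median(estimates)  # 2.0: unaffected by the two outliers
fragile = statistics.fmean(estimates)  # about -28.8: dominated by the outliers
print(robust, fragile)

def pointwise_median(cate_by_component):
    """Median across components, taken separately at each evaluation point."""
    return [statistics.median(column) for column in zip(*cate_by_component)]

# Three components, two evaluation points; the third component is corrupted.
print(pointwise_median([[1.0, 2.0], [1.2, 2.2], [99.0, -99.0]]))  # [1.2, 2.0]
```

With three of the five estimates corrupted, the median would no longer be protected, which is why the guarantee holds only while the failing components remain a minority.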

Installation

pip install metacausal

This installs the core package and its required dependencies (numpy, pandas, scipy). Estimator libraries are optional extras:

# Individual libraries
pip install "metacausal[econml]"
pip install "metacausal[doubleml]"
pip install "metacausal[causalml]"
pip install "metacausal[stochtree]"

# Visualisation helpers (matplotlib)
pip install "metacausal[plots]"

# Everything (frameworks + plots)
pip install "metacausal[all]"

Python 3.11 or later is required.

Quick start

from metacausal import CausalEnsemble
from metacausal.datasets import load_lalonde

X, T, Y = load_lalonde()

# Default ensemble: nine estimators (since 0.4.0) spanning EconML,
# DoubleML, CausalML, and stochtree, aggregated by pointwise median.
ens = CausalEnsemble()
ens.fit(X, T, Y, random_state=42)

# Point estimate
ate = ens.ate()
print(f"Ensemble ATE: {ate.ate:.1f}")
for name, est in ate.component_estimates.items():
    print(f"  {name:<25} {est.ate:>9.1f}")

# Honest bootstrap confidence interval
boot = ens.bootstrap(n_boot=200, random_state=42, n_jobs=-1)
print(f"95% CI: [{boot.ate_ci_lower:.1f}, {boot.ate_ci_upper:.1f}]")

The three-step fit → ate / cate → bootstrap pattern is the recommended one, because it lets you inspect intermediate state and swap aggregation strategies on an already-fitted ensemble. The convenience wrapper ens.estimate(X, T, Y, n_boot=200, ...) does fit + bootstrap (or fit + ate) in a single call.

Aggregation strategies at a glance

| Tier       | Strategy                  | String alias / class         | Data used                                   |
|------------|---------------------------|------------------------------|---------------------------------------------|
| Pointwise  | Median (default)          | "median" / Median            | Component predictions only                  |
| Pointwise  | Mean                      | "mean" / Mean                | Component predictions only                  |
| Pointwise  | Trimmed Mean              | "trimmed_mean" / TrimmedMean | Component predictions only                  |
| Agreement  | Consensus Based Averaging | "cba" / CBA                  | Component CATE predictions on training data |
| Supervised | Causal Stacking           | CausalStacking               | Cross-fitted OOF predictions + nuisance     |
| Supervised | R-Stacking                | RStacking                    | Cross-fitted OOF predictions + nuisance     |
| Supervised | Q-Aggregation             | QAggregation                 | Cross-fitted OOF predictions + nuisance     |

# By string alias (default configuration)
ens = CausalEnsemble(aggregation="trimmed_mean")

# By object (lets you configure hyperparameters)
from metacausal.aggregation import QAggregation
ens = CausalEnsemble(aggregation=QAggregation(nu=0.5, greedy=True))

See the accompanying paper (forthcoming) for the mathematical details of each strategy.
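Until the paper is available, the agreement-based tier can be pictured with a rough sketch: compute pairwise Kendall's τ between component CATE vectors and keep the components that agree with the rest. The selection rule used here (mean pairwise τ above a hypothetical `threshold`) is an illustrative stand-in and should not be read as the rule CBA actually implements; the data are made up.

```python
from itertools import combinations
from statistics import fmean

def kendall_tau(a, b):
    """Kendall's tau-a: (concordant - discordant) pairs over all pairs."""
    score = 0
    for i, j in combinations(range(len(a)), 2):
        product = (a[i] - a[j]) * (b[i] - b[j])
        score += (product > 0) - (product < 0)  # +1, -1, or 0 for a tie
    return score / (len(a) * (len(a) - 1) / 2)

# Hypothetical CATE predictions from three components on five units.
cates = {
    "a": [0.1, 0.4, 0.2, 0.9, 0.7],
    "b": [0.2, 0.5, 0.3, 1.0, 0.6],
    "c": [0.9, 0.1, 0.8, 0.2, 0.3],  # ranks the units very differently
}

def mean_agreement(name):
    """Average tau between one component and every other component."""
    others = [o for o in cates if o != name]
    return fmean([kendall_tau(cates[name], cates[o]) for o in others])

threshold = 0.0  # hypothetical cut-off
selected = [name for name in cates if mean_agreement(name) > threshold]

# Average the CATEs of the selected (high-agreement) components.
consensus = [fmean(values) for values in zip(*(cates[n] for n in selected))]
print(selected, consensus)
```

Components "a" and "b" rank the units identically (τ = 1), while "c" disagrees with both (τ = -0.6), so only "a" and "b" contribute to the consensus average.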

Outcome types

MetaCausal supports two outcome types:

  • Continuous (default): any numeric Y not detected as binary. Counts, bounded continuous outcomes, and ordinal-as-numeric outcomes are all treated as continuous without warning. The choice of base learner is the user's responsibility (a HistGradientBoostingRegressor by default; user-supplied components can pass a Poisson booster if appropriate).
  • Binary: numeric Y with values ⊆ {0, 1} or boolean dtype. The estimand is the risk-difference ATE/CATE (the difference in outcome probabilities between treatment arms).

Detection happens at fit time: CausalEnsemble().fit(X, T, Y) inspects Y, picks the right pool from default_methods, and routes nuisance estimation through predict_proba for binary outcomes. To force an interpretation, pass outcome_type="continuous" or outcome_type="binary" at construction. Multi-class / nominal and survival outcomes are out of scope; encoding-as-multiple-binary or a dedicated survival library is the recommended path.
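The detection rule described above can be approximated in a few lines. This is a sketch under stated assumptions, not the package's `infer_outcome_type`: the function name `sketch_infer_outcome_type` is hypothetical, and the constant-Y check mirrors the behaviour the 0.5.0 release notes describe.

```python
def sketch_infer_outcome_type(y):
    """Classify an outcome vector as 'binary' or 'continuous'."""
    values = set(y)
    if len(values) < 2:
        # A constant outcome carries no information about treatment effects.
        raise ValueError("Y takes a single value; cannot infer an outcome type.")
    if values <= {0, 1}:  # also matches booleans, since True == 1 and False == 0
        return "binary"
    return "continuous"

print(sketch_infer_outcome_type([0, 1, 1, 0]))     # binary
print(sketch_infer_outcome_type([True, False]))    # binary
print(sketch_infer_outcome_type([0, 1, 2, 5]))     # continuous (counts)
```

A count outcome that happens to contain only 0s and 1s would be classified as binary by a rule like this, which is the situation the explicit `outcome_type` override is for.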

The default binary pool (7 components since 0.4.0) drops DoubleMLPLR, the EconML T- and X-Learners, BaseRRegressor, and stochtree BCF, all of which either lack a binary-capable code path in their upstream library or would silently fit a linear-probability model, and substitutes the CausalML classifier siblings (BaseTClassifier, BaseXClassifier, BaseRClassifier). DoubleMLIRM, CausalForestDML, DRLearner, and TMLELearner remain. The S-Learner and its classifier sibling BaseSClassifier were removed from the default pools in 0.4.0.

Usage recipes

Mixed-framework method list

Estimators from EconML and CausalML are auto-detected by module prefix; DoubleML, stochtree, and arbitrary callables go through explicit adapters.

from metacausal import CausalEnsemble, GenericATEAdapter
from metacausal.adapters import DoubleMLAdapter, CausalMLAdapter
from econml.dml import CausalForestDML
from econml.metalearners import TLearner, XLearner
from doubleml import DoubleMLIRM
from causalml.inference.meta import BaseDRRegressor
from sklearn.ensemble import (
    HistGradientBoostingRegressor as HGBR,
    HistGradientBoostingClassifier as HGBC,
)

def naive_diff(X, T, Y):
    return float(Y[T == 1].mean() - Y[T == 0].mean())

ens = CausalEnsemble(
    methods=[
        CausalForestDML(discrete_treatment=True),              # auto-wrapped (EconML)
        TLearner(models=HGBR()),                               # auto-wrapped (EconML)
        XLearner(models=HGBR(), propensity_model=HGBC()),      # auto-wrapped (EconML)
        DoubleMLAdapter(DoubleMLIRM, ml_g=HGBR(), ml_m=HGBC()),
        CausalMLAdapter(BaseDRRegressor(learner=HGBR())),
        GenericATEAdapter(naive_diff, name="naive_diff"),
    ],
    aggregation="median",
)
ens.fit(X, T, Y, random_state=42)
print(ens.ate().ate)
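Detection "by module prefix" can be pictured as inspecting the module in which an estimator's class is defined. The sketch below is an assumption about how such a dispatch might look, not MetaCausal's actual code; `detect_framework`, `needs_explicit_adapter`, and the stand-in class `FakeForest` are hypothetical.

```python
def detect_framework(estimator):
    """Return the top-level package that defines the estimator's class."""
    return type(estimator).__module__.split(".")[0]

# Frameworks whose estimators would be wrapped automatically (per the text above).
AUTO_WRAPPED = {"econml", "causalml"}

def needs_explicit_adapter(estimator):
    """True when the object does not come from an auto-detected framework."""
    return detect_framework(estimator) not in AUTO_WRAPPED

class FakeForest:
    """Stand-in class; we pretend it was defined inside the econml package."""

FakeForest.__module__ = "econml.dml.causal_forest"

def naive_diff(X, T, Y):
    return 0.0

print(detect_framework(FakeForest()))        # econml
print(needs_explicit_adapter(FakeForest()))  # False
print(needs_explicit_adapter(naive_diff))    # True: plain callables need an adapter
```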

To configure analytical upstream inference, wrap the estimator explicitly instead of relying on auto-detection:

from causalml.inference.meta import BaseTRegressor
from doubleml import DoubleMLIRM
from econml.dml import CausalForestDML
from metacausal.adapters import DoubleMLAdapter, EconMLAdapter, CausalMLAdapter, StochtreeAdapter
from sklearn.ensemble import (
    HistGradientBoostingRegressor as HGBR,
    HistGradientBoostingClassifier as HGBC,
)

dml = DoubleMLAdapter(DoubleMLIRM, ml_g=HGBR(), ml_m=HGBC(), alpha=0.10)
econ = EconMLAdapter(CausalForestDML(model_y=HGBR(), model_t=HGBC(), discrete_treatment=True), alpha=0.10, inference="statsmodels")
cml = CausalMLAdapter(BaseTRegressor(learner=HGBR(), ate_alpha=0.10))  # upstream CausalML control
st = StochtreeAdapter(alpha=0.10)

For CausalML, analytical ATE CI settings stay on the wrapped upstream estimator (ate_alpha on meta-learners / TMLE, alpha on CausalTreeRegressor); the adapter does not override them.

Binary outcome on real data

load_lalonde(binarize_y=...) returns the 1978-earnings outcome as a binary indicator — "median" for the (~50/50) above-median split, "positive" for the (~69/31) "any 1978 earnings" indicator. Useful as a real-data fixture without leaving the package.

from metacausal import CausalEnsemble
from metacausal.datasets import load_lalonde

X, T, Y = load_lalonde(binarize_y="median")

# outcome_type="auto" detects binary Y, materialises the binary
# default pool (7 components since 0.4.0, all targeting the risk
# difference), and fits. ATE is on the risk-difference scale, in [-1, 1].
ens = CausalEnsemble()
ens.fit(X, T, Y, random_state=42)

print(ens.ate().ate)
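To make the two binarisation rules and the risk-difference scale concrete, here is a standard-library sketch on six made-up earnings values (not the LaLonde data). The helpers `binarize` and `risk_difference` are hypothetical, and treating "above-median" as a strict inequality is an assumption.

```python
import statistics

def binarize(y, rule):
    """rule='median': 1 if strictly above the sample median; 'positive': 1 if > 0."""
    if rule == "median":
        cut = statistics.median(y)
    elif rule == "positive":
        cut = 0.0
    else:
        raise ValueError(f"unknown rule: {rule!r}")
    return [int(value > cut) for value in y]

def risk_difference(t, y_bin):
    """Unadjusted difference in outcome probability: P(Y=1 | T=1) - P(Y=1 | T=0)."""
    treated = [y for y, arm in zip(y_bin, t) if arm == 1]
    control = [y for y, arm in zip(y_bin, t) if arm == 0]
    return statistics.fmean(treated) - statistics.fmean(control)

# Made-up earnings and treatment indicators.
earnings = [0.0, 0.0, 1200.0, 3500.0, 800.0, 6000.0]
treatment = [0, 0, 0, 1, 1, 1]

above_median = binarize(earnings, "median")   # median is 1000.0 -> [0, 0, 1, 1, 0, 1]
any_earnings = binarize(earnings, "positive")  # [0, 0, 1, 1, 1, 1]
print(risk_difference(treatment, above_median))
print(risk_difference(treatment, any_earnings))
```

Both results are differences of probabilities, so they always fall in [-1, 1]; the unadjusted contrast here ignores covariates and is only meant to show the scale of the estimand.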

CATE estimation with a supervised strategy

from metacausal import CausalEnsemble
from metacausal.aggregation import CausalStacking

ens = CausalEnsemble(aggregation=CausalStacking())
ens.fit(X, T, Y, random_state=42)

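# Pointwise CATE CIs on a held-out grid. X_eval is supplied by the caller:
# a covariate matrix with the same columns as X, shape (n_eval, p).
boot = ens.bootstrap(X_eval, n_boot=200, random_state=42, n_jobs=-1)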

print(boot.cate)           # ensemble CATE at X_eval, shape (n_eval,)
print(boot.cate_ci_lower)  # pointwise 95% lower bound
print(boot.cate_ci_upper)  # pointwise 95% upper bound

# Inspect the learned ensemble weights
for name, w in zip(boot.ensemble_weights.model_names,
                   boot.ensemble_weights.weights):
    print(f"  {name:<25} {w:>6.3f}")

Compare aggregation strategies without refitting

An aggregation=... argument to ate() or cate() re-aggregates from cached predictions without refitting components — useful for quick comparisons.

ens = CausalEnsemble(aggregation="median")
ens.fit(X, T, Y, random_state=42)

for agg in ["median", "mean", "trimmed_mean", "cba"]:
    ate = ens.ate(aggregation=agg)
    print(f"{agg:<15} ATE = {ate.ate:.1f}")
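Re-aggregation is cheap because the expensive step, fitting the components, happens once; switching strategy only changes the function applied to the stored estimates. A minimal sketch of that caching pattern, with a hypothetical `CachedEnsemble` class, made-up component estimates, and an assumed 20% trimming fraction:

```python
import statistics

def trimmed_mean(values, trim=0.2):
    """Drop the lowest and highest `trim` fraction, then average the rest."""
    ordered = sorted(values)
    k = int(len(ordered) * trim)
    return statistics.fmean(ordered[k:len(ordered) - k])

AGGREGATORS = {
    "median": statistics.median,
    "mean": statistics.fmean,
    "trimmed_mean": trimmed_mean,
}

class CachedEnsemble:
    """Stores component ATEs once; aggregation is a pure function of the cache."""

    def __init__(self, component_ates):
        self._cache = dict(component_ates)  # filled once, at "fit" time

    def ate(self, aggregation="median"):
        return AGGREGATORS[aggregation](list(self._cache.values()))

toy = CachedEnsemble({"a": 1.0, "b": 2.0, "c": 3.0, "d": 4.0, "e": 100.0})
for name in AGGREGATORS:
    print(f"{name:<15} ATE = {toy.ate(name):.1f}")
```

On these numbers the median and trimmed mean both give 3.0 while the mean is pulled to 22.0 by the single outlying component.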

Visualisation helpers

The optional metacausal.plots submodule (installed via the [plots] extra) provides four matplotlib helpers that consume the result types above:

  • forest(boot): component and ensemble ATEs with bootstrap CIs.
  • weights(ens): aggregation weight bars (agreement-based and supervised strategies).
  • cate_profile(source, x, xlabel=...): ensemble CATE along one covariate, with optional bootstrap band and per-component overlay.
  • disagreement(ens, X): pairwise component-CATE rank-correlation heatmap.

from metacausal.plots import forest, cate_profile

forest(boot)
cate_profile(boot, x=grid, xlabel="re74 (1974 earnings, USD)")

Extending MetaCausal

MetaCausal exposes five injection points that let researchers extend the package without forking it: custom component adapters, custom aggregation strategies, replacement nuisance pipelines (fit_nuisance_fn), replacement pseudo-outcome functions (pseudo_outcome_fn), and custom cross-fitting splitters. The accompanying paper (forthcoming) covers each injection point in detail.

The lowest-effort path for adding a new estimator is GenericCATEAdapter, which wraps a fit function, a CATE prediction function, and (optionally) an ATE prediction function into a component without implementing the full protocol:

from metacausal import CausalEnsemble, GenericCATEAdapter

def fit_fn(X, T, Y, **kwargs):
    # Train your model and return any state you need.
    ...
    return state

def cate_fn(state, X):
    # Return per-observation CATE estimates, shape (n,).
    return state.predict_cate(X)

def ate_fn(state, X):  # optional; defaults to mean of cate_fn(state, X)
    return float(cate_fn(state, X).mean())

my_method = GenericCATEAdapter(
    fit_fn, cate_fn, fn_ate=ate_fn, name="my_method",
)

ens = CausalEnsemble(methods=[my_method, ...])

Reproducibility and parallelism

A single random_state seed deterministically propagates to every stochastic sub-step — component models, their sub-estimators, cross-fitting folds, nuisance fits, and bootstrap replicates — so reruns are bit-identical.
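One common way to obtain this kind of determinism is to derive a child seed for each stochastic sub-step from the single parent seed. The sketch below shows the pattern with the standard library; it is not MetaCausal's actual seeding mechanism, and `spawn_seeds` and the step labels are hypothetical.

```python
import random

def spawn_seeds(random_state, labels):
    """Derive one child seed per labelled sub-step from a single parent seed."""
    parent = random.Random(random_state)
    return {label: parent.getrandbits(32) for label in labels}

STEPS = ["components", "cross_fitting", "bootstrap"]

seeds_run1 = spawn_seeds(42, STEPS)
seeds_run2 = spawn_seeds(42, STEPS)
print(seeds_run1 == seeds_run2)  # True: same parent seed, same child seeds

# Each sub-step builds its own generator from its child seed, so the draws
# do not depend on the order in which sub-steps (or parallel workers) run.
draw1 = random.Random(seeds_run1["bootstrap"]).choices(range(10), k=5)
draw2 = random.Random(seeds_run2["bootstrap"]).choices(range(10), k=5)
print(draw1 == draw2)  # True
```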

A single n_jobs knob on fit, bootstrap, and estimate routes parallelism to the outermost applicable level (bootstrap replicates when n_boot > 0; otherwise supervised cross-fitting or component fits) and pins BLAS/OpenMP threads inside each worker to prevent oversubscription. The accompanying paper (forthcoming) explains the rationale.

The outer process (your main script) keeps the platform-default BLAS thread count, which is fine on macOS and Windows. On Linux, where joblib's loky backend can occasionally deadlock at fork time when the parent's BLAS pool is already running threads, defensive users may want to set the standard thread env vars (OMP_NUM_THREADS=1, OPENBLAS_NUM_THREADS=1, MKL_NUM_THREADS=1, NUMEXPR_NUM_THREADS=1, VECLIB_MAXIMUM_THREADS=1) before invoking Python. The bundled replication runner and the test suite's tests/conftest.py set these automatically, so reviewers and contributors do not need the shell prefix.

from metacausal.aggregation import CausalStacking

# Parallelise supervised cross-fitting, deterministic:
ens = CausalEnsemble(aggregation=CausalStacking())
ens.fit(X, T, Y, random_state=42, n_jobs=-1)

# Or: full fit + bootstrap pipeline with bootstrap-level parallelism:
boot = ens.estimate(X, T, Y, n_boot=500, random_state=42, n_jobs=-1)
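For users who prefer to set the thread environment variables listed above from Python rather than from the shell, one option is a small helper at the very top of the entry script. This is a sketch, with `pin_blas_threads` a hypothetical helper name; it relies on the variables being set before the numerical libraries are first imported, since BLAS and OpenMP runtimes typically read them at load time.

```python
import os

THREAD_VARS = (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
)

def pin_blas_threads(n_threads=1):
    """Set the thread-count variables, keeping any value the user already chose."""
    for var in THREAD_VARS:
        os.environ.setdefault(var, str(n_threads))

pin_blas_threads()  # call before importing numpy, scipy, or metacausal

for var in THREAD_VARS:
    print(var, "=", os.environ[var])
```

Using `setdefault` rather than plain assignment means an explicit shell setting still takes precedence over the script's default.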

Citation

A BibTeX entry will be added here when the arXiv preprint of the accompanying manuscript is posted. For interim references to the software itself, see the PyPI listing.

Further reading

  • Paper: a preprint covering the methodology, architecture, and extensibility hooks is in preparation. An arXiv link will be added here once it is posted.
  • Replication material: will be included as ancillary files with the forthcoming arXiv submission.

Release notes

0.5.0 — 2026-06-15

  • bootstrap() flags a point estimate outside its CI under both resampling schemes. The BootstrapWarning for an ATE confidence interval that does not contain the point estimate previously fired only under the subsample scheme; it now fires under nonparametric too, with a scheme-specific explanation. Containment is governed by the same percentile condition for both schemes — the √(m/n) scaling changes only the interval width — and with-replacement resampling (~63% distinct units) or per-replicate weight re-optimization can shift the replicate distribution off the point estimate. Pre-1.0 behavior change: nonparametric bootstraps that previously ran silently may now emit a warning.
  • Single-valued outcomes are rejected with an actionable error. infer_outcome_type previously classified a constant Y (all zeros, all ones, or any single repeated value) as binary, which surfaced later as an opaque scikit-learn "needs samples of at least 2 classes" failure. It now raises a clear ValueError up front. Pre-1.0 behavior change.
  • Docstring corrections. The supervised aggregation strategies (CausalStacking, RStacking, QAggregation) documented a stale LGBMRegressor/LGBMClassifier default and omitted the regressor-vs-classifier requirement for binary outcomes; they now describe the HistGradientBoosting* defaults and the outcome-type-dependent model contract. GenericCATEAdapter no longer references a non-existent extensibility guide, and its fn_ate example now performs a genuine doubly robust (AIPW) computation instead of silently reproducing the default mean(cate(X)).
  • Correction to the 0.4.0 parallel-safety note. The intermittent segmentation faults in EconML's Cython tree builder (CausalForestDML) are a latent out-of-bounds bug in EconML's generalized random forest — not merely loky oversubscription. They reproduce single-threaded, surfacing under the repeated component refits a bootstrap performs (EconML #470, unresolved as of econml 0.16.0). The 0.4.0 n_jobs=1 pin is kept as defensive hardening (it removes genuine oversubscription and lowers crash probability) but does not fully eliminate the crash: bootstrap() with CausalForestDML in the pool can still segfault intermittently under heavy parallelism.

0.4.0 — 2026-06-07

  • Parallel-safety fix: the default CausalForestDML now builds its forest with n_jobs=1, and EconMLAdapter pins any wrapped joblib-parallel estimator to a single job when it runs inside one of MetaCausal's own parallel workers (fit, bootstrap, or supervised cross-fitting). EconML defaults CausalForestDML to n_jobs=-1, whose inner loky pool would otherwise nest inside the outer worker and intermittently segfault EconML's Cython tree builder under oversubscription. Outputs are unchanged — the forest is seeded deterministically — only the nested parallelism is removed.
  • Default pool update: S-Learner and its classifier sibling (BaseSClassifier) are now excluded from the default pools for continuous and binary outcomes, respectively; the default pools are now nine and seven components.

0.3.1 — 2026-06-05

  • Dependency declaration: joblib>=1.2 is now a direct core dependency instead of arriving only transitively through scikit-learn.
  • Top-level adapter imports: DoubleMLAdapter, EconMLAdapter, and StochtreeAdapter are now exported from metacausal, so users no longer need submodule import paths for those adapters.
  • Custom CATE shorthand: CausalEnsemble now auto-wraps 2- and 3-callable lists/tuples as GenericCATEAdapter, allowing (fn_fit, fn_cate) and (fn_fit, fn_cate, fn_ate) directly in methods=[...].

0.3.0 — 2026-05-30

  • Analytical CI controls: DoubleMLAdapter(alpha=...) now forwards to DoubleML.confint(level=1 - alpha), and EconMLAdapter(alpha=..., inference=...) now forwards both the interval level and the fit-time inference backend.
  • CausalML clarification: analytical ATE CI level was already configurable upstream via the wrapped estimator (ate_alpha / estimator-specific alpha); the docs now describe that accurately instead of implying a missing adapter feature.
  • API cleanup: removed the unused alpha argument from CausalEnsemble.cate(). Analytical component-level CI configuration now lives on the relevant adapter or wrapped upstream estimator, while ensemble CI level remains on bootstrap(alpha=...).

0.2.2 — 2026-05-04

  • Dependency hygiene: the four causal-ML extras (econml, doubleml, stochtree, causalml) are now pinned to a single patch each — floors at the exact version validated by CI on the most recent main-branch run, caps at the next patch. Motivated by stochtree #376, where a patch release (0.4.0 → 0.4.2) silently changed the semantics of BCFModel.predict(terms="tau") and broke 0.2.0. Patch-level caps mean every upstream release lands outside the cap, triggers a Dependabot PR, and runs pytest -m integration before we widen — closing the silent-install hole that produced the 0.2.1 hotfix.

0.2.1 — 2026-05-04

  • Bug fix: StochtreeAdapter now calls BCFModel.predict(..., terms="cate") instead of terms="tau". With stochtree 0.4.2 (which added a parametric treatment-intercept term in the BCF sampler), terms="tau" returned the forest-only piece and excluded the parametric component, producing wildly seed-sensitive ATEs that disagreed sharply with the rest of the default ensemble. terms="cate" returns the full conditional treatment effect — including parametric and random-slope components, when present — for any BCF configuration. Fixes upstream issue stochtree #376 on the metacausal side.

0.2.0 — 2026-05-04

New

  • Outcome-type handling: CausalEnsemble auto-detects continuous vs binary Y at fit(), materialises the right default pool, and routes nuisance through predict_proba for binary. Override via outcome_type="continuous"|"binary". Public metacausal.infer_outcome_type(Y) utility; binarize_y={"median","positive"} on load_lalonde().
  • Subsample bootstrap (bootstrap(method="subsample")): m-out-of-n without replacement, T-stratified, with Politis–Romano scaled-percentile CIs. Eliminates duplicate-unit leakage across cross-fit folds.
  • Structured warning hierarchy: ComponentFailureWarning, ComponentExclusionWarning, BootstrapWarning under a common MetaCausalWarning umbrella.
  • CausalMLAdapter accepts a propensity_model= kwarg, forwarding a fitted propensity to non-TMLE meta-learners.

Breaking changes for custom-strategy / custom-adapter authors

  • AggregationStrategy and family are abc.ABC with a unified aggregate entry point. Subclasses now implement aggregate rather than per-mode methods.
  • Every adapter must declare supported_outcome_types and implement validate_outcome_type(detected); the injectable fit_nuisance_fn gains an outcome_type parameter.

Other

  • Bounded version constraints on econml, doubleml, causalml, stochtree (capped at next minor; floors anchored to tested versions).
  • requires-python raised to >=3.11 (causalml 0.16 floor).
  • Tier-2 integration tests via pytest -m integration.
  • Bug fixes: load_lalonde no longer leaks a file handle; EconMLAdapter suppresses the upstream DataConversionWarning from DRLearner(discrete_outcome=True).
  • PyPI metadata polish (classifiers, license badge).

Tested against: doubleml 0.11.2, econml 0.16.0, causalml 0.16.0, stochtree 0.4.0.

0.1.0 — 2026-04-25

Initial public release.

License

MetaCausal is distributed under the MIT License. See LICENSE.

