Causal Inference using Ensemble Matching
CausalEM – Ensemble Matching for Causal Inference
CausalEM is a toolbox for multi-arm treatment‑effect estimation and mediation analysis using stochastic matching and a stacked ensemble of heterogeneous ML models. It supports continuous, binary, and survival outcomes.
Key Features
| Feature | Impact |
|---|---|
| Stochastic nearest-neighbor (NN) matching | Larger effective sample size (ESS) and improved TE estimation accuracy compared to standard (deterministic) NN matching |
| G-computation using a two-stage, stacked ensemble of heterogeneous learners | Generalization of the standard G-computation framework to ensemble learning; cross-fitting of propensity-score and outcome models, similar to DoubleML |
| Support for multi-arm treatments | Improved multi-arm ESS via stochastic matching |
| Mediation analysis | Plug-in G-computation for interventional mediation effects (IDE/IIE) with binary treatment and binary/continuous mediators and outcomes, supporting bootstrap confidence intervals and optional stochastic matching |
| Support for survival outcomes | Simulation of data from survival-outcome models, so that the stacked ensemble can estimate TE for right-censored, time-to-event data |
| Bootstrapped confidence interval (CI) estimation | Honest CI estimation by including the entire pipeline (matching + TE estimation) in the bootstrap loop |
| Compatible with scikit-learn | Maximum flexibility in choosing ML models: scikit-learn (and, for survival outcomes, scikit-survival) estimators can be used in the propensity-score, outcome, and meta-learner stages |
| Full reproducibility of results | Careful implementation of random number generation (RNG) seeding, including in scikit-learn models |
| Available in Python and R | Identical function-centric API in both languages using reticulate; combined with RNG management, leads to identical, reproducible results across platforms |
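The first row of the table can be illustrated with a self-contained NumPy sketch. It contrasts deterministic and stochastic 1:1 nearest-neighbor matching on a scalar score and computes the effective sample size with Kish's formula, (Σw)²/Σw², using how often each control is used as its weight. This is not CausalEM's implementation. The Gaussian kernel on the score distance, the bandwidth named `scale`, matching with replacement, and the Kish definition of ESS are all assumptions chosen for illustration.

```python
import numpy as np


def match_controls(score_t, score_c, scale=1.0, stochastic=True, rng=None):
    """Return, for each treated unit, the index of one matched control.

    Hypothetical sketch: deterministic matching picks the closest control;
    stochastic matching samples a control with probability proportional to a
    Gaussian kernel of the score distance (bandwidth `scale`), with replacement.
    """
    rng = np.random.default_rng(rng)
    dist = np.abs(score_t[:, None] - score_c[None, :])  # (n_treated, n_control)
    if not stochastic:
        return dist.argmin(axis=1)
    logits = -0.5 * (dist / scale) ** 2
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    prob = np.exp(logits)
    prob /= prob.sum(axis=1, keepdims=True)
    return np.array([rng.choice(len(score_c), p=p) for p in prob])


def kish_ess(counts):
    """Kish effective sample size of non-negative weights (here: match counts)."""
    w = np.asarray(counts, dtype=float)
    return w.sum() ** 2 / (w ** 2).sum()


rng = np.random.default_rng(0)
score_t = rng.normal(0.5, 1.0, size=50)   # e.g. logit propensity scores, treated
score_c = rng.normal(0.0, 1.0, size=200)  # controls

ess = {}
for stochastic in (False, True):
    # Pool 10 matching draws and count how often each control is used
    draws = [
        match_controls(score_t, score_c, stochastic=stochastic, rng=seed)
        for seed in range(10)
    ]
    counts = np.bincount(np.concatenate(draws), minlength=len(score_c))
    ess[stochastic] = kish_ess(counts)
    print(f"stochastic={stochastic}: control ESS = {ess[stochastic]:.1f}")
```

Deterministic matching reuses the same controls in every draw, so pooling draws leaves the weights unchanged. Sampling controls in proportion to their closeness spreads the weight over more units, and this is the mechanism behind the larger ESS described above.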
API
| Function | Brief description |
|---|---|
| estimate_te | Main pipeline – ensemble matching + meta‑learner |
| estimate_mediation | Mediation analysis with plug-in G-computation |
| MatchingCATEEstimator | [Experimental] Individual-level treatment effect estimation (CATE) |
| StochasticMatcher | 1:1 nearest‑neighbor matcher (deterministic ↔ stochastic) |
| summarize_matching | Diagnostics: ESS, ASMD, variance ratios, overlap plots |
| load_data_lalonde | Standard Lalonde job‑training dataset (two-arm, continuous outcome) |
| load_data_tof | New simulated Tetralogy of Fallot (ToF) dataset (two-arm or three-arm, survival/binary/continuous outcome) |
⚙️ Installation
```bash
pip install causalem
```
Optional dev extras:
```bash
pip install "causalem[dev]"
```
Minimum Python 3.9. Tested on macOS and Windows.
Package Vignette
For a more detailed introduction to CausalEM, including the underlying math, see the package vignette [insert link later], available on arXiv.
🚀 Quick Start
Two-arm Analysis
Load the necessary packages:
```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from causalem import (
    estimate_te,
    load_data_tof,
    stochastic_match,
    summarize_matching,
)
```
Load the ToF data with two treatment levels and binarized outcome:
```python
X, t, y = load_data_tof(
    raw=False,
    treat_levels=['PrP', 'SPS'],
    outcome_type="binary",
)
```
Stochastic matching using propensity scores:
```python
# Propensity scores from a logistic model, converted to the logit scale
lr = LogisticRegression(solver="newton-cg", max_iter=1000)
lr.fit(X, t)
score = lr.predict_proba(X)[:, 1]
logit_score = np.log(score / (1 - score))

# Stochastic matching on the logit propensity score
cluster = stochastic_match(
    treatment=t,
    score=logit_score,
    nsmp=10,
    scale=1.0,
    random_state=0,
)

# Balance and sample-size diagnostics
diag = summarize_matching(
    cluster, X,
    treatment=t, plot=False
)
print("Combined Effective Sample Size (ESS):", diag.ess["combined"])
print("Absolute standardized mean difference (ASMD) by covariate:\n")
print(diag.summary)
```
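`summarize_matching` reports ASMD per covariate. A common definition of ASMD divides the absolute difference in group means by the pooled standard deviation, sqrt((s_t² + s_c²)/2), and the sketch below uses that definition as a reference for reading the numbers. It is hypothetical: this README does not say whether CausalEM uses this exact denominator or whether it weights controls by how often they were matched (the optional `w_control` argument here).

```python
import numpy as np


def asmd(x_treated, x_control, w_control=None):
    """Absolute standardized mean difference for one covariate.

    Denominator: pooled SD of the unweighted groups. `w_control` optionally
    weights controls (e.g. by match counts) when computing their mean.
    """
    m_t = np.mean(x_treated)
    m_c = np.average(x_control, weights=w_control)
    pooled_sd = np.sqrt((np.var(x_treated, ddof=1) + np.var(x_control, ddof=1)) / 2)
    return abs(m_t - m_c) / pooled_sd


rng = np.random.default_rng(1)
x_t = rng.normal(1.0, 1.0, 100)  # covariate in the treated group
x_c = rng.normal(0.0, 1.0, 300)  # covariate in the control group

before = asmd(x_t, x_c)
# Up-weight controls that resemble the treated group (stand-in for matching)
w = np.exp(-0.5 * (x_c - 1.0) ** 2)
after = asmd(x_t, x_c, w_control=w)
print(f"ASMD before: {before:.2f}, after weighting: {after:.2f}")
```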
TE estimation (includes stochastic matching as the first step, followed by outcome modeling):
```python
res = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
)
print("Two-arm TE:", res["te"])
```
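The outcome-modeling step of `estimate_te` builds on G-computation. The idea is to fit a model of the outcome on covariates and treatment, predict every unit's outcome with treatment set to each arm, and average the difference. Below is a minimal single-model sketch with scikit-learn on synthetic data. It deliberately omits matching, cross-fitting, and stacking, which are what CausalEM adds on top. It reports a risk difference; this README does not state which scale CausalEM uses for binary-outcome effects.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
n = 2000
X = rng.normal(size=(n, 2))
t = rng.binomial(1, 1 / (1 + np.exp(-X[:, 0])))  # treatment confounded by X[:, 0]
logit_y = -0.5 + 1.0 * t + X[:, 0] - 0.5 * X[:, 1]
y = rng.binomial(1, 1 / (1 + np.exp(-logit_y)))

# Outcome model on (covariates, treatment); default L2 penalty is mild at this n
model = LogisticRegression(max_iter=1000).fit(np.column_stack([X, t]), y)

# Counterfactual predictions with treatment forced to 1 and to 0 for everyone
p1 = model.predict_proba(np.column_stack([X, np.ones(n)]))[:, 1]
p0 = model.predict_proba(np.column_stack([X, np.zeros(n)]))[:, 1]
te_risk_diff = np.mean(p1 - p0)
print(f"G-computation risk difference: {te_risk_diff:.3f}")
```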
Multi-arm Analysis
Load data for multi-arm analysis:
```python
df = load_data_tof(
    raw=True,
    outcome_type="binary",
)
t_all = df["treatment"].to_numpy()
X_all = df[["age", "zscore"]].to_numpy()
y_all = df["outcome"].to_numpy()
```
Constructing propensity scores using multinomial logistic regression:
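```python
# Multinomial logistic regression: the default lbfgs solver already fits a
# multinomial model for multi-class targets, and the multi_class argument is
# deprecated in recent scikit-learn releases
lr_multi = LogisticRegression(max_iter=1000)
lr_multi.fit(X_all, t_all)
proba = lr_multi.predict_proba(X_all)
# Log-odds of each non-reference arm, used as a multi-column matching score
ref = "PrP"
cols = [i for i, c in enumerate(lr_multi.classes_) if c != ref]
logit_multi = np.log(proba[:, cols] / (1 - proba[:, cols]))
```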
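The log-odds transforms used for the two-arm and multi-arm scores return ±inf if a fitted probability is exactly 0 or 1, which can happen with nearly separable data. A hedged variant clips the probabilities before transforming them. The `1e-6` bound is an arbitrary choice, and `scipy.special.logit` is used here in place of the manual formula. The probability matrix below is made up for illustration.

```python
import numpy as np
from scipy.special import logit


def safe_logit(p, eps=1e-6):
    """Log-odds of p after clipping to [eps, 1 - eps] to avoid infinite scores."""
    return logit(np.clip(p, eps, 1 - eps))


# Made-up predicted probabilities for three arms (column 0 = reference arm)
proba = np.array([[0.0, 0.3, 0.7],
                  [0.2, 0.8, 0.0],
                  [0.5, 0.25, 0.25]])
# Same idea as above: keep only the non-reference columns
logit_multi = safe_logit(proba[:, 1:])
print(logit_multi)
```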
Multi-arm stochastic matching:
```python
cluster_multi = stochastic_match(
    treatment=t_all,
    score=logit_multi,
    nsmp=5,
    scale=1.0,
    ref_group=ref,
    random_state=0,
)
diag_multi = summarize_matching(
    cluster_multi, X_all, treatment=t_all, ref_group=ref, plot=False
)
print("Multi-arm ESS per draw:\n", diag_multi.ess["per_draw"])  # dict of counts by group
```
Multi-arm TE estimation:
```python
res_multi = estimate_te(
    X_all,
    t_all,
    y_all,
    outcome_type="binary",
    ref_group=ref,
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
)
print("Multi-arm pairwise effects:\n", res_multi["pairwise"])
```
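This README does not document the exact structure of `res_multi["pairwise"]`. Conceptually, once a mean potential outcome is available for each arm, every pairwise effect is a contrast between two of those means. The sketch below uses hypothetical arm names and values, not ToF results, and uses differences as the contrast.

```python
from itertools import combinations

# Hypothetical mean potential outcomes (e.g. event probabilities) per arm
mu = {"A": 0.30, "B": 0.22, "C": 0.18}

# All pairwise differences: second arm minus first arm
pairwise = {(a, b): mu[b] - mu[a] for a, b in combinations(sorted(mu), 2)}

# Contrasts against a chosen reference arm only
ref = "A"
vs_ref = {arm: mu[arm] - mu[ref] for arm in mu if arm != ref}

for pair, diff in pairwise.items():
    print(pair, round(diff, 3))
```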
Confidence-Interval Calculation
Adding bootstrap CI to the two-arm analysis:
```python
res_boot = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    nboot=200,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
    random_state_boot=7,
)
print("Bootstrap CI:", res_boot["ci"])
```
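An "honest" bootstrap, as described under Key Features, reruns the whole pipeline on every resample, including matching, rather than resampling an already-matched set. The sketch below illustrates this with a toy estimator: deterministic 1:1 nearest-neighbor matching on one covariate followed by the mean matched difference. It uses a percentile interval. CausalEM's actual resampling scheme and interval type are not specified in this README, so both are assumptions here.

```python
import numpy as np


def toy_pipeline(x, t, y):
    """Toy TE estimator: match each treated unit to its nearest control on x
    (with replacement), then average the matched outcome differences."""
    xt, yt = x[t == 1], y[t == 1]
    xc, yc = x[t == 0], y[t == 0]
    nn = np.abs(xt[:, None] - xc[None, :]).argmin(axis=1)
    return np.mean(yt - yc[nn])


rng = np.random.default_rng(7)
n = 400
x = rng.normal(size=n)
t = rng.binomial(1, 1 / (1 + np.exp(-x)))
y = 2.0 * t + x + rng.normal(size=n)  # true effect = 2

est = toy_pipeline(x, t, y)
boot = []
for _ in range(200):
    idx = rng.integers(0, n, size=n)                   # resample units
    boot.append(toy_pipeline(x[idx], t[idx], y[idx]))  # re-match inside the loop
lo, hi = np.percentile(boot, [2.5, 97.5])
print(f"TE = {est:.2f}, 95% bootstrap CI = ({lo:.2f}, {hi:.2f})")
```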
Heterogeneous Ensemble
```python
learners = [
    LogisticRegression(max_iter=1000),
    RandomForestClassifier(n_estimators=200, max_depth=3),
]
res_ensemble = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    model_outcome=learners,
    niter=len(learners),
    do_stacking=True,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=42,
)
print("Ensemble TE:", res_ensemble["te"])
```
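According to the 1.2.0 release notes, `model_outcome` also accepts a generator that yields one estimator per iteration. Below is a sketch of such a generator. It cycles through a list of template models and yields a fresh, unfitted `sklearn.base.clone` each time. The helper name `learner_stream` is hypothetical. The `estimate_te` call is shown only as a comment, with its other arguments omitted.

```python
from itertools import cycle, islice

from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression


def learner_stream(templates):
    """Hypothetical helper: yield an unfitted clone of each template in turn, forever."""
    for template in cycle(templates):
        yield clone(template)


templates = [
    LogisticRegression(max_iter=1000),
    RandomForestClassifier(n_estimators=200, max_depth=3),
]
first_five = list(islice(learner_stream(templates), 5))
print([type(m).__name__ for m in first_five])

# Possible use (remaining arguments omitted):
# estimate_te(X, t, y, outcome_type="binary",
#             model_outcome=learner_stream(templates), niter=5)
```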
Stacking vs No-Stacking
```python
# No-stacking: average per-iteration effects without appearance weights
res_ns = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    do_stacking=False,
    random_state_master=0,
)

# Stacking: meta-learner fit with appearance weights over the matched union
res_stack = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    do_stacking=True,
    random_state_master=0,
)
```
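The comments above refer to appearance weights over the matched union. A plausible reading, assumed here rather than confirmed by this README, is that each unit's weight is the number of times it appears across the matching draws. The matched union is then every unit with a positive weight. The sketch below uses made-up draws and made-up unit-level effect predictions.

```python
import numpy as np

# Hypothetical matched sets from 3 stochastic draws (unit indices)
draws = [np.array([0, 1, 2, 5]), np.array([0, 2, 3, 5]), np.array([0, 1, 4, 5])]
n_units = 7  # unit 6 is never matched

# Appearance weight = number of times each unit appears across all draws
weights = np.bincount(np.concatenate(draws), minlength=n_units)
union = np.flatnonzero(weights)  # the matched union

# Hypothetical unit-level effect predictions from a (meta-)model
unit_effect = np.array([0.10, 0.20, 0.05, 0.30, 0.15, 0.00, 0.90])
te_weighted = np.average(unit_effect[union], weights=weights[union])
print(weights, union, round(te_weighted, 4))
```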
TE Estimation for Survival Outcomes
```python
X_surv, t_surv, y_surv = load_data_tof(
    raw=False,
    treat_levels=['SPS', 'PrP'],
)
res_surv = estimate_te(
    X_surv,
    t_surv,
    y_surv,
    outcome_type="survival",
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=0,
)
print("Survival HR:", res_surv["te"])
```
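For survival outcomes, the stacked ensemble works on data simulated from fitted survival models. The 0.6.2 notes add that `n_mc` sets the number of Monte Carlo draws per matched unit. A standard way to draw event times from a fitted step survival curve S(t), such as those returned by scikit-survival's `predict_survival_function`, is inverse-transform sampling. The sketch below applies it to an exponential curve. CausalEM's exact simulation scheme is not specified here, and so is its handling of draws beyond the last time point, which this sketch censors at that point. Both are assumptions.

```python
import numpy as np


def sample_event_times(times, surv, n_draws, rng=None):
    """Inverse-transform sampling from a step survival curve.

    times: increasing time grid; surv: S(times), non-increasing.
    U ~ Uniform(0, 1) maps to the first grid time with S(t) <= U; if S never
    drops that low, the draw is censored at the last grid time (event=False).
    """
    rng = np.random.default_rng(rng)
    times = np.asarray(times, dtype=float)
    surv = np.asarray(surv, dtype=float)
    u = rng.uniform(size=n_draws)
    # -surv is non-decreasing, so this finds the first i with surv[i] <= u
    idx = np.searchsorted(-surv, -u, side="left")
    event = idx < len(times)
    t_draw = times[np.minimum(idx, len(times) - 1)]
    return t_draw, event


times = np.arange(0.5, 50.5, 0.5)
surv = np.exp(-0.1 * times)  # exponential survival, mean event time 10
t_draw, event = sample_event_times(times, surv, n_draws=20000, rng=0)
print(f"mean drawn time: {t_draw.mean():.2f}, fraction of events: {event.mean():.3f}")
```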
Mediation Analysis
```python
# Load ToF data with mediation structure
from causalem.datasets import load_data_tof
from causalem.mediation import estimate_mediation

# Load ToF data: binary treatment (PrP vs SPS), continuous mediator (op_time), binary outcome
X, A, M, Y = load_data_tof(
    raw=False,
    treat_levels=['PrP', 'SPS'],  # Binary treatment comparison
    outcome_type="binary",        # Binary outcome for simpler interpretation
    include_mediator=True,        # Include mediator variable (op_time)
)

# Estimate mediation effects
result = estimate_mediation(X, A, M, Y, random_state_master=42)
print("Total Effect (TE):", result["te"])
print("Interventional Direct Effect (IDE):", result["ide"])
print("Interventional Indirect Effect (IIE):", result["iie"])
print("Proportion Mediated:", result["prop_mediated"])
```
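Under the usual interventional-effect definitions, let G_a be a random draw of the mediator from its distribution under treatment a, given X. Then IDE = E[Y(1, G_0)] - E[Y(0, G_0)], IIE = E[Y(1, G_1)] - E[Y(1, G_0)], TE = IDE + IIE, and the proportion mediated is IIE / TE. The plug-in G-computation sketch below applies these definitions with linear models on synthetic data with a continuous outcome, where the true IDE is 2 and the true IIE is 3 × 1.5 = 4.5. It is not CausalEM's implementation. Whether `estimate_mediation` uses exactly these definitions and model classes is an assumption.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(42)
n = 5000
X = rng.normal(size=(n, 1))
A = rng.binomial(1, 1 / (1 + np.exp(-X[:, 0])))
M = 0.5 + 1.5 * A + X[:, 0] + rng.normal(size=n)
Y = 1.0 + 2.0 * A + 3.0 * M + X[:, 0] + rng.normal(size=n)

# Mediator model M | A, X (Gaussian with constant residual SD)
m_model = LinearRegression().fit(np.column_stack([A, X]), M)
m_sd = np.std(M - m_model.predict(np.column_stack([A, X])))
# Outcome model Y | A, M, X
y_model = LinearRegression().fit(np.column_stack([A, M, X]), Y)


def draw_mediator(a):
    """Draw G_a: mediator values simulated under A = a for every unit."""
    mean = m_model.predict(np.column_stack([np.full(n, a), X]))
    return mean + rng.normal(scale=m_sd, size=n)


def mean_outcome(a, m):
    """Plug-in estimate of E[Y(a, m)], averaged over units."""
    return y_model.predict(np.column_stack([np.full(n, a), m, X])).mean()


g0, g1 = draw_mediator(0), draw_mediator(1)
ide = mean_outcome(1, g0) - mean_outcome(0, g0)
iie = mean_outcome(1, g1) - mean_outcome(1, g0)
te = ide + iie
print(f"TE={te:.2f}  IDE={ide:.2f}  IIE={iie:.2f}  prop. mediated={iie / te:.2f}")
```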
CATE Estimation (Experimental)
⚠️ Experimental Feature: The CATE estimator API is under active development and may change.
Unlike estimate_te(), which returns population-level averages, MatchingCATEEstimator predicts individual-level treatment effects:
```python
import numpy as np
from causalem._experimental import MatchingCATEEstimator
from causalem import load_data_lalonde

X, t, y = load_data_lalonde(raw=False)

# Initialize and fit the CATE estimator
est = MatchingCATEEstimator(
    niter=10,
    matching_is_stochastic=True,
    matching_scale=1.0,
    do_stacking=True,
    random_state=42,
)
est.fit(X, t, y)

# Get individual treatment effects
individual_effects = est.effect()
print(f"Individual effects range: [{individual_effects.min():.2f}, {individual_effects.max():.2f}]")

# Population summaries
print(f"ATE on matched: {est.ate():.2f}")
print(f"ATT on matched: {est.att():.2f}")

# Identify high-benefit subgroups (top quartile of predicted effects)
high_benefit_idx = np.where(individual_effects > np.percentile(individual_effects, 75))[0]
print(f"High-benefit group size: {len(high_benefit_idx)}")
```
For more details, see causalem/_experimental/README.md.
License
This project is licensed under the terms of the MIT License.
Release Notes
1.3.0
New Experimental Feature: CATE (Conditional Average Treatment Effect) Estimation
- Added `MatchingCATEEstimator` class in the `causalem._experimental` module for individual-level treatment effect prediction
- Provides a scikit-learn style `fit()`/`effect()` API for learning and predicting heterogeneous treatment effects
- Key capabilities:
  - Individual-level treatment effect predictions (not just population averages)
  - Prediction on new/unseen data
  - ATM and ATT estimands supported
  - Stochastic and deterministic matching
  - Ensemble stacking with meta-learners
  - Compatible with heterogeneous base learners
- Current scope (binary treatment, non-survival outcomes):
  - ✓ Binary and continuous outcomes
  - ✓ Stochastic/deterministic matching
  - ✓ Stacking and no-stacking modes
  - ✓ ATM/ATT estimands
- Future: Multi-arm treatment, survival outcomes, bootstrap CIs
- Validation: Comprehensive test suite verifying parity with `estimate_te()`
- Documentation: Detailed design documentation in `causalem/_experimental/README.md`
- Status: ⚠️ Experimental API - may change in future releases
API Example:
```python
from causalem._experimental import MatchingCATEEstimator

est = MatchingCATEEstimator(niter=10, do_stacking=True, random_state=42)
est.fit(X, t, y)
# Individual effects
effects = est.effect()
# Population summaries
ate = est.ate()
att = est.att()
```
1.2.0
New Feature: Covariate Inclusion in Stacking Meta-Learner
- Added `include_covariates_in_stacking` parameter to `estimate_te()` to enable including covariates in the meta-learner stage
- When `True`, covariates are included alongside base learner predictions in the meta-learner design matrix, allowing the meta-learner to learn non-linear combinations of predictions conditional on covariates
- Implemented across all pathways: binary, multi-arm, and survival outcomes
- For stacking mode with `do_stacking=True`, both base predictions and original covariates are passed to the meta-learner
- Defaults to `False` to preserve backward compatibility
- Warning issued if `include_covariates_in_stacking=True` but `do_stacking=False` (parameter has no effect without stacking)
- Comprehensive test coverage: 9 new tests covering all outcome types and edge cases
Documentation Enhancement: Heterogeneous Ensembles
- Improved documentation for the existing heterogeneous learner feature in the `model_outcome` parameter
- Previously feature-complete but undocumented: `model_outcome` now clearly documents support for:
  - List/tuple of estimators: Mix different model types across iterations (e.g., Random Forest + Gradient Boosting + Linear models)
  - Generator/iterator: Dynamically yield different models for each iteration
  - Single estimator: Homogeneous ensemble (backward compatible)
- Added practical examples showing heterogeneous ensemble usage with lists and generators
- Documented benefits: improved robustness by combining models with different inductive biases
- Comprehensive test suite added: 22 tests (675 lines) in `tests/test_heterogeneous_learners.py` covering:
  - All input types (list, tuple, generator)
  - All outcome types (continuous, binary, survival)
  - Multi-arm treatments
  - Error handling (insufficient models, exhausted generators)
  - Integration with all features (bootstrap, stacking, covariates, ATT, stochastic matching)
  - Reproducibility and comparisons
Bug Fixes:
- Fixed multi-arm stacking to correctly use encoder categories when constructing counterfactual design matrices
API Enhancements:
```python
# Meta-learner uses only base predictions (default)
result = estimate_te(X, t, y, do_stacking=True, include_covariates_in_stacking=False)

# Meta-learner uses both base predictions and covariates
result = estimate_te(X, t, y, do_stacking=True, include_covariates_in_stacking=True)

# Heterogeneous ensemble with different model types
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.linear_model import LinearRegression

outcome_models = [
    RandomForestRegressor(n_estimators=100),
    GradientBoostingRegressor(n_estimators=100),
    LinearRegression(),
]
result = estimate_te(X, t, y, model_outcome=outcome_models, niter=3)
```
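The notes above say that `include_covariates_in_stacking=True` appends the covariates to the base-learner predictions in the meta-learner design matrix. The sketch below shows that construction with a linear meta-learner on synthetic data. The helper name `stacking_design` is hypothetical, and the base-learner predictions are made up. A linear meta-learner only adds the covariates additively. Letting the weighting of predictions depend on covariates would need a non-linear meta-learner or interaction terms.

```python
import numpy as np
from sklearn.linear_model import LinearRegression


def stacking_design(base_preds, X=None):
    """Hypothetical helper: one column per base learner, plus the original
    covariates when X is given."""
    Z = np.column_stack(base_preds)
    return Z if X is None else np.column_stack([Z, X])


rng = np.random.default_rng(0)
n = 300
X = rng.normal(size=(n, 2))
y = X[:, 0] + 0.5 * X[:, 1] ** 2 + rng.normal(scale=0.3, size=n)
# Made-up out-of-fold predictions from three base learners of varying quality
base_preds = [y + rng.normal(scale=s, size=n) for s in (0.5, 1.0, 2.0)]

meta_plain = LinearRegression().fit(stacking_design(base_preds), y)
meta_cov = LinearRegression().fit(stacking_design(base_preds, X), y)
print(stacking_design(base_preds).shape, stacking_design(base_preds, X).shape)
print("meta weights (predictions only):", np.round(meta_plain.coef_, 2))
```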
1.1.0
New Feature: Estimand Parameter (ATT vs ATM)
- Added `estimand` parameter to `estimate_te()` and `estimate_te_multi()` functions
- Supports two estimands:
  - `'ATM'` (default): Average Treatment Effect on Matched sample - averages over all units appearing in matched sets (preserves backward compatibility)
  - `'ATT'`: Average Treatment Effect on Treated (common support) - averages over treated/`ref_group` units that were successfully matched
- Implemented across all pathways: binary, multi-arm, and survival outcomes
- For multi-arm with `estimand='ATT'`, the `ref_group` parameter specifies which arm is the "treated" group
- ATT computes effects on matched treated units only (not all treated), following standard matching literature practice of estimating on the common support
- Comprehensive test coverage: 14 new tests covering all outcome types and pathways
API Enhancement:
```python
# Target effect on matched sample (default)
result = estimate_te(X, t, y, estimand='ATM')
# Target effect on treated population
result = estimate_te(X, t, y, estimand='ATT')
```
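According to the notes above, the two estimands differ only in which matched units are averaged: ATM uses all units appearing in matched sets, while ATT uses only matched treated (`ref_group`) units. The sketch below uses hypothetical unit-level effects and masks and takes unweighted averages. Whether appearance counts enter these averages is not stated in these notes.

```python
import numpy as np

# Hypothetical unit-level effects, treatment indicators, and match status
effect = np.array([1.2, 0.8, 1.5, 0.4, 0.9, 0.6])
treated = np.array([1, 1, 1, 0, 0, 0], dtype=bool)
matched = np.array([1, 1, 0, 1, 1, 1], dtype=bool)  # unit 2: treated, unmatched

atm = effect[matched].mean()            # all matched units
att = effect[matched & treated].mean()  # matched treated units (common support)
print(f"ATM = {atm:.3f}, ATT (common support) = {att:.3f}")
```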
1.0.1
- Removed the R section of `README.md` since it has not been released yet.
- Added release notes for version 1.0.0.
1.0.0
- Removed `binarize_outcome` parameter from `load_data_lalonde` and `load_data_tof`.
- Absorbed `load_data_tof_with_mediator` into `load_data_tof`.
0.7.0
- Added mediation analysis functionality with the `estimate_mediation` function for interventional mediation effects using plug-in G-computation.
- Supports binary treatment with binary/continuous mediators and continuous outcomes.
- Features bootstrap confidence intervals and optional integration with stochastic matching for improved robustness.
- Estimates total effect (TE), interventional direct effect (IDE), and interventional indirect effect (IIE).
0.6.2
- Exposed a new `n_mc` argument in `estimate_te` for specifying Monte‑Carlo draws per matched unit in survival analyses, replacing the previously fixed single draw.
- Clarified treatment‑effect estimands for stacking vs. no‑stacking modes, noting that stacked results are appearance‑weighted across the matched union.
- Documented appearance‑weighted meta‑learning and matched‑union survival contrasts.
0.6.1
- Corrected the version number in the `pyproject.toml` file.
0.6.0
- Improved consistency of the return data structure when `do_stacking=False` in multi-arm TE estimation.
0.5.4
- Added GitHub Action for publishing to PyPI
0.5.3
- First public release
0.5.1
- Edits to readme
- Added GitHub Action for publishing to (test) PyPI
0.5.0
- First test release
Download files
File details
Details for the file causalem-1.3.0.tar.gz.
File metadata
- Download URL: causalem-1.3.0.tar.gz
- Upload date:
- Size: 187.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3ba218d25b89e81ab647c0de622b709c74949f9be80136ea3997d696d98c7c0e |
| MD5 | 5e5dd59c410b294c312cc9973297518c |
| BLAKE2b-256 | 8ff20a6cfdc1e415f76d7169765d54b0f7a162bc2f30c7b75d531d97816e3583 |
File details
Details for the file causalem-1.3.0-py3-none-any.whl.
File metadata
- Download URL: causalem-1.3.0-py3-none-any.whl
- Upload date:
- Size: 157.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 73e307ecac4bbbde16f69d8020f745c389adebb903c103bd3e1cb677f98aaa31 |
| MD5 | 017c43bc1fcb90e64492feea2ddb8b11 |
| BLAKE2b-256 | 367f4741f3621ea6bb4c2ed7261cfda7ae9fa88841864a22508ef4f2f6dd86aa |