Causal Inference using Ensemble Matching
CausalEM – Ensemble Matching for Causal Inference
CausalEM is a toolbox for multi-arm treatment‑effect estimation using stochastic matching and a stacked ensemble of heterogeneous ML models. It supports continuous, binary, and survival outcomes.
Key Features
- Stochastic nearest-neighbor (NN) matching -> Larger effective sample size (ESS) and improved treatment-effect (TE) estimation accuracy compared to standard (deterministic) NN matching.
- G-computation using a two-stage, stacked ensemble of heterogeneous learners -> Generalization of the standard G-computation framework to ensemble learning; cross-fitting of propensity-score and outcome models, similar to DoubleML.
- Support for multi-arm treatments -> Improved multi-arm ESS via stochastic matching.
- Mediation analysis -> Plug-in G-computation for interventional mediation effects (IDE/IIE) with binary treatment and binary/continuous mediators and outcomes, supporting bootstrap confidence intervals and optional stochastic matching.
- Support for survival outcomes -> Use of data simulation from survival outcome models to implement stacked-ensemble for TE estimation in right-censored, time-to-event data.
- Bootstrapped confidence interval (CI) estimation -> Honest estimation of CI by including entire (matching + TE estimation) pipeline in bootstrap loop.
- Compatible with scikit-learn -> Maximum flexibility in using ML models by providing access to scikit-learn (and scikit-survival for survival) for propensity-score, outcome and meta-learner stages.
- Full reproducibility of results -> Careful implementation of random number generation (RNG) seeding, including in scikit-learn models.
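To make the first feature concrete, here is a minimal sketch of one way to randomize 1:1 nearest-neighbor matching on a scalar score. It is an illustration only, not CausalEM's implementation: the function name `stochastic_nn_match`, the exponential distance weighting, and the role given to `scale` are all assumptions.

```python
import math
import random


def stochastic_nn_match(treated, controls, scale=1.0, seed=0):
    """Match each treated score to one control, without replacement.

    Hypothetical sketch: a control at distance d is drawn with weight
    exp(-d / scale). As scale approaches 0 the draw concentrates on the
    nearest available control (deterministic greedy matching).
    """
    rng = random.Random(seed)
    available = list(range(len(controls)))
    pairs = []
    for i, s in enumerate(treated):
        if not available:
            break  # more treated units than controls
        dists = [abs(s - controls[j]) for j in available]
        dmin = min(dists)
        # Subtract the minimum distance for numerical stability.
        weights = [math.exp(-(d - dmin) / scale) for d in dists]
        j = rng.choices(available, weights=weights, k=1)[0]
        available.remove(j)
        pairs.append((i, j))
    return pairs


treated = [0.2, 0.9, -0.4]
controls = [0.1, 1.0, -0.5, 0.45, 2.0]

# Two seeds may give two different matched samples.
draw_a = stochastic_nn_match(treated, controls, scale=1.0, seed=1)
draw_b = stochastic_nn_match(treated, controls, scale=1.0, seed=2)

# A tiny scale recovers greedy nearest-neighbor matching.
near_deterministic = stochastic_nn_match(treated, controls, scale=1e-6, seed=1)
```

Repeating the draw with different seeds yields several matched samples that can differ in which controls they use, which is the intuition behind the larger ESS claimed above.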
API
| Function | Brief description |
|---|---|
| estimate_te | Main pipeline – ensemble matching + meta‑learner |
| estimate_mediation | Mediation analysis with plug-in G-computation |
| StochasticMatcher | 1:1 nearest‑neighbor matcher (deterministic ↔ stochastic) |
| summarize_matching | Diagnostics: ESS, ASMD, variance ratios, overlap plots |
| load_data_lalonde | Copy of Lalonde job‑training dataset |
| load_data_tof | Simulated TOF dataset (survival or binary outcome) |
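Two of the diagnostics named in the table, ESS and ASMD, have common textbook definitions, sketched below in plain Python. Whether `summarize_matching` uses exactly these formulas is an assumption, and the helper names `kish_ess` and `asmd` are hypothetical.

```python
import math
import statistics


def kish_ess(weights):
    """Kish effective sample size: (sum w)^2 / sum(w^2)."""
    total = sum(weights)
    return total * total / sum(w * w for w in weights)


def asmd(x_treated, x_control):
    """Absolute standardized mean difference using a pooled SD."""
    pooled_var = (statistics.variance(x_treated)
                  + statistics.variance(x_control)) / 2.0
    diff = statistics.fmean(x_treated) - statistics.fmean(x_control)
    return abs(diff) / math.sqrt(pooled_var)


# Equal weights: the ESS equals the number of units.
ess_equal = kish_ess([1, 1, 1, 1])
# Unequal weights shrink the ESS below the unit count.
ess_unequal = kish_ess([3, 1])
# Group means differ by one pooled standard deviation.
balance = asmd([1.0, 2.0, 3.0], [2.0, 3.0, 4.0])
```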
⚙️ Installation
pip install causalem
Optional dev extras:
pip install "causalem[dev]"
Minimum Python 3.9. Tested on macOS and Windows.
Package Vignette
For a more detailed introduction to CausalEM, including the underlying math, see the package vignette [insert link later], available on arXiv.
🚀 Quick Start
Two-arm Analysis
Load the necessary packages:
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from causalem import (
    estimate_te,
    load_data_tof,
    stochastic_match,
    summarize_matching,
)
Load the ToF data with two treatment levels and binarized outcome:
X, t, y = load_data_tof(
    raw=False,
    treat_levels=["PrP", "SPS"],
    outcome_type="binary",
)
Stochastic matching using propensity scores:
lr = LogisticRegression(solver="newton-cg", max_iter=1000)
lr.fit(X, t)
score = lr.predict_proba(X)[:, 1]
logit_score = np.log(score / (1 - score))
cluster = stochastic_match(
    treatment=t,
    score=logit_score,
    nsmp=10,
    scale=1.0,
    random_state=0,
)
diag = summarize_matching(
    cluster, X,
    treatment=t, plot=False
)
print("Combined Effective Sample Size (ESS):", diag.ess["combined"])
print("Absolute standardized mean difference (ASMD) by covariate:\n")
print(diag.summary)
TE estimation:
res = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
)
print("Two-arm TE:", res["te"])
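For readers unfamiliar with G-computation, the plug-in idea can be sketched in a few lines: fit an outcome model per arm, predict both potential outcomes for every unit, and average the difference. The sketch below uses a one-covariate least-squares line per arm and no matching or ensembling; the helper names `fit_line` and `g_computation_ate` are hypothetical, and this is not how `estimate_te` is implemented.

```python
def fit_line(x, y):
    """Ordinary least squares for y = intercept + slope * x."""
    n = len(x)
    mx = sum(x) / n
    my = sum(y) / n
    sxx = sum((xi - mx) ** 2 for xi in x)
    sxy = sum((xi - mx) * (yi - my) for xi, yi in zip(x, y))
    slope = sxy / sxx
    return my - slope * mx, slope


def g_computation_ate(x, t, y):
    """Average of predicted Y(1) - Y(0) over all units (plug-in estimate)."""
    models = {}
    for arm in (0, 1):
        xs = [xi for xi, ti in zip(x, t) if ti == arm]
        ys = [yi for yi, ti in zip(y, t) if ti == arm]
        models[arm] = fit_line(xs, ys)
    (a0, b0), (a1, b1) = models[0], models[1]
    # Predict both potential outcomes for every unit, treated or not.
    diffs = [(a1 + b1 * xi) - (a0 + b0 * xi) for xi in x]
    return sum(diffs) / len(diffs)


# Noise-free toy data: Y = 1 + 2*X + 3*T, so the true effect is 3.
x = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
t = [0, 1, 0, 1, 0, 1]
y = [1 + 2 * xi + 3 * ti for xi, ti in zip(x, t)]
ate = g_computation_ate(x, t, y)
```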
Multi-arm Analysis
Load data for multi-arm analysis:
df = load_data_tof(
    raw=True,
    outcome_type="binary",
)
t_all = df["treatment"].to_numpy()
X_all = df[["age", "zscore"]].to_numpy()
y_all = df["outcome"].to_numpy()
Constructing propensity scores using multinomial logistic regression:
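# The default lbfgs solver fits a multinomial model for multi-class targets;
# the explicit multi_class argument is deprecated in recent scikit-learn.
lr_multi = LogisticRegression(max_iter=1000)
lr_multi.fit(X_all, t_all)
proba = lr_multi.predict_proba(X_all)
ref = "PrP"
# Keep one logit column per non-reference arm
cols = [i for i, c in enumerate(lr_multi.classes_) if c != ref]
logit_multi = np.log(proba[:, cols] / (1 - proba[:, cols]))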
Multi-arm stochastic matching:
cluster_multi = stochastic_match(
    treatment=t_all,
    score=logit_multi,
    nsmp=5,
    scale=1.0,
    ref_group=ref,
    random_state=0,
)
diag_multi = summarize_matching(
    cluster_multi, X_all, treatment=t_all, ref_group=ref, plot=False
)
# dict of counts by group
print("Multi-arm ESS per draw:\n", diag_multi.ess["per_draw"])
Multi-arm TE estimation:
res_multi = estimate_te(
    X_all,
    t_all,
    y_all,
    outcome_type="binary",
    ref_group=ref,
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
)
print("Multi-arm pairwise effects:\n", res_multi["pairwise"])
Confidence-Interval Calculation
Adding bootstrap CI to the two-arm analysis:
res_boot = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    nboot=200,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
    random_state_boot=7,
)
print("Bootstrap CI:", res_boot["ci"])
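The reason for putting the whole pipeline inside the bootstrap loop is that every data-dependent step then contributes to the interval width. A generic percentile-bootstrap sketch is shown below, with a toy difference-in-means function standing in for matching plus TE estimation. The names `percentile_bootstrap_ci` and `diff_in_means`, and the use of the percentile method, are assumptions and may differ from what `estimate_te` does internally.

```python
import random


def percentile_bootstrap_ci(rows, pipeline, nboot=200, alpha=0.05, seed=7):
    """Percentile CI where `pipeline` is re-run on every resample.

    `pipeline` stands in for the full matching + estimation procedure;
    here it is any callable mapping a list of rows to a point estimate.
    """
    rng = random.Random(seed)
    n = len(rows)
    estimates = []
    for _ in range(nboot):
        # Resample units with replacement, then redo every step.
        resample = [rows[rng.randrange(n)] for _ in range(n)]
        estimates.append(pipeline(resample))
    estimates.sort()
    lo = estimates[int((alpha / 2) * (nboot - 1))]
    hi = estimates[int((1 - alpha / 2) * (nboot - 1))]
    return lo, hi


def diff_in_means(rows):
    """Toy pipeline: difference in mean outcome between the two arms."""
    y1 = [y for t, y in rows if t == 1]
    y0 = [y for t, y in rows if t == 0]
    if not y1 or not y0:
        return 0.0  # degenerate resample containing a single arm
    return sum(y1) / len(y1) - sum(y0) / len(y0)


# Simulated rows of (treatment, outcome) with a true effect of 2.
data_rng = random.Random(0)
rows = [(t, 2.0 * t + data_rng.gauss(0, 1)) for t in [0, 1] * 50]
ci_low, ci_high = percentile_bootstrap_ci(rows, diff_in_means)
```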
Heterogeneous Ensemble
learners = [
    LogisticRegression(max_iter=1000),
    RandomForestClassifier(n_estimators=200, max_depth=3),
]
res_ensemble = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    model_outcome=learners,
    niter=len(learners),
    do_stacking=True,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=42,
)
print("Ensemble TE:", res_ensemble["te"])
Stacking vs No-Stacking
# No-stacking: average per-iteration effects without appearance weights
res_ns = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    do_stacking=False,
    random_state_master=0,
)
# Stacking: meta-learner fit with appearance weights over the matched union
res_stack = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    do_stacking=True,
    random_state_master=0,
)
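The comments above refer to appearance weights over the matched union. One plausible reading, sketched here, is that a unit's weight is the number of matching draws in which it appears, and the matched union is the set of units appearing at least once. This is an assumption about the intended meaning, not a description of CausalEM's internals, and `appearance_weights` and `weighted_mean` are hypothetical helpers.

```python
from collections import Counter


def appearance_weights(draws):
    """Count how often each unit index appears across matching draws.

    `draws` is a list of matched samples, each a collection of unit
    indices. The keys of the result form the matched union.
    """
    counts = Counter()
    for matched_units in draws:
        counts.update(set(matched_units))  # count a unit once per draw
    return dict(counts)


def weighted_mean(values, weights):
    """Weighted average of per-unit values over the matched union."""
    total = sum(weights.values())
    return sum(values[i] * w for i, w in weights.items()) / total


# Three matching draws over units 0..6; unit 3 is never matched.
draws = [[0, 1, 4, 5], [0, 2, 4, 6], [0, 1, 5, 6]]
w = appearance_weights(draws)

# Hypothetical per-unit effect predictions, averaged with the weights.
unit_effects = {0: 1.0, 1: 2.0, 2: 3.0, 4: 1.0, 5: 2.0, 6: 3.0}
stacked_estimate = weighted_mean(unit_effects, w)
```

Under this reading, units matched in every draw (unit 0 here) pull the stacked estimate toward their own predicted effect, which is why the stacked and no-stacking estimands can differ.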
TE Estimation for Survival Outcomes
X_surv, t_surv, y_surv = load_data_tof(
    raw=False,
    treat_levels=["SPS", "PrP"],
)
res_surv = estimate_te(
    X_surv,
    t_surv,
    y_surv,
    outcome_type="survival",
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=0,
)
print("Survival HR:", res_surv["te"])
Mediation Analysis
from causalem.datasets import load_data_tof
from causalem.mediation import estimate_mediation

# Load ToF data with mediation structure: binary treatment (PrP vs SPS),
# continuous mediator (op_time), binary outcome
X, A, M, Y = load_data_tof(
    raw=False,
    treat_levels=["PrP", "SPS"],  # Binary treatment comparison
    outcome_type="binary",        # Binary outcome for simpler interpretation
    include_mediator=True,        # Include mediator variable (op_time)
)

# Estimate mediation effects
result = estimate_mediation(X, A, M, Y, random_state_master=42)
print("Total Effect (TE):", result["te"])
print("Interventional Direct Effect (IDE):", result["ide"])
print("Interventional Indirect Effect (IIE):", result["iie"])
print("Proportion Mediated:", result["prop_mediated"])
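As a rough illustration of the plug-in idea, the sketch below computes interventional direct and indirect effects for a binary treatment and a binary mediator from cell means, ignoring covariates. That simplification, the helper name `plug_in_mediation`, and reporting the sum IDE + IIE under the key `te` are assumptions made for the example, not the behavior of `estimate_mediation`.

```python
def plug_in_mediation(a, m, y):
    """Plug-in IDE/IIE for binary A and binary M, without covariates."""
    def mean(values):
        return sum(values) / len(values)

    # P(M = 1 | A = arm) and E[Y | A = arm, M = level] from cell means.
    p_m1 = {arm: mean([mi for ai, mi in zip(a, m) if ai == arm])
            for arm in (0, 1)}
    mu = {(arm, lvl): mean([yi for ai, mi, yi in zip(a, m, y)
                            if ai == arm and mi == lvl])
          for arm in (0, 1) for lvl in (0, 1)}

    def p_m(lvl, arm):
        return p_m1[arm] if lvl == 1 else 1.0 - p_m1[arm]

    # Direct effect: switch treatment, hold the mediator law at A = 0.
    ide = sum((mu[1, lvl] - mu[0, lvl]) * p_m(lvl, 0) for lvl in (0, 1))
    # Indirect effect: hold treatment at 1, switch the mediator law.
    iie = sum(mu[1, lvl] * (p_m(lvl, 1) - p_m(lvl, 0)) for lvl in (0, 1))
    return {"ide": ide, "iie": iie, "te": ide + iie}


# Toy data with Y = A + 2*M; treatment raises P(M = 1) from 0.25 to 0.75.
a = [0, 0, 0, 0, 1, 1, 1, 1]
m = [0, 0, 0, 1, 0, 1, 1, 1]
y = [ai + 2 * mi for ai, mi in zip(a, m)]
effects = plug_in_mediation(a, m, y)
prop_mediated = effects["iie"] / effects["te"]
```

In this toy example the direct and indirect effects are both 1, so half of the overall effect runs through the mediator.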
License
This project is licensed under the terms of the MIT License.
Release Notes
0.7.0
- Added mediation analysis functionality with estimate_mediation function for interventional mediation effects using plug-in G-computation.
- Supports binary treatment with binary/continuous mediators and continuous outcomes.
- Features bootstrap confidence intervals and optional integration with stochastic matching for improved robustness.
- Estimates total effect (TE), interventional direct effect (IDE), and interventional indirect effect (IIE).
0.6.2
- Exposed a new n_mc argument in estimate_te for specifying Monte‑Carlo draws per matched unit in survival analyses, replacing the previously fixed single draw.
- Clarified treatment‑effect estimands for stacking vs. no‑stacking modes, noting that stacked results are appearance‑weighted across the matched union.
- Documented appearance‑weighted meta‑learning and matched‑union survival contrasts.
0.6.1
- Corrected the version number in the pyproject.toml file.
0.6.0
- Improved consistency of return data structure when do_stacking=False in multi-arm TE estimation.
0.5.4
- Added GitHub Action for publishing to PyPI
0.5.3
- First public release
0.5.1
- Edits to README
- Added GitHub Action for publishing to (test) PyPI
0.5.0
- First test release