Causal Inference using Ensemble Matching
CausalEM – Ensemble Matching for Causal Inference
CausalEM is an ensemble‑based toolbox for multi-arm treatment‑effect estimation using stochastic matching, with support for continuous, binary, and right-censored time-to-event (survival) outcomes.
Key Features
- Stochastic adaptation of nearest-neighbor (NN) matching -> Larger effective sample size (ESS) and more accurate treatment-effect (TE) estimation than standard (deterministic) NN matching.
- G-computation using a two-stage, stacked ensemble of heterogeneous learners -> Generalizes the standard G-computation framework to ensemble learning; cross-fits the propensity-score and outcome models, similar to DoubleML.
- Support for multi-arm treatments -> Stochastic matching in CausalEM can be especially helpful in multi-arm scenarios for improving ESS.
- Support for survival outcomes -> Uses data simulated from survival outcome models to implement the stacked ensemble for TE estimation in right-censored, time-to-event data.
- Bootstrapped confidence interval (CI) estimation -> Honest CI estimates, obtained by including the entire matching + TE estimation pipeline in the bootstrap loop.
- Compatible with scikit-learn -> Maximum flexibility in the choice of machine-learning models: scikit-learn models can be used for the propensity-score, outcome and meta-learner stages (scikit-survival for survival outcomes).
- Full reproducibility of results -> Careful seeding of random number generation (RNG), including in scikit-learn models.
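To make the idea of randomized NN matching concrete, here is a minimal sketch of one way to do 1:1 matching on a scalar score. It is an illustration only and not necessarily the algorithm CausalEM implements: the function name `stochastic_nn_match`, the rule of sampling a control with probability proportional to `exp(-distance / scale)`, and the meaning given to `scale` are all assumptions made for this example.

```python
import math
import random


def stochastic_nn_match(treated, control, scale=1.0, seed=0):
    """Match each treated score to one control score without replacement.

    Hypothetical illustration: a control at distance d is drawn with
    probability proportional to exp(-d / scale). As scale -> 0 this
    approaches deterministic greedy nearest-neighbor matching.
    """
    rng = random.Random(seed)
    available = list(range(len(control)))
    pairs = []
    for i, s in enumerate(treated):
        if not available:
            break  # more treated than control units: stop early
        dists = [abs(s - control[j]) for j in available]
        d_min = min(dists)
        # Subtract the minimum distance so the largest weight is exactly 1.
        weights = [math.exp(-(d - d_min) / scale) for d in dists]
        j = rng.choices(available, weights=weights, k=1)[0]
        pairs.append((i, j))
        available.remove(j)
    return pairs


treated_scores = [0.25, 0.8, 1.5]
control_scores = [0.1, 0.3, 0.9, 1.4, 2.0]

# Repeated draws with different seeds can give different matched sets.
for seed in range(3):
    print(seed, stochastic_nn_match(treated_scores, control_scores, seed=seed))
```

Because different draws can use different control units, pooling several draws can bring more of the control pool into the analysis than a single deterministic match would.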
API
| Function | Brief description |
|---|---|
| estimate_te | Main pipeline – ensemble matching + meta‑learner |
| StochasticMatcher | 1:1 nearest‑neighbor matcher (deterministic ↔ stochastic) |
| summarize_matching | Diagnostics: ESS, ASMD, variance ratios, overlap plots |
| load_data_lalonde | Copy of Lalonde job‑training dataset |
| load_data_tof | Simulated TOF dataset (survival or binary outcome) |
⚙️ Installation
pip install causalem
Optional dev extras:
pip install "causalem[dev]"
Minimum Python 3.9. Tested on macOS and Windows.
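After installing, a quick sanity check of the environment can look like the sketch below. It uses only the standard library; the helper name `installed_version` is made up for this example and is not part of CausalEM.

```python
import sys
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Python 3.9 is the minimum version stated above.
python_ok = sys.version_info >= (3, 9)
print("Python >= 3.9:", python_ok)
print("causalem version:", installed_version("causalem"))
```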
Package Vignette
For a more detailed introduction to CausalEM, including the underlying math, see the package vignette [insert link later], available on arXiv.
🚀 Quick Start
Two-arm Analysis
Load the necessary packages:
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from causalem import (
    estimate_te,
    load_data_tof,
    stochastic_match,
    summarize_matching,
)
Load the ToF data with two treatment levels and binarized outcome:
X, t, y = load_data_tof(
    raw=False,
    treat_levels=["PrP", "SPS"],
    binarize_outcome=True,
)
Stochastic matching using propensity scores:
# Fit a propensity-score model and move the scores to the logit scale.
lr = LogisticRegression(solver="newton-cg", max_iter=1000)
lr.fit(X, t)
score = lr.predict_proba(X)[:, 1]
logit_score = np.log(score / (1 - score))
# Match on the logit score.
cluster = stochastic_match(
    treatment=t,
    score=logit_score,
    nsmp=10,
    scale=1.0,
    random_state=0,
)
# Matching diagnostics.
diag = summarize_matching(
    cluster,
    X,
    treatment=t,
    plot=False,
)
print("Combined Effective Sample Size (ESS):", diag.ess["combined"])
print("Absolute standardized mean difference (ASMD) by covariate:\n")
print(diag.summary)
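For readers unfamiliar with the ASMD diagnostic printed above, the sketch below computes it for a single covariate using the common pooled-standard-deviation convention. The function name `asmd` and the toy numbers are made up for illustration, and the exact convention used by `summarize_matching` may differ.

```python
import math
import statistics


def asmd(treated_values, control_values):
    """Absolute standardized mean difference for one covariate.

    Uses the pooled standard deviation sqrt((var_t + var_c) / 2), a common
    convention that may differ from the one CausalEM applies.
    """
    mean_diff = statistics.fmean(treated_values) - statistics.fmean(control_values)
    pooled_sd = math.sqrt(
        (statistics.variance(treated_values) + statistics.variance(control_values)) / 2
    )
    return abs(mean_diff) / pooled_sd


# Toy covariate values for the two arms of a matched sample.
age_treated = [30.0, 35.0, 40.0, 45.0]
age_control = [28.0, 33.0, 38.0, 43.0]
print(round(asmd(age_treated, age_control), 3))
```

Smaller values indicate better covariate balance between the matched arms; a threshold of about 0.1 is a frequently quoted rule of thumb.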
TE estimation:
res = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
)
print("Two-arm TE:", res["te"])
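For intuition about the G-computation idea that the pipeline builds on, here is a deliberately simple sketch: fit one outcome model per arm, predict both potential outcomes for every unit, and average the difference. It uses a one-covariate least-squares line instead of a stacked ensemble, involves no matching or cross-fitting, and the names `fit_line` and `g_computation_ate` are made up for this example. It is not a description of what `estimate_te` does internally.

```python
import statistics


def fit_line(x, y):
    """Ordinary least squares for y = a + b * x with a single covariate."""
    x_bar, y_bar = statistics.fmean(x), statistics.fmean(y)
    sxx = sum((xi - x_bar) ** 2 for xi in x)
    sxy = sum((xi - x_bar) * (yi - y_bar) for xi, yi in zip(x, y))
    slope = sxy / sxx
    return y_bar - slope * x_bar, slope


def g_computation_ate(x, t, y):
    """Average treatment effect by G-computation with one model per arm."""
    models = {}
    for arm in (0, 1):
        xs = [xi for xi, ti in zip(x, t) if ti == arm]
        ys = [yi for yi, ti in zip(y, t) if ti == arm]
        models[arm] = fit_line(xs, ys)
    # Predict both potential outcomes for every unit, then average the difference.
    diffs = []
    for xi in x:
        y1 = models[1][0] + models[1][1] * xi
        y0 = models[0][0] + models[0][1] * xi
        diffs.append(y1 - y0)
    return statistics.fmean(diffs)


# Toy data: outcome = 1 + 2 * x + 3 * treatment, so the true effect is 3.
x_toy = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
t_toy = [0, 1, 0, 1, 0, 1]
y_toy = [1 + 2 * xi + 3 * ti for xi, ti in zip(x_toy, t_toy)]
print(g_computation_ate(x_toy, t_toy, y_toy))
```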
Multi-arm Analysis
Load data for multi-arm analysis:
df = load_data_tof(
    raw=True,
    binarize_outcome=True,
)
t_all = df["treatment"].to_numpy()
X_all = df[["age", "zscore"]].to_numpy()
y_all = df["outcome"].to_numpy()
Constructing propensity scores using multinomial logistic regression:
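# With the default lbfgs solver, LogisticRegression already fits a multinomial
# model for multi-class targets; the explicit multi_class argument is
# deprecated in recent scikit-learn releases, so it is omitted here.
lr_multi = LogisticRegression(max_iter=1000)
lr_multi.fit(X_all, t_all)
proba = lr_multi.predict_proba(X_all)
ref = "PrP"
# Keep one logit-score column per non-reference arm.
cols = [i for i, c in enumerate(lr_multi.classes_) if c != ref]
logit_multi = np.log(proba[:, cols] / (1 - proba[:, cols]))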
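The logit transform above becomes infinite when a predicted probability is exactly 0 or 1. A minimal guard, sketched below, is to clip the probability before taking the log. The helper name `safe_logit` and the default `eps` are assumptions for this example; whether CausalEM applies any such clipping internally is not stated here.

```python
import math


def safe_logit(p, eps=1e-6):
    """Logit of a probability, clipped to [eps, 1 - eps] to stay finite."""
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p / (1.0 - p))


for p in (0.0, 0.25, 0.5, 1.0):
    print(p, safe_logit(p))
```

With NumPy arrays the same effect can be obtained by applying `np.clip(proba, eps, 1 - eps)` before the log.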
Multi-arm stochastic matching:
cluster_multi = stochastic_match(
    treatment=t_all,
    score=logit_multi,
    nsmp=5,
    scale=1.0,
    ref_group=ref,
    random_state=0,
)
diag_multi = summarize_matching(
    cluster_multi, X_all, treatment=t_all, ref_group=ref, plot=False
)
print("Multi-arm ESS per draw:\n", diag_multi.ess["per_draw"])
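As background on the ESS figures printed above, one widely used definition is Kish's effective sample size, computed from how often (or with what weight) each unit enters the matched sample. The sketch below illustrates that formula; whether `summarize_matching` uses exactly this definition is an assumption, and the name `kish_ess` is made up for this example.

```python
def kish_ess(weights):
    """Kish effective sample size: (sum of w)^2 / (sum of w^2)."""
    total = sum(weights)
    total_sq = sum(w * w for w in weights)
    return total * total / total_sq


# Four control units, each used once: the ESS equals the number of units.
print(kish_ess([1, 1, 1, 1]))  # 4.0
# The same total weight concentrated on fewer units lowers the ESS.
print(kish_ess([3, 1, 0, 0]))  # 1.6
```

This is why spreading matches over more distinct units, as stochastic matching aims to do, tends to raise the ESS.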
Multi-arm TE estimation:
res_multi = estimate_te(
    X_all,
    t_all,
    y_all,
    outcome_type="binary",
    ref_group=ref,
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
)
print("Multi-arm pairwise effects:\n", res_multi["pairwise"])
Confidence-Interval Calculation
Adding bootstrap CI to the two-arm analysis:
res_boot = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    nboot=200,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
    random_state_boot=7,
)
print("Bootstrap CI:", res_boot["ci"])
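The key point of the bootstrap approach is that the whole estimation pipeline is rerun on every resample, so the interval reflects all sources of variability. The sketch below shows that pattern with a percentile interval and a trivial difference-in-means estimator standing in for the full matching + TE pipeline. The percentile method, the function names, and the toy data are assumptions for illustration; CausalEM's interval construction may differ.

```python
import random
import statistics


def difference_in_means(rows):
    """Stand-in for a full estimation pipeline; rows are (treatment, outcome)."""
    treated = [y for t, y in rows if t == 1]
    control = [y for t, y in rows if t == 0]
    return statistics.fmean(treated) - statistics.fmean(control)


def bootstrap_ci(rows, estimator, nboot=200, alpha=0.05, seed=0):
    """Percentile bootstrap CI; the whole estimator is rerun on each resample."""
    rng = random.Random(seed)
    estimates = []
    while len(estimates) < nboot:
        sample = rng.choices(rows, k=len(rows))
        if {t for t, _ in sample} != {0, 1}:
            continue  # skip degenerate resamples that lost an arm
        estimates.append(estimator(sample))
    estimates.sort()
    lo = estimates[round((alpha / 2) * (nboot - 1))]
    hi = estimates[round((1 - alpha / 2) * (nboot - 1))]
    return lo, hi


# Toy data: 50 units per arm, true effect of 1.0 plus unit-variance noise.
data_rng = random.Random(1)
rows = [(t, 1.0 * t + data_rng.gauss(0, 1)) for t in [0, 1] * 50]
lo, hi = bootstrap_ci(rows, difference_in_means, nboot=200, seed=7)
print("point estimate:", difference_in_means(rows))
print("95% CI:", (lo, hi))
```

Passing the estimator as a function makes it explicit that everything inside it, including any matching step, is repeated per resample.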
Heterogeneous Ensemble
# One outcome learner per iteration, combined by stacking.
learners = [
    LogisticRegression(max_iter=1000),
    RandomForestClassifier(n_estimators=200, max_depth=3),
]
res_ensemble = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    model_outcome=learners,
    niter=len(learners),
    do_stacking=True,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=42,
)
print("Ensemble TE:", res_ensemble["te"])
TE Estimation for Survival Outcomes
X_surv, t_surv, y_surv = load_data_tof(
    raw=False,
    treat_levels=["SPS", "PrP"],
)
res_surv = estimate_te(
    X_surv,
    t_surv,
    y_surv,
    outcome_type="survival",
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=0,
)
print("Survival HR:", res_surv["te"])
License
This project is licensed under the terms of the MIT License – see the LICENSE file.
Release Notes
0.5.0
- First public release