Causal Inference via Ensemble G-Computing with Stochastic Matching
Reason this release was yanked:
Erroneous 'In Development' tag was added to release notes.
CausalEM – Ensemble G-Computing with Stochastic Matching for Causal Inference
CausalEM implements ensemble G-computing with stochastic matching for treatment effect (TE) estimation and mediation analysis. Two methodological innovations, stochastic matching and two-stage ensemble G-computing, combine the diagnostic transparency of matching with the modeling flexibility and robustness of ensemble machine learning. The package supports multi-arm treatments and continuous, binary, and survival outcomes, and provides a scikit-learn-compatible API.
Key Features
| Feature | Impact |
|---|---|
| Stochastic nearest-neighbor (NN) matching | Larger effective sample size (ESS) and improved TE estimation accuracy compared to standard (deterministic) NN matching |
| Two-stage ensemble G-computing | Implements G-computing via stacked meta-learning with frequency-weighted aggregation of heterogeneous base learners; cross-fitting of propensity, outcome, and meta-learner models |
| Scalable to large datasets | Memory-efficient matching algorithm handles 100k+ observations on standard hardware (36 GB RAM); tested on datasets up to 220k rows |
| Support for multi-arm treatments | Improved multi-arm ESS via stochastic matching |
| Mediation analysis | Plug-in G-computation for interventional mediation effects (IDE/IIE) with binary treatment and binary/continuous mediators and outcomes, supporting bootstrap confidence intervals and optional stochastic matching |
| Support for survival outcomes | Extends the stacked ensemble to right-censored, time-to-event data by simulating outcomes from fitted survival models |
| Bootstrapped confidence interval (CI) estimation | Honest CIs: the entire pipeline (matching + TE estimation) is rerun inside the bootstrap loop |
| Compatible with scikit-learn | Maximum flexibility in model choice: scikit-learn models (and scikit-survival models for survival outcomes) can be used in the propensity-score, outcome, and meta-learner stages |
| Full reproducibility of results | Careful implementation of random number generation (RNG) seeding, including in scikit-learn models |
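The first row of the table claims that randomizing nearest-neighbor matches raises the effective sample size. The toy sketch below illustrates the idea on a one-dimensional score: each treated unit draws one control from those inside the caliper, with probability decaying in distance, and control usage across repeated draws is summarized with Kish's ESS. The Gaussian kernel, the Kish formula, and every function name here are assumptions for illustration; causalem's `stochastic_matching` may use a different sampling rule and ESS definition.

```python
import math
import random
from collections import Counter


def stochastic_match(treated, controls, scale, caliper, rng):
    """Toy stochastic NN matching on a 1-D score (hypothetical, not causalem's code).

    For each treated score, pick one control at random with probability
    proportional to exp(-d^2 / (2 * scale^2)), restricted to controls within
    the caliper. Returns a list of (treated_index, control_index) pairs.
    """
    pairs = []
    for i, s in enumerate(treated):
        cands = [j for j, c in enumerate(controls) if abs(c - s) <= caliper]
        if not cands:
            continue  # unmatched treated unit (outside common support)
        weights = [math.exp(-((controls[j] - s) ** 2) / (2 * scale**2)) for j in cands]
        pairs.append((i, rng.choices(cands, weights=weights, k=1)[0]))
    return pairs


def kish_ess(counts):
    """Kish effective sample size (sum w)^2 / sum w^2 for appearance counts."""
    total = sum(counts)
    return total * total / sum(c * c for c in counts)


rng = random.Random(0)
treated = [0.1, 0.2, 0.25]
controls = [0.0, 0.15, 0.22, 0.3, 0.9]

# Repeat the stochastic draw 10 times and count how often each control is used.
usage = Counter()
for _draw in range(10):
    for _t, j in stochastic_match(treated, controls, scale=0.1, caliper=0.2, rng=rng):
        usage[j] += 1
print("control usage:", dict(usage))
print("Kish ESS of controls:", round(kish_ess(list(usage.values())), 2))
```

With deterministic NN matching, every repeated draw would reuse the same two closest controls (a Kish ESS of 1.8 over 30 matches in this toy data); spreading matches across nearby controls typically raises that number.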
API
| Function | Brief description |
|---|---|
| estimate_te | Main pipeline: ensemble matching + meta-learner |
| estimate_mediation | Mediation analysis with plug-in G-computation |
| MatchingCATEEstimator | [Experimental] Individual-level treatment effect estimation (CATE) |
| stochastic_matching | Stochastic/deterministic nearest-neighbor matching |
| summarize_matching | Diagnostics: ESS, ASMD, variance ratios, overlap plots |
| load_data_lalonde | Standard Lalonde job-training dataset (two-arm, continuous outcome) |
| load_data_tof | New simulated Tetralogy of Fallot (ToF) dataset (two-arm or three-arm, survival/binary/continuous outcome) |
Installation
pip install causalem
Optional dev extras:
pip install "causalem[dev]"
Requires Python 3.9 or later. Tested on macOS and Windows.
Package Vignette
For a more detailed introduction to CausalEM, including the underlying math, see the package vignette [insert link later], available on arXiv.
Quick Start
Two-arm Analysis
Load the necessary packages:
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from causalem import (
estimate_te,
load_data_tof,
stochastic_matching,
summarize_matching
)
Load the ToF data with two treatment levels and a binarized outcome:
X, t, y = load_data_tof(
raw=False,
treat_levels=['PrP', 'SPS'],
outcome_type="binary",
)
Stochastic matching using propensity scores:
lr = LogisticRegression(solver="newton-cg", max_iter=1000)
lr.fit(X, t)
score = lr.predict_proba(X)[:, 1]
logit_score = np.log(score / (1 - score))
cluster = stochastic_matching(
treatment=t,
score=logit_score,
nsmp=10,
scale=1.0,
random_state=0,
)
diag = summarize_matching(
cluster, X,
treatment=t, plot=False
)
print("Combined Effective Sample Size (ESS):", diag.ess["combined"])
print("Absolute standardized mean difference (ASMD) by covariate:\n")
print(diag.summary)
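`summarize_matching` reports the absolute standardized mean difference (ASMD) for each covariate. For reference, the common definition divides the absolute difference in group means by the pooled SD, sqrt((s_t² + s_c²) / 2). The sketch below computes that on plain lists with made-up values; whether `summarize_matching` uses exactly this pooled-SD convention, and how it weights units that appear in several matched draws, is an assumption.

```python
import math
import statistics


def asmd(x_treated, x_control):
    """Absolute standardized mean difference with a pooled-variance denominator."""
    diff = statistics.fmean(x_treated) - statistics.fmean(x_control)
    pooled_sd = math.sqrt(
        (statistics.variance(x_treated) + statistics.variance(x_control)) / 2
    )
    return abs(diff) / pooled_sd


# Toy covariate values before and after matching (made-up numbers).
age_treated = [2.0, 3.0, 4.0, 5.0]
age_control_all = [1.0, 1.5, 2.0, 6.0, 9.0, 10.0]
age_control_matched = [2.0, 3.0, 4.0, 6.0]

print("ASMD before matching:", round(asmd(age_treated, age_control_all), 3))
print("ASMD after matching:", round(asmd(age_treated, age_control_matched), 3))
```

A common rule of thumb treats ASMD values below about 0.1 as good balance, so the diagnostics table is most useful when compared before and after matching.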
TE estimation (includes stochastic matching as the first step, followed by outcome modeling):
res = estimate_te(
X,
t,
y,
outcome_type="binary",
niter=5,
random_state_master=1,
)
print("Two-arm TE:", res["te"])
This call uses the v2.0.0 matching defaults: std_matching=True, matching_caliper=0.2, and matching_scale=0.8, with caliper and scale expressed in units of the SD of the logit propensity score (PS).
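With std_matching=True, the caliper and scale are multiples of the SD of the logit propensity score. The worked sketch below converts them to absolute logit units for a toy score vector; the helper name is hypothetical, and whether the package uses the sample or population SD (and over which units) is an assumption.

```python
import math
import statistics


def absolute_matching_params(logit_scores, caliper_sd=0.2, scale_sd=0.8):
    """Convert SD-based caliper/scale to absolute logit-PS units (hypothetical helper)."""
    sd = statistics.stdev(logit_scores)  # sample SD; an assumption
    return caliper_sd * sd, scale_sd * sd


def logit(p):
    return math.log(p / (1 - p))


ps = [0.1, 0.3, 0.5, 0.7, 0.9]
logit_ps = [logit(p) for p in ps]
caliper_abs, scale_abs = absolute_matching_params(logit_ps)
print(f"SD of logit PS: {statistics.stdev(logit_ps):.3f}")
print(f"caliper = {caliper_abs:.3f}, scale = {scale_abs:.3f} (logit units)")
```

Because both values scale with the spread of the scores, the same defaults stay meaningful whether the propensity model separates the arms weakly or strongly.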
Multi-arm Analysis
Load data for multi-arm analysis:
df = load_data_tof(
raw=True,
outcome_type="binary",
)
t_all = df["treatment"].to_numpy()
X_all = df[["age", "zscore"]].to_numpy()
y_all = df["outcome"].to_numpy()
Construct propensity scores using multinomial logistic regression:
lr_multi = LogisticRegression(max_iter=1000)  # default lbfgs solver fits a multinomial model for 3+ classes; multi_class= is deprecated in scikit-learn >= 1.5
lr_multi.fit(X_all, t_all)
proba = lr_multi.predict_proba(X_all)
ref = "PrP"
cols = [i for i, c in enumerate(lr_multi.classes_) if c != ref]
logit_multi = np.log(proba[:, cols] / (1 - proba[:, cols]))
Multi-arm stochastic matching:
cluster_multi = stochastic_matching(
treatment=t_all,
score=logit_multi,
nsmp=5,
scale=1.0,
ref_group=ref,
random_state=0,
)
diag_multi = summarize_matching(
cluster_multi, X_all, treatment=t_all, ref_group=ref, plot=False
)
print("Multi-arm ESS per draw:\n", diag_multi.ess["per_draw"]) # dict of counts by group
Multi-arm TE estimation:
res_multi = estimate_te(
X_all,
t_all,
y_all,
outcome_type="binary",
ref_group=ref,
niter=5,
random_state_master=1,
)
print("Multi-arm pairwise effects:\n", res_multi["pairwise"])
This call also uses the v2.0.0 defaults with SD-based matching parameters.
Confidence-Interval Calculation
Adding bootstrap CI to the two-arm analysis:
res_boot = estimate_te(
X,
t,
y,
outcome_type="binary",
niter=5,
nboot=200,
random_state_master=1,
random_state_boot=7,
)
print("Bootstrap CI:", res_boot["ci"])
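The Key Features table calls the bootstrap honest because the entire pipeline, matching included, is rerun on every resample. The sketch below shows that pattern with a percentile interval. The toy pipeline (1:1 nearest-neighbor matching with replacement on a single score, followed by a mean matched difference) stands in for `estimate_te`, and the percentile method is an assumption; neither is taken from causalem.

```python
import random
import statistics


def toy_pipeline(units):
    """Stand-in for matching + TE estimation (hypothetical).

    units: list of (score, treated_flag, outcome). Matches each treated unit to
    the control with the closest score (with replacement) and returns the mean
    matched difference in outcomes.
    """
    treated = [u for u in units if u[1] == 1]
    controls = [u for u in units if u[1] == 0]
    if not treated or not controls:
        return None  # resample lacks one arm; skip it
    diffs = []
    for s, _, y in treated:
        _, _, y_c = min(controls, key=lambda c: abs(c[0] - s))
        diffs.append(y - y_c)
    return statistics.fmean(diffs)


def bootstrap_ci(units, nboot=200, alpha=0.05, seed=7):
    """Percentile CI that reruns the whole pipeline on each resample."""
    rng = random.Random(seed)
    estimates = []
    for _ in range(nboot):
        resample = rng.choices(units, k=len(units))
        est = toy_pipeline(resample)  # matching is redone inside the loop
        if est is not None:
            estimates.append(est)
    estimates.sort()
    lo = estimates[int(alpha / 2 * (len(estimates) - 1))]
    hi = estimates[int((1 - alpha / 2) * (len(estimates) - 1))]
    return lo, hi


rng = random.Random(1)
data = []
for _ in range(100):
    score = rng.gauss(0, 1)
    t = 1 if rng.random() < 0.5 else 0
    y = 0.5 * score + 1.0 * t + rng.gauss(0, 0.5)  # true effect is 1.0
    data.append((score, t, y))

print("point estimate:", round(toy_pipeline(data), 3))
print("95% percentile CI:", tuple(round(v, 3) for v in bootstrap_ci(data)))
```

Matching only once and bootstrapping the matched pairs would ignore the variability of the matching step itself, which typically yields intervals that are too narrow.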
Heterogeneous Ensemble
learners = [
LogisticRegression(max_iter=1000),
RandomForestClassifier(n_estimators=200, max_depth=3),
]
res_ensemble = estimate_te(
X,
t,
y,
outcome_type="binary",
model_outcome=learners,
niter=len(learners),
do_stacking=True,
random_state_master=42,
)
print("Ensemble TE:", res_ensemble["te"])
Stacking vs No-Stacking
# No-stacking: average per-iteration effects without appearance weights
res_ns = estimate_te(
X,
t,
y,
outcome_type="binary",
niter=5,
do_stacking=False,
random_state_master=0,
)
# Stacking: meta-learner fit with appearance weights over the matched union
res_stack = estimate_te(
X,
t,
y,
outcome_type="binary",
niter=5,
do_stacking=True,
random_state_master=0,
)
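The two comments above contrast a plain average of per-iteration effects with appearance weighting over the matched union. The toy calculation below illustrates only that aggregation difference: it assumes each unit in the union has a single predicted unit-level effect and that its weight is the number of matched draws it appears in. No meta-learner is fitted, so this illustrates the weighting idea rather than causalem's stacking.

```python
import statistics
from collections import Counter

# Unit-level predicted effects (made-up numbers) for units 0..4.
unit_effect = {0: 1.0, 1: 2.0, 2: 3.0, 3: 4.0, 4: 5.0}

# Units that appear in each of three matched draws (niter=3).
draws = [[0, 1, 2], [0, 1, 3], [0, 4]]

# No-stacking style: estimate an effect per draw, then average the draws.
per_draw = [statistics.fmean(unit_effect[u] for u in d) for d in draws]
no_stack = statistics.fmean(per_draw)

# Appearance-weighted style: pool the matched union, weighting each unit
# by how many draws it appears in.
appear = Counter(u for d in draws for u in d)
weighted = sum(appear[u] * unit_effect[u] for u in appear) / sum(appear.values())

print("per-draw effects:", per_draw)
print("average of per-draw effects:", round(no_stack, 4))
print("appearance-weighted effect:", round(weighted, 4))
```

The two numbers differ whenever draws have different sizes or units recur unevenly, which is why the release notes describe the stacking and no-stacking estimands separately.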
TE Estimation for Survival Outcomes
X_surv, t_surv, y_surv = load_data_tof(
raw=False,
treat_levels=['SPS', 'PrP'],
)
res_surv = estimate_te(
X_surv,
t_surv,
y_surv,
outcome_type="survival",
niter=5,
random_state_master=0,
)
print("Survival HR:", res_surv["te"])
Mediation Analysis
# Load ToF data with mediation structure
from causalem.datasets import load_data_tof
from causalem.mediation import estimate_mediation
# Load ToF data: binary treatment (PrP vs SPS), continuous mediator (op_time), binary outcome
X, A, M, Y = load_data_tof(
raw=False,
treat_levels=['PrP', 'SPS'], # Binary treatment comparison
outcome_type="binary", # Binary outcome for simpler interpretation
include_mediator=True # Include mediator variable (op_time)
)
# Estimate mediation effects
result = estimate_mediation(X, A, M, Y, random_state_master=42)
print("Total Effect (TE):", result["te"])
print("Interventional Direct Effect (IDE):", result["ide"])
print("Interventional Indirect Effect (IIE):", result["iie"])
print("Proportion Mediated:", result["prop_mediated"])
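Proportion mediated is conventionally the indirect effect divided by the total effect. The sketch below computes that ratio from a dictionary shaped like the result above; that `prop_mediated` is defined exactly as IIE / TE in `estimate_mediation`, the helper name, and the numbers are all assumptions for illustration. The guard matters because the ratio is unstable when the total effect is close to zero.

```python
import math


def proportion_mediated(result, tol=1e-8):
    """IIE / TE, returning NaN when TE is (near) zero. Definition assumed."""
    te = result["te"]
    if abs(te) < tol:
        return math.nan
    return result["iie"] / te


example = {"te": 0.12, "ide": 0.09, "iie": 0.03}  # made-up values
print("proportion mediated:", round(proportion_mediated(example), 3))
print("zero TE:", proportion_mediated({"te": 0.0, "ide": 0.0, "iie": 0.0}))
```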
CATE Estimation (Experimental)
Experimental Feature: The CATE estimator API is under active development and may change.
Unlike estimate_te(), which returns population-level averages, MatchingCATEEstimator predicts individual-level treatment effects:
from causalem._experimental import MatchingCATEEstimator
from causalem import load_data_lalonde
X, t, y = load_data_lalonde(raw=False)
# Initialize and fit the CATE estimator
est = MatchingCATEEstimator(
niter=10,
do_stacking=True,
random_state=42
)
est.fit(X, t, y)
# Get individual treatment effects
individual_effects = est.effect()
print(f"Individual effects range: [{individual_effects.min():.2f}, {individual_effects.max():.2f}]")
# Population summaries
print(f"ATE on matched: {est.ate():.2f}")
print(f"ATT on matched: {est.att():.2f}")
# Identify high-benefit subgroups
import numpy as np
high_benefit_idx = np.where(individual_effects > np.percentile(individual_effects, 75))[0]
print(f"High-benefit group size: {len(high_benefit_idx)}")
For more details, see causalem/_experimental/README.md.
License
This project is licensed under the terms of the MIT License.
Release Notes
2.0.1 (In Development)
- Documentation: updated package branding to emphasize the ensemble G-computing framework and clarified the two core innovations (stochastic matching and two-stage ensemble G-computing).
- No functional changes: all algorithms and results are identical to 2.0.0.
2.0.0
Performance Improvements:
- Major scalability enhancement for large datasets (100k+ observations)
  - Memory optimization: the matching algorithm now computes distances on the fly rather than pre-computing the full n×n distance matrix
  - Memory reduction: O(n²) → O(n) memory complexity enables matching on datasets with 100k+ observations on standard hardware
  - Tested at scale: successfully validated on 181k observations (real orthopedic registry data) and synthetic datasets up to 220k rows
  - Speed improvement: ~5-10× faster with the caliper parameter (default 0.2 SD)
  - Backward compatibility: results are bit-identical to the previous implementation when using the same random seed
  - Implementation: internal matching functions now accept propensity score arrays directly instead of distance matrices
  - Legacy support: pre-computed distance matrix input is preserved for backward compatibility (not recommended for large n)
- Scalability validation
  - Previous limitation: MemoryError at ~50k observations (the distance matrix alone requires 20 GB RAM)
  - New capability: handles 200k+ observations with <5 GB peak memory usage
  - Validated on the ACIC2019 benchmark: bit-identical results to the original implementation
  - Note: computational time for the matching phase still scales approximately as O(n²); only the memory bottleneck was removed
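Two points in the notes above can be checked with a little arithmetic and code. A dense float64 n×n distance matrix needs 8n² bytes, which is 20 GB at n = 50,000. And because propensity-score distances are one-dimensional, deterministic nearest neighbors can be found by sorting the scores once and binary-searching, which needs only O(n) extra memory. The sort-and-search strategy is a generic technique shown for illustration; it is not necessarily how causalem's internal matching functions work, and the function names are hypothetical.

```python
import bisect


def dense_matrix_gb(n, bytes_per_entry=8):
    """Memory for an n x n float64 distance matrix, in GB (10**9 bytes)."""
    return n * n * bytes_per_entry / 1e9


def nearest_controls(treated_scores, control_scores):
    """Index of the nearest control for each treated score, O(n) extra memory."""
    order = sorted(range(len(control_scores)), key=control_scores.__getitem__)
    sorted_scores = [control_scores[i] for i in order]
    result = []
    for s in treated_scores:
        k = bisect.bisect_left(sorted_scores, s)
        # The nearest neighbor is just left or just right of the insertion point.
        best = min(
            (j for j in (k - 1, k) if 0 <= j < len(sorted_scores)),
            key=lambda j: abs(sorted_scores[j] - s),
        )
        result.append(order[best])
    return result


print(f"n=50,000: {dense_matrix_gb(50_000):.1f} GB")
print(f"n=220,000: {dense_matrix_gb(220_000):.1f} GB")
print(nearest_controls([0.1, 0.48, 2.0], [0.5, -1.0, 0.0, 1.9]))
```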
Breaking Changes:
- `groups` parameter renamed to `_cv_groups` in `estimate_te()`, `estimate_te_multi()`, and `estimate_mediation()`
  - Now marked as internal-only with an underscore prefix
  - Migration: `groups=my_groups` → `_cv_groups=my_groups`
- `stochastic_match()` function renamed to `stochastic_matching()`
  - Improves naming consistency across the API (both `stochastic_matching` and `summarize_matching` now use "matching")
  - Migration: update all imports and function calls from `stochastic_match` to `stochastic_matching`
- Removed natural effects from mediation analysis
  - `effect_type` parameter removed from `estimate_mediation()`
  - Only interventional effects (IDE/IIE) are now returned
  - Natural effects were misleading when using flexible ML models that can learn treatment-mediator interactions
  - Migration: remove `effect_type="natural"` arguments; code that relied on NDE/NIE keys will need to use IDE/IIE instead
- SD-based caliper and scale for `estimate_mediation()` (completing API harmonization with `estimate_te()`)
  - Added `std_matching` parameter (default `True`) for SD-based interpretation of caliper/scale
  - Changed `matching_caliper` default: `None` → `0.2`
  - Changed `matching_scale` default: `1.0` → `0.8`
  - With `std_matching=True`, caliper and scale are interpreted as multiples of SD(logit propensity scores)
  - Benefits: automatic adaptation to data spread, robust matching across diverse datasets
  - Migration: to preserve v1.x behavior, set `std_matching=False, matching_scale=1.0, matching_caliper=None`
- `matching_is_stochastic` parameter replaced with `matching_method` in `estimate_mediation()`
  - The old parameter was boolean; the new parameter takes one of three values: `"stochastic"`, `"deterministic"`, or `None`
  - Migration:
    - `matching_is_stochastic=True` → `matching_method="stochastic"`
    - `matching_is_stochastic=False, niter=1` → `matching_method=None, niter=1`
    - `matching_is_stochastic=False, niter>1` → `matching_method="deterministic"`
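The migration rules for `matching_is_stochastic` can be expressed as a small lookup. The helper below is hypothetical (it is not part of causalem) and rewrites a v1.x keyword dictionary according to the three rules above. It also pins `niter` explicitly, since the v2.0.0 default changed from 1 to 10.

```python
def migrate_mediation_kwargs(old_kwargs):
    """Rewrite v1.x estimate_mediation() kwargs to v2.0.0 names (hypothetical helper)."""
    kwargs = dict(old_kwargs)
    if "matching_is_stochastic" not in kwargs:
        return kwargs
    stochastic = kwargs.pop("matching_is_stochastic")
    niter = kwargs.get("niter", 1)  # v1.x default was niter=1
    kwargs["niter"] = niter  # pin it: the v2.0.0 default is 10
    if stochastic:
        kwargs["matching_method"] = "stochastic"
    elif niter == 1:
        kwargs["matching_method"] = None
    else:
        kwargs["matching_method"] = "deterministic"
    return kwargs


print(migrate_mediation_kwargs({"matching_is_stochastic": True, "niter": 10}))
print(migrate_mediation_kwargs({"matching_is_stochastic": False}))
print(migrate_mediation_kwargs({"matching_is_stochastic": False, "niter": 5}))
```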
- New defaults in `estimate_mediation()`, which now uses stochastic matching by default
  - `niter` default: `1` → `10`
  - `matching_method` default: effectively `None` → `"stochastic"`
  - Migration: to preserve v1.x behavior, explicitly set `matching_method=None, niter=1`
- SD-based caliper and scale specification in `estimate_te()` and `estimate_te_multi()`
  - Added `std_matching` parameter (default `True`) enabling standard-deviation-based units
  - `matching_caliper` default: `None` → `0.2` (0.2 SD when `std_matching=True`)
  - `matching_scale` default: `1.0` → `0.8` (0.8 SD when `std_matching=True`)
  - When `std_matching=True` (default): caliper and scale are interpreted as multiples of SD(logit propensity scores)
  - When `std_matching=False`: absolute values are used (v1.x behavior)
  - Migration: to preserve v1.x behavior, set `std_matching=False, matching_scale=1.0, matching_caliper=None`
  - Rationale: 0.2 SD is recommended in the matching literature (Austin 2011, Rosenbaum & Rubin 1985) and adapts to data spread
- Changed default return format for dataset loaders
  - `load_data_tof()` and `load_data_lalonde()` now default to `raw=False`
  - These functions now return `(X, t, y)` or `(X, t, m, y)` arrays by default instead of DataFrames
  - Migration: to get DataFrame output (v1.x behavior), explicitly pass `raw=True`
  - Rationale: array output aligns with how these functions are predominantly used in practice
New Features:
- SD-based matching parameters: the `std_matching` parameter enables an adaptive caliper and scale based on propensity score spread
- Added `estimand` parameter to `estimate_mediation()` supporting `"ATM"` (default) and `"ATT"`
- Added `n_splits_propensity` parameter to `estimate_mediation()` (default: 5) for propensity model cross-fitting
- Updated `MatchingCATEEstimator` (experimental) to support the `std_matching` parameter with new defaults
Bug Fixes:
- Fixed `prob_clip_eps` parameter not being passed through in the survival pathway
- Improved parameter handling in `_estimate_te_survival_single_iter()`
Documentation:
- Clarified that the meta-learner is cross-fitted using `n_splits_outcome` folds in addition to the base models
1.3.0
New Experimental Feature: CATE (Conditional Average Treatment Effect) Estimation
- Added `MatchingCATEEstimator` class in the `causalem._experimental` module for individual-level treatment effect prediction
- Provides a scikit-learn-style `fit()`/`effect()` API for learning and predicting heterogeneous treatment effects
- Key capabilities:
  - Individual-level treatment effect predictions (not just population averages)
  - Prediction on new/unseen data
  - ATM and ATT estimands supported
  - Stochastic and deterministic matching
  - Ensemble stacking with meta-learners
  - Compatible with heterogeneous base learners
- Current scope (binary treatment, non-survival outcomes):
  - ✓ Binary and continuous outcomes
  - ✓ Stochastic/deterministic matching
  - ✓ Stacking and no-stacking modes
  - ✓ ATM/ATT estimands
  - Future: multi-arm treatment, survival outcomes, bootstrap CIs
- Validation: comprehensive test suite verifying parity with `estimate_te()`
- Documentation: detailed design documentation in `causalem/_experimental/README.md`
- Status: experimental API; may change in future releases
API Example:
from causalem._experimental import MatchingCATEEstimator
est = MatchingCATEEstimator(niter=10, do_stacking=True, random_state=42)
est.fit(X, t, y)
# Individual effects
effects = est.effect()
# Population summaries
ate = est.ate()
att = est.att()
1.2.0
New Feature: Covariate Inclusion in Stacking Meta-Learner
- Added `include_covariates_in_stacking` parameter to `estimate_te()` to enable including covariates in the meta-learner stage
- When `True`, covariates are included alongside base learner predictions in the meta-learner design matrix, allowing the meta-learner to learn non-linear combinations of predictions conditional on covariates
- Implemented across all pathways: binary, multi-arm, and survival outcomes
- For stacking mode with `do_stacking=True`, both base predictions and original covariates are passed to the meta-learner
- Defaults to `False` to preserve backward compatibility
- A warning is issued if `include_covariates_in_stacking=True` but `do_stacking=False` (the parameter has no effect without stacking)
- Comprehensive test coverage: 9 new tests covering all outcome types and edge cases
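The difference between the two settings of `include_covariates_in_stacking` comes down to which columns the meta-learner sees. The sketch below builds both design matrices from made-up cross-fitted base-learner predictions; the helper name and the row layout (one column per base learner, covariates appended on the right) are assumptions made for illustration, not causalem internals.

```python
def meta_design_matrix(base_preds, covariates=None):
    """Build meta-learner rows from base predictions, optionally adding covariates.

    base_preds: list of per-learner prediction lists (one list per base learner).
    covariates: optional list of covariate rows, one per unit.
    """
    n = len(base_preds[0])
    rows = [[preds[i] for preds in base_preds] for i in range(n)]
    if covariates is not None:
        rows = [row + list(x) for row, x in zip(rows, covariates)]
    return rows


# Made-up predictions from two base learners for three units.
preds = [[0.2, 0.5, 0.9], [0.3, 0.4, 0.8]]
X = [[45, 1.2], [60, -0.3], [38, 0.7]]

print(meta_design_matrix(preds))     # like include_covariates_in_stacking=False
print(meta_design_matrix(preds, X))  # like include_covariates_in_stacking=True
```

With covariates present, a flexible meta-learner can weight the base learners differently in different regions of covariate space, at the cost of more parameters to fit.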
Documentation Enhancement: Heterogeneous Ensembles
- Improved documentation for the existing heterogeneous learner feature in the `model_outcome` parameter
- Previously feature-complete but undocumented: `model_outcome` now clearly documents support for:
  - List/tuple of estimators: mix different model types across iterations (e.g., Random Forest + Gradient Boosting + Linear models)
  - Generator/iterator: dynamically yield different models for each iteration
  - Single estimator: homogeneous ensemble (backward compatible)
- Added practical examples showing heterogeneous ensemble usage with lists and generators
- Documented benefits: improved robustness by combining models with different inductive biases
- Comprehensive test suite added: 22 tests (675 lines) in `tests/test_heterogeneous_learners.py` covering:
  - All input types (list, tuple, generator)
  - All outcome types (continuous, binary, survival)
  - Multi-arm treatments
  - Error handling (insufficient models, exhausted generators)
  - Integration with all features (bootstrap, stacking, covariates, ATT, stochastic matching)
  - Reproducibility and comparisons
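The three accepted forms of `model_outcome` imply a simple rule for picking one model per iteration. The helper below is a hypothetical sketch of that rule, including the two error cases listed above (too few models in a list, an exhausted generator); it is not causalem's code. Using only the first `niter` entries of a longer list, and copying a single estimator with `copy.deepcopy` as a stand-in for `sklearn.base.clone`, are further assumptions. Strings stand in for estimators in the usage lines.

```python
import copy


def models_per_iteration(model_outcome, niter):
    """Return one model per iteration from a list/tuple, iterator, or single estimator."""
    if isinstance(model_outcome, (list, tuple)):
        if len(model_outcome) < niter:
            raise ValueError(f"need {niter} models, got {len(model_outcome)}")
        return list(model_outcome[:niter])  # extra entries ignored (assumption)
    if hasattr(model_outcome, "__next__"):  # generator or other iterator
        models = []
        for _ in range(niter):
            try:
                models.append(next(model_outcome))
            except StopIteration:
                raise ValueError(f"iterator exhausted after {len(models)} models") from None
        return models
    # Single estimator: homogeneous ensemble, one independent copy per iteration.
    return [copy.deepcopy(model_outcome) for _ in range(niter)]


print(models_per_iteration(["rf", "gb", "lin"], 3))
print(models_per_iteration((f"model_{i}" for i in range(5)), 2))
try:
    models_per_iteration(iter(["only_one"]), 2)
except ValueError as err:
    print("error:", err)
```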
Bug Fixes:
- Fixed multi-arm stacking to correctly use encoder categories when constructing counterfactual design matrices
API Enhancements:
# Meta-learner uses only base predictions (default)
result = estimate_te(X, t, y, do_stacking=True, include_covariates_in_stacking=False)
# Meta-learner uses both base predictions and covariates
result = estimate_te(X, t, y, do_stacking=True, include_covariates_in_stacking=True)
# Heterogeneous ensemble with different model types
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.linear_model import LinearRegression
outcome_models = [
RandomForestRegressor(n_estimators=100),
GradientBoostingRegressor(n_estimators=100),
LinearRegression()
]
result = estimate_te(X, t, y, model_outcome=outcome_models, niter=3)
1.1.0
New Feature: Estimand Parameter (ATT vs ATM)
- Added `estimand` parameter to `estimate_te()` and `estimate_te_multi()`
- Supports two estimands:
  - `'ATM'` (default): Average Treatment Effect on the Matched sample; averages over all units appearing in matched sets (preserves backward compatibility)
  - `'ATT'`: Average Treatment Effect on the Treated (common support); averages over treated/`ref_group` units that were successfully matched
- Implemented across all pathways: binary, multi-arm, and survival outcomes
- For multi-arm with `estimand='ATT'`, the `ref_group` parameter specifies which arm is the "treated" group
- ATT computes effects on matched treated units only (not all treated), following standard matching-literature practice of estimating on the common support
- Comprehensive test coverage: 14 new tests covering all outcome types and pathways
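The two estimands differ in which units the unit-level effects are averaged over. In the sketch below, each matched set lists unit IDs and every unit carries a made-up predicted effect; ATM averages over every unit appearing in any matched set, ATT only over matched treated units. The data layout, and counting each unit once rather than weighting by appearances, are simplifying assumptions.

```python
import statistics

# Made-up predicted unit-level effects and treatment flags.
effect = {"a": 2.0, "b": 1.0, "c": 0.5, "d": 1.5, "e": 3.0}
treated = {"a": 1, "b": 1, "c": 0, "d": 0, "e": 1}

# Matched sets: treated unit "e" was not matched (outside common support).
matched_sets = [["a", "c"], ["b", "d", "c"]]

in_matched = {u for s in matched_sets for u in s}
atm = statistics.fmean(effect[u] for u in sorted(in_matched))
att = statistics.fmean(effect[u] for u in sorted(in_matched) if treated[u] == 1)

print("ATM:", atm)  # averages a, b, c, d
print("ATT:", att)  # averages matched treated units a and b; e is excluded
```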
API Enhancement:
# Target effect on matched sample (default)
result = estimate_te(X, t, y, estimand='ATM')
# Target effect on treated population
result = estimate_te(X, t, y, estimand='ATT')
1.0.1
- Removed the R section of `README.md` since it has not been released yet.
- Added release notes for version 1.0.0.
1.0.0
- Removed `binarize_outcome` parameter from `load_data_lalonde` and `load_data_tof`.
- Absorbed `load_data_tof_with_mediator` into `load_data_tof`.
0.7.0
- Added mediation analysis functionality with the `estimate_mediation` function for interventional mediation effects using plug-in G-computation.
- Supports binary treatment with binary/continuous mediators and continuous outcomes.
- Features bootstrap confidence intervals and optional integration with stochastic matching for improved robustness.
- Estimates total effect (TE), interventional direct effect (IDE), and interventional indirect effect (IIE).
0.6.2
- Exposed a new `n_mc` argument in `estimate_te` for specifying Monte-Carlo draws per matched unit in survival analyses, replacing the previously fixed single draw.
- Clarified treatment-effect estimands for stacking vs. no-stacking modes, noting that stacked results are appearance-weighted across the matched union.
- Documented appearance-weighted meta-learning and matched-union survival contrasts.
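`n_mc` controls how many outcome draws are simulated per matched unit in the survival pathway. The toy sketch below shows why more draws help for a single unit: it simulates event times from an exponential model with a known hazard (an assumption; causalem's fitted survival models will differ) and estimates the probability of surviving past a horizon tau, whose exact value exp(-hazard * tau) is known, so the Monte Carlo error is visible. The function name is hypothetical.

```python
import math
import random


def mc_survival_prob(hazard, tau, n_mc, rng):
    """Monte Carlo estimate of P(T > tau) for T ~ Exponential(hazard)."""
    draws = [rng.expovariate(hazard) for _ in range(n_mc)]
    return sum(t > tau for t in draws) / n_mc


rng = random.Random(0)
hazard, tau = 0.5, 2.0
exact = math.exp(-hazard * tau)
for n_mc in (1, 10, 1000, 100_000):
    est = mc_survival_prob(hazard, tau, n_mc, rng)
    print(f"n_mc={n_mc:>6}: estimate={est:.3f}, exact={exact:.3f}")
```

With a single draw the per-unit estimate is just 0 or 1; the Monte Carlo standard error shrinks roughly as 1/sqrt(n_mc).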
0.6.1
- Corrected the version number in the `pyproject.toml` file.
0.6.0
- Improved consistency of the return data structure when `do_stacking=False` in multi-arm TE estimation.
0.5.4
- Added a GitHub Action for publishing to PyPI
0.5.3
- First public release
0.5.1
- Edits to README
- Added a GitHub Action for publishing to TestPyPI
0.5.0
- First test release
Download files
File details
Details for the file causalem-2.0.1.tar.gz.
File metadata
- Download URL: causalem-2.0.1.tar.gz
- Upload date:
- Size: 164.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 59c1e7b60b2827148a1784a297221a5d4067cf0eb4d0d9fdc2e04f651f6ba736 |
| MD5 | 59e8574a396ddca81e405a9e847145cd |
| BLAKE2b-256 | 9e552d5bc07a1d4b92efc3452971efdeccef6c2110bd3eda02d02d0f01956a09 |
File details
Details for the file causalem-2.0.1-py3-none-any.whl.
File metadata
- Download URL: causalem-2.0.1-py3-none-any.whl
- Upload date:
- Size: 127.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1de8edfc9f2f0145346278601bce33240f528d77a804fdbd8711c0c287090fa2 |
| MD5 | 7bfb4e9462d10c178bfa061a05dc88d1 |
| BLAKE2b-256 | 0186971470882261d4c8ed495c8108e2448314e4b3f7eaf8b0ab8b99187f7309 |