
Causal Inference using Ensemble Matching


CausalEM – Ensemble Matching for Causal Inference

CausalEM is a toolbox for multi-arm treatment‑effect estimation and mediation analysis using stochastic matching and a stacked ensemble of heterogeneous ML models. It supports continuous, binary, and survival outcomes.


Key Features

| Feature | Impact |
| --- | --- |
| Stochastic nearest-neighbor (NN) matching | Larger effective sample size (ESS) and more accurate TE estimation than standard (deterministic) NN matching |
| G-computation using a two-stage, stacked ensemble of heterogeneous learners | Generalizes the standard G-computation framework to ensemble learning; propensity-score and outcome models are cross-fitted, similar to DoubleML |
| Scalable to large datasets | Memory-efficient matching handles 100k+ observations on standard hardware (36 GB RAM); tested on datasets of up to 220k rows |
| Support for multi-arm treatments | Improved multi-arm ESS via stochastic matching |
| Mediation analysis | Plug-in G-computation for interventional mediation effects (IDE/IIE) with a binary treatment and binary/continuous mediators and outcomes; supports bootstrap confidence intervals and optional stochastic matching |
| Support for survival outcomes | Simulates data from fitted survival outcome models so the stacked ensemble can estimate TEs on right-censored, time-to-event data |
| Bootstrapped confidence intervals (CIs) | Honest CIs: the entire pipeline (matching + TE estimation) is rerun inside the bootstrap loop |
| Compatible with scikit-learn | scikit-learn models (and scikit-survival models for survival outcomes) can be plugged into the propensity-score, outcome, and meta-learner stages |
| Full reproducibility of results | Careful seeding of random number generation (RNG), including inside scikit-learn models |
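To make the difference between deterministic and stochastic NN matching concrete, here is a minimal NumPy sketch. It is an illustrative assumption, not CausalEM's actual algorithm: each treated unit draws one control with probability proportional to a Gaussian kernel of the logit-PS distance (bandwidth `scale`), restricted to controls within `caliper`, whereas deterministic matching would always pick the closest control. The function name `stochastic_nn_match` is hypothetical.

```python
import numpy as np


def stochastic_nn_match(score, treatment, scale=0.8, caliper=0.2, rng=None):
    """Hypothetical sketch of stochastic 1:1 nearest-neighbor matching.

    score          : 1-D array of logit propensity scores
    treatment      : 1-D array of 0/1 treatment indicators
    scale, caliper : kernel bandwidth and maximum distance, in absolute logit units
    Returns {treated_index: control_index}; treated units with no control
    inside the caliper stay unmatched.
    """
    rng = np.random.default_rng(rng)
    treated = np.flatnonzero(treatment == 1)
    controls = np.flatnonzero(treatment == 0)
    pairs = {}
    for i in treated:
        # Distances for one treated unit at a time: O(n) memory, no n x n matrix.
        d = np.abs(score[controls] - score[i])
        ok = d <= caliper
        if not ok.any():
            continue
        # Gaussian kernel weights: closer controls are more likely to be drawn.
        # Deterministic NN matching would instead always take controls[d.argmin()].
        w = np.exp(-0.5 * (d[ok] / scale) ** 2)
        pairs[int(i)] = int(rng.choice(controls[ok], p=w / w.sum()))
    return pairs


# Toy data: treated unit 0 has three close controls; treated unit 4 has none.
score = np.array([0.0, 0.05, 0.1, 0.15, 1.0])
treatment = np.array([1, 0, 0, 0, 1])
draws = [stochastic_nn_match(score, treatment, rng=seed) for seed in range(20)]
print(draws[:3])
```

Across draws, treated unit 0 is paired with different nearby controls, which spreads appearances over more units; that is the intuition behind the larger ESS claimed above.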

API

| Function / class | Brief description |
| --- | --- |
| estimate_te | Main pipeline: ensemble matching + meta-learner |
| estimate_mediation | Mediation analysis with plug-in G-computation |
| MatchingCATEEstimator | [Experimental] Individual-level treatment effect estimation (CATE) |
| stochastic_matching | Stochastic/deterministic nearest-neighbor matching |
| summarize_matching | Diagnostics: ESS, ASMD, variance ratios, overlap plots |
| load_data_lalonde | Standard Lalonde job-training dataset (two-arm, continuous outcome) |
| load_data_tof | New simulated Tetralogy of Fallot (ToF) dataset (two-arm or three-arm; survival, binary, or continuous outcome) |

⚙️ Installation

pip install causalem

Optional dev extras:

pip install "causalem[dev]"

Minimum Python 3.9. Tested on macOS and Windows.


Package Vignette

For a more detailed introduction to CausalEM, including the underlying math, see the package vignette [insert link later], available on arXiv.


🚀 Quick Start

Two-arm Analysis

Load the necessary packages:

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

from causalem import (
  estimate_te,
  load_data_tof,
  stochastic_matching,
  summarize_matching
)

Load the ToF data with two treatment levels and a binarized outcome:

X, t, y = load_data_tof(
    raw=False,
    treat_levels=['PrP', 'SPS'],
    outcome_type="binary",
)

Stochastic matching using propensity scores:

lr = LogisticRegression(solver="newton-cg", max_iter=1000)
lr.fit(X, t)
score = lr.predict_proba(X)[:, 1]
logit_score = np.log(score / (1 - score))

cluster = stochastic_matching(
    treatment=t,
    score=logit_score,
    nsmp=10,
    scale=1.0,
    random_state=0,
)

diag = summarize_matching(
  cluster, X,
  treatment=t, plot=False
)
print("Combined Effective Sample Size (ESS):", diag.ess["combined"])
print("Absolute standardized mean difference (ASMD) by covariate:\n")
print(diag.summary)
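summarize_matching reports ESS, but its exact definition is not spelled out here. As an assumption, a common choice is Kish's effective sample size of appearance weights, (Σw)² / Σw², where w counts how often each unit appears across matching draws. The hypothetical helper `kish_ess` below shows why spreading appearances over more distinct controls, as stochastic matching does, raises ESS.

```python
import numpy as np


def kish_ess(weights):
    """Kish effective sample size: (sum w)^2 / sum w^2."""
    w = np.asarray(weights, dtype=float)
    return w.sum() ** 2 / np.square(w).sum()


# Appearance counts of four controls over 10 matching draws of one treated unit.
deterministic = np.array([10, 0, 0, 0])  # the nearest control is reused every draw
stochastic = np.array([3, 3, 2, 2])      # draws spread over several close controls

print(kish_ess(deterministic))  # 1.0
print(kish_ess(stochastic))
```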

TE estimation (includes stochastic matching as the first step, followed by outcome modeling):

res = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    random_state_master=1,
)
print("Two-arm TE:", res["te"])

This call uses the v2.0.0 defaults: std_matching=True, matching_caliper=0.2, and matching_scale=0.8, with caliper and scale expressed in units of the SD of the logit propensity score.
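With std_matching=True, caliper and scale are multiples of the standard deviation of the logit propensity score (for example, the logit_score computed in the matching step). The helper below, `sd_based_matching_params`, is hypothetical and not part of CausalEM; it simply converts the SD-based values into absolute logit units. Whether the package uses the population (ddof=0) or sample (ddof=1) SD is an assumption here.

```python
import numpy as np


def sd_based_matching_params(logit_ps, caliper=0.2, scale=0.8, ddof=0):
    """Convert SD-based caliper/scale into absolute logit-PS units (sketch)."""
    sd = np.std(logit_ps, ddof=ddof)
    return caliper * sd, scale * sd


logit_ps = np.array([-1.0, 1.0, -1.0, 1.0])  # population SD is exactly 1.0
abs_caliper, abs_scale = sd_based_matching_params(logit_ps)
print(abs_caliper, abs_scale)  # 0.2 0.8
```

Doubling the spread of the scores doubles both absolute values, which is the "adapts to data spread" behavior described in the release notes.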

Multi-arm Analysis

Load data for multi-arm analysis:

df = load_data_tof(
    raw=True,
    outcome_type="binary",
)
t_all = df["treatment"].to_numpy()
X_all = df[["age", "zscore"]].to_numpy()
y_all = df["outcome"].to_numpy()

Constructing propensity scores using multinomial logistic regression:

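# lbfgs (the default solver) fits a multinomial model when there are 3+ arms;
# the multi_class argument is deprecated since scikit-learn 1.5.
lr_multi = LogisticRegression(max_iter=1000)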
lr_multi.fit(X_all, t_all)
proba = lr_multi.predict_proba(X_all)
ref = "PrP"
cols = [i for i, c in enumerate(lr_multi.classes_) if c != ref]
logit_multi = np.log(proba[:, cols] / (1 - proba[:, cols]))

Multi-arm stochastic matching:

cluster_multi = stochastic_matching(
    treatment=t_all,
    score=logit_multi,
    nsmp=5,
    scale=1.0,
    ref_group=ref,
    random_state=0,
)
diag_multi = summarize_matching(
    cluster_multi, X_all, treatment=t_all, ref_group=ref, plot=False
)
print("Multi-arm ESS per draw:\n", diag_multi.ess["per_draw"])  # dict of counts by group

Multi-arm TE estimation:

res_multi = estimate_te(
    X_all,
    t_all,
    y_all,
    outcome_type="binary",
    ref_group=ref,
    niter=5,
    random_state_master=1,
)
print("Multi-arm pairwise effects:\n", res_multi["pairwise"])

As in the two-arm case, this call uses the v2.0.0 defaults with SD-based matching parameters.

Confidence-Interval Calculation

Adding a bootstrap CI to the two-arm analysis; each of the nboot resamples reruns the full matching + TE estimation pipeline:

res_boot = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    nboot=200,
    random_state_master=1,
    random_state_boot=7,
)
print("Bootstrap CI:", res_boot["ci"])
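The Key Features table calls these CIs honest because matching and TE estimation are both repeated inside the bootstrap loop. The NumPy sketch below illustrates that structure with a percentile interval. `toy_pipeline` (1:1 nearest-neighbor matching on a score, then a mean difference) is a hypothetical stand-in for estimate_te, and the percentile method is an assumption about how the interval is formed.

```python
import numpy as np


def toy_pipeline(score, t, y):
    """Hypothetical stand-in for estimate_te: NN matching, then a mean difference."""
    treated = np.flatnonzero(t == 1)
    controls = np.flatnonzero(t == 0)
    # Nearest control (with replacement) for every treated unit; a dense
    # treated x control distance block is fine at toy scale.
    dist = np.abs(score[treated, None] - score[None, controls])
    nearest = controls[dist.argmin(axis=1)]
    return np.mean(y[treated] - y[nearest])


def bootstrap_ci(score, t, y, nboot=200, alpha=0.05, seed=7):
    """Percentile bootstrap CI with matching redone inside every resample."""
    rng = np.random.default_rng(seed)
    n = len(y)
    stats = []
    for _ in range(nboot):
        idx = rng.integers(0, n, size=n)  # resample units with replacement
        stats.append(toy_pipeline(score[idx], t[idx], y[idx]))
    lo, hi = np.quantile(stats, [alpha / 2, 1 - alpha / 2])
    return lo, hi


# Simulated data in which the score is the only confounder and the true effect is 2.0.
rng = np.random.default_rng(0)
n = 500
score = rng.normal(size=n)
t = (rng.random(n) < 1 / (1 + np.exp(-score))).astype(int)
y = 2.0 * t + score + rng.normal(size=n)
lo, hi = bootstrap_ci(score, t, y)
print(f"95% bootstrap CI: [{lo:.2f}, {hi:.2f}]")
```

Matching once on the full data and bootstrapping only the final estimate would ignore matching variability; rerunning it per resample is what makes the interval honest.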

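Heterogeneous Ensemble

Passing a list of estimators to model_outcome uses a different learner in each matching iteration, so niter is set to the number of learners: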

learners = [
    LogisticRegression(max_iter=1000),
    RandomForestClassifier(n_estimators=200, max_depth=3),
]
res_ensemble = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    model_outcome=learners,
    niter=len(learners),
    do_stacking=True,
    random_state_master=42,
)
print("Ensemble TE:", res_ensemble["te"])
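The 1.2.0 notes state that model_outcome accepts a single estimator, a list/tuple, or a generator, and that supplying too few models is an error. The sketch below shows one plausible way such input could be expanded into one model per iteration. The helper `models_per_iteration` and the placeholder class `Learner` are hypothetical, and copy.deepcopy stands in for the sklearn.base.clone a real implementation would more likely use.

```python
import copy
from collections.abc import Iterator


def models_per_iteration(model_outcome, niter):
    """Hypothetical sketch: expand model_outcome into niter independent models."""
    if isinstance(model_outcome, (list, tuple)):
        if len(model_outcome) < niter:
            raise ValueError(f"need {niter} models, got {len(model_outcome)}")
        # One model per iteration; extra entries are ignored (an assumption).
        return [copy.deepcopy(m) for m in model_outcome[:niter]]
    if isinstance(model_outcome, Iterator):
        models = []
        for _ in range(niter):
            try:
                models.append(next(model_outcome))
            except StopIteration:
                raise ValueError("generator exhausted before niter models") from None
        return models
    # Single estimator: homogeneous ensemble built from independent copies.
    return [copy.deepcopy(model_outcome) for _ in range(niter)]


class Learner:
    """Placeholder for a scikit-learn estimator in this example."""

    def __init__(self, name):
        self.name = name


mixed = models_per_iteration([Learner("logistic"), Learner("forest")], niter=2)
print([m.name for m in mixed])  # ['logistic', 'forest']
```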

Stacking vs No-Stacking

# No-stacking: average per-iteration effects without appearance weights
res_ns = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    do_stacking=False,
    random_state_master=0,
)

# Stacking: meta-learner fit with appearance weights over the matched union
res_stack = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    do_stacking=True,
    random_state_master=0,
)
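The comments above distinguish the two modes: no-stacking averages per-iteration effects, while stacking fits a meta-learner over the matched union with appearance weights. The NumPy sketch below illustrates both under simplifying assumptions that are not CausalEM internals: every iteration supplies counterfactual predictions for every unit (ignoring cross-fitting), the meta-learner is a weighted least-squares fit of the observed outcome on the base predictions, and a unit's appearance weight is the number of iterations whose matched set contains it. All function names are hypothetical.

```python
import numpy as np


def no_stacking_te(pred1, pred0, appear):
    """Average of per-iteration effects, each over that iteration's matched set.

    pred1, pred0 : (niter, n) counterfactual predictions under treatment/control
    appear       : (niter, n) boolean, True if unit i is in iteration k's matched set
    """
    per_iter = [np.mean(p1[a] - p0[a]) for p1, p0, a in zip(pred1, pred0, appear)]
    return float(np.mean(per_iter))


def stacking_te(pred1, pred0, appear, t, y):
    """Appearance-weighted linear meta-learner over the matched union (sketch)."""
    w = appear.sum(axis=0).astype(float)  # appearance weight of each unit
    union = w > 0
    # Base-learner predictions at each unit's observed treatment: (n_union, niter).
    z_obs = np.where(t == 1, pred1, pred0)[:, union].T
    sw = np.sqrt(w[union])
    # Weighted least squares: minimize sum_i w_i * (y_i - z_i @ beta)^2.
    beta, *_ = np.linalg.lstsq(z_obs * sw[:, None], y[union] * sw, rcond=None)
    effect = pred1[:, union].T @ beta - pred0[:, union].T @ beta
    return float(np.average(effect, weights=w[union]))


# Toy check: every base learner predicts the truth, so both modes recover 3.0.
rng = np.random.default_rng(0)
n, niter = 50, 3
x = rng.normal(size=n)
t = (rng.random(n) < 0.5).astype(int)
pred1 = np.tile(3.0 + x, (niter, 1))
pred0 = np.tile(x, (niter, 1))
y = np.where(t == 1, 3.0 + x, x)
appear = rng.random((niter, n)) < 0.7
ns, st = no_stacking_te(pred1, pred0, appear), stacking_te(pred1, pred0, appear, t, y)
print(ns, st)
```

With heterogeneous base learners the two modes generally differ: no-stacking gives each iteration equal say, while stacking lets the meta-learner weight learners by how well they fit the appearance-weighted matched union.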

TE Estimation for Survival Outcomes

X_surv, t_surv, y_surv = load_data_tof(
    raw=False,
    treat_levels=['SPS', 'PrP'],
)
res_surv = estimate_te(
    X_surv,
    t_surv,
    y_surv,
    outcome_type="survival",
    niter=5,
    random_state_master=0,
)
print("Survival HR:", res_surv["te"])

Mediation Analysis

# Imports for the mediation example
from causalem.datasets import load_data_tof
from causalem.mediation import estimate_mediation

# Load ToF data: binary treatment (PrP vs SPS), continuous mediator (op_time), binary outcome
X, A, M, Y = load_data_tof(
    raw=False,
    treat_levels=['PrP', 'SPS'],  # Binary treatment comparison
    outcome_type="binary",        # Binary outcome for simpler interpretation
    include_mediator=True         # Include mediator variable (op_time)
)

# Estimate mediation effects
result = estimate_mediation(X, A, M, Y, random_state_master=42)

print("Total Effect (TE):", result["te"])
print("Interventional Direct Effect (IDE):", result["ide"])
print("Interventional Indirect Effect (IIE):", result["iie"])
print("Proportion Mediated:", result["prop_mediated"])
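For reference, the standard interventional effect definitions are written below, where $G_a$ denotes a random draw from the mediator distribution under treatment $a$ given $X$. This is the usual textbook decomposition; CausalEM's exact conventions, including how prop_mediated is computed, are an assumption here and should be checked against the vignette.

```latex
\begin{aligned}
\mathrm{IDE} &= \mathbb{E}\big[Y(1, G_0)\big] - \mathbb{E}\big[Y(0, G_0)\big] \\
\mathrm{IIE} &= \mathbb{E}\big[Y(1, G_1)\big] - \mathbb{E}\big[Y(1, G_0)\big] \\
\mathrm{TE}  &= \mathrm{IDE} + \mathrm{IIE} = \mathbb{E}\big[Y(1, G_1)\big] - \mathbb{E}\big[Y(0, G_0)\big] \\
\text{Proportion mediated} &= \mathrm{IIE} / \mathrm{TE}
\end{aligned}
```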

CATE Estimation (Experimental)

⚠️ Experimental Feature: The CATE estimator API is under active development and may change.

Unlike estimate_te(), which returns population-level averages, MatchingCATEEstimator predicts individual-level treatment effects:

from causalem._experimental import MatchingCATEEstimator
from causalem import load_data_lalonde

X, t, y = load_data_lalonde(raw=False)

# Initialize and fit the CATE estimator
est = MatchingCATEEstimator(
    niter=10,
    do_stacking=True,
    random_state=42
)
est.fit(X, t, y)

# Get individual treatment effects
individual_effects = est.effect()
print(f"Individual effects range: [{individual_effects.min():.2f}, {individual_effects.max():.2f}]")

# Population summaries
print(f"ATE on matched: {est.ate():.2f}")
print(f"ATT on matched: {est.att():.2f}")

# Identify high-benefit subgroups
import numpy as np
high_benefit_idx = np.where(individual_effects > np.percentile(individual_effects, 75))[0]
print(f"High-benefit group size: {len(high_benefit_idx)}")

For more details, see causalem/_experimental/README.md.


License

This project is licensed under the terms of the MIT License.

Release Notes

2.0.0

🚀 Performance Improvements:

  1. Major scalability enhancement for large datasets (100k+ observations)

    • Memory optimization: Matching algorithm now computes distances on the fly rather than pre-computing the full n×n distance matrix
    • Memory reduction: Memory complexity drops from O(n²) to O(n), enabling matching on datasets with 100k+ observations on standard hardware
    • Tested at scale: Successfully validated on 181k observations (real orthopedic registry data) and synthetic datasets up to 220k rows
    • Speed improvement: ~5-10× faster when a caliper is used (default 0.2 SD)
    • Backward compatibility: Results are bit-identical to previous implementation when using same random seed
    • Implementation: Modified internal matching functions to accept propensity score arrays directly instead of distance matrices
    • Legacy support: Pre-computed distance matrix input preserved for backward compatibility (not recommended for large n)
  2. Scalability validation

    • Previous limitation: MemoryError at ~50k observations (the dense distance matrix alone requires ~20 GB RAM)
    • New capability: Handles 200k+ observations with <5 GB peak memory usage
    • Validated on ACIC2019 benchmark: Bit-identical results to original implementation
    • Note: Computational time still scales approximately O(n²) for matching phase, but memory bottleneck removed
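The ~20 GB figure follows from storing a dense n×n matrix of 8-byte floats (float64 is an assumption about the dtype), while on-the-fly computation only holds one row of n distances at a time. A quick check of that arithmetic, using two hypothetical helper functions:

```python
def dense_distance_gb(n, bytes_per_value=8):
    """Memory for a full n x n distance matrix, in decimal gigabytes."""
    return n * n * bytes_per_value / 1e9


def one_row_mb(n, bytes_per_value=8):
    """Memory for a single row of n distances, in decimal megabytes."""
    return n * bytes_per_value / 1e6


print(dense_distance_gb(50_000))  # 20.0
print(dense_distance_gb(181_000))
print(one_row_mb(181_000))
```

At 181k rows a dense matrix would need roughly 262 GB, whereas a single row needs under 2 MB; this is why removing the matrix removes the memory bottleneck even though matching time still grows roughly as O(n²).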

⚠️ Breaking Changes:

  1. groups parameter renamed to _cv_groups in estimate_te(), estimate_te_multi(), and estimate_mediation()

    • Now marked as internal-only with underscore prefix
    • Migration: groups=my_groups → _cv_groups=my_groups
  2. stochastic_match() function renamed to stochastic_matching()

    • Improves naming consistency across the API (both stochastic_matching and summarize_matching now use "matching")
    • Migration: Update all imports and function calls from stochastic_match to stochastic_matching
  3. Removed natural effects from mediation analysis

    • effect_type parameter removed from estimate_mediation()
    • Only interventional effects (IDE/IIE) are now returned
    • Natural effects were misleading when using flexible ML models that can learn treatment-mediator interactions
    • Migration: Remove effect_type="natural" arguments; code that relied on NDE/NIE keys will need to use IDE/IIE instead
  4. SD-based caliper and scale for estimate_mediation() (completing API harmonization with estimate_te())

    • Added std_matching parameter (default True) for SD-based interpretation of caliper/scale
    • Changed matching_caliper default: None → 0.2
    • Changed matching_scale default: 1.0 → 0.8
    • With std_matching=True, caliper and scale are interpreted as multiples of SD(logit propensity scores)
    • Benefits: Automatic adaptation to data spread, robust matching across diverse datasets
    • Migration: To preserve v1.x behavior, set std_matching=False, matching_scale=1.0, matching_caliper=None
  5. matching_is_stochastic parameter replaced with matching_method in estimate_mediation()

    • Old parameter was boolean; new parameter is string with three options: "stochastic", "deterministic", or None
    • Migration:
      • matching_is_stochastic=True → matching_method="stochastic"
      • matching_is_stochastic=False, niter=1 → matching_method=None, niter=1
      • matching_is_stochastic=False, niter>1 → matching_method="deterministic"
  6. estimate_mediation() now uses stochastic matching by default

    • niter default: 1 → 10
    • matching_method default: effectively None → "stochastic"
    • Migration: To preserve v1.x behavior, explicitly set matching_method=None, niter=1
  7. SD-based caliper and scale specification in estimate_te() and estimate_te_multi()

    • Added std_matching parameter (default True) enabling standard deviation-based units
    • matching_caliper default: None → 0.2 (0.2 SD when std_matching=True)
    • matching_scale default: 1.0 → 0.8 (0.8 SD when std_matching=True)
    • When std_matching=True (default): caliper and scale interpreted as multiples of SD(logit propensity scores)
    • When std_matching=False: use absolute values (v1.x behavior)
    • Migration: To preserve v1.x behavior, set std_matching=False, matching_scale=1.0, matching_caliper=None
    • Rationale: A caliper of 0.2 SD is recommended in the matching literature (Austin 2011, Rosenbaum & Rubin 1985) and adapts to data spread
  8. Changed default return format for dataset loaders

    • load_data_tof() and load_data_lalonde() now default to raw=False
    • These functions now return (X, t, y) or (X, t, m, y) arrays by default instead of DataFrames
    • Migration: To get DataFrame output (v1.x behavior), explicitly pass raw=True
    • Rationale: Array output aligns with how these functions are predominantly used in practice
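Items 3 to 6 change the behavior of estimate_mediation(). The plain-Python sketch below collects the keyword arguments that, according to these notes, restore v1.x behavior, plus a hypothetical helper (`v1_matching_method`) that applies the matching_is_stochastic mapping from item 5. It does not call CausalEM itself.

```python
# Keyword arguments that restore v1.x behavior of estimate_mediation(),
# per items 4 and 6 of the 2.0.0 breaking changes.
V1_MEDIATION_KWARGS = {
    "std_matching": False,
    "matching_scale": 1.0,
    "matching_caliper": None,
    "matching_method": None,
    "niter": 1,
}


def v1_matching_method(matching_is_stochastic, niter):
    """Map the removed boolean flag to the new matching_method value (item 5)."""
    if matching_is_stochastic:
        return "stochastic"
    return None if niter == 1 else "deterministic"


print(v1_matching_method(False, niter=5))  # deterministic
```

The dictionary would then be passed as estimate_mediation(X, A, M, Y, **V1_MEDIATION_KWARGS). Code that read NDE/NIE keys still needs to switch to IDE/IIE (item 3).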

New Features:

  • SD-based matching parameters: std_matching parameter enables adaptive caliper and scale based on propensity score spread
  • Added estimand parameter to estimate_mediation() supporting "ATM" (default) and "ATT"
  • Added n_splits_propensity parameter to estimate_mediation() (default: 5) for propensity model cross-fitting
  • Updated MatchingCATEEstimator (experimental) to support std_matching parameter with new defaults

Bug Fixes:

  • Fixed prob_clip_eps parameter not being passed through in survival pathway
  • Improved parameter handling in _estimate_te_survival_single_iter()

Documentation:

  • Clarified that meta-learner is cross-fitted using n_splits_outcome folds in addition to base models

1.3.0

New Experimental Feature: CATE (Conditional Average Treatment Effect) Estimation

  • Added MatchingCATEEstimator class in causalem._experimental module for individual-level treatment effect prediction
  • Provides scikit-learn style fit()/effect() API for learning and predicting heterogeneous treatment effects
  • Key capabilities:
    • Individual-level treatment effect predictions (not just population averages)
    • Prediction on new/unseen data
    • ATM and ATT estimands supported
    • Stochastic and deterministic matching
    • Ensemble stacking with meta-learners
    • Compatible with heterogeneous base learners
  • Current scope (binary treatment, non-survival outcomes):
    • ✓ Binary and continuous outcomes
    • ✓ Stochastic/deterministic matching
    • ✓ Stacking and no-stacking modes
    • ✓ ATM/ATT estimands
    • Future: Multi-arm treatment, survival outcomes, bootstrap CIs
  • Validation: Comprehensive test suite verifying parity with estimate_te()
  • Documentation: Detailed design documentation in causalem/_experimental/README.md
  • Status: ⚠️ Experimental API - may change in future releases

API Example:

from causalem._experimental import MatchingCATEEstimator

est = MatchingCATEEstimator(niter=10, do_stacking=True, random_state=42)
est.fit(X, t, y)

# Individual effects
effects = est.effect()

# Population summaries
ate = est.ate()
att = est.att()

1.2.0

New Feature: Covariate Inclusion in Stacking Meta-Learner

  • Added include_covariates_in_stacking parameter to estimate_te() to enable including covariates in the meta-learner stage
  • When True, covariates are included alongside base learner predictions in the meta-learner design matrix, allowing the meta-learner to learn non-linear combinations of predictions conditional on covariates
  • Implemented across all pathways: binary, multi-arm, and survival outcomes
  • For stacking mode with do_stacking=True, both base predictions and original covariates are passed to the meta-learner
  • Defaults to False to preserve backward compatibility
  • Warning issued if include_covariates_in_stacking=True but do_stacking=False (parameter has no effect without stacking)
  • Comprehensive test coverage: 9 new tests covering all outcome types and edge cases
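The include_covariates_in_stacking option only changes what the meta-learner sees. A minimal sketch of that design-matrix construction, using a hypothetical helper `meta_design_matrix` (not CausalEM's internal code), which also mirrors the documented warning when the flag is set without stacking:

```python
import warnings

import numpy as np


def meta_design_matrix(base_preds, X, do_stacking=True, include_covariates_in_stacking=False):
    """Build the meta-learner input (sketch).

    base_preds : (n, n_models) predictions from the base learners
    X          : (n, p) original covariates
    """
    if not do_stacking:
        if include_covariates_in_stacking:
            warnings.warn("include_covariates_in_stacking has no effect without stacking")
        return None  # no meta-learner is fitted
    if include_covariates_in_stacking:
        # Base predictions and covariates side by side.
        return np.column_stack([base_preds, X])
    return np.asarray(base_preds)


base_preds = np.zeros((4, 3))
X = np.ones((4, 2))
print(meta_design_matrix(base_preds, X).shape)  # (4, 3)
print(meta_design_matrix(base_preds, X, include_covariates_in_stacking=True).shape)  # (4, 5)
```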

Documentation Enhancement: Heterogeneous Ensembles

  • Improved documentation for the existing heterogeneous learner feature in model_outcome parameter
  • Previously feature-complete but undocumented: model_outcome now clearly documents support for:
    • List/tuple of estimators: Mix different model types across iterations (e.g., Random Forest + Gradient Boosting + Linear models)
    • Generator/iterator: Dynamically yield different models for each iteration
    • Single estimator: Homogeneous ensemble (backward compatible)
  • Added practical examples showing heterogeneous ensemble usage with lists and generators
  • Documented benefits: improved robustness by combining models with different inductive biases
  • Comprehensive test suite added: 22 tests (675 lines) in tests/test_heterogeneous_learners.py covering:
    • All input types (list, tuple, generator)
    • All outcome types (continuous, binary, survival)
    • Multi-arm treatments
    • Error handling (insufficient models, exhausted generators)
    • Integration with all features (bootstrap, stacking, covariates, ATT, stochastic matching)
    • Reproducibility and comparisons

Bug Fixes:

  • Fixed multi-arm stacking to correctly use encoder categories when constructing counterfactual design matrices

API Enhancements:

# Meta-learner uses only base predictions (default)
result = estimate_te(X, t, y, do_stacking=True, include_covariates_in_stacking=False)

# Meta-learner uses both base predictions and covariates
result = estimate_te(X, t, y, do_stacking=True, include_covariates_in_stacking=True)

# Heterogeneous ensemble with different model types
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.linear_model import LinearRegression
outcome_models = [
    RandomForestRegressor(n_estimators=100),
    GradientBoostingRegressor(n_estimators=100),
    LinearRegression()
]
result = estimate_te(X, t, y, model_outcome=outcome_models, niter=3)

1.1.0

New Feature: Estimand Parameter (ATT vs ATM)

  • Added estimand parameter to estimate_te() and estimate_te_multi() functions
  • Supports two estimands:
    • 'ATM' (default): Average Treatment Effect on Matched sample - averages over all units appearing in matched sets (preserves backward compatibility)
    • 'ATT': Average Treatment Effect on Treated (common support) - averages over treated/ref_group units that were successfully matched
  • Implemented across all pathways: binary, multi-arm, and survival outcomes
  • For multi-arm with estimand='ATT', ref_group parameter specifies which arm is the "treated" group
  • ATT computes effects on matched treated units only (not all treated), following standard matching literature practice of estimating on the common support
  • Comprehensive test coverage: 14 new tests covering all outcome types and pathways

API Enhancement:

# Target effect on matched sample (default)
result = estimate_te(X, t, y, estimand='ATM')

# Target effect on treated population
result = estimate_te(X, t, y, estimand='ATT')
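The two estimands differ only in which units are averaged over. A minimal NumPy sketch, assuming per-unit effect estimates are already available and that ATM is an unweighted mean over the matched union (the 0.6.2 notes indicate stacked results are appearance-weighted instead); the function name `estimand_average` is hypothetical:

```python
import numpy as np


def estimand_average(unit_effects, treatment, matched, estimand="ATM"):
    """Average per-unit effects over the units the estimand targets (sketch).

    unit_effects : (n,) estimated individual effects
    treatment    : (n,) 0/1 indicator (1 = treated / ref_group)
    matched      : (n,) boolean, True if the unit appears in any matched set
    """
    if estimand == "ATM":
        mask = matched                      # every unit in the matched sample
    elif estimand == "ATT":
        mask = matched & (treatment == 1)   # matched treated units only
    else:
        raise ValueError("estimand must be 'ATM' or 'ATT'")
    return float(np.mean(unit_effects[mask]))


effects = np.array([1.0, 2.0, 3.0, 4.0, 10.0])
treatment = np.array([1, 1, 0, 0, 1])
matched = np.array([True, True, True, True, False])  # unit 4 found no match
print(estimand_average(effects, treatment, matched, "ATM"))  # 2.5
print(estimand_average(effects, treatment, matched, "ATT"))  # 1.5
```

The unmatched treated unit is excluded from the ATT, reflecting the common-support convention described above.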

1.0.1

  • Removed the R section of README.md since it has not been released yet.
  • Added release notes for version 1.0.0.

1.0.0

  • Removed binarize_outcome parameter from load_data_lalonde and load_data_tof.
  • Absorbed load_data_tof_with_mediator into load_data_tof.

0.7.0

  • Added mediation analysis functionality with estimate_mediation function for interventional mediation effects using plug-in G-computation.
  • Supports binary treatment with binary/continuous mediators and continuous outcomes.
  • Features bootstrap confidence intervals and optional integration with stochastic matching for improved robustness.
  • Estimates total effect (TE), interventional direct effect (IDE), and interventional indirect effect (IIE).

0.6.2

  • Exposed a new n_mc argument in estimate_te for specifying Monte‑Carlo draws per matched unit in survival analyses, replacing the previously fixed single draw.
  • Clarified treatment‑effect estimands for stacking vs. no‑stacking modes, noting that stacked results are appearance‑weighted across the matched union.
  • Documented appearance‑weighted meta‑learning and matched‑union survival contrasts.

0.6.1

  • Corrected the version number in pyproject.toml file.

0.6.0

  • Improved consistency of return data structure when do_stacking=False in multi-arm TE estimation.

0.5.4

  • Added GitHub Action for publishing to PyPI

0.5.3

  • First public release

0.5.1

  • Edits to README
  • Added GitHub Action for publishing to (test) PyPI

0.5.0

  • First test release

Download files

Source Distribution

causalem-2.0.0.tar.gz (164.7 kB)

Uploaded Source

Built Distribution

causalem-2.0.0-py3-none-any.whl (127.5 kB)

Uploaded Python 3

File details

Details for the file causalem-2.0.0.tar.gz.

File metadata

  • Download URL: causalem-2.0.0.tar.gz
  • Upload date:
  • Size: 164.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.7

File hashes

Hashes for causalem-2.0.0.tar.gz
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 970c8075a959b21ee59fa23749aae1b4ad77697ae9a732eb4436bcb5ce27ece9 |
| MD5 | 6ecf7e2e5d592f0f6e2bf9d6b5c780e2 |
| BLAKE2b-256 | dee98f977db02dfa3802b6e6261774eefb11a17489c2653c56af6c9511debfcd |

File details

Details for the file causalem-2.0.0-py3-none-any.whl.

File metadata

  • Download URL: causalem-2.0.0-py3-none-any.whl
  • Upload date:
  • Size: 127.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.7

File hashes

Hashes for causalem-2.0.0-py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | fb9dd747d1cd69beb6499e61c1e4548d2ffa2b5d7bad46b5df86c9eae626a907 |
| MD5 | 537eb2d29d8dd93f10d649f8ef51c064 |
| BLAKE2b-256 | bbb6f14995c38d9d2dd4e97da82fefaf53b09b5330c3b249a759edb2ae494c01 |
