
Causal Inference using Ensemble Matching


CausalEM – Ensemble Matching for Causal Inference

CausalEM is a toolbox for multi-arm treatment-effect (TE) estimation using stochastic matching and a stacked ensemble of heterogeneous ML models. It supports continuous, binary, and survival outcomes.


Key Features

  1. Stochastic nearest-neighbor (NN) matching -> Larger effective sample size (ESS) and more accurate TE estimation than standard (deterministic) NN matching.
  2. G-computation using a two-stage, stacked ensemble of heterogeneous learners -> Generalization of the standard G-computation framework to ensemble learning; cross-fitting of propensity-score and outcome models, similar to DoubleML.
  3. Support for multi-arm treatments -> Improved multi-arm ESS via stochastic matching.
  4. Support for survival outcomes -> Data simulated from survival outcome models is used to implement the stacked ensemble for TE estimation in right-censored, time-to-event data.
  5. Bootstrapped confidence interval (CI) estimation -> Honest CI estimation by including the entire pipeline (matching + TE estimation) in the bootstrap loop.
  6. Compatible with scikit-learn -> Flexible choice of ML models: scikit-learn estimators (and scikit-survival estimators for survival outcomes) can be used for the propensity-score, outcome, and meta-learner stages.
  7. Full reproducibility of results -> Careful implementation of random number generation (RNG) seeding, including in scikit-learn models.
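
To make the first feature concrete, the following is a minimal sketch of deterministic versus stochastic 1:1 matching on a scalar score. It is not CausalEM's implementation: the function name `match_controls` is hypothetical, and sampling a control with probability proportional to exp(-distance / scale) is an assumption chosen only for illustration. It uses the standard library so that it runs without extra dependencies.

```python
import math
import random


def match_controls(treated_scores, control_scores, scale=1.0, stochastic=True, seed=0):
    """Match each treated unit to one control, without replacement.

    Deterministic mode picks the closest remaining control. Stochastic mode
    samples a remaining control with probability proportional to
    exp(-distance / scale); this weighting is an illustrative assumption.
    """
    rng = random.Random(seed)
    available = list(range(len(control_scores)))
    pairs = []
    for i, s in enumerate(treated_scores):
        if not available:
            break  # no controls left to match
        dists = [abs(s - control_scores[j]) for j in available]
        if stochastic:
            weights = [math.exp(-d / scale) for d in dists]
            pick = rng.choices(range(len(available)), weights=weights, k=1)[0]
        else:
            pick = min(range(len(available)), key=dists.__getitem__)
        pairs.append((i, available.pop(pick)))
    return pairs


treated = [0.1, 0.5, 0.9]
controls = [0.0, 0.15, 0.45, 0.6, 1.0, 1.1]

# Deterministic matching always returns the same pairs.
print(match_controls(treated, controls, stochastic=False))  # [(0, 1), (1, 2), (2, 4)]

# Repeated stochastic draws can use different controls in different draws.
used = set()
for draw in range(10):
    used.update(j for _, j in match_controls(treated, controls, seed=draw))
print("Distinct controls used across draws:", sorted(used))
```

Pooling several stochastic draws lets more distinct controls contribute to the analysis, which is the intuition behind the larger effective sample size.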

API

Function             Brief description
estimate_te          Main pipeline – ensemble matching + meta‑learner
StochasticMatcher    1:1 nearest‑neighbor matcher (deterministic ↔ stochastic)
summarize_matching   Diagnostics: ESS, ASMD, variance ratios, overlap plots
load_data_lalonde    Copy of Lalonde job‑training dataset
load_data_tof        Simulated TOF dataset (survival or binary outcome)
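
The two main diagnostics named above can be sketched with their textbook definitions: the Kish effective sample size, (sum of weights)^2 / (sum of squared weights), and the absolute standardized mean difference with a pooled standard deviation. The helper names `kish_ess` and `asmd` are hypothetical, and the exact definitions used inside summarize_matching may differ.

```python
import math
import statistics


def kish_ess(weights):
    """Kish effective sample size: (sum w)^2 / sum(w^2)."""
    total = sum(weights)
    return total * total / sum(w * w for w in weights)


def asmd(treated, control):
    """Absolute standardized mean difference using the pooled standard deviation."""
    pooled_sd = math.sqrt(
        (statistics.variance(treated) + statistics.variance(control)) / 2
    )
    return abs(statistics.mean(treated) - statistics.mean(control)) / pooled_sd


print(kish_ess([1, 1, 1, 1]))  # 4.0: equal weights give ESS equal to the unit count
print(kish_ess([3, 1]))        # 1.6: unequal weights reduce ESS below the unit count
print(asmd([1.0, 2.0, 3.0], [2.0, 3.0, 4.0]))  # 1.0
```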

⚙️ Installation

pip install causalem

Optional dev extras:

pip install "causalem[dev]"

Requires Python 3.9 or later. Tested on macOS and Windows.


Package Vignette

For a more detailed introduction to CausalEM, including the underlying math, see the package vignette [insert link later], available on arXiv.


🚀 Quick Start

Two-arm Analysis

Load the necessary packages:

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

from causalem import (
    estimate_te,
    load_data_tof,
    stochastic_match,
    summarize_matching,
)

Load the TOF data with two treatment levels and a binarized outcome:

X, t, y = load_data_tof(
    raw=False,
    treat_levels=["PrP", "SPS"],
    binarize_outcome=True,
)

Stochastic matching using propensity scores:

lr = LogisticRegression(solver="newton-cg", max_iter=1000)
lr.fit(X, t)
score = lr.predict_proba(X)[:, 1]
logit_score = np.log(score / (1 - score))

cluster = stochastic_match(
    treatment=t,
    score=logit_score,
    nsmp=10,
    scale=1.0,
    random_state=0,
)

diag = summarize_matching(
    cluster, X, treatment=t, plot=False
)
print("Combined Effective Sample Size (ESS):", diag.ess["combined"])
print("Absolute standardized mean difference (ASMD) by covariate:\n")
print(diag.summary)
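
The logit transform used for the matching score above is undefined when a fitted propensity is exactly 0 or 1. The sketch below clips the probability before taking the logit. The helper name `safe_logit` and the tolerance `eps` are hypothetical and not part of CausalEM; whether clipping is appropriate depends on how much overlap the data have.

```python
import math


def safe_logit(p, eps=1e-6):
    """Logit of p after clipping p into [eps, 1 - eps] to keep the result finite."""
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p / (1.0 - p))


print(safe_logit(0.5))  # 0.0
print(safe_logit(0.0))  # large negative but finite
print(safe_logit(1.0))  # large positive but finite
```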

TE estimation:

res = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
)
print("Two-arm TE:", res["te"])
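
For readers unfamiliar with G-computation, the idea behind the outcome-model stage can be sketched as follows: fit an outcome model, predict every unit's outcome under each treatment, and average the difference. In this toy example a table of cell means stands in for the ensemble of learners; the data, the covariate, and the helper `fit_outcome_model` are all made up for illustration and do not reflect CausalEM's internals.

```python
from collections import defaultdict
from statistics import mean

# Toy rows: (binary covariate x, treatment t, outcome y).
data = [
    (0, 0, 1.0), (0, 0, 1.2), (0, 1, 2.0),
    (1, 0, 3.0), (1, 1, 4.1), (1, 1, 3.9),
]


def fit_outcome_model(rows):
    """Outcome model E[y | x, t], estimated by the mean of each (x, t) cell."""
    cells = defaultdict(list)
    for x, t, y in rows:
        cells[(x, t)].append(y)
    table = {key: mean(values) for key, values in cells.items()}
    return lambda x, t: table[(x, t)]


model = fit_outcome_model(data)

# G-computation: predict both potential outcomes for every unit, then average.
ate = mean(model(x, 1) - model(x, 0) for x, _, _ in data)

# Naive contrast for comparison: it ignores that treated units have higher x.
naive = mean(y for _, t, y in data if t == 1) - mean(y for _, t, y in data if t == 0)

print("G-computation estimate:", round(ate, 3))
print("Naive difference in means:", round(naive, 3))
```

The two numbers differ because treatment is more common among units with x = 1, which also have higher outcomes; standardizing over the covariate distribution removes that imbalance.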

Multi-arm Analysis

Load data for multi-arm analysis:

df = load_data_tof(
    raw=True,
    binarize_outcome=True,
)
t_all = df["treatment"].to_numpy()
X_all = df[["age", "zscore"]].to_numpy()
y_all = df["outcome"].to_numpy()

Constructing propensity scores using multinomial logistic regression:

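# A multinomial fit is the default for multi-class targets with the lbfgs solver;
# the explicit multi_class argument is deprecated in recent scikit-learn releases.
lr_multi = LogisticRegression(max_iter=1000)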
lr_multi.fit(X_all, t_all)
proba = lr_multi.predict_proba(X_all)
ref = "PrP"
cols = [i for i, c in enumerate(lr_multi.classes_) if c != ref]
logit_multi = np.log(proba[:, cols] / (1 - proba[:, cols]))

Multi-arm stochastic matching:

cluster_multi = stochastic_match(
    treatment=t_all,
    score=logit_multi,
    nsmp=5,
    scale=1.0,
    ref_group=ref,
    random_state=0,
)
diag_multi = summarize_matching(
    cluster_multi, X_all, treatment=t_all, ref_group=ref, plot=False
)
print("Multi-arm ESS per draw:\n", diag_multi.ess["per_draw"])

Multi-arm TE estimation:

res_multi = estimate_te(
    X_all,
    t_all,
    y_all,
    outcome_type="binary",
    ref_group=ref,
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
)
print("Multi-arm pairwise effects:\n", res_multi["pairwise"])

Confidence-Interval Calculation

Adding bootstrap CI to the two-arm analysis:

res_boot = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    nboot=200,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
    random_state_boot=7,
)
print("Bootstrap CI:", res_boot["ci"])
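
The point of running the whole pipeline inside the bootstrap loop can be illustrated with a generic percentile bootstrap. In this sketch the function `pipeline` is a stand-in (a plain difference in means) for the full matching + TE estimation procedure, and the data are simulated; none of the names below are CausalEM functions, and the package's own CI construction may differ.

```python
import random
import statistics


def pipeline(rows):
    """Stand-in for the full pipeline: difference in mean outcome between arms."""
    treated = [y for t, y in rows if t == 1]
    control = [y for t, y in rows if t == 0]
    return statistics.mean(treated) - statistics.mean(control)


def bootstrap_ci(rows, nboot=200, alpha=0.05, seed=7):
    """Percentile bootstrap CI; the entire pipeline is re-run on every resample."""
    rng = random.Random(seed)
    estimates = []
    while len(estimates) < nboot:
        sample = rng.choices(rows, k=len(rows))  # resample units with replacement
        if len({t for t, _ in sample}) < 2:
            continue  # skip resamples that are missing one of the arms
        estimates.append(pipeline(sample))
    estimates.sort()
    lo = estimates[round((alpha / 2) * (nboot - 1))]
    hi = estimates[round((1 - alpha / 2) * (nboot - 1))]
    return lo, hi


# Simulated data: (treatment, outcome) with a true effect of 1.0.
data_rng = random.Random(0)
rows = [(i % 2, 1.0 * (i % 2) + data_rng.gauss(0, 1)) for i in range(100)]

print("Point estimate:", pipeline(rows))
print("95% bootstrap CI:", bootstrap_ci(rows))
```

Because every resample repeats all steps, the interval reflects variability from each stage of the analysis rather than from the final estimation step alone.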

Heterogeneous Ensemble

learners = [
    LogisticRegression(max_iter=1000),
    RandomForestClassifier(n_estimators=200, max_depth=3),
]
res_ensemble = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    model_outcome=learners,
    niter=len(learners),
    do_stacking=True,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=42,
)
print("Ensemble TE:", res_ensemble["te"])

Stacking vs No-Stacking

# No-stacking: average per-iteration effects without appearance weights
res_ns = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    do_stacking=False,
    random_state_master=0,
)

# Stacking: meta-learner fit with appearance weights over the matched union
res_stack = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    do_stacking=True,
    random_state_master=0,
)
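
One plausible reading of "appearance weights over the matched union" is that each unit is weighted by the number of matching draws in which it appears, and that the matched union is the set of units appearing in at least one draw. The sketch below computes both under that assumption; the data layout (one set of unit indices per draw) is made up for illustration and is not taken from the package source.

```python
from collections import Counter

# Hypothetical matched samples: one set of unit indices per stochastic draw.
draws = [
    {0, 1, 4, 5},
    {0, 2, 4, 6},
    {0, 1, 4, 6},
]

# Appearance weight (assumed definition): number of draws containing the unit.
appearance = Counter(unit for draw in draws for unit in draw)

# Matched union: every unit that appears in at least one draw.
matched_union = sorted(appearance)

print(matched_union)                     # [0, 1, 2, 4, 5, 6]
print(dict(sorted(appearance.items())))  # {0: 3, 1: 2, 2: 1, 4: 3, 5: 1, 6: 2}
```

Under this reading, such counts could serve as sample weights when fitting a meta-learner on the matched union, so that units matched more often carry more influence.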

TE Estimation for Survival Outcomes

X_surv, t_surv, y_surv = load_data_tof(
    raw=False,
    treat_levels=["SPS", "PrP"],
)
res_surv = estimate_te(
    X_surv,
    t_surv,
    y_surv,
    outcome_type="survival",
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=0,
)
print("Survival HR:", res_surv["te"])

License

This project is licensed under the terms of the MIT License.

Release Notes

0.6.2

  • Exposed a new n_mc argument in estimate_te for specifying Monte‑Carlo draws per matched unit in survival analyses, replacing the previously fixed single draw.
  • Clarified treatment‑effect estimands for stacking vs. no‑stacking modes, noting that stacked results are appearance‑weighted across the matched union.
  • Documented appearance‑weighted meta‑learning and matched‑union survival contrasts.

0.6.1

  • Corrected the version number in the pyproject.toml file.

0.6.0

  • Improved consistency of return data structure when do_stacking=False in multi-arm TE estimation.

0.5.4

  • Added a GitHub Action for publishing to PyPI

0.5.3

  • First public release

0.5.1

  • Edits to the README
  • Added a GitHub Action for publishing to (test) PyPI

0.5.0

  • First test release
