
Causal Inference using Ensemble Matching


CausalEM – Ensemble Matching for Causal Inference


CausalEM is an ensemble‑based toolbox for multi-arm treatment‑effect estimation using stochastic matching, with support for continuous, binary, and right-censored time-to-event (survival) outcomes.


Key Features

  1. Stochastic adaptation of nearest-neighbor (NN) matching -> Larger effective sample size (ESS) and improved treatment-effect (TE) estimation accuracy vs. standard (deterministic) NN matching.
  2. G-computation using a two-stage, stacked ensemble of heterogeneous learners -> Generalization of the standard G-computation framework to ensemble learning; cross-fitting of propensity-score and outcome models, similar to DoubleML.
  3. Support for multi-arm treatments -> Stochastic matching in CausalEM can be especially helpful for improving ESS in multi-arm scenarios.
  4. Support for survival outcomes -> Data simulated from survival outcome models are used to implement the stacked ensemble for TE estimation on right-censored, time-to-event data.
  5. Bootstrapped confidence interval (CI) estimation -> Honest CI estimates, obtained by including the entire matching + TE estimation pipeline in the bootstrap loop.
  6. Compatible with scikit-learn -> scikit-learn models can be used for the propensity-score, outcome, and meta-learner stages (scikit-survival models for survival outcomes).
  7. Full reproducibility of results -> Careful seeding of random number generation (RNG), including inside scikit-learn models.
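The stochastic-matching idea in item 1 can be illustrated with a small NumPy sketch. It assumes a 0/1 treatment coding, 1:1 matching with replacement, and a Gaussian kernel on the score distance whose bandwidth is called `scale`; all of these are illustrative assumptions, and `stochastic_nn_match` is a hypothetical helper, not CausalEM's `stochastic_match`.

```python
import numpy as np


def stochastic_nn_match(score, treatment, scale=1.0, rng=None):
    """Match each treated unit (treatment == 1) to one control (hypothetical sketch).

    Deterministic NN matching always takes the closest control. Here a control
    is instead *sampled* with probability proportional to
    exp(-0.5 * (d / scale) ** 2), where d is the score distance.
    """
    rng = np.random.default_rng(rng)
    score = np.asarray(score, dtype=float)
    treatment = np.asarray(treatment)
    treated = np.flatnonzero(treatment == 1)
    controls = np.flatnonzero(treatment == 0)
    matches = np.empty(treated.size, dtype=int)
    for i, unit in enumerate(treated):
        d = score[controls] - score[unit]
        logw = -0.5 * (d / scale) ** 2
        w = np.exp(logw - logw.max())  # subtract the max for numerical stability
        matches[i] = rng.choice(controls, p=w / w.sum())
    return treated, matches
```

As `scale` shrinks toward zero this reduces to deterministic NN matching; repeating the draw with different seeds spreads the matches over more control units, which is the intuition behind the larger ESS claimed in item 1.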

API

| Name | Brief description |
| --- | --- |
| estimate_te | Main pipeline – ensemble matching + meta‑learner |
| StochasticMatcher | 1:1 nearest‑neighbor matcher (deterministic ↔ stochastic) |
| summarize_matching | Diagnostics: ESS, ASMD, variance ratios, overlap plots |
| load_data_lalonde | Copy of Lalonde job‑training dataset |
| load_data_tof | Simulated TOF dataset (survival or binary outcome) |

⚙️ Installation

pip install causalem

Optional dev extras:

pip install "causalem[dev]"

Minimum Python 3.9. Tested on macOS and Windows.


Package Vignette

For a more detailed introduction to CausalEM, including the underlying math, see the package vignette [insert link later], available on arXiv.


🚀 Quick Start

Two-arm Analysis

Load the necessary packages:

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

from causalem import (
    estimate_te,
    load_data_tof,
    stochastic_match,
    summarize_matching,
)

Load the ToF data with two treatment levels and binarized outcome:

X, t, y = load_data_tof(
    raw=False,
    treat_levels=['PrP', 'SPS'],
    binarize_outcome=True,
)

Stochastic matching using propensity scores:

lr = LogisticRegression(solver="newton-cg", max_iter=1000)
lr.fit(X, t)
score = lr.predict_proba(X)[:, 1]
logit_score = np.log(score / (1 - score))

cluster = stochastic_match(
    treatment=t,
    score=logit_score,
    nsmp=10,
    scale=1.0,
    random_state=0,
)

diag = summarize_matching(cluster, X, treatment=t, plot=False)
print("Combined Effective Sample Size (ESS):", diag.ess["combined"])
print("Absolute standardized mean difference (ASMD) by covariate:\n")
print(diag.summary)
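The two diagnostics printed above can be recomputed by hand as a sanity check. This sketch assumes a 0/1 treatment (for example `(t == "SPS").astype(int)`), the common pooled-SD denominator for ASMD, and Kish's formula for ESS applied to per-unit match counts; `summarize_matching` may define either quantity differently, and both helpers are hypothetical.

```python
import numpy as np


def asmd(x, treatment):
    """Absolute standardized mean difference for one covariate.

    Uses the pooled SD sqrt((var_1 + var_0) / 2), a common convention.
    """
    x = np.asarray(x, dtype=float)
    treatment = np.asarray(treatment)
    x1, x0 = x[treatment == 1], x[treatment == 0]
    pooled_sd = np.sqrt((x1.var(ddof=1) + x0.var(ddof=1)) / 2)
    return abs(x1.mean() - x0.mean()) / pooled_sd


def kish_ess(weights):
    """Kish effective sample size: (sum w)^2 / sum(w^2)."""
    w = np.asarray(weights, dtype=float)
    return w.sum() ** 2 / (w ** 2).sum()
```

For instance, one control matched in all 4 draws while three others are never matched gives `kish_ess([4, 0, 0, 0]) == 1.0`, whereas four controls matched once each give `4.0`.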

TE estimation:

res = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
)
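For intuition about the G-computation step, here is a plug-in estimator of the average risk difference for a binary outcome: fit one outcome model on covariates plus the treatment indicator, predict every unit under both arms, and average the difference. It is a single-learner sketch assuming 0/1 treatment coding; CausalEM additionally ensembles over matched draws and stacks learners, which is omitted here. Whether `estimate_te` reports a risk difference for `outcome_type="binary"` is an assumption, and `gcomp_risk_difference` is a hypothetical name.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression


def gcomp_risk_difference(X, treatment, y, model=None):
    """Plug-in G-computation estimate of the average risk difference."""
    X = np.asarray(X, dtype=float)
    treatment = np.asarray(treatment, dtype=float)
    model = model if model is not None else LogisticRegression(max_iter=1000)
    # Fit the outcome model on covariates plus the treatment indicator.
    model.fit(np.column_stack([X, treatment]), y)
    # Predict each unit's outcome probability with treatment set to 1, then to 0.
    p1 = model.predict_proba(np.column_stack([X, np.ones(len(X))]))[:, 1]
    p0 = model.predict_proba(np.column_stack([X, np.zeros(len(X))]))[:, 1]
    return float(np.mean(p1 - p0))
```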
print("Two-arm TE:", res["te"])

Multi-arm Analysis

Load data for multi-arm analysis:

df = load_data_tof(
    raw=True,
    binarize_outcome=True,
)
t_all = df["treatment"].to_numpy()
X_all = df[["age", "zscore"]].to_numpy()
y_all = df["outcome"].to_numpy()

Constructing propensity scores using multinomial logistic regression:

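# With the default lbfgs solver, scikit-learn fits a multinomial model when
# there are more than two classes; the multi_class argument is deprecated
# in recent scikit-learn releases.
lr_multi = LogisticRegression(max_iter=1000)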
lr_multi.fit(X_all, t_all)
proba = lr_multi.predict_proba(X_all)
ref = "PrP"
cols = [i for i, c in enumerate(lr_multi.classes_) if c != ref]
logit_multi = np.log(proba[:, cols] / (1 - proba[:, cols]))
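If a fitted probability is exactly 0 or 1, the `np.log(p / (1 - p))` line above would produce infinite values. A defensive variant clips the probabilities first; `arm_logits` and the `eps` threshold are illustrative choices, not part of CausalEM.

```python
import numpy as np


def arm_logits(proba, classes, ref, eps=1e-12):
    """One-vs-rest logits for each non-reference arm (hypothetical helper)."""
    # Clip away exact 0/1 so the log-odds stay finite.
    p = np.clip(np.asarray(proba, dtype=float), eps, 1 - eps)
    cols = [i for i, c in enumerate(classes) if c != ref]
    return np.log(p[:, cols] / (1 - p[:, cols]))
```

Usage would mirror the original line: `logit_multi = arm_logits(proba, lr_multi.classes_, ref)`.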

Multi-arm stochastic matching:

cluster_multi = stochastic_match(
    treatment=t_all,
    score=logit_multi,
    nsmp=5,
    scale=1.0,
    ref_group=ref,
    random_state=0,
)
diag_multi = summarize_matching(
    cluster_multi, X_all, treatment=t_all, ref_group=ref, plot=False
)
print("Multi-arm ESS per draw:\n", diag_multi.ess["per_draw"])

Multi-arm TE estimation:

res_multi = estimate_te(
    X_all,
    t_all,
    y_all,
    outcome_type="binary",
    ref_group=ref,
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
)
print("Multi-arm pairwise effects:\n", res_multi["pairwise"])
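The G-computation idea extends to several arms by one-hot encoding the arm, predicting every unit under each arm in turn, and differencing the resulting arm means. The sign convention (second arm minus first) and the helper `gcomp_pairwise` are assumptions for illustration; the layout of `res_multi["pairwise"]` may differ.

```python
from itertools import combinations

import numpy as np
from sklearn.linear_model import LogisticRegression


def gcomp_pairwise(X, treatment, y, arms):
    """Plug-in G-computation for K arms (hypothetical helper).

    Returns the mean predicted outcome under each arm and a dict of
    pairwise differences keyed by (arm_a, arm_b) -> mean_b - mean_a.
    """
    X = np.asarray(X, dtype=float)
    treatment = np.asarray(treatment)

    def design(labels):
        # Covariates plus one-hot arm indicators; arms[0] is the baseline.
        dummies = [(labels == a).astype(float) for a in arms[1:]]
        return np.column_stack([X, *dummies])

    model = LogisticRegression(max_iter=1000).fit(design(treatment), y)
    means = {}
    for a in arms:
        # Assign every unit to arm `a` and average the predicted P(y = 1).
        counterfactual = np.full(len(X), a)
        means[a] = float(model.predict_proba(design(counterfactual))[:, 1].mean())
    pairwise = {(a, b): means[b] - means[a] for a, b in combinations(arms, 2)}
    return means, pairwise
```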

Confidence-Interval Calculation

Adding bootstrap CI to the two-arm analysis:

res_boot = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    nboot=200,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
    random_state_boot=7,
)
print("Bootstrap CI:", res_boot["ci"])
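The "honest" CI in item 5 of the feature list comes from putting the whole pipeline inside the bootstrap loop. A generic percentile-bootstrap sketch that re-runs an arbitrary `estimator(X, t, y)` on every resample looks like this; the percentile method, the helper names, and the toy `diff_in_means` estimator are assumptions, and `estimate_te` may build its interval differently.

```python
import numpy as np


def percentile_bootstrap_ci(estimator, X, t, y, nboot=200, alpha=0.05, seed=None):
    """Percentile bootstrap CI that re-runs the whole `estimator` per resample."""
    rng = np.random.default_rng(seed)
    X, t, y = np.asarray(X), np.asarray(t), np.asarray(y)
    n = len(y)
    stats = np.empty(nboot)
    for b in range(nboot):
        idx = rng.integers(0, n, size=n)  # resample units with replacement
        # Any matching and model fitting happen inside `estimator`, so their
        # variability is reflected in the spread of the bootstrap estimates.
        stats[b] = estimator(X[idx], t[idx], y[idx])
    lo, hi = np.quantile(stats, [alpha / 2, 1 - alpha / 2])
    return float(lo), float(hi)


def diff_in_means(X, t, y):
    """Toy estimator standing in for a full matching + TE pipeline."""
    return y[t == 1].mean() - y[t == 0].mean()
```

In a real analysis `estimator` would wrap matching plus TE estimation, so matching variability also widens the interval.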

Heterogeneous Ensemble

Passing a list of heterogeneous scikit-learn learners as the outcome model, with stacking enabled:

learners = [
    LogisticRegression(max_iter=1000),
    RandomForestClassifier(n_estimators=200, max_depth=3),
]
res_ensemble = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    model_outcome=learners,
    niter=len(learners),
    do_stacking=True,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=42,
)
print("Ensemble TE:", res_ensemble["te"])
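To see what a two-stage stacked ensemble of the same two learners looks like in plain scikit-learn, `StackingClassifier` fits the base learners, builds out-of-fold predictions from them, and trains a meta-learner on those predictions. This is an analogy on synthetic data only; how `do_stacking=True` combines learners inside CausalEM is not specified here.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier, StackingClassifier
from sklearn.linear_model import LogisticRegression

# Synthetic binary-outcome data, for illustration only.
X_demo, y_demo = make_classification(n_samples=300, n_features=5, random_state=0)

stack = StackingClassifier(
    estimators=[
        ("lr", LogisticRegression(max_iter=1000)),
        ("rf", RandomForestClassifier(n_estimators=200, max_depth=3, random_state=0)),
    ],
    final_estimator=LogisticRegression(),  # second-stage meta-learner
    cv=5,  # meta-features come from 5-fold out-of-fold predictions
)
stack.fit(X_demo, y_demo)
proba = stack.predict_proba(X_demo)[:, 1]
```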

TE Estimation for Survival Outcomes

Load the ToF data with its survival outcome (no binarization) and estimate the TE:

X_surv, t_surv, y_surv = load_data_tof(
    raw=False,
    treat_levels=['SPS', 'PrP'],
)
res_surv = estimate_te(
    X_surv,
    t_surv,
    y_surv,
    outcome_type="survival",
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=0,
)
print("Survival HR:", res_surv["te"])
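Item 4 of the feature list mentions simulating from survival outcome models. As a self-contained illustration of that idea (not CausalEM's actual models), the sketch below simulates right-censored data from an exponential proportional-hazards model with a known hazard ratio, then recovers it with the exponential MLE (events divided by person-time in each arm). Reading `res_surv["te"]` as a hazard ratio follows the printed label and is otherwise an assumption; both helpers are hypothetical.

```python
import numpy as np


def simulate_exponential(n, base_rate, hr, censor_max, rng):
    """Simulate right-censored exponential survival data for two arms."""
    t = rng.integers(0, 2, size=n)
    rate = base_rate * np.where(t == 1, hr, 1.0)  # proportional hazards
    event_time = rng.exponential(1.0 / rate)
    censor_time = rng.uniform(0, censor_max, size=n)
    time = np.minimum(event_time, censor_time)
    event = (event_time <= censor_time).astype(int)
    return t, time, event


def exponential_hr(t, time, event):
    """MLE hazard ratio under an exponential model: (d1 / T1) / (d0 / T0)."""
    r1 = event[t == 1].sum() / time[t == 1].sum()
    r0 = event[t == 0].sum() / time[t == 0].sum()
    return r1 / r0
```

The exponential model is chosen here only because its MLE has a closed form.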

License

This project is licensed under the terms of the MIT License – see the LICENSE file.

Release Notes

0.5.0

  • First public release

Download files

Source Distribution

causalem-0.5.0.tar.gz (76.5 kB)

Uploaded Source

Built Distribution

causalem-0.5.0-py3-none-any.whl (69.2 kB)

Uploaded Python 3

File details

Details for the file causalem-0.5.0.tar.gz.

File metadata

  • Download URL: causalem-0.5.0.tar.gz
  • Upload date:
  • Size: 76.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.4

File hashes

Hashes for causalem-0.5.0.tar.gz
Algorithm Hash digest
SHA256 663e4f287a6bc6db3b1ab52c9e8d7a36987d777ab13ad830ac0a0d218cee9829
MD5 a3081921146db705f9ba352441ed8f0e
BLAKE2b-256 34a9159335e08422e7bbe00023d348542dc8ad984573e43469ea5680308cf0a7

File details

Details for the file causalem-0.5.0-py3-none-any.whl.

File metadata

  • Download URL: causalem-0.5.0-py3-none-any.whl
  • Upload date:
  • Size: 69.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.4

File hashes

Hashes for causalem-0.5.0-py3-none-any.whl
Algorithm Hash digest
SHA256 2886541efeb0647b8b4751f964bdcb2d34782cd487d6821508d93b87221abb3d
MD5 a27676548f865611c71c003b5040f9f2
BLAKE2b-256 a96da4357bdb6bbec7e62cccb76f648143281ba6cd1076557c9b0dadcb61e165
