
Causal Inference using Ensemble Matching


CausalEM – Ensemble Matching for Causal Inference

CausalEM is a toolbox for multi-arm treatment‑effect estimation and mediation analysis using stochastic matching and a stacked ensemble of heterogeneous ML models. It supports continuous, binary, and survival outcomes.


Key Features

| Feature | Impact |
|---|---|
| Stochastic nearest-neighbor (NN) matching | Larger effective sample size (ESS) and more accurate treatment-effect (TE) estimation than standard (deterministic) NN matching |
| G-computation using a two-stage, stacked ensemble of heterogeneous learners | Generalizes the standard G-computation framework to ensemble learning; cross-fits the propensity-score and outcome models, similar to DoubleML |
| Support for multi-arm treatments | Larger multi-arm ESS via stochastic matching |
| Mediation analysis | Plug-in G-computation for interventional direct and indirect effects (IDE/IIE) with a binary treatment and binary or continuous mediators and outcomes; supports bootstrap confidence intervals and optional stochastic matching |
| Support for survival outcomes | Simulates data from fitted survival outcome models to enable stacked-ensemble TE estimation for right-censored, time-to-event data |
| Bootstrapped confidence interval (CI) estimation | Honest CIs: the entire pipeline (matching + TE estimation) runs inside the bootstrap loop |
| Compatible with scikit-learn | Maximum flexibility: any scikit-learn model (scikit-survival for survival outcomes) can be used in the propensity-score, outcome, and meta-learner stages |
| Full reproducibility of results | Careful random number generator (RNG) seeding, including in scikit-learn models |
| Available in Python and R | Identical function-centric API in both languages (the R package wraps Python via reticulate); combined with RNG management, this yields identical, reproducible results across platforms |
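The ESS gain claimed in the first row can be illustrated with a toy example. The sketch below is a hypothetical matcher, not CausalEM's matching code: in deterministic mode each treated unit takes its nearest control on a 1-D score, while in stochastic mode it samples a control with probability proportional to a Gaussian kernel of the score distance (the kernel and the `toy_match`/`kish_ess` names are assumptions). ESS is measured here with Kish's formula over control appearance counts, which may differ from how CausalEM defines ESS.

```python
import numpy as np

def kish_ess(counts):
    """Kish effective sample size of non-negative weights: (sum w)^2 / sum(w^2)."""
    w = np.asarray(counts, dtype=float)
    return float(w.sum() ** 2 / (w ** 2).sum())

def toy_match(t, score, nsmp, scale, rng):
    """Hypothetical 1:1 matcher (with replacement) of treated to control units on a 1-D score.

    scale=None: deterministic nearest-neighbor matching.
    Otherwise: each treated unit draws a control with probability proportional
    to exp(-0.5 * (distance / scale) ** 2).
    Returns an (nsmp, n_treated) array of control indices.
    """
    treated, controls = np.flatnonzero(t == 1), np.flatnonzero(t == 0)
    dist = np.abs(score[treated][:, None] - score[controls][None, :])
    if scale is None:
        return np.tile(controls[dist.argmin(axis=1)], (nsmp, 1))
    logits = -0.5 * (dist / scale) ** 2
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))  # numerically stable
    probs /= probs.sum(axis=1, keepdims=True)
    return np.column_stack([rng.choice(controls, size=nsmp, p=p) for p in probs])

rng = np.random.default_rng(0)
t_sim = rng.integers(0, 2, size=400)
score_sim = rng.normal(size=400) + 0.5 * t_sim

ess = {}
for label, scale in [("deterministic", None), ("stochastic", 0.5)]:
    matches = toy_match(t_sim, score_sim, nsmp=10, scale=scale, rng=rng)
    # How often each control unit appears across all draws
    counts = np.bincount(matches.ravel(), minlength=len(t_sim))[t_sim == 0]
    ess[label] = kish_ess(counts)
print({k: round(v, 1) for k, v in ess.items()})
```

In this toy, smaller `scale` values move the matcher back toward deterministic behavior (closer matches, lower ESS), while larger values trade match closeness for ESS.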

API

| Function | Brief description |
|---|---|
| estimate_te | Main pipeline: ensemble matching + meta-learner |
| estimate_mediation | Mediation analysis with plug-in G-computation |
| StochasticMatcher | 1:1 nearest-neighbor matcher (deterministic or stochastic) |
| summarize_matching | Diagnostics: ESS, ASMD, variance ratios, overlap plots |
| load_data_lalonde | Standard Lalonde job-training dataset (two-arm, continuous outcome) |
| load_data_tof | New simulated Tetralogy of Fallot (ToF) dataset (two-arm or three-arm; survival, binary, or continuous outcome) |

⚙️ Installation

pip install causalem

Optional dev extras:

pip install "causalem[dev]"

Requires Python 3.9 or later. Tested on macOS and Windows.


Package Vignette

For a more detailed introduction to CausalEM, including the underlying math, see the package vignette [insert link later], available on arXiv.


🚀 Quick Start

Two-arm Analysis

Load the necessary packages:

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

from causalem import (
  estimate_te,
  load_data_tof,
  stochastic_match,
  summarize_matching
)

Load the ToF data with two treatment levels and binarized outcome:

X, t, y = load_data_tof(
    raw=False,
    treat_levels=['PrP', 'SPS'],
    outcome_type="binary",
)

Stochastic matching using propensity scores:

lr = LogisticRegression(solver="newton-cg", max_iter=1000)
lr.fit(X, t)
score = lr.predict_proba(X)[:, 1]
logit_score = np.log(score / (1 - score))

cluster = stochastic_match(
    treatment=t,
    score=logit_score,
    nsmp=10,
    scale=1.0,
    random_state=0,
)

diag = summarize_matching(
  cluster, X,
  treatment=t, plot=False
)
print("Combined Effective Sample Size (ESS):", diag.ess["combined"])
print("Absolute standardized mean difference (ASMD) by covariate:\n")
print(diag.summary)
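`summarize_matching` reports ASMD and variance ratios per covariate. A minimal sketch of the textbook, unweighted definitions follows; `balance_table` is a hypothetical helper, and whether CausalEM pools variances this way or weights matched units by how often they appear is an assumption.

```python
import numpy as np

def balance_table(X, treatment):
    """Per-covariate ASMD and variance ratio between treated (1) and control (0) units.

    ASMD = |mean_t - mean_c| / sqrt((var_t + var_c) / 2); variance ratio = var_t / var_c.
    """
    X = np.asarray(X, dtype=float)
    treatment = np.asarray(treatment)
    xt, xc = X[treatment == 1], X[treatment == 0]
    var_t, var_c = xt.var(axis=0, ddof=1), xc.var(axis=0, ddof=1)
    asmd = np.abs(xt.mean(axis=0) - xc.mean(axis=0)) / np.sqrt((var_t + var_c) / 2.0)
    return asmd, var_t / var_c

X_toy = np.array([[1.0, 10.0], [2.0, 14.0], [3.0, 11.0], [4.0, 13.0]])
t_toy = np.array([1, 1, 0, 0])
asmd, vr = balance_table(X_toy, t_toy)
# ASMD: 2*sqrt(2) for column 0 (means differ), 0 for column 1 (equal means)
# Variance ratio: 1 for column 0, 4 for column 1 (treated spread is larger)
print(asmd, vr)
```

A small ASMD does not imply similar spreads, which is why the variance ratio is reported alongside it: column 1 above has identical means but a fourfold variance difference.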

TE estimation (includes stochastic matching as the first step, followed by outcome modeling):

res = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
)
print("Two-arm TE:", res["te"])
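At the core of `estimate_te` is G-computation: fit an outcome model, predict every unit's outcome with treatment set to each level, and average the difference. The sketch below uses a single linear least-squares model on simulated, confounded data (all names are hypothetical); CausalEM instead fits scikit-learn learners on matched samples and combines them in a stacked ensemble.

```python
import numpy as np

def g_computation_te(X, t, y):
    """Plug-in G-computation with a single linear outcome model (illustration only).

    Fits y ~ 1 + t + X by least squares, predicts each unit's outcome with t set
    to 1 and to 0, and averages the difference.
    """
    X = np.asarray(X, dtype=float)
    t = np.asarray(t, dtype=float)
    y = np.asarray(y, dtype=float)
    ones = np.ones(len(t))
    beta, *_ = np.linalg.lstsq(np.column_stack([ones, t, X]), y, rcond=None)
    y1 = np.column_stack([ones, ones, X]) @ beta        # everyone treated
    y0 = np.column_stack([ones, 0.0 * ones, X]) @ beta  # everyone untreated
    return float(np.mean(y1 - y0))

rng = np.random.default_rng(0)
X_sim = rng.normal(size=(500, 2))
t_sim = (X_sim[:, 0] + rng.normal(size=500) > 0).astype(int)  # confounded assignment
y_sim = 2.0 * t_sim + X_sim @ np.array([1.0, -0.5]) + rng.normal(scale=0.1, size=500)

naive = y_sim[t_sim == 1].mean() - y_sim[t_sim == 0].mean()
te = g_computation_te(X_sim, t_sim, y_sim)
print(f"naive difference: {naive:.2f}, G-computation TE: {te:.2f} (true effect: 2.0)")
```

The naive difference in means is biased because `X_sim[:, 0]` drives both treatment and outcome; adjusting for it through the outcome model recovers an estimate near the true effect.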

Multi-arm Analysis

Load data for multi-arm analysis:

df = load_data_tof(
    raw=True,
    outcome_type="binary",
)
t_all = df["treatment"].to_numpy()
X_all = df[["age", "zscore"]].to_numpy()
y_all = df["outcome"].to_numpy()

Constructing propensity scores using multinomial logistic regression:

lr_multi = LogisticRegression(max_iter=1000)  # default lbfgs solver fits a multinomial model for 3+ classes
lr_multi.fit(X_all, t_all)
proba = lr_multi.predict_proba(X_all)
ref = "PrP"
cols = [i for i, c in enumerate(lr_multi.classes_) if c != ref]
logit_multi = np.log(proba[:, cols] / (1 - proba[:, cols]))

Multi-arm stochastic matching:

cluster_multi = stochastic_match(
    treatment=t_all,
    score=logit_multi,
    nsmp=5,
    scale=1.0,
    ref_group=ref,
    random_state=0,
)
diag_multi = summarize_matching(
    cluster_multi, X_all, treatment=t_all, ref_group=ref, plot=False
)
print("Multi-arm ESS per draw:\n", diag_multi.ess["per_draw"])  # dict of counts by group
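The multi-arm call above takes a `ref_group` and returns clusters. One plausible reading, sketched below purely as an assumption: each reference-arm unit is matched to its nearest neighbor (Euclidean distance on the score vector) in every other arm, with replacement. `toy_multiarm_match` is hypothetical and deterministic, whereas `stochastic_match` additionally randomizes the choice of match.

```python
import numpy as np

def toy_multiarm_match(treatment, score, ref_group):
    """Hypothetical deterministic multi-arm 1:1 matching (with replacement).

    Each reference-arm unit is matched to its nearest unit in every other arm,
    forming one cluster per reference unit. Returns {arm: indices aligned with
    the reference units}.
    """
    treatment = np.asarray(treatment)
    score = np.asarray(score, dtype=float).reshape(len(treatment), -1)
    ref_idx = np.flatnonzero(treatment == ref_group)
    clusters = {ref_group: ref_idx}
    for arm in np.unique(treatment):
        if arm == ref_group:
            continue
        arm_idx = np.flatnonzero(treatment == arm)
        # Distances between every reference unit and every unit in this arm
        diff = score[ref_idx][:, None, :] - score[arm_idx][None, :, :]
        clusters[str(arm)] = arm_idx[np.linalg.norm(diff, axis=2).argmin(axis=1)]
    return clusters

rng = np.random.default_rng(0)
t_sim = rng.choice(["PrP", "RVOTd", "SPS"], size=90)
score_sim = rng.normal(size=(90, 2))  # e.g., two logit columns, like logit_multi
clusters_sim = toy_multiarm_match(t_sim, score_sim, ref_group="PrP")
print({arm: len(idx) for arm, idx in clusters_sim.items()})
```

Every arm contributes exactly one unit per reference unit, so each cluster supports a pairwise comparison between any two arms.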

Multi-arm TE estimation:

res_multi = estimate_te(
    X_all,
    t_all,
    y_all,
    outcome_type="binary",
    ref_group=ref,
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
)
print("Multi-arm pairwise effects:\n", res_multi["pairwise"])

Confidence-Interval Calculation

Adding bootstrap CI to the two-arm analysis:

res_boot = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    nboot=200,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
    random_state_boot=7,
)
print("Bootstrap CI:", res_boot["ci"])
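Bootstrap CIs are honest only if every stage is refit on each resample. A minimal percentile-bootstrap sketch of that structure: `estimator` is a placeholder for the full matching + TE pipeline (here a plain difference in means), `bootstrap_ci` is a hypothetical helper rather than `estimate_te`'s internals, and whether CausalEM uses percentile intervals is an assumption.

```python
import numpy as np

def bootstrap_ci(estimator, X, t, y, nboot=200, alpha=0.05, random_state=7):
    """Percentile bootstrap CI that reruns the entire estimator on each resample.

    `estimator(X, t, y)` stands in for the full pipeline (matching + outcome
    modeling); refitting every stage per resample lets the interval reflect
    matching variability as well as outcome-model variability.
    """
    rng = np.random.default_rng(random_state)
    n = len(t)
    stats = np.empty(nboot)
    for b in range(nboot):
        idx = rng.integers(0, n, size=n)  # resample units with replacement
        stats[b] = estimator(X[idx], t[idx], y[idx])
    lo, hi = np.quantile(stats, [alpha / 2, 1 - alpha / 2])
    return float(lo), float(hi)

def diff_in_means(X, t, y):
    """Placeholder estimator; a real pipeline would match and model here."""
    return y[t == 1].mean() - y[t == 0].mean()

rng = np.random.default_rng(0)
X_sim = rng.normal(size=(300, 2))
t_sim = rng.integers(0, 2, size=300)
y_sim = 1.5 * t_sim + rng.normal(size=300)
ci_sim = bootstrap_ci(diff_in_means, X_sim, t_sim, y_sim)
print("95% bootstrap CI:", ci_sim)
```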

Heterogeneous Ensemble

learners = [
    LogisticRegression(max_iter=1000),
    RandomForestClassifier(n_estimators=200, max_depth=3),
]
res_ensemble = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    model_outcome=learners,
    niter=len(learners),
    do_stacking=True,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=42,
)
print("Ensemble TE:", res_ensemble["te"])
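With `do_stacking=True`, the base learners' predictions are combined by a meta-learner. A numpy-only sketch of that second stage: two polynomial fits stand in for the scikit-learn learners, predictions are cross-fitted (out-of-fold), and an unconstrained least-squares fit serves as the meta-learner. The regression setting, the fold scheme, and the meta-learner form are all assumptions; CausalEM's meta-learner may differ.

```python
import numpy as np

def cross_fitted_predictions(x, y, degrees=(1, 3), n_folds=5, random_state=0):
    """Out-of-fold predictions from polynomial base learners (stand-ins for real models)."""
    rng = np.random.default_rng(random_state)
    folds = rng.permutation(len(x)) % n_folds  # balanced random fold labels
    preds = np.empty((len(x), len(degrees)))
    for k in range(n_folds):
        train, test = folds != k, folds == k
        for j, deg in enumerate(degrees):
            coef = np.polyfit(x[train], y[train], deg)
            preds[test, j] = np.polyval(coef, x[test])
    return preds

def mse(y_true, y_pred):
    return float(np.mean((y_true - y_pred) ** 2))

rng = np.random.default_rng(1)
x_sim = rng.uniform(-2, 2, size=400)
y_sim = x_sim ** 3 - x_sim + rng.normal(scale=0.3, size=400)

P = cross_fitted_predictions(x_sim, y_sim)
# Meta-learner: unconstrained least-squares weights on the out-of-fold predictions
weights, *_ = np.linalg.lstsq(P, y_sim, rcond=None)
base_mse = [mse(y_sim, P[:, j]) for j in range(P.shape[1])]
stacked_mse = mse(y_sim, P @ weights)
print("weights:", np.round(weights, 2), "base MSEs:", np.round(base_mse, 3),
      "stacked MSE:", round(stacked_mse, 3))
```

Because each single learner is a special case of the linear combination, the stacked out-of-fold MSE cannot exceed that of the better base learner.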

Stacking vs No-Stacking

# No-stacking: average per-iteration effects without appearance weights
res_ns = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    do_stacking=False,
    random_state_master=0,
)

# Stacking: meta-learner fit with appearance weights over the matched union
res_stack = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    do_stacking=True,
    random_state_master=0,
)
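The 0.6.2 release notes describe stacked results as appearance-weighted across the matched union. The sketch below (hypothetical draws and unit-level contrasts) shows what that weighting means: count how often each unit appears across the `niter` matched draws and use the counts as weights. With identical unit-level contrasts and equal-size draws, the appearance-weighted union average equals the average of per-iteration means. In `estimate_te`, the two modes also use different models (per-iteration fits vs. a meta-learner), so their results can differ.

```python
import numpy as np

rng = np.random.default_rng(0)
n_units, niter, draw_size = 50, 5, 20
# Hypothetical matched draws: indices of the units selected in each iteration
draws = [rng.choice(n_units, size=draw_size, replace=True) for _ in range(niter)]
tau_sim = rng.normal(loc=1.0, size=n_units)  # hypothetical unit-level contrasts

# No-stacking style: average of the per-iteration mean contrasts
per_iteration_avg = float(np.mean([tau_sim[d].mean() for d in draws]))

# Stacking style: appearance counts over the matched union act as weights
appearance = np.bincount(np.concatenate(draws), minlength=n_units)
union_avg = float(np.average(tau_sim, weights=appearance))

print(per_iteration_avg, union_avg)
```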

TE Estimation for Survival Outcomes

X_surv, t_surv, y_surv = load_data_tof(
    raw=False,
    treat_levels=['SPS', 'PrP'],
)
res_surv = estimate_te(
    X_surv,
    t_surv,
    y_surv,
    outcome_type="survival",
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=0,
)
print("Survival HR:", res_surv["te"])
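For survival outcomes, CausalEM simulates data from fitted survival outcome models. A stripped-down sketch of that idea under strong assumptions: one exponential (constant-hazard) model per arm, fitted by maximum likelihood on right-censored data, followed by `n_mc` simulated event times per unit and a simple contrast of the simulated outcomes. CausalEM uses scikit-survival models within its ensemble. The exponential model, the mean-survival contrast, and all variable names here are illustrative; `n_mc` only echoes the `estimate_te` argument described in the 0.6.2 notes.

```python
import numpy as np

def exponential_rate(time, status):
    """MLE of a constant hazard under right censoring: number of events / total follow-up."""
    return status.sum() / time.sum()

rng = np.random.default_rng(0)
n = 2000
t_sim = rng.integers(0, 2, size=n)
true_rate = np.where(t_sim == 1, 0.05, 0.10)         # true hazard ratio = 0.5
event_time = rng.exponential(1.0 / true_rate)
censor_time = rng.exponential(1.0 / 0.03, size=n)
time_sim = np.minimum(event_time, censor_time)
status_sim = (event_time <= censor_time).astype(int)  # 1 = event observed

rate1 = exponential_rate(time_sim[t_sim == 1], status_sim[t_sim == 1])
rate0 = exponential_rate(time_sim[t_sim == 0], status_sim[t_sim == 0])
hr_sim = rate1 / rate0

# Simulation step: draw n_mc uncensored event times per unit from each arm's
# fitted model, then contrast the simulated outcomes (here, mean survival time).
n_mc = 10
sim1 = rng.exponential(1.0 / rate1, size=(n, n_mc))
sim0 = rng.exponential(1.0 / rate0, size=(n, n_mc))
mean_diff_sim = sim1.mean() - sim0.mean()
print(f"HR: {hr_sim:.2f}, mean survival difference: {mean_diff_sim:.1f} (true: 0.5, 10.0)")
```

Simulating uncensored outcomes from fitted models turns a censored-data problem into one where ordinary contrasts (means, proportions) can be computed on complete data.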

Mediation Analysis

# Load ToF data with mediation structure
from causalem.datasets import load_data_tof
from causalem.mediation import estimate_mediation

# Load ToF data: binary treatment (PrP vs SPS), continuous mediator (op_time), binary outcome
X, A, M, Y = load_data_tof(
    raw=False,
    treat_levels=['PrP', 'SPS'],  # Binary treatment comparison
    outcome_type="binary",        # Binary outcome for simpler interpretation  
    include_mediator=True         # Include mediator variable (op_time)
)

# Estimate mediation effects
result = estimate_mediation(X, A, M, Y, random_state_master=42)

print("Total Effect (TE):", result["te"])
print("Interventional Direct Effect (IDE):", result["ide"]) 
print("Interventional Indirect Effect (IIE):", result["iie"])
print("Proportion Mediated:", result["prop_mediated"])
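Interventional effects compare outcomes when the mediator is set to a random draw G_a from its distribution under treatment a (given covariates). Assuming `estimate_mediation` follows the usual definitions, IDE = E[Y(1, G_0)] - E[Y(0, G_0)], IIE = E[Y(1, G_1)] - E[Y(1, G_0)], and TE = IDE + IIE. The sketch below computes these by plug-in G-computation with linear-Gaussian mediator and outcome models on simulated data; the models, the Monte-Carlo draw count `n_mc`, and proportion mediated = IIE / TE are assumptions, not the package's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5000
x_sim = rng.normal(size=n)
a_sim = rng.integers(0, 2, size=n)
m_sim = 0.8 * a_sim + 0.5 * x_sim + rng.normal(size=n)               # mediator
y_sim = 1.0 * a_sim + 1.5 * m_sim + 0.3 * x_sim + rng.normal(size=n)  # outcome
ones = np.ones(n)

# Stage 1: mediator model M ~ 1 + A + X with Gaussian errors
design_m = np.column_stack([ones, a_sim, x_sim])
gamma, *_ = np.linalg.lstsq(design_m, m_sim, rcond=None)
sigma_m = np.std(m_sim - design_m @ gamma, ddof=design_m.shape[1])

# Stage 2: outcome model Y ~ 1 + A + M + X
beta, *_ = np.linalg.lstsq(np.column_stack([ones, a_sim, m_sim, x_sim]), y_sim, rcond=None)

def mean_outcome(a, a_med, n_mc=20):
    """Monte-Carlo estimate of E[Y(a, G_a_med)].

    Mediators are drawn from the fitted mediator distribution given X under
    treatment a_med; the outcome model is then evaluated with treatment a.
    """
    mu_m = gamma[0] + gamma[1] * a_med + gamma[2] * x_sim
    m_draw = mu_m[:, None] + sigma_m * rng.normal(size=(n, n_mc))
    y_hat = beta[0] + beta[1] * a + beta[2] * m_draw + beta[3] * x_sim[:, None]
    return float(y_hat.mean())

y11, y10, y00 = mean_outcome(1, 1), mean_outcome(1, 0), mean_outcome(0, 0)
te, ide, iie = y11 - y00, y10 - y00, y11 - y10
print(f"TE={te:.2f}  IDE={ide:.2f}  IIE={iie:.2f}  proportion mediated={iie / te:.2f}")
```

In this simulation the true values are IDE = 1.0 and IIE = 0.8 × 1.5 = 1.2, so the estimates should land close to them.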

CausalEM in R

CausalEM is also available as an R package that wraps the Python implementation via reticulate. The R API mirrors the Python one and, with the same seeds, reproduces the same results.

Installation (R)

First, ensure you have Python and the causalem package installed:

pip install "causalem>=0.7.0"

Then install the R package (note: currently available via GitHub, not yet on CRAN):

# Install from GitHub (if available)
# devtools::install_github("asmahani/matchml", subdir = "R/CausalEM")

# Or install locally from the repository
# install.packages("path/to/matchml/R/CausalEM", repos = NULL, type = "source")

Quick Start (R)

Basic two-arm analysis:

library(causalem)

# Load sample data
df <- load_data_lalonde(raw = FALSE)
res <- estimate_te(df$X, df$t, df$y, 
                   outcome_type = "continuous",
                   niter = 5,
                   random_state_master = 123)
print(res$te)

Multi-arm analysis:

# Load 3-arm survival data
surv_df <- load_data_tof(treat_levels = c("PrP", "RVOTd", "SPS"))
X_surv <- as.matrix(surv_df[c("age", "zscore")])
t_surv <- surv_df$treatment
y_surv <- as.matrix(surv_df[c("time", "status")])

# Multi-arm analysis
res <- estimate_te(X_surv, t_surv, y_surv,
                   outcome_type = "survival",
                   ref_group = "PrP",
                   niter = 5,
                   random_state_master = 456)

print(res$pairwise)

Mediation analysis:

# Load dataset with mediator
data <- load_data_tof_with_mediator(raw = FALSE, 
                                   treat_levels = c("PrP", "SPS"),
                                   outcome_type = "continuous")

# Estimate mediation effects
res <- estimate_mediation(data$X, data$t, data$m, data$y,
                         n_mc_mediator = 100,
                         random_state_master = 123)

print(res$te)   # Total effect
print(res$ide)  # Interventional direct effect  
print(res$iie)  # Interventional indirect effect

For more details on the R package, see R/CausalEM/README.md.


License

This project is licensed under the terms of the MIT License.

Release Notes

0.7.0

  • Added the estimate_mediation function, which estimates interventional mediation effects using plug-in G-computation.
  • Supports binary treatment with binary/continuous mediators and continuous outcomes.
  • Features bootstrap confidence intervals and optional integration with stochastic matching for improved robustness.
  • Estimates total effect (TE), interventional direct effect (IDE), and interventional indirect effect (IIE).

0.6.2

  • Exposed a new n_mc argument in estimate_te for specifying Monte‑Carlo draws per matched unit in survival analyses, replacing the previously fixed single draw.
  • Clarified treatment‑effect estimands for stacking vs. no‑stacking modes, noting that stacked results are appearance‑weighted across the matched union.
  • Documented appearance‑weighted meta‑learning and matched‑union survival contrasts.

0.6.1

  • Corrected the version number in the pyproject.toml file.

0.6.0

  • Improved consistency of return data structure when do_stacking=False in multi-arm TE estimation.

0.5.4

  • Added a GitHub Action for publishing to PyPI.

0.5.3

  • First public release

0.5.1

  • Edits to README
  • Added a GitHub Action for publishing to TestPyPI

0.5.0

  • First test release


Download files


Source Distribution

causalem-1.0.0.tar.gz (171.1 kB)

Uploaded Source

Built Distribution


causalem-1.0.0-py3-none-any.whl (147.3 kB)

Uploaded Python 3

File details

Details for the file causalem-1.0.0.tar.gz.

File metadata

  • Download URL: causalem-1.0.0.tar.gz
  • Size: 171.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.7

File hashes

Hashes for causalem-1.0.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 661db48e45eee61de7faab9f7f0eb112b05268eb03708efc69607dede7577235 |
| MD5 | 9584b218d8b599822b77e1dd5521f26a |
| BLAKE2b-256 | a15eb75c8988b92711e0a88f5ff8978e517881dd38e1554e272db1c3f1b96001 |


File details

Details for the file causalem-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: causalem-1.0.0-py3-none-any.whl
  • Size: 147.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.7

File hashes

Hashes for causalem-1.0.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1054c48ddb9107f0f5488a266d7fafc94bb3fb90527d3f24c1b21f87c9da8cca |
| MD5 | 39ace70c68999168751f4c869c655000 |
| BLAKE2b-256 | 0b16e05d788a911a433f6c819233a96323da227db524d53ef5815e905552cf28 |

