Fast rolling entry matching for staggered adoption studies using polars + numpy
pyrollmatch
Matching and weighting for staggered treatment adoption studies in Python, built on polars and numpy.
Installation
pip install pyrollmatch
From source:
git clone https://github.com/AlanHuang99/pyrollmatch.git
cd pyrollmatch
pip install -e ".[dev]"
Quick Start
import polars as pl
from pyrollmatch import rollmatch
data = pl.read_parquet("panel_data.parquet")
result = rollmatch(
data,
treat="treat",
tm="time",
entry="entry_time",
id="unit_id",
covariates=["x1", "x2", "x3"],
ps_caliper=0.2,
num_matches=1,
)
result.balance # covariate balance table (SMDs)
result.matched_data # matched pairs: [tm, treat_id, control_id, difference]
result.weights # unit-level weights: [id, weight]
Methods
Nearest-Neighbor Matching (default)
Greedy nearest-neighbor matching on propensity scores or pairwise covariate distances. Supports PS calipers, per-variable calipers, matching order, and three replacement modes.
# Propensity score matching (logistic regression)
result = rollmatch(data, ..., ps_caliper=0.2, num_matches=1)
# Mahalanobis distance matching
result = rollmatch(data, ..., model_type="mahalanobis")
# Mahalanobis matching with PS caliper (MatchIt mahvars pattern)
result = rollmatch(data, ..., ps_caliper=0.25, mahvars=["x1", "x2"])
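The core greedy nearest-neighbor idea can be sketched in plain numpy. This is a simplified, single-period illustration, not pyrollmatch's implementation: greedy_ps_match is a hypothetical helper, scores are assumed to be precomputed, controls are never reused, treated units are visited highest score first (as with m_order="largest"), and the caliper width is assumed to be ps_caliper times the average-variance pooled SD.

```python
import numpy as np

def greedy_ps_match(ps_treated, ps_control, ps_caliper=0.2):
    """Hypothetical sketch: greedy 1:1 nearest-neighbor matching on scores.

    One period, no control reuse, treated units visited highest score first.
    Caliper width is assumed to be ps_caliper * sqrt((var_t + var_c) / 2).
    Returns a list of (treated_index, control_index, difference).
    """
    ps_treated = np.asarray(ps_treated, dtype=float)
    ps_control = np.asarray(ps_control, dtype=float)
    pooled_sd = np.sqrt((ps_treated.var(ddof=1) + ps_control.var(ddof=1)) / 2)
    width = ps_caliper * pooled_sd
    available = np.ones(len(ps_control), dtype=bool)
    pairs = []
    for i in np.argsort(-ps_treated):          # highest score first
        diff = np.abs(ps_control - ps_treated[i])
        diff[~available] = np.inf              # used controls are ineligible
        j = int(np.argmin(diff))
        if diff[j] <= width:                   # inf never passes the caliper
            pairs.append((int(i), j, float(diff[j])))
            available[j] = False
    return pairs
```

For example, greedy_ps_match([0.8, 0.6, 0.3], [0.79, 0.55, 0.1, 0.62]) pairs the first two treated units and leaves the third unmatched, because its closest remaining control lies outside the caliper.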
Entropy Balancing
Reweights the control pool so that the chosen covariate moments match the treated group exactly, by solving a convex optimization problem (Hainmueller 2012). Each entry cohort reweights the full control pool independently (stacked design).
result = rollmatch(data, ..., method="ebal", moment=1)
result.weights # unit weights
result.weighted_data # per-cohort weights [tm, id, weight]
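For intuition, first-moment entropy balancing for a single cohort can be sketched with numpy by minimizing the convex dual (a log-sum-exp in the Lagrange multipliers) with damped Newton steps. This is a textbook-style sketch of Hainmueller's method, not pyrollmatch's solver; entropy_balance is a hypothetical helper, and the target means would be the treated cohort's covariate means.

```python
import numpy as np

def entropy_balance(X_control, target_means, base_weights=None, tol=1e-10, max_iter=200):
    """Hypothetical sketch: first-moment entropy balancing via its dual.

    Returns control weights w (w > 0, sum(w) == 1) that minimize the KL
    divergence from base_weights subject to w @ X_control == target_means.
    """
    C = np.asarray(X_control, dtype=float) - np.asarray(target_means, dtype=float)
    n = C.shape[0]
    if base_weights is None:
        q = np.full(n, 1.0 / n)
    else:
        q = np.asarray(base_weights, dtype=float) / np.sum(base_weights)

    def weights(lam):
        z = C @ lam
        w = q * np.exp(z - z.max())  # shift by max for numerical stability
        return w / w.sum()

    def dual(lam):
        # Dual objective: log(sum_i q_i * exp(c_i . lam))
        z = C @ lam
        return z.max() + np.log(np.sum(q * np.exp(z - z.max())))

    lam = np.zeros(C.shape[1])
    for _ in range(max_iter):
        w = weights(lam)
        grad = C.T @ w  # weighted mean minus target, per covariate
        if np.max(np.abs(grad)) < tol:
            break
        hess = (C * w[:, None]).T @ C - np.outer(grad, grad)
        step = np.linalg.solve(hess, grad)
        t = 1.0
        # Backtracking line search keeps each Newton step a descent step
        while dual(lam - t * step) > dual(lam) - 1e-4 * t * (grad @ step) and t > 1e-10:
            t /= 2
        lam = lam - t * step
    return weights(lam)
```

The dual has a minimum only when the target means lie inside the convex hull of the control covariates; otherwise exact balance is infeasible and the iterations will not converge.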
Custom Methods
User-defined per-period weighting functions:
def my_method(treated_data, control_data, covariates, id, **kwargs):
    # Return a pl.DataFrame with columns [id, weight]
    ...
result = rollmatch(data, ..., method=my_method)
Scoring Models
11 model types via the model_type parameter:
Propensity score models
Fit a classifier, match on |score_i - score_j|.
| model_type | Description |
|---|---|
| "logistic" | Standard logistic regression (default) |
| "probit" | Probit model (inverse normal CDF) |
| "gbm" | Gradient boosting (HistGradientBoostingClassifier) |
| "rf" | Random forest |
| "lasso" | L1-regularized logistic |
| "ridge" | L2-regularized logistic |
| "elasticnet" | L1+L2 regularized logistic |
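To show what "fit a classifier, match on |score_i - score_j|" means, here is a numpy sketch of an unpenalized logistic model fitted by Newton-Raphson, followed by the treated-by-control score distance matrix. It assumes scores are predicted probabilities; logistic_scores and score_distances are hypothetical helpers, and whether pyrollmatch's default logistic fit applies any penalty is not stated here.

```python
import numpy as np

def logistic_scores(X, treat, max_iter=50, tol=1e-10):
    """Hypothetical sketch: propensity scores from unpenalized logistic regression."""
    X1 = np.column_stack([np.ones(len(X)), np.asarray(X, dtype=float)])  # intercept column
    y = np.asarray(treat, dtype=float)
    beta = np.zeros(X1.shape[1])
    for _ in range(max_iter):
        p = 1.0 / (1.0 + np.exp(-(X1 @ beta)))
        grad = X1.T @ (y - p)                         # gradient of the log-likelihood
        hess = X1.T @ (X1 * (p * (1 - p))[:, None])   # Fisher information
        step = np.linalg.solve(hess, grad)
        beta += step
        if np.max(np.abs(step)) < tol:
            break
    return 1.0 / (1.0 + np.exp(-(X1 @ beta)))

def score_distances(scores, treat):
    """Treated-by-control matrix of |score_i - score_j|."""
    scores = np.asarray(scores, dtype=float)
    treat = np.asarray(treat).astype(bool)
    return np.abs(scores[treat][:, None] - scores[~treat][None, :])
```

At the maximum-likelihood solution the fitted scores sum to the number of treated units (the intercept's score equation), which is a quick sanity check on any logistic fit.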
Distance-based models
Match on pairwise covariate distances; no propensity model is fitted.
| model_type | Description |
|---|---|
| "mahalanobis" | Mahalanobis distance (pooled within-group covariance, MatchIt convention) |
| "scaled_euclidean" | Euclidean on covariates standardized by pooled within-group SD |
| "robust_mahalanobis" | Rank-based Mahalanobis (Rosenbaum 2010). Robust to outliers |
| "euclidean" | Raw Euclidean distance |
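The pooled within-group covariance and the two metrics built on it can be sketched in numpy as follows. These are hypothetical helpers written from the table's descriptions, not pyrollmatch internals; robust_mahalanobis is not shown, and plain "euclidean" corresponds to passing an identity covariance.

```python
import numpy as np

def pooled_within_cov(X, treat):
    """Pooled within-group covariance: each group is centered on its own mean."""
    X = np.asarray(X, dtype=float)
    treat = np.asarray(treat).astype(bool)
    Xt, Xc = X[treat], X[~treat]
    nt, nc = len(Xt), len(Xc)
    St = np.atleast_2d(np.cov(Xt, rowvar=False))
    Sc = np.atleast_2d(np.cov(Xc, rowvar=False))
    return ((nt - 1) * St + (nc - 1) * Sc) / (nt + nc - 2)

def mahalanobis_matrix(Xt, Xc, cov):
    """Treated-by-control Mahalanobis distances under covariance cov."""
    inv = np.linalg.inv(cov)
    diff = np.asarray(Xt, float)[:, None, :] - np.asarray(Xc, float)[None, :, :]
    d2 = np.einsum("tcj,jk,tck->tc", diff, inv, diff)
    return np.sqrt(np.maximum(d2, 0.0))  # clip tiny negative rounding errors

def scaled_euclidean_matrix(Xt, Xc, cov):
    """Euclidean distance after dividing each covariate by its pooled within-group SD."""
    return mahalanobis_matrix(Xt, Xc, np.diag(np.diag(cov)))
```

Keeping only the diagonal of the pooled covariance is what separates scaled_euclidean (per-covariate scaling) from mahalanobis (which also accounts for correlations between covariates).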
Matching Parameters
Caliper
# Propensity score caliper: ps_caliper * pooled_SD
result = rollmatch(data, ..., ps_caliper=0.2)
# Per-variable calipers (in SD units by default)
result = rollmatch(data, ..., caliper={"age": 0.5, "income": 0.3})
# Per-variable calipers in raw units
result = rollmatch(data, ..., caliper={"age": 5}, std_caliper=False)
ps_caliper_std controls how the pooled SD is computed: "average" (default), "weighted", or "none" (raw units).
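The difference between per-variable calipers in SD units and in raw units can be illustrated with a small pure-Python check. within_calipers is a hypothetical helper; the SDs are supplied by the caller here, since exactly which SD pyrollmatch uses for per-variable calipers is an assumption left open.

```python
def within_calipers(x_treated, x_control, caliper, sds=None):
    """Hypothetical sketch: check every per-variable caliper for one pair.

    caliper: {covariate: width}. With sds={covariate: SD}, widths are read
    in SD units; with sds=None they are raw units (like std_caliper=False).
    """
    for name, width in caliper.items():
        limit = width * sds[name] if sds is not None else width
        if abs(x_treated[name] - x_control[name]) > limit:
            return False  # one violated caliper rules the pair out
    return True
```

With an age SD of 10, caliper={"age": 0.5} allows a 5-year gap, which is the same limit as caliper={"age": 5} with std_caliper=False.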
Replacement Modes
Controls whether control units can be reused across matches.
| replacement= | Within period | Across periods | Use case |
|---|---|---|---|
| "unrestricted" | Reuse freely | Reuse freely | Maximize match rate |
| "cross_cohort" | No reuse | Reuse allowed | Default. Balanced within-period |
| "global_no" | No reuse | No reuse | Strictest. Each control used at most once |
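The three modes differ only in which previously used controls are blocked. The pure-Python sketch below assumes each treated unit already has a caliper-filtered candidate list ordered nearest first; match_across_periods and its input layout are hypothetical, not pyrollmatch's API.

```python
def match_across_periods(periods, replacement="cross_cohort"):
    """Hypothetical greedy matcher illustrating the three replacement modes.

    periods: list of (tm, {treated_id: [control ids, nearest first]}).
    Returns (tm, treated_id, control_id) tuples.
    """
    if replacement not in {"unrestricted", "cross_cohort", "global_no"}:
        raise ValueError(f"unknown replacement mode: {replacement!r}")
    used_ever = set()
    pairs = []
    for tm, candidates in periods:
        used_now = set()  # controls already taken in this period
        for treated_id, ranked_controls in candidates.items():
            for control_id in ranked_controls:
                if replacement != "unrestricted" and control_id in used_now:
                    continue  # no reuse within a period
                if replacement == "global_no" and control_id in used_ever:
                    continue  # no reuse across periods either
                pairs.append((tm, treated_id, control_id))
                used_now.add(control_id)
                used_ever.add(control_id)
                break
    return pairs
```

With periods = [(5, {"A": ["c1", "c2"], "B": ["c1", "c3"]}), (6, {"C": ["c1", "c2"]})], "unrestricted" gives c1 to all three treated units, "cross_cohort" moves B to c3 but lets C reuse c1 in period 6, and "global_no" also pushes C to c2.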
Matching Order (m_order)
Controls which treated units are matched first. Matters when replacement is constrained.
| m_order= | Behavior |
|---|---|
| "largest" | Highest PS first (default for PS models). Hard-to-match units get first pick |
| "smallest" | Lowest PS first |
| "random" | Random order |
| "data" | Original data order (default for distance models) |
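Each m_order value amounts to a permutation of the treated units. The numpy sketch below is a hypothetical helper; in particular, seeding "random" through numpy's default_rng is an assumption about how a random_state could be used, not a description of pyrollmatch's internals.

```python
import numpy as np

def matching_order(scores, m_order="largest", random_state=None):
    """Hypothetical sketch: indices of treated units in the order they are matched."""
    scores = np.asarray(scores, dtype=float)
    if m_order == "largest":
        return np.argsort(-scores, kind="stable")   # highest score first
    if m_order == "smallest":
        return np.argsort(scores, kind="stable")    # lowest score first
    if m_order == "random":
        rng = np.random.default_rng(random_state)   # seeded for reproducibility
        return rng.permutation(len(scores))
    if m_order == "data":
        return np.arange(len(scores))               # original row order
    raise ValueError(f"unknown m_order: {m_order!r}")
```

The stable sort keeps tied scores in data order, so "largest" and "smallest" are deterministic even with duplicate scores.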
The mahvars Pattern
Match on Mahalanobis distance of specific covariates while using a propensity score caliper to restrict the pool. Follows the MatchIt mahvars convention.
result = rollmatch(
data, ...,
covariates=["x1", "x2", "x3"], # PS estimated on all covariates
ps_caliper=0.25, # PS caliper for pool restriction
mahvars=["x1", "x2"], # Mahalanobis matching on these
)
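For a single treated unit, the mahvars rule reduces to two steps: drop controls outside the PS caliper (or already used), then take the smallest Mahalanobis distance on the mahvars among the rest. mahvars_pick is a hypothetical numpy helper illustrating that rule, with the inverse covariance supplied by the caller.

```python
import numpy as np

def mahvars_pick(ps_t, x_t, ps_c, Xc, cov_inv, width, available):
    """Hypothetical sketch: best control for one treated unit, or None.

    Controls must be unused and within the PS caliper; among those, the one
    with the smallest Mahalanobis distance on the mahvars is chosen.
    """
    eligible = np.asarray(available, bool) & (np.abs(np.asarray(ps_c, float) - ps_t) <= width)
    if not eligible.any():
        return None
    diff = np.asarray(Xc, float) - np.asarray(x_t, float)
    d2 = np.einsum("ij,jk,ik->i", diff, cov_inv, diff)  # squared Mahalanobis distances
    d2[~eligible] = np.inf                              # caliper restricts the pool
    return int(np.argmin(d2))
```

A control can be the closest on the mahvars yet still be skipped because its propensity score falls outside the caliper; this is exactly how the PS caliper restricts the pool.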
Balance Diagnostics
Post-Matching Balance
from pyrollmatch import balance_test, equivalence_test
# SMD + t-test + variance ratio + KS test
diag = balance_test(scored_data, result.matched_data,
"treat", "unit_id", "time", covariates)
# TOST equivalence test (Hartman & Hidalgo 2018)
equiv = equivalence_test(scored_data, result.matched_data,
"treat", "unit_id", "time", covariates)
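Two of the reported statistics are simple enough to compute by hand. The sketch below uses one common convention (SMD with an average-variance pooled SD); smd and variance_ratio are hypothetical helpers, and the exact denominator balance_test uses is an assumption.

```python
import numpy as np

def smd(x_treated, x_control):
    """Standardized mean difference with an average-variance pooled SD."""
    xt, xc = np.asarray(x_treated, float), np.asarray(x_control, float)
    pooled_sd = np.sqrt((xt.var(ddof=1) + xc.var(ddof=1)) / 2)
    return (xt.mean() - xc.mean()) / pooled_sd

def variance_ratio(x_treated, x_control):
    """Treated-to-control variance ratio; values near 1 indicate similar spread."""
    return np.var(x_treated, ddof=1) / np.var(x_control, ddof=1)
```

For x_treated = [1, 2, 3] and x_control = [0, 1, 2], both groups have variance 1 and the means differ by 1, so the SMD is 1.0.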
Per-Period Balance
Pooled SMD can mask within-cohort imbalance. Check each entry cohort:
from pyrollmatch import balance_by_period
agg, detail = balance_by_period(
scored_data, result.matched_data,
"treat", "unit_id", "time", covariates,
)
# agg: covariate × {wtd_mean_smd, median_abs_smd, max_abs_smd}
# detail: period × covariate × {n_treated, n_controls, smd}
Weighted Diagnostics (for ebal/custom)
from pyrollmatch import balance_test_weighted, equivalence_test_weighted
diag = balance_test_weighted(reduced_data, result.weights,
"treat", "unit_id", covariates)
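With weights instead of matched pairs, the group means become weighted means. The sketch keeps the unweighted pooled SD in the denominator so that before and after values share one scale; weighted_smd is a hypothetical helper and that denominator choice is an assumption, not necessarily what balance_test_weighted does.

```python
import numpy as np

def weighted_smd(x, treat, weights):
    """Hypothetical sketch: SMD of weighted group means over the unweighted pooled SD."""
    x = np.asarray(x, float)
    t = np.asarray(treat).astype(bool)
    w = np.asarray(weights, float)
    mt = np.average(x[t], weights=w[t])
    mc = np.average(x[~t], weights=w[~t])
    sd = np.sqrt((x[t].var(ddof=1) + x[~t].var(ddof=1)) / 2)
    return (mt - mc) / sd
```

Weights that pull the control mean onto the treated mean, as entropy balancing does for first moments, drive this statistic to zero.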
Data Format
Input: polars.DataFrame in long panel format (one row per unit per time period).
| Column | Type | Description |
|---|---|---|
| id | int/str | Unit identifier |
| tm | int | Time period (integer, monotonically increasing) |
| treat | int (0/1) | Time-invariant treatment group indicator. 1 = eventually treated, 0 = never treated |
| entry | int | Treatment onset period for treated units. For controls: any value > max(tm) or null |
| covariates | float | Matching variables |
unit_id | time | treat | entry_time | x1 | x2
--------|------|-------|------------|------|-----
1 | 1 | 1 | 5 | 2.3 | 1.1 <- treated, enters period 5
1 | 2 | 1 | 5 | 2.5 | 1.0
...
101 | 1 | 0 | 99 | 1.8 | 0.9 <- control (entry=99 sentinel)
101 | 2 | 0 | 99 | 1.9 | 1.0
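The format rules above can be checked before calling rollmatch. check_panel is a hypothetical pure-Python helper over a list of row dicts, covering only the rules stated in the table: treat is 0/1 and constant within a unit, and a control's entry is either null or greater than max(tm).

```python
def check_panel(rows, id="unit_id", tm="time", treat="treat", entry="entry_time"):
    """Hypothetical sketch: list problems found in long-format panel rows (dicts)."""
    problems = []
    max_tm = max(r[tm] for r in rows)
    treat_by_unit = {}
    for r in rows:
        if r[treat] not in (0, 1):
            problems.append(f"unit {r[id]}: treat must be 0 or 1")
        treat_by_unit.setdefault(r[id], set()).add(r[treat])
        if r[treat] == 0 and r[entry] is not None and r[entry] <= max_tm:
            problems.append(f"control {r[id]}: entry {r[entry]} <= max({tm}) = {max_tm}")
    for uid, values in treat_by_unit.items():
        if len(values) > 1:
            problems.append(f"unit {uid}: treat varies over time")
    return problems
```

The example rows above, with the entry=99 sentinel for controls, pass these checks as long as 99 exceeds the last observed period.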
Pipeline
For advanced use, each step is independently callable:
from pyrollmatch import reduce_data, score_data, compute_balance
# 1. Reduce: select baseline covariates for each entry cohort
reduced = reduce_data(data, "treat", "time", "entry_time", "unit_id", lookback=1)
# 2. Score: fit model, compute scores/distances
scored = score_data(reduced, ["x1", "x2", "x3"], "treat", model_type="logistic")
scored.data # DataFrame with "score" column
scored.model # fitted sklearn classifier
# 3. Match: via rollmatch() or match_all_periods() directly
# 4. Balance: assess covariate balance (matches: the matched pairs from step 3, e.g. result.matched_data)
balance = compute_balance(scored.data, matches, "treat", "unit_id", "time", covariates)
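One plausible reading of the reduce step is sketched below in pure Python: for each entry cohort e, keep the rows observed at baseline period e - lookback, namely the treated units entering at e plus every never-treated control. reduce_rows and the "cohort" key are hypothetical; reduce_data's actual output columns and edge-case handling may differ.

```python
def reduce_rows(rows, lookback=1, tm="time", treat="treat", entry="entry_time"):
    """Hypothetical sketch of a per-cohort baseline reduction on dict rows."""
    cohorts = sorted({r[entry] for r in rows if r[treat] == 1})
    out = []
    for e in cohorts:
        baseline = e - lookback  # covariates are taken before treatment onset
        for r in rows:
            if r[tm] == baseline and (r[treat] == 0 or r[entry] == e):
                out.append({**r, "cohort": e})  # controls reappear in every cohort
    return out
```

Because controls are copied into every cohort's baseline slice, the same control can be compared against several cohorts, which is why the replacement modes distinguish reuse within a period from reuse across periods.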
API Reference
- REFERENCE.md — parameter tables, return types, migration guide
- API docs — searchable HTML reference (auto-generated from docstrings)
Reproducibility
rollmatch produces bit-identical matches when called repeatedly with the same input data and the same configuration, on the same machine.
- All distance-based models (mahalanobis, scaled_euclidean, robust_mahalanobis, euclidean) are deterministic by construction.
- All propensity score models (logistic, probit, lasso, ridge, elasticnet, gbm, rf) use a fixed internal seed (42), and rf runs with n_jobs=1 to avoid joblib's last-bit non-determinism across thread counts.
- BLAS thread count (OMP_NUM_THREADS, OPENBLAS_NUM_THREADS, etc.) does not affect matching results at any observable level. Pair IDs and distance values are bit-identical across thread counts on the same machine.
To override the internal seed for sensitivity analyses, pass random_state:
# Seed sensitivity analysis: worst matched |SMD| across 100 different RF seeds
sigs = []
for s in range(100):
    r = rollmatch(data, ..., model_type="rf", random_state=s)
    sigs.append(r.balance["matched_smd"].abs().max())
random_state is also threaded through to m_order="random" for reproducible random matching order.
Across different machines, pair IDs are robust, but last-bit floating-point drift in BLAS routines (different OpenBLAS/MKL builds, different SIMD instruction widths) can in principle change the values in the difference column. In practice, this only affects matching decisions when two controls are tied to within ~1e-12 in Mahalanobis distance to a treated unit, which is essentially impossible unless your covariates contain exact duplicates. The regression tests in tests/test_reproducibility.py guard same-machine reproducibility.
Testing
uv run pytest tests/ # 239 tests
uv run pytest tests/ -k stress # stress/scale tests
Tests include synthetic data, the Lalonde dataset, and a staggered panel fixture.
Acknowledgements
Inspired by the rollmatch R package by RTI International (Witman et al. 2018). Distance metrics and matching conventions follow MatchIt (Imai, King, Stuart 2011).
References
- Witman, A., et al. (2018). "Comparison Group Selection in the Presence of Rolling Entry." Health Services Research, 54(1), 262-270.
- Hainmueller, J. (2012). "Entropy Balancing for Causal Effects." Political Analysis, 20(1), 25-46.
- Imai, K., King, G., Stuart, E. (2011). "MatchIt: Nonparametric Preprocessing for Parametric Causal Inference." Journal of Statistical Software, 42(8).
- Rosenbaum, P. (2010). Design of Observational Studies, ch. 8. Springer.
- Hartman, E. & Hidalgo, F. D. (2018). "An Equivalence Approach to Balance and Placebo Tests." American Journal of Political Science, 62(4), 1000-1013.
License
MIT
Download files
File details
Details for the file pyrollmatch-0.1.6.tar.gz.
File metadata
- Download URL: pyrollmatch-0.1.6.tar.gz
- Upload date:
- Size: 3.8 MB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bbc0fbe491d05f649a8c573fa7f66f6ca3f667f58b414996867d2e46602442ac |
| MD5 | c5bd9b0c0394d216fdc64e388acac02f |
| BLAKE2b-256 | 66d0e4b69d60c96e732feff257b3a2ef2b8e6968f815dd9b0b154dbb1ac184cd |
File details
Details for the file pyrollmatch-0.1.6-py3-none-any.whl.
File metadata
- Download URL: pyrollmatch-0.1.6-py3-none-any.whl
- Upload date:
- Size: 37.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e401a7dece6aae87a1c028ff62f0c564e782c51e89dce73302a2ca5976a89a18 |
| MD5 | 9674b0ae3405136ede4578c77b316544 |
| BLAKE2b-256 | b5334eaeb2f0afb798b3e8fed3653b7857176501ef555b77ed13dc71b584c832 |