Interpretable empirical-Bayes aggregation of partition estimators for categorical regression
Subset Mixture Model (SMM)
SMM is an interpretable, uncertainty-aware regression method for data with categorical features. It learns a global convex mixture over subset-induced partition estimators—one per non-empty feature subset—and returns a full predictive distribution together with a transparent account of why each prediction has the value and the uncertainty it does.
Version 0.2 upgrades the method to be genuinely probabilistic:
- Cross-fitted weights — subset cell means are multiway target encodings, so the mixture weights are learned on out-of-fold statistics to avoid target leakage (small, high-order cells no longer get to memorize the response).
- Conjugate Student-t cells — each cell uses a Normal-Inverse-Gamma posterior, so its predictive component is a Student-t that widens and grows heavy tails in sparse cells (the plug-in Gaussian is the large-sample limit).
- Exact mixture inference — NLL, CDF, central intervals (by bisection) and CRPS are computed exactly from the mixture of Student-t components.
- Four-part predictive variance — within-cell noise, cell-estimation uncertainty, subset-resolution disagreement, and weight uncertainty (Laplace).
- Interpretation as diagnostics — the learned weights are summarized by entropy, effective number of subsets, order-level mass, and concentration, rather than claimed to be a sparse ANOVA decomposition.
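The cross-fitting idea in the first bullet can be sketched in plain Python. This is an illustration only: the function name `out_of_fold_cell_means`, the round-robin fold assignment, and the global-mean fallback for unseen cells are assumptions, not the package's implementation.

```python
from collections import defaultdict


def out_of_fold_cell_means(keys, y, n_folds=5):
    """For each row, average y over rows in *other* folds that share its cell key.

    keys: hashable cell identifiers (e.g. tuples of category values).
    y: float responses. Requires n_folds >= 2.
    A row whose cell has no out-of-fold support gets the out-of-fold global mean.
    """
    n = len(y)
    fold = [i % n_folds for i in range(n)]  # round-robin folds (assumption)
    oof = [0.0] * n
    for k in range(n_folds):
        sums, counts = defaultdict(float), defaultdict(int)
        total, seen = 0.0, 0
        for i in range(n):
            if fold[i] != k:  # statistics come only from the other folds
                sums[keys[i]] += y[i]
                counts[keys[i]] += 1
                total += y[i]
                seen += 1
        global_mean = total / seen
        for i in range(n):
            if fold[i] == k:
                c = counts.get(keys[i], 0)
                oof[i] = sums[keys[i]] / c if c else global_mean
    return oof


keys = ["a", "a", "b", "a", "b", "c"]
y = [1.0, 3.0, 10.0, 5.0, 14.0, 100.0]
oof = out_of_fold_cell_means(keys, y, n_folds=2)
print(oof)
```

In this toy input the cell `"c"` has a single row with response 100. Its in-sample cell mean would reproduce that response exactly, whereas the out-of-fold statistic never sees it, which is the leakage that cross-fitting is meant to prevent.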
Installation
```bash
pip install subset-mixture-model
```
Import as `smm`:
```python
import smm
```
Requires `torch>=2.0`, `numpy`, `pandas`, `scipy`. `plot_calibration` also needs `matplotlib`.
Quickstart
```python
import numpy as np
import pandas as pd
from smm import SMM

# --- synthetic data: a main effect (region) + an interaction (season x tier) ---
rng = np.random.default_rng(0)
N = 2000
region = rng.integers(0, 2, N)
season = rng.integers(0, 3, N)
tier = rng.integers(0, 2, N)
y = 5.0 * region + 3.0 * (season == 1) * tier + rng.normal(0, 2.0, N)
df = pd.DataFrame({"region": region, "season": season, "tier": tier, "y": y})
train, val, test = df[:1400], df[1400:1700], df[1700:]
FEATURES, TARGET = ["region", "season", "tier"], "y"

# --- fit: cross-fitted weights + conjugate Student-t cells + Laplace UQ ---
model = SMM(FEATURES, TARGET, kappa0=1.0, lam=0.5).fit(train, val)

# --- point prediction and a full predictive distribution ---
mean = model.predict(test)
mean, std = model.predict_with_uncertainty(test)
lo, hi = model.interval(test, level=0.95)  # exact mixture interval

y_test = test[TARGET].values
print("RMSE :", np.sqrt(np.mean((mean - y_test) ** 2)).round(3))
print("NLL :", round(model.nll(test, y_test), 3))
print("CRPS :", round(model.crps(test, y_test), 3))
print("cov95:", round(float(((y_test >= lo) & (y_test <= hi)).mean()), 3))
```
Why this prediction? Uncertainty decomposition
```python
mean, std, aleatoric_std, epistemic_std = model.predict_with_uncertainty(
    test, return_components=True
)
# std**2 == aleatoric_std**2 + epistemic_std**2 (within-cell noise vs. everything else)
```
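The additive split of variance rests on the law of total variance for a mixture. The sketch below shows only the two-term version of that identity (average within-component variance plus between-component disagreement). It does not reproduce the package's four-part decomposition, and the function name `mixture_moments` and the toy numbers are made up for illustration.

```python
def mixture_moments(weights, means, variances):
    """Mean, within-component variance, and between-component variance of a mixture.

    weights must be non-negative and sum to 1.
    """
    mean = sum(w * m for w, m in zip(weights, means))
    within = sum(w * v for w, v in zip(weights, variances))
    between = sum(w * (m - mean) ** 2 for w, m in zip(weights, means))
    return mean, within, between


# Three hypothetical subset components for one test row.
weights = [0.5, 0.3, 0.2]
means = [5.0, 4.0, 7.0]
variances = [4.0, 4.5, 6.0]

mean, within, between = mixture_moments(weights, means, variances)
total_var = within + between

# Cross-check against the direct definition Var = E[Y^2] - E[Y]^2.
second_moment = sum(w * (v + m * m) for w, m, v in zip(weights, means, variances))
print(mean, within, between, total_var, second_moment - mean ** 2)
```

The between-component term is the part that grows when subsets at different resolutions disagree about the cell mean, which corresponds to the "subset-resolution disagreement" item listed earlier.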
Which interactions matter? Global diagnostics
```python
print(model.weight_table(top_k=5))  # subsets ranked by learned weight
print(model.diagnostics())          # H, N_eff, HHI, order-level mass M_k
```
`diagnostics()` summarizes how concentrated the weight distribution is. Diffuse weights (large N_eff) are an honest signal that no sparse subset explanation dominates and that prediction draws on many comparable resolutions, while concentrated weights identify a few dominant interactions.
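To make these summaries concrete, here is a small standalone sketch that computes them from a weight vector. The exact definitions used by the package are not given in this README, so the choices below are assumptions: entropy in nats, N_eff taken as exp(H), HHI as the sum of squared weights, and M_k as the total weight on subsets of order k. The helper name `diagnostics_sketch` is hypothetical.

```python
import math


def diagnostics_sketch(weights, orders):
    """Concentration summaries of a simplex weight vector.

    weights: mixture weights summing to 1.
    orders: subset size |s| for each weight, in the same order.
    """
    entropy = -sum(w * math.log(w) for w in weights if w > 0.0)
    n_eff = math.exp(entropy)           # assumption: perplexity form of N_eff
    hhi = sum(w * w for w in weights)   # assumption: Herfindahl-style sum of squares
    order_mass = {}
    for w, k in zip(weights, orders):
        order_mass[k] = order_mass.get(k, 0.0) + w
    return {"H": entropy, "N_eff": n_eff, "HHI": hhi, "order_mass": order_mass}


# Three features give 7 non-empty subsets: three of order 1, three of order 2, one of order 3.
orders = [1, 1, 1, 2, 2, 2, 3]
diffuse = diagnostics_sketch([1.0 / 7.0] * 7, orders)
peaked = diagnostics_sketch([0.94, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01], orders)
print(diffuse["N_eff"], peaked["N_eff"])
```

Uniform weights give the largest possible N_eff (the number of subsets), while weights piled on one subset push N_eff toward 1 and HHI toward 1.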
Why this prediction? Local contributions
```python
row = test.iloc[[0]]
print(model.explain(row))  # per-subset contributions; they sum to the prediction
```
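One way per-subset contributions can sum to the prediction is if each contribution is the subset weight times that subset's cell mean for the row. The sketch below assumes that form, and assumes a centered variant measures each cell mean against a shared baseline. Neither is confirmed by this README, and `local_contributions` is a hypothetical helper, not the package's `explain`.

```python
def local_contributions(weights, cell_means, baseline=None):
    """Per-subset contributions to a single row's predictive mean.

    weights, cell_means: dicts keyed by subset (a tuple of feature names).
    baseline: if given, contributions are lifts relative to it and sum to
              (prediction - baseline), because the weights sum to 1.
    """
    offset = 0.0 if baseline is None else baseline
    return {s: w * (cell_means[s] - offset) for s, w in weights.items()}


# Hypothetical weights and cell means for one row.
weights = {("region",): 0.5, ("season", "tier"): 0.3, ("region", "season", "tier"): 0.2}
cell_means = {("region",): 5.0, ("season", "tier"): 4.0, ("region", "season", "tier"): 7.0}

raw = local_contributions(weights, cell_means)
prediction = sum(raw.values())
centered = local_contributions(weights, cell_means, baseline=3.0)
print(prediction, sum(centered.values()))
```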
What the model gives you
| Method | Returns |
|---|---|
| `SMM(features, target, ...)` | estimator; key knobs `kappa0` (shrinkage), `alpha0`, `lam` (order-aware prior), `K` (folds), `mode` (`"nig"`/`"plugin"`), `max_order` |
| `.fit(df, val_df=None)` | cross-fit → optimize weights (early-stop on `val_df`) → Laplace |
| `.predict(df)` | predictive mean |
| `.predict_with_uncertainty(df, return_components=)` | mean, std [, aleatoric, epistemic] |
| `.nll` / `.crps` | exact mixture proper scores |
| `.interval(df, level)` | exact central interval via bisection on the mixture CDF |
| `.weight_table` / `.diagnostics` / `.explain` | interpretability outputs |

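
Bisection on a mixture CDF works because the CDF is continuous and increasing, so each quantile is the unique root of F(x) = p. The sketch below uses Gaussian components (the large-sample plug-in limit) so that it runs on the standard library alone; Student-t components would only change the component CDF. All function names are illustrative, not the package's.

```python
import math


def normal_cdf(x, mu, sigma):
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


def mixture_cdf(x, weights, mus, sigmas):
    return sum(w * normal_cdf(x, m, s) for w, m, s in zip(weights, mus, sigmas))


def mixture_quantile(p, weights, mus, sigmas, tol=1e-10):
    """Solve mixture_cdf(x) = p by bisection on a bracket that contains the root."""
    lo = min(m - 12.0 * s for m, s in zip(mus, sigmas))  # CDF is ~0 here
    hi = max(m + 12.0 * s for m, s in zip(mus, sigmas))  # CDF is ~1 here
    for _ in range(200):  # hard cap so the loop always terminates
        if hi - lo <= tol:
            break
        mid = 0.5 * (lo + hi)
        if mixture_cdf(mid, weights, mus, sigmas) < p:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def central_interval(level, weights, mus, sigmas):
    a = (1.0 - level) / 2.0
    return (mixture_quantile(a, weights, mus, sigmas),
            mixture_quantile(1.0 - a, weights, mus, sigmas))


w, mu, sd = [0.5, 0.5], [0.0, 5.0], [1.0, 2.0]
lo95, hi95 = central_interval(0.95, w, mu, sd)
print(lo95, hi95)
```

For a bimodal mixture like this one, the central interval differs from mean ± 1.96·std, which is why the interval is computed from the mixture CDF rather than from the first two moments.
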
Lower-level building blocks are also exported: `SubsetMaker`, `crossfit_components`, `SubsetMixturePredictor`, `compute_posterior_covariance`, `predict_with_uncertainty`, `NIGPrior`, `weight_table`, `weight_diagnostics`, `calibration_stats`.
To reproduce the first-version plug-in behavior, use `SMM(..., mode="plugin", cross_fit=False)`.
Subset priors, raw categoricals, and reports (0.3)
Raw categorical columns are accepted directly: SMM encodes them internally and preserves labels, so `explain()` and `cell_table()` show `arena=MSG`, `day=Saturday` rather than integer codes.
Control the prior over subsets with a chainable SubsetPrior:
```python
from smm import SMM, SubsetPrior

prior = (SubsetPrior.order_decay(lam=0.6, strength=0.5)        # favor low order
         .boost_subset(("arena", "day", "month"), factor=4.0)  # domain belief
         .penalize_order_ge(5, factor=0.2)
         .exclude_contains("leaky_feature"))                   # hard governance mask
model = SMM(features, target, subset_prior=prior).fit(train, val)
```
Bases: `uniform`, `order_decay`, `prefer_orders`. Soft modifiers: `boost_subset`, `boost_features`, `boost_group_interactions`, `penalize_order_ge`. Hard masks: `exclude_contains`, `exclude_order_ge`, `include_only`.
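The difference between a base, a soft modifier, and a hard mask can be pictured with a toy calculation. This README does not state how `SubsetPrior` turns these settings into prior concentrations, so everything below is one plausible reading rather than the package's behavior: base mass proportional to `lam ** |s|`, soft modifiers as multiplicative factors, hard masks as removal, then renormalization. The `strength` argument is ignored, and `prior_mass_sketch` is a hypothetical name.

```python
from itertools import combinations


def prior_mass_sketch(features, lam, boosts=None, penalize_ge=None, exclude=()):
    """Toy normalized prior mass over non-empty feature subsets.

    boosts: {subset tuple (in `features` order): multiplicative factor}.
    penalize_ge: (order, factor), applied to subsets with at least that order.
    exclude: feature names; any subset containing one of them is dropped.
    """
    boosts = boosts or {}
    mass = {}
    for k in range(1, len(features) + 1):
        for s in combinations(features, k):
            if any(f in s for f in exclude):  # hard mask: the subset is removed
                continue
            m = lam ** k                      # base: geometric decay in order
            m *= boosts.get(s, 1.0)           # soft modifier: boost
            if penalize_ge is not None and k >= penalize_ge[0]:
                m *= penalize_ge[1]           # soft modifier: penalty
            mass[s] = m
    total = sum(mass.values())
    return {s: m / total for s, m in mass.items()}


feats = ["arena", "day", "month", "leaky_feature"]
mass = prior_mass_sketch(
    feats,
    lam=0.6,
    boosts={("arena", "day", "month"): 4.0},
    penalize_ge=(5, 0.2),
    exclude=("leaky_feature",),
)
for subset, m in sorted(mass.items(), key=lambda kv: -kv[1]):
    print(subset, round(m, 3))
```

A soft modifier only reweights a subset, so the data can still override it; a hard mask removes the subset from the mixture altogether.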
Inspection & governance utilities:
| Method | What it shows |
|---|---|
| `.explain(row, centered=True, by_order=)` | per-subset (or per-order) local contributions as lift vs. baseline |
| `.cell_table(subset, sort_by=)` | the learned NIG cells: n, posterior mean, predictive sd, shrinkage |
| `.reliability_report(df)` | per-row valid-subset mass, effective cell support `n_eff`, fallback flag |
| `.weight_intervals()` / `.order_mass_intervals()` | Laplace posterior intervals on weights / order mass `M_k` |
| `.calibration_stats(df, y)` / `.pit(df, y)` | exact-mixture coverage and PIT values |
| `.interval(df, include_weight_uncertainty=True)` | intervals inflated for mixture-weight uncertainty |
| `compare_smm_variants(train, test, features, target)` | one-call table over plug-in / NIG / cross-fit / uniform / best-subset |

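
Coverage and PIT values are two views of the same quantity. The PIT value of an observation is the predictive CDF evaluated at it, u = F(y), and for a continuous predictive distribution y falls inside the central interval of a given level exactly when u lies between (1 - level)/2 and (1 + level)/2. The helper below is a hypothetical illustration of that relationship on made-up PIT values; it is not the package's `calibration_stats`.

```python
def coverage_from_pit(pit, level):
    """Fraction of PIT values that fall inside the central band of the given level."""
    a = (1.0 - level) / 2.0
    return sum(1 for u in pit if a <= u <= 1.0 - a) / len(pit)


# Made-up PIT values for ten held-out rows.
pit = [0.01, 0.20, 0.50, 0.60, 0.97, 0.99, 0.40, 0.30, 0.75, 0.10]
cov90 = coverage_from_pit(pit, 0.90)
print(cov90)
```

A calibrated model produces PIT values that look uniform on [0, 1], so empirical coverage should track the nominal level across levels.
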
Method summary
For each non-empty feature subset s, SMM groups the training data by the values of s and models each resulting cell with a conjugate Normal-Inverse-Gamma posterior, giving a Student-t predictive component. A single global simplex weight vector π over all subsets is learned by MAP under an (optionally order-aware) Dirichlet prior, using out-of-fold cell statistics. The predictive distribution is the mixture ∑ₛ πₛ · tₛ, and uncertainty in π is propagated by a Laplace approximation in the low-dimensional logit space.
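The cell-level step can be illustrated with the textbook Normal-Inverse-Gamma update, whose posterior predictive is a Student-t. The sketch below uses the standard parameterization (prior mean `m0`, pseudo-count `kappa0`, Inverse-Gamma shape `a0` and rate `b0`). Apart from `kappa0`, these names and the default values are assumptions and may not match the package's `NIGPrior`.

```python
import math


def nig_predictive(y, m0=0.0, kappa0=1.0, a0=2.0, b0=1.0):
    """Student-t predictive (df, loc, scale) for one cell under a conjugate NIG prior."""
    n = len(y)
    ybar = sum(y) / n if n else m0
    ss = sum((v - ybar) ** 2 for v in y)
    kappa_n = kappa0 + n
    m_n = (kappa0 * m0 + n * ybar) / kappa_n  # cell mean shrunk toward m0
    a_n = a0 + n / 2.0
    b_n = b0 + 0.5 * ss + kappa0 * n * (ybar - m0) ** 2 / (2.0 * kappa_n)
    df = 2.0 * a_n                            # few observations -> heavier tails
    scale = math.sqrt(b_n * (kappa_n + 1.0) / (a_n * kappa_n))
    return df, m_n, scale


sparse = nig_predictive([4.0, 6.0])        # a cell with 2 observations
dense = nig_predictive([4.0, 6.0] * 50)    # same values, 100 observations
print(sparse)
print(dense)
```

With only two observations the predictive has few degrees of freedom, a wide scale, and a location pulled well toward the prior mean. With a hundred it approaches a Gaussian centered near the sample mean, which is the plug-in limit mentioned earlier.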
SMM is intended for problems whose predictive structure is concentrated in a modest number of naturally categorical features (full powerset for D ≤ 8; order-restricted for larger D). It is a transparent, uncertainty-aware alternative to gradient-boosted trees in that regime, not a general replacement.
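The size limits follow from simple counting: D features have 2^D - 1 non-empty subsets, and capping the order keeps only the low-order terms of that sum. A short sketch of the enumeration follows; `enumerate_subsets` is an illustrative helper, not the package's `SubsetMaker`.

```python
from itertools import combinations
from math import comb


def enumerate_subsets(features, max_order=None):
    """All non-empty subsets of `features`, optionally capped at `max_order`."""
    top = len(features) if max_order is None else min(max_order, len(features))
    return [s for k in range(1, top + 1) for s in combinations(features, k)]


full = enumerate_subsets(["region", "season", "tier"])
print(len(full), full)

# Counts for larger D, without materializing the subsets.
print(2 ** 8 - 1)                               # full powerset at D = 8
print(sum(comb(12, k) for k in range(1, 4)))    # D = 12 restricted to order <= 3
```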
Citation
```bibtex
@article{danielson2026smm,
  title   = {Subset Mixture Models: Interpretable Probabilistic Aggregation
             of Partition Estimators for Categorical Regression},
  author  = {Danielson, Aaron John},
  journal = {Under review},
  year    = {2026},
}
```
License
MIT © Aaron John Danielson