Interpretable empirical-Bayes aggregation of partition estimators for categorical regression

Subset Mixture Model (SMM)

SMM is an interpretable, uncertainty-aware regression method for data with categorical features. It learns a global convex mixture over subset-induced partition estimators—one per non-empty feature subset—and returns a full predictive distribution together with a transparent account of why each prediction has the value and the uncertainty it does.

Version 0.2 upgrades the method to be genuinely probabilistic:

  • Cross-fitted weights — subset cell means are multiway target encodings, so the mixture weights are learned on out-of-fold statistics to avoid target leakage (small, high-order cells no longer get to memorize the response).
  • Conjugate Student-t cells — each cell uses a Normal-Inverse-Gamma posterior, so its predictive component is a Student-t that widens and grows heavy tails in sparse cells (the plug-in Gaussian is the large-sample limit).
  • Exact mixture inference — NLL, CDF, central intervals (by bisection) and CRPS are computed exactly from the mixture of Student-t components.
  • Four-part predictive variance — within-cell noise, cell-estimation uncertainty, subset-resolution disagreement, and weight uncertainty (Laplace).
  • Interpretation as diagnostics — the learned weights are summarized by entropy, effective number of subsets, order-level mass, and concentration, rather than claimed to be a sparse ANOVA decomposition.
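
The conjugate update behind the Student-t cells can be sketched as follows. This is the textbook Normal-Inverse-Gamma update, not code taken from the package: the `NIG` dataclass, `update`, and `predictive` are hypothetical names, and the prior values are arbitrary. It illustrates how a sparse cell ends up with few degrees of freedom and a wide scale, while a well-populated cell approaches the plug-in Gaussian.

```python
from dataclasses import dataclass
from statistics import fmean


@dataclass
class NIG:
    """Normal-Inverse-Gamma parameters (hypothetical container, not the package's NIGPrior)."""
    m: float      # location of the cell mean
    kappa: float  # pseudo-count behind the mean
    alpha: float  # Inverse-Gamma shape for the variance
    beta: float   # Inverse-Gamma scale for the variance


def update(prior, ys):
    """Standard conjugate update of a NIG prior with the observations in one cell."""
    n = len(ys)
    if n == 0:
        return prior
    ybar = fmean(ys)
    ss = sum((y - ybar) ** 2 for y in ys)
    kappa_n = prior.kappa + n
    m_n = (prior.kappa * prior.m + n * ybar) / kappa_n
    alpha_n = prior.alpha + n / 2
    beta_n = (prior.beta + 0.5 * ss
              + prior.kappa * n * (ybar - prior.m) ** 2 / (2 * kappa_n))
    return NIG(m_n, kappa_n, alpha_n, beta_n)


def predictive(post):
    """Return (df, loc, scale) of the Student-t posterior predictive."""
    df = 2 * post.alpha
    scale = (post.beta * (post.kappa + 1) / (post.alpha * post.kappa)) ** 0.5
    return df, post.m, scale


prior = NIG(m=0.0, kappa=1.0, alpha=2.0, beta=2.0)   # arbitrary illustrative prior
sparse = predictive(update(prior, [4.8, 5.3]))        # a cell with 2 observations
dense = predictive(update(prior, [5.0 + 0.1 * ((i % 5) - 2) for i in range(200)]))

print("sparse cell (df, loc, scale):", sparse)
print("dense cell  (df, loc, scale):", dense)
```

With only two observations the posterior mean is also pulled strongly toward the prior location, which is the shrinkage that `kappa0` controls in the estimator.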

Installation

pip install subset-mixture-model

Import as smm:

import smm

Requires torch>=2.0, numpy, pandas, scipy. plot_calibration also needs matplotlib.


Quickstart

import numpy as np
import pandas as pd
from smm import SMM

# --- synthetic data: a main effect (region) + an interaction (season x tier) ---
rng = np.random.default_rng(0)
N = 2000
region = rng.integers(0, 2, N)
season = rng.integers(0, 3, N)
tier   = rng.integers(0, 2, N)
y = 5.0 * region + 3.0 * (season == 1) * tier + rng.normal(0, 2.0, N)
df = pd.DataFrame({"region": region, "season": season, "tier": tier, "y": y})

train, val, test = df[:1400], df[1400:1700], df[1700:]
FEATURES, TARGET = ["region", "season", "tier"], "y"

# --- fit: cross-fitted weights + conjugate Student-t cells + Laplace UQ ---
model = SMM(FEATURES, TARGET, kappa0=1.0, lam=0.5).fit(train, val)

# --- point prediction and a full predictive distribution ---
mean = model.predict(test)
mean, std = model.predict_with_uncertainty(test)
lo, hi = model.interval(test, level=0.95)          # exact mixture interval

y_test = test[TARGET].values
print("RMSE :", np.sqrt(np.mean((mean - y_test) ** 2)).round(3))
print("NLL  :", round(model.nll(test, y_test), 3))
print("CRPS :", round(model.crps(test, y_test), 3))
print("cov95:", round(float(((y_test >= lo) & (y_test <= hi)).mean()), 3))

Why this prediction? Uncertainty decomposition

mean, std, aleatoric_std, epistemic_std = model.predict_with_uncertainty(
    test, return_components=True
)
# std**2 == aleatoric_std**2 + epistemic_std**2  (within-cell noise vs. everything else)
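
The additive split in the comment above follows from the law of total variance for a mixture. The sketch below covers only two of the four terms (within-component variance and between-subset disagreement) and uses made-up weights, means, and variances; `mixture_moments` is a hypothetical helper, not a package function.

```python
import math


def mixture_moments(weights, means, variances):
    """Mean of a mixture plus its within- and between-component variance."""
    mean = sum(w * m for w, m in zip(weights, means))
    # Average variance inside each component.
    within = sum(w * v for w, v in zip(weights, variances))
    # Spread of the component means around the mixture mean.
    between = sum(w * (m - mean) ** 2 for w, m in zip(weights, means))
    return mean, within, between


weights = [0.5, 0.3, 0.2]      # illustrative mixture weights (sum to 1)
means = [4.0, 5.0, 7.0]        # illustrative component means for one row
variances = [4.0, 3.0, 2.0]    # illustrative component variances

mean, within, between = mixture_moments(weights, means, variances)
total_std = math.sqrt(within + between)
print(mean, within, between, total_std)
```

The two parts add up to the total variance of the mixture exactly, which is why the components can be reported as standard deviations whose squares sum to `std**2`.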

Which interactions matter? Global diagnostics

print(model.weight_table(top_k=5))     # subsets ranked by learned weight
print(model.diagnostics())             # H, N_eff, HHI, order-level mass M_k

diagnostics() summarizes how concentrated the weight distribution is, through the entropy H, the effective number of subsets N_eff, the concentration index HHI, and the order-level mass M_k. Diffuse weights (large N_eff) are an honest signal that no sparse subset explanation dominates, because prediction draws on many comparable resolutions, while concentrated weights identify a few dominant interactions.
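
These summaries are simple functions of the weight vector. The sketch below assumes N_eff is the exponential of the entropy and HHI is the sum of squared weights; the package may define them differently, and `weight_diagnostics_sketch` is a hypothetical name chosen to avoid confusion with the exported `weight_diagnostics`.

```python
import math


def weight_diagnostics_sketch(weights, orders):
    """Entropy, effective number of subsets, HHI, and mass per subset order."""
    entropy = -sum(w * math.log(w) for w in weights if w > 0)
    n_eff = math.exp(entropy)            # assumed definition: perplexity of the weights
    hhi = sum(w * w for w in weights)    # assumed definition: sum of squared weights
    order_mass = {}
    for w, k in zip(weights, orders):
        order_mass[k] = order_mass.get(k, 0.0) + w
    return {"H": entropy, "N_eff": n_eff, "HHI": hhi, "M_k": order_mass}


# Three features give 7 non-empty subsets: three of order 1, three of order 2, one of order 3.
orders = [1, 1, 1, 2, 2, 2, 3]
uniform = weight_diagnostics_sketch([1 / 7] * 7, orders)
concentrated = weight_diagnostics_sketch([0.9] + [0.1 / 6] * 6, orders)

print("uniform      N_eff:", uniform["N_eff"], "HHI:", uniform["HHI"])
print("concentrated N_eff:", concentrated["N_eff"], "HHI:", concentrated["HHI"])
```

Uniform weights give N_eff equal to the number of subsets, while placing 0.9 of the mass on a single subset drives N_eff below 2.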

Why this prediction? Local contributions

row = test.iloc[[0]]
print(model.explain(row))     # per-subset contributions; they sum to the prediction
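
Why the contributions sum to the prediction: the predictive mean is a convex combination of the per-subset cell means for the row, so each subset contributes its weight times its cell mean. The numbers below are invented for illustration and the dictionaries are not the return format of `explain()`.

```python
# Hypothetical learned weights over three subsets (they sum to 1).
weights = {
    ("region",): 0.55,
    ("season", "tier"): 0.35,
    ("region", "season", "tier"): 0.10,
}
# Hypothetical posterior cell means for one row, one per subset.
cell_means = {
    ("region",): 5.2,
    ("season", "tier"): 3.1,
    ("region", "season", "tier"): 8.4,
}
baseline = 3.0  # hypothetical global mean used for centering

# Raw contributions sum to the prediction.
raw = {s: weights[s] * cell_means[s] for s in weights}
prediction = sum(raw.values())

# Centered contributions sum to the lift over the baseline, because the weights sum to 1.
centered = {s: weights[s] * (cell_means[s] - baseline) for s in weights}
lift = sum(centered.values())

print("prediction:", prediction)
print("lift vs. baseline:", lift)
```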

What the model gives you

| Method | Returns |
| --- | --- |
| SMM(features, target, ...) | estimator; key knobs kappa0 (shrinkage), alpha0, lam (order-aware prior), K (folds), mode ("nig"/"plugin"), max_order |
| .fit(df, val_df=None) | cross-fit → optimize weights (early-stop on val_df) → Laplace |
| .predict(df) | predictive mean |
| .predict_with_uncertainty(df, return_components=) | mean, std [, aleatoric, epistemic] |
| .nll / .crps | exact mixture proper scores |
| .interval(df, level) | exact central interval via bisection on the mixture CDF |
| .weight_table / .diagnostics / .explain | interpretability outputs |
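
A central interval by bisection on a mixture CDF can be sketched as below. To stay within the standard library, Gaussian components stand in for the Student-t components that SMM uses; the bisection itself is unchanged because it only needs a monotone CDF. All function names and the bracketing rule (12 scales either side) are assumptions made for the sketch.

```python
import math


def normal_cdf(y, loc, scale):
    return 0.5 * (1.0 + math.erf((y - loc) / (scale * math.sqrt(2.0))))


def mixture_cdf(y, weights, locs, scales):
    """CDF of a mixture: the weighted sum of the component CDFs."""
    return sum(w * normal_cdf(y, m, s) for w, m, s in zip(weights, locs, scales))


def mixture_quantile(p, weights, locs, scales, tol=1e-10):
    """Solve mixture_cdf(y) = p by bisection on a bracket that covers all components."""
    lo = min(m - 12.0 * s for m, s in zip(locs, scales))
    hi = max(m + 12.0 * s for m, s in zip(locs, scales))
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if mixture_cdf(mid, weights, locs, scales) < p:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def central_interval(level, weights, locs, scales):
    """Equal-tailed interval: quantiles at (1 - level) / 2 and 1 - (1 - level) / 2."""
    tail = 0.5 * (1.0 - level)
    return (mixture_quantile(tail, weights, locs, scales),
            mixture_quantile(1.0 - tail, weights, locs, scales))


weights, locs, scales = [0.6, 0.4], [0.0, 4.0], [1.0, 2.0]   # illustrative mixture
lo, hi = central_interval(0.95, weights, locs, scales)
print("95% central interval:", lo, hi)
```

No closed-form quantile exists for a mixture, which is why a root-finder on the CDF is needed even though each component has a known quantile function.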

Lower-level building blocks are also exported: SubsetMaker, crossfit_components, SubsetMixturePredictor, compute_posterior_covariance, predict_with_uncertainty, NIGPrior, weight_table, weight_diagnostics, calibration_stats.

To reproduce the first-version plug-in behavior, use SMM(..., mode="plugin", cross_fit=False).


Subset priors, raw categoricals, and reports (0.3)

Raw categorical columns are accepted directly — SMM encodes them internally and preserves labels, so explain() and cell_table() show arena=MSG, day=Saturday rather than integer codes.

Control the prior over subsets with a chainable SubsetPrior:

from smm import SMM, SubsetPrior

prior = (SubsetPrior.order_decay(lam=0.6, strength=0.5)          # favor low order
         .boost_subset(("arena", "day", "month"), factor=4.0)    # domain belief
         .penalize_order_ge(5, factor=0.2)
         .exclude_contains("leaky_feature"))                     # hard governance mask

model = SMM(features, target, subset_prior=prior).fit(train, val)

Bases: uniform, order_decay, prefer_orders. Soft modifiers: boost_subset, boost_features, boost_group_interactions, penalize_order_ge. Hard masks: exclude_contains, exclude_order_ge, include_only.
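
To make the distinction between a base, a soft modifier, and a hard mask concrete, here is a sketch of how such a prior over subsets could be assembled. Everything in it is an assumption made for illustration: the geometric decay `lam ** (order - 1)`, the multiplicative boosts and penalties, and the function `subset_scores` are not taken from the package, whose `SubsetPrior` may compute its values differently.

```python
from itertools import combinations


def subset_scores(features, lam, boosts=None, penalize_from=None, penalty=1.0, banned=()):
    """Normalized prior scores over non-empty subsets (hypothetical scheme)."""
    boosts = boosts or {}
    scores = {}
    for k in range(1, len(features) + 1):
        for s in combinations(features, k):
            if any(f in s for f in banned):
                continue                      # hard mask: the subset is dropped entirely
            score = lam ** (k - 1)            # base: favor low-order subsets
            score *= boosts.get(s, 1.0)       # soft modifier: boost chosen subsets
            if penalize_from is not None and k >= penalize_from:
                score *= penalty              # soft modifier: down-weight high orders
            scores[s] = score
    total = sum(scores.values())
    return {s: v / total for s, v in scores.items()}


prior = subset_scores(
    ["arena", "day", "month", "leaky_feature"],
    lam=0.6,
    boosts={("arena", "day", "month"): 4.0},
    penalize_from=5,
    penalty=0.2,
    banned=("leaky_feature",),
)
for subset, score in sorted(prior.items(), key=lambda kv: -kv[1]):
    print(subset, round(score, 3))
```

Soft modifiers only rescale a subset's prior mass, so the data can still overrule them, whereas a hard mask removes the subset from consideration altogether.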

Inspection & governance utilities:

| Method | What it shows |
| --- | --- |
| .explain(row, centered=True, by_order=) | per-subset (or per-order) local contributions as lift vs. baseline |
| .cell_table(subset, sort_by=) | the learned NIG cells: n, posterior mean, predictive sd, shrinkage |
| .reliability_report(df) | per-row valid-subset mass, effective cell support n_eff, fallback flag |
| .weight_intervals() / .order_mass_intervals() | Laplace posterior intervals on weights / order mass M_k |
| .calibration_stats(df, y) / .pit(df, y) | exact-mixture coverage and PIT values |
| .interval(df, include_weight_uncertainty=True) | intervals inflated for mixture-weight uncertainty |
| compare_smm_variants(train, test, features, target) | one-call table over plug-in / NIG / cross-fit / uniform / best-subset |

Method summary

For each non-empty feature subset s, SMM groups the training data by the values of s and models each resulting cell with a conjugate Normal-Inverse-Gamma posterior, giving a Student-t predictive component. A single global simplex weight vector π over all subsets is learned by MAP under an (optionally order-aware) Dirichlet prior, using out-of-fold cell statistics. The predictive distribution is the mixture ∑ₛ πₛ · tₛ, and uncertainty in π is propagated by a Laplace approximation in the low-dimensional logit space.
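
The grouping and out-of-fold steps can be sketched in plain Python. This is a simplified illustration rather than the package's `crossfit_components`: it uses raw cell means instead of NIG posteriors, assigns folds deterministically by row index, and the helper names are hypothetical.

```python
from collections import defaultdict
from itertools import combinations


def nonempty_subsets(features):
    """All non-empty feature subsets, ordered by size."""
    return [s for k in range(1, len(features) + 1) for s in combinations(features, k)]


def out_of_fold_cell_means(rows, target, subset, n_folds=2):
    """For each row, the mean target over rows in the same cell but in other folds.

    Returns None where a cell has no out-of-fold rows.
    """
    cell_sum, cell_n = defaultdict(float), defaultdict(int)
    fold_sum, fold_n = defaultdict(float), defaultdict(int)
    keys = []
    for i, row in enumerate(rows):
        cell = tuple(row[f] for f in subset)
        fold = i % n_folds                  # deterministic folds, for illustration only
        keys.append((cell, fold))
        cell_sum[cell] += row[target]
        cell_n[cell] += 1
        fold_sum[cell, fold] += row[target]
        fold_n[cell, fold] += 1
    out = []
    for cell, fold in keys:
        n = cell_n[cell] - fold_n[cell, fold]
        out.append((cell_sum[cell] - fold_sum[cell, fold]) / n if n else None)
    return out


rows = [
    {"region": 0, "tier": 0, "y": 1.0},
    {"region": 0, "tier": 1, "y": 3.0},
    {"region": 1, "tier": 0, "y": 10.0},
    {"region": 1, "tier": 1, "y": 14.0},
    {"region": 0, "tier": 0, "y": 2.0},
    {"region": 1, "tier": 1, "y": 12.0},
]
subsets = nonempty_subsets(["region", "tier"])
coarse = out_of_fold_cell_means(rows, "y", ("region",))
fine = out_of_fold_cell_means(rows, "y", ("region", "tier"))
print(subsets)
print("region only   :", coarse)
print("region x tier :", fine)
```

In this tiny dataset every `region x tier` cell falls entirely within one fold, so it has no out-of-fold estimate at all: a row never sees its own response, which is the leakage that in-sample cell means would allow.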

SMM is intended for problems whose predictive structure is concentrated in a modest number of naturally categorical features (full powerset for D ≤ 8; order-restricted for larger D). It is a transparent, uncertainty-aware alternative to gradient-boosted trees in that regime, not a general replacement.
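
The size of the subset family explains the D ≤ 8 guideline. A quick count, using a hypothetical helper `n_subsets` whose `max_order` argument mirrors the estimator's knob of the same name:

```python
from math import comb


def n_subsets(d, max_order=None):
    """Number of non-empty subsets of d features, optionally capped at max_order."""
    top = d if max_order is None else min(d, max_order)
    return sum(comb(d, k) for k in range(1, top + 1))


print(n_subsets(8))                 # full powerset for D = 8: 255
print(n_subsets(20))                # full powerset for D = 20: 1048575
print(n_subsets(20, max_order=3))   # D = 20 restricted to order <= 3: 1350
```

The full powerset grows as 2^D - 1, so an order restriction keeps the mixture tractable once D moves past single digits.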


Citation

@article{danielson2026smm,
  title   = {Subset Mixture Models: Interpretable Probabilistic Aggregation
             of Partition Estimators for Categorical Regression},
  author  = {Danielson, Aaron John},
  journal = {Under review},
  year    = {2026},
}

License

MIT © Aaron John Danielson
