Interpretable empirical-Bayes aggregation of partition estimators for categorical regression
Subset Mixture Model (SMM)
SMM is an interpretable, uncertainty-aware regression method for datasets with categorical features. It learns a weighted average of partition estimators (one per non-empty feature subset, each predicting the mean target of the training rows that share a point's values on that subset) and tells you exactly why each prediction has the value and uncertainty it does.
Key idea
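Each feature subset s groups the training data by its unique value combinations (cells) and stores the empirical mean and variance of the target in each cell. SMM learns a single global weight vector π over all 2^D − 1 non-empty subsets of the D categorical features by minimizing the negative log-likelihood of the resulting Gaussian mixture, regularized by a Dirichlet prior on π. The prediction for a point x is the π-weighted average of the cell means that x falls into: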
$$\hat{f}(\mathbf{x}) = \sum_{s \in \mathcal{S}} \hat{\pi}_s \cdot \hat{\mu}_{m(s,\mathbf{x})}(s)$$
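Here S is the set of non-empty feature subsets, m(s, x) is the cell of subset s that contains x, and μ̂_{m(s,x)}(s) is that cell's training mean. Subsets for which x falls into a cell with no training data are left out for that point, and the remaining weights are renormalized. The learned weights directly answer the question of which feature combinations matter, and every prediction is a convex combination of training-data statistics you can verify yourself: no black box.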
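To make the formula concrete, here is a minimal pandas sketch of a mixture of partition estimators. `build_lookup` and `mixture_predict` are hypothetical helpers written for this illustration, not part of the smm API: they keep only cell means (no variances), take π as given instead of learning it, and raise an error when no subset has a cell for the point, whereas the library's own fallback behavior is not described here.

```python
from itertools import combinations

import numpy as np
import pandas as pd


def build_lookup(train, cat_cols, target):
    """Map every non-empty subset of cat_cols to {cell key (tuple): mean of target}."""
    lookup = {}
    for r in range(1, len(cat_cols) + 1):
        for subset in combinations(cat_cols, r):
            means = train.groupby(list(subset))[target].mean()
            # Normalize keys to tuples so one-column and multi-column subsets look alike.
            lookup[subset] = {
                (k if isinstance(k, tuple) else (k,)): float(v) for k, v in means.items()
            }
    return lookup


def mixture_predict(x, lookup, pi):
    """sum_s pi_s * mu_{m(s,x)}(s), with pi renormalized over subsets whose cell for x exists."""
    weights, cell_means = [], []
    for w, (subset, cells) in zip(pi, lookup.items()):
        key = tuple(x[c] for c in subset)  # m(s, x): the cell of subset s containing x
        if key in cells:
            weights.append(w)
            cell_means.append(cells[key])
    if not weights:
        raise ValueError("no subset has a training cell for this point")
    weights = np.asarray(weights) / np.sum(weights)
    return float(weights @ np.asarray(cell_means))


train = pd.DataFrame({"a": [0, 0, 1, 1], "b": [0, 1, 0, 1], "y": [1.0, 3.0, 5.0, 7.0]})
lookup = build_lookup(train, ["a", "b"], "y")
pi = np.full(len(lookup), 1 / len(lookup))  # uniform weights, for illustration only
print(mixture_predict({"a": 1, "b": 0}, lookup, pi))  # (6 + 3 + 5) / 3 ≈ 4.667
print(mixture_predict({"a": 2, "b": 0}, lookup, pi))  # only ("b",) has a cell -> 3.0
```

The second call shows the masking: the value a = 2 never occurs in training, so the ("a",) and ("a", "b") estimators drop out and all weight moves to ("b",).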
Oracle guarantee: in log-loss, the learned mixture comes within log(|S|)/n of the best single-subset estimator, where n is the number of training examples.
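Because |S| = 2^D − 1, the bound depends only mildly on the number of features: log(2^D − 1) < D · log 2, so the gap grows at most linearly in D even though the number of subsets grows exponentially. The arithmetic below assumes natural logarithms (log-loss in nats) and uses n = 1400, the training-set size in the worked example; `oracle_gap` is a hypothetical helper.

```python
import math


def oracle_gap(n_features: int, n_train: int) -> float:
    """log(|S|) / n for |S| = 2**n_features - 1 non-empty subsets (natural log assumed)."""
    n_subsets = 2 ** n_features - 1
    return math.log(n_subsets) / n_train


# |S| explodes with D, but the gap stays below D * log(2) / n.
for d in (3, 10, 20):
    print(f"D={d:2d}  |S|={2 ** d - 1:>8d}  gap={oracle_gap(d, 1400):.5f}")
```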
Installation
```bash
pip install subset-mixture-model
```

Import as `smm`:

```python
import smm
```
Complete worked example
The example below is fully self-contained. It creates a synthetic dataset with known structure, trains SMM, makes predictions with uncertainty estimates, checks whether those intervals are calibrated, and uses the diagnostic tools to trace why each prediction looks the way it does.
```python
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from smm import (
    SubsetMaker, SubsetWeightsModel, SubsetDataset,
    subset_mixture_neg_log_posterior,
    SubsetMixturePredictor,
    compute_posterior_covariance,
    predict_with_uncertainty,
    coverage,
    weight_table,
    explain_prediction,
    calibration_stats,
)

# ── 1. Synthetic data ──────────────────────────────────────────────────────────
#
# Three categorical features: region (2 values), season (3 values), tier (2 values).
# True signal: region drives a baseline; season × tier creates an interaction.
# SMM should discover both without being told which interactions to look for.
rng = np.random.default_rng(42)
N = 2000
region = rng.integers(0, 2, N)
season = rng.integers(0, 3, N)
tier = rng.integers(0, 2, N)
y = 5.0 * region + 3.0 * (season == 1) * tier + rng.normal(0, 2.0, N)
df = pd.DataFrame({"region": region, "season": season, "tier": tier, "y": y})

# 70 / 15 / 15 train / validation / test split
idx = rng.permutation(N)
n_train, n_val = int(0.70 * N), int(0.15 * N)
train_df = df.iloc[idx[:n_train]].reset_index(drop=True)
val_df = df.iloc[idx[n_train:n_train + n_val]].reset_index(drop=True)
test_df = df.iloc[idx[n_train + n_val:]].reset_index(drop=True)

CAT_COLS = ["region", "season", "tier"]
TARGET = "y"

# ── 2. Build the lookup table ──────────────────────────────────────────────────
#
# SubsetMaker enumerates all 2^3 − 1 = 7 non-empty feature subsets, groups the
# training data by value combinations within each subset, and stores the
# empirical (mean, variance) of the target per group.
subset_maker = SubsetMaker(train_df, CAT_COLS, [TARGET])
n_subsets = len(subset_maker.lookup)
print(f"Subsets: {n_subsets}")  # → 7

# ── 3. Train the weight model ──────────────────────────────────────────────────
#
# SubsetWeightsModel holds a single logit vector η ∈ R^|S|.
# The training loss is negative log-posterior: NLL of the Gaussian mixture
# plus a Dirichlet prior (alpha > 1 discourages degenerate weights).
model = SubsetWeightsModel(n_subsets)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-4)
ALPHA = 1.1

train_loader = DataLoader(
    SubsetDataset(train_df, CAT_COLS, [TARGET]), batch_size=256, shuffle=True
)
val_loader = DataLoader(
    SubsetDataset(val_df, CAT_COLS, [TARGET]), batch_size=256, shuffle=False
)

best_val, no_improve, best_state = float("inf"), 0, None
for epoch in range(300):
    # Training pass
    model.train()
    for x, y_batch in train_loader:
        optimizer.zero_grad()
        mus, variances, mask = subset_maker.batch_lookup(x)
        subset_mixture_neg_log_posterior(
            model(), y_batch, mus, variances, mask, alpha=ALPHA
        ).backward()
        optimizer.step()

    # Validation pass, used for early stopping
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y_batch in val_loader:
            mus, variances, mask = subset_maker.batch_lookup(x)
            val_loss += subset_mixture_neg_log_posterior(
                model(), y_batch, mus, variances, mask, alpha=ALPHA
            ).item()
    val_loss /= len(val_loader)

    # Keep the best weights; stop after 20 epochs without improvement (min. 50 epochs)
    if val_loss < best_val:
        best_val, no_improve = val_loss, 0
        best_state = {k: v.clone() for k, v in model.state_dict().items()}
    else:
        no_improve += 1
    if no_improve >= 20 and epoch >= 50:
        break

model.load_state_dict(best_state)

# ── 4. Predictor and posterior covariance ──────────────────────────────────────
#
# The Laplace approximation treats the MAP estimate η̂ as the center of a
# Gaussian, computes the Hessian of the loss there, and propagates uncertainty
# to the simplex via the softmax Jacobian: Σ_π = J H⁻¹ Jᵀ.
pi_hat = F.softmax(model.eta.detach(), dim=0)
predictor = SubsetMixturePredictor(subset_maker, pi_hat)
sigma_pi = compute_posterior_covariance(
    model, subset_maker, train_df, CAT_COLS, TARGET, alpha=ALPHA
)

# ── 5. Predict ─────────────────────────────────────────────────────────────────
y_mean, y_std, aleatoric_std, epistemic_std = predict_with_uncertainty(
    predictor, sigma_pi, test_df, return_components=True
)
y_true = test_df[TARGET].values
print(f"Test RMSE: {np.sqrt(np.mean((y_mean - y_true)**2)):.3f}")
print(f"95% coverage: {coverage(y_true, y_mean, y_std, level=0.95):.3f}")

# ── 6. Diagnostic: which subsets drive predictions? ───────────────────────────
#
# weight_table() returns a DataFrame sorted by π_s descending.
# Because the true signal has a "region" main effect and a "season × tier"
# interaction, those two subsets should dominate.
wt = weight_table(subset_maker, pi_hat, top_k=5)
print("\nTop-5 subsets by weight:")
print(wt[["subset", "weight", "n_cells", "cumulative_weight"]].to_string(index=False))
# Expected output (approximately):
#   subset            weight  n_cells  cumulative_weight
#   (region,)         0.41    2        0.41
#   (season, tier)    0.33    6        0.74
#   (region, season)  0.11    ...      0.85

# ── 7. Diagnostic: why this prediction for one test point? ────────────────────
#
# explain_prediction() shows every subset that has a valid training cell for
# this point, along with its cell statistics, renormalized weight, and additive
# contribution to the final prediction.
row = test_df.iloc[[0]]
exp = explain_prediction(predictor, row)
print(f"\nPrediction: {exp.attrs['predicted_mean']:.3f} "
      f"(true: {y_true[0]:.3f})")
print(f"Uncertainty: total={y_std[0]:.3f} "
      f"aleatoric={aleatoric_std[0]:.3f} epistemic={epistemic_std[0]:.3f}")
print("\nPer-subset breakdown:")
print(exp[["subset", "cell_mean", "masked_weight", "contribution"]].to_string(index=False))
# Each row is a training-data statistic you can verify from train_df directly.
# The prediction equals the sum of the "contribution" column.

# ── 8. Diagnostic: are the intervals well-calibrated? ────────────────────────
cal = calibration_stats(y_true, y_mean, y_std)
print("\nCalibration:")
print(cal.to_string(index=False))
# Values near nominal → well-calibrated.
# SMM is typically slightly conservative (empirical ≥ nominal).
```
Understanding the diagnostics
Weight table
| subset | weight | n_cells | cumulative_weight |
|---|---|---|---|
| (region,) | 0.41 | 2 | 0.41 |
| (season, tier) | 0.33 | 6 | 0.74 |
| ... | | | |
A high weight on (region,) means the region-only estimator carries a large share of the mixture: knowing the region alone is highly predictive. A high weight on (season, tier) means the season–tier interaction is informative beyond either feature alone. Subsets with negligible weight contribute little, so the model has effectively selected which interactions matter without being told.
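The cumulative_weight column suggests a simple way to shortlist the interactions that matter: keep the top-ranked subsets until their weights reach a chosen share of the total. A minimal sketch, assuming π is available as a NumPy array aligned with a list of subset tuples; `smallest_covering_set`, the 0.85 threshold, and the weights below are made up for illustration and are not output from SMM.

```python
import numpy as np
import pandas as pd


def smallest_covering_set(subsets, pi, mass=0.9):
    """Rank subsets by weight and keep the shortest prefix whose cumulative weight reaches `mass`."""
    table = pd.DataFrame({"subset": subsets, "weight": pi})
    table = table.sort_values("weight", ascending=False, ignore_index=True)
    table["cumulative_weight"] = table["weight"].cumsum()
    # First row whose cumulative weight is >= mass; keep rows up to and including it.
    k = int(np.searchsorted(table["cumulative_weight"].to_numpy(), mass)) + 1
    return table.head(k)


# Made-up weights over the 7 subsets of (region, season, tier), for illustration only.
subsets = [("region",), ("season",), ("tier",), ("region", "season"),
           ("region", "tier"), ("season", "tier"), ("region", "season", "tier")]
pi = np.array([0.40, 0.02, 0.03, 0.10, 0.05, 0.30, 0.10])
top = smallest_covering_set(subsets, pi, mass=0.85)
print(top.to_string(index=False))
```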
Prediction breakdown
| subset | cell_mean | masked_weight | contribution |
|---|---|---|---|
| (region,) | 8.1 | 0.52 | 4.21 |
| (season, tier) | 6.8 | 0.36 | 2.45 |
| (region, season) | 7.9 | 0.09 | 0.71 |
Every row traces back to a specific group of training examples. Each contribution is the row's masked_weight times its cell_mean (for example, 0.52 × 8.1 ≈ 4.21), and the predicted value is the sum of the contribution column. This is not a black box.
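To check a breakdown row against the raw data, filter the training set to the rows that share the test point's values on that row's subset and recompute the statistics. The helper below is hypothetical; it uses pandas' default sample variance (ddof=1), and whether SubsetMaker stores the variance with ddof=1 or ddof=0 is not stated here. In the worked example you would call it as `cell_stats(train_df, ("region",), test_df.iloc[0])`.

```python
import pandas as pd


def cell_stats(train_df, subset, row, target="y"):
    """Count, mean, and variance of the training rows matching `row` on every column in `subset`."""
    match = pd.Series(True, index=train_df.index)
    for col in subset:
        match &= train_df[col] == row[col]
    cell = train_df.loc[match, target]
    return len(cell), cell.mean(), cell.var()  # Series.var uses ddof=1 by default


train = pd.DataFrame({
    "region": [0, 0, 1, 1, 1],
    "season": [0, 1, 0, 0, 1],
    "y": [1.0, 2.0, 4.0, 6.0, 9.0],
})
n, mean, var = cell_stats(train, ("region", "season"), {"region": 1, "season": 0})
print(n, mean, var)
```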
Uncertainty decomposition
| Component | Source | Grows when… |
|---|---|---|
| Aleatoric | Variance within matched training cells | Training cell has high spread |
| Epistemic | Laplace posterior over mixture weights | Data is sparse; fewer subsets are active |
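The epistemic row can be made concrete with the delta method described in section 4 of the worked example (Σ_π = J H⁻¹ Jᵀ). This minimal NumPy sketch is not the library's implementation: the Hessian is supplied directly instead of being computed from the training loss, and both the small ridge term and the propagation to one prediction via sqrt(μ_xᵀ Σ_π μ_x), where μ_x holds the point's matched cell means, are assumptions of the sketch. How predict_with_uncertainty combines this with the aleatoric term is not shown.

```python
import numpy as np


def softmax(eta):
    z = np.exp(eta - eta.max())
    return z / z.sum()


def laplace_pi_covariance(eta_hat, hessian, ridge=1e-6):
    """Delta-method covariance of pi = softmax(eta) around the MAP point: J H^{-1} J^T."""
    pi = softmax(eta_hat)
    jac = np.diag(pi) - np.outer(pi, pi)  # d pi / d eta for the softmax (symmetric)
    # A loss that depends on eta only through softmax(eta) is flat along eta + c * 1,
    # so a small ridge keeps H invertible (an assumption of this sketch).
    h = hessian + ridge * np.eye(len(eta_hat))
    return jac @ np.linalg.solve(h, jac.T)


def epistemic_std(cell_means, sigma_pi):
    """Std of f(x) = pi^T mu_x induced by Cov(pi): sqrt(mu_x^T Sigma_pi mu_x)."""
    var = float(cell_means @ sigma_pi @ cell_means)
    return float(np.sqrt(max(var, 0.0)))  # clip tiny negative round-off


eta_hat = np.array([0.5, 0.1, -0.3])  # illustrative MAP logits for three subsets
hessian = 50.0 * np.eye(3)            # illustrative Hessian of the loss at eta_hat
sigma_pi = laplace_pi_covariance(eta_hat, hessian)
mu_x = np.array([8.0, 7.0, 7.5])      # illustrative matched cell means for one point
print(epistemic_std(mu_x, sigma_pi))
```

Because π stays on the simplex, every row of Σ_π sums to zero, so adding a constant to all cell means leaves the epistemic std unchanged: under this approximation it measures disagreement among the subsets' cell means, not their level.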
Calibration table
| nominal | empirical |
|---|---|
| 0.50 | 0.51 |
| 0.80 | 0.82 |
| 0.95 | 0.96 |
Empirical coverage close to nominal (points near the diagonal of a reliability diagram) means the stated intervals contain the true value at the stated rate. A reliability diagram can be plotted with plot_calibration() (requires matplotlib).
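Empirical coverage at a nominal level is the fraction of points that fall inside the corresponding interval. A minimal sketch, assuming Gaussian central intervals mean ± z · std (this README does not state which interval calibration_stats uses); `central_coverage` and `calibration_table` are hypothetical helpers, and the data are simulated so that the stated std matches the true noise.

```python
from statistics import NormalDist

import numpy as np


def central_coverage(y_true, y_mean, y_std, level=0.95):
    """Fraction of y_true inside the Gaussian central interval y_mean ± z * y_std."""
    z = NormalDist().inv_cdf(0.5 + level / 2)  # about 1.96 for level=0.95
    inside = np.abs(np.asarray(y_true) - np.asarray(y_mean)) <= z * np.asarray(y_std)
    return float(inside.mean())


def calibration_table(y_true, y_mean, y_std, levels=(0.5, 0.8, 0.95)):
    """(nominal, empirical) coverage pairs, one per level."""
    return [(lvl, central_coverage(y_true, y_mean, y_std, lvl)) for lvl in levels]


rng = np.random.default_rng(0)
y_mean = rng.normal(size=10_000)
y_std = np.full_like(y_mean, 2.0)
y_true = y_mean + rng.normal(0.0, 2.0, size=y_mean.shape)  # noise matches y_std
for nominal, empirical in calibration_table(y_true, y_mean, y_std):
    print(f"{nominal:.2f}  {empirical:.3f}")
```

Halving y_std in this simulation makes the intervals overconfident, and the empirical column drops well below nominal.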
API reference
Data and training
| Symbol | Description |
|---|---|
| `SubsetMaker(df, cat_cols, [target])` | Build the powerset lookup table from training data |
| `SubsetWeightsModel(n_subsets)` | Trainable logit parameter η of length \|S\| |
| `SubsetDataset(df, cat_cols, [target])` | PyTorch `Dataset` wrapping a DataFrame |
| `subset_mixture_neg_log_posterior(logits, y, mus, vars, mask, alpha)` | Training loss (NLL + Dirichlet prior) |
| `subset_mixture_mse(logits, y, mus, mask)` | MSE loss for warmup or debugging |
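For reference, a negative log-posterior of this form, −Σ_i log Σ_s w_is N(y_i; μ_is, σ²_is) − (α − 1) Σ_s log π_s, fits in a few lines. This NumPy sketch assumes the per-example weights w_is are π renormalized over the subsets that have a training cell for example i, and that the loss is summed over the batch; the library's exact masking and reduction (sum vs. mean) may differ, and `neg_log_posterior` is a hypothetical name, not the smm function.

```python
import numpy as np


def neg_log_posterior(logits, y, mus, variances, mask, alpha=1.1):
    """Gaussian-mixture NLL (summed over the batch) minus a Dirichlet(alpha) log-prior on softmax(logits).

    logits: [S]; y: [B]; mus, variances, mask: [B, S]. Masked entries need finite, positive
    placeholder variances; they are excluded and the remaining weights renormalized per row.
    """
    log_pi = logits - np.logaddexp.reduce(logits)  # log-softmax
    log_normal = -0.5 * (np.log(2 * np.pi * variances) + (y[:, None] - mus) ** 2 / variances)
    log_joint = np.where(mask, log_pi[None, :] + log_normal, -np.inf)
    log_norm = np.logaddexp.reduce(np.where(mask, log_pi[None, :], -np.inf), axis=1)
    nll = -(np.logaddexp.reduce(log_joint, axis=1) - log_norm).sum()
    log_prior = (alpha - 1.0) * log_pi.sum()  # Dirichlet log-density up to an additive constant
    return nll - log_prior


y = np.array([0.2, 5.1])
mus = np.array([[0.0, 1.0], [5.0, 4.0]])
variances = np.ones_like(mus)
mask = np.array([[True, True], [True, False]])  # second point has no cell for subset 2
print(neg_log_posterior(np.zeros(2), y, mus, variances, mask))
```

With α > 1 the prior term (α − 1) Σ log π_s tends to −∞ as any weight approaches zero, which is why it discourages degenerate weights.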
Inference
| Symbol | Description |
|---|---|
| `SubsetMixturePredictor(subset_maker, pi_hat)` | Inference wrapper |
| `predictor.predict(df, return_debug=False)` | Point predictions; with `return_debug=True`, also returns the per-example weight matrix `[B, \|S\|]` and the fallback mask |
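One plausible construction of a per-example weight matrix of shape [B, |S|] follows the renormalized weights described under Prediction breakdown: each row keeps the global π only for subsets that have a training cell for that example and rescales them to sum to one. A minimal NumPy sketch under that assumption; `masked_weights` is hypothetical, and reading the fallback mask as "rows where no subset has a cell" is a guess, since this README does not say when fallback is triggered or what is predicted in that case.

```python
import numpy as np


def masked_weights(pi, mask):
    """Per-example weights [B, S]: global pi restricted to each row's available subsets, renormalized.

    Also returns a boolean [B] flag for rows where no subset is available
    (a hypothetical stand-in for the library's fallback mask).
    """
    w = np.where(mask, pi[None, :], 0.0)
    totals = w.sum(axis=1, keepdims=True)
    no_match = totals[:, 0] == 0.0
    # Divide only where a row has some weight; all-masked rows stay at zero.
    w = np.divide(w, totals, out=np.zeros_like(w), where=totals > 0)
    return w, no_match


pi = np.array([0.5, 0.3, 0.2])
mask = np.array([[True, True, True],
                 [False, True, True],
                 [False, False, False]])
weights, no_match = masked_weights(pi, mask)
print(weights)
print(no_match)
```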
Uncertainty
| Symbol | Description |
|---|---|
| `compute_posterior_covariance(model, subset_maker, train_df, cat_cols, target, alpha)` | Laplace approximation → Σ_π of shape [\|S\|, \|S\|] |
| `predict_with_uncertainty(predictor, sigma_pi, df, return_components=False)` | Mean and total std; optionally also aleatoric and epistemic stds |
| `coverage(y_true, y_mean, y_std, level=0.95)` | Empirical interval coverage at the given level |
Diagnostics
| Symbol | Description |
|---|---|
| `weight_table(subset_maker, pi_hat, top_k=None)` | DataFrame of subsets ranked by learned weight |
| `explain_prediction(predictor, row_df)` | Per-subset contribution breakdown for one test point |
| `calibration_stats(y_true, y_mean, y_std, levels=None)` | Empirical vs. nominal coverage at multiple levels |
| `plot_calibration(y_true, y_mean, y_std, levels=None, ax=None)` | Reliability diagram (requires matplotlib) |
Requirements
Core: torch >= 2.0, numpy >= 1.24, pandas >= 2.0, scipy >= 1.11
plot_calibration additionally requires matplotlib.
Citation
```bibtex
@article{danielson2025smm,
  title   = {Subset Mixture Model: Interpretable Aggregation of Partition Estimators},
  author  = {Danielson, Aaron John},
  journal = {Transactions on Machine Learning Research},
  year    = {2025},
}
```
License
MIT © Aaron John Danielson
File details
Details for the file subset_mixture_model-0.1.2.tar.gz.
File metadata
- Download URL: subset_mixture_model-0.1.2.tar.gz
- Upload date:
- Size: 20.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6903f584f72ecc8d52c44019ff7649f09d34a42fd4d43da6a213748b23e62635 |
| MD5 | 3f64c70a6f071b43f08fa5377e14424b |
| BLAKE2b-256 | a7c24a24c61ebe5e18519a0ff8022cecee6aa99faeaf32739f1b2f7b73e5223b |
File details
Details for the file subset_mixture_model-0.1.2-py3-none-any.whl.
File metadata
- Download URL: subset_mixture_model-0.1.2-py3-none-any.whl
- Upload date:
- Size: 16.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6ab96737941a731e843c5e6643644daecd001e1e79eb73880649aa70a034252d |
| MD5 | 840c81fdf5690ce71ee1243ccbc94f96 |
| BLAKE2b-256 | 34927bf6a1adfbe2706a3330721eb2f866ea648b8448986040a785c11e563253 |