Interpretable empirical-Bayes aggregation of partition estimators for categorical regression

Subset Mixture Model (SMM)

SMM is an interpretable, uncertainty-aware regression method for datasets with categorical features. It learns a weighted average of partition estimators—one per non-empty feature subset—and tells you exactly why each prediction has the value and uncertainty it does.


Key idea

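Each feature subset s partitions the training data by the unique value combinations of its features and stores the empirical mean and variance of the target in each cell. SMM learns a single global weight vector π over all 2^D − 1 non-empty subsets of the D features by minimizing negative log-likelihood, and predicts with the π-weighted average of the cell means, where m(s, x) denotes the cell of subset s that contains x: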

$$\hat{f}(\mathbf{x}) = \sum_{s \in \mathcal{S}} \hat{\pi}_s \cdot \hat{\mu}_{m(s,\mathbf{x})}(s)$$

The learned weights directly answer: which feature combinations matter? Predictions are convex combinations of verifiable training-data statistics—no black box.
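
To make the weighted average concrete, here is a minimal sketch in plain pandas and NumPy. It does not use the smm package. The toy data, the helper names cell_mean and predict, and the weights pi are all hypothetical, and the weights are fixed by hand rather than learned.

```python
import numpy as np
import pandas as pd

# Toy training data with two categorical features (hypothetical values).
train = pd.DataFrame({
    "a": [0, 0, 1, 1, 0, 1],
    "b": [0, 1, 0, 1, 0, 1],
    "y": [1.0, 2.0, 3.0, 6.0, 3.0, 8.0],
})

# All 2^2 - 1 = 3 non-empty feature subsets.
subsets = [("a",), ("b",), ("a", "b")]

# Illustrative weights on the simplex; SMM would learn these from data.
pi = np.array([0.5, 0.2, 0.3])

def cell_mean(subset, point):
    """Mean of y over training rows that match `point` on every column in `subset`."""
    match = np.ones(len(train), dtype=bool)
    for col in subset:
        match &= (train[col] == point[col]).to_numpy()
    return float(train.loc[match, "y"].mean())

def predict(point):
    """Convex combination of the matching cell mean from each subset."""
    mus = np.array([cell_mean(s, point) for s in subsets])
    return float(pi @ mus)

y_hat = predict({"a": 1, "b": 1})
print(y_hat)  # about 6.0 = 0.5 * 17/3 + 0.2 * 16/3 + 0.3 * 7
```

Each term in the sum is a group mean that can be recomputed by hand from the training rows, which is the sense in which the prediction is verifiable.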

Oracle guarantee: the learned mixture achieves within log(|S|)/n of the best single-subset estimator in log-loss.


Installation

pip install subset-mixture-model

Import as smm:

import smm

Complete worked example

The example below is fully self-contained. It creates a synthetic dataset with known structure, trains SMM, makes calibrated predictions with uncertainty estimates, then uses the diagnostic tools to trace why each prediction looks the way it does.

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from smm import (
    SubsetMaker, SubsetWeightsModel, SubsetDataset,
    subset_mixture_neg_log_posterior,
    SubsetMixturePredictor,
    compute_posterior_covariance,
    predict_with_uncertainty,
    coverage,
    weight_table,
    explain_prediction,
    calibration_stats,
)

# ── 1. Synthetic data ──────────────────────────────────────────────────────────
#
# Three categorical features: region (2 values), season (3 values), tier (2 values).
# True signal: region drives a baseline; season × tier creates an interaction.
# SMM should discover both without being told which interactions to look for.

rng = np.random.default_rng(42)
N = 2000

region = rng.integers(0, 2, N)
season = rng.integers(0, 3, N)
tier   = rng.integers(0, 2, N)

y = 5.0 * region + 3.0 * (season == 1) * tier + rng.normal(0, 2.0, N)

df = pd.DataFrame({"region": region, "season": season, "tier": tier, "y": y})

idx = rng.permutation(N)
n_train, n_val = int(0.70 * N), int(0.15 * N)
train_df = df.iloc[idx[:n_train]].reset_index(drop=True)
val_df   = df.iloc[idx[n_train:n_train + n_val]].reset_index(drop=True)
test_df  = df.iloc[idx[n_train + n_val:]].reset_index(drop=True)

CAT_COLS = ["region", "season", "tier"]
TARGET   = "y"

# ── 2. Build the lookup table ──────────────────────────────────────────────────
#
# SubsetMaker enumerates all 2^3 − 1 = 7 non-empty feature subsets, groups the
# training data by value combinations within each subset, and stores the
# empirical (mean, variance) of the target per group.

subset_maker = SubsetMaker(train_df, CAT_COLS, [TARGET])
n_subsets = len(subset_maker.lookup)
print(f"Subsets: {n_subsets}")   # → 7

# ── 3. Train the weight model ──────────────────────────────────────────────────
#
# SubsetWeightsModel holds a single logit vector η ∈ R^|S|.
# The training loss is negative log-posterior: NLL of the Gaussian mixture
# plus a Dirichlet prior (alpha > 1 discourages degenerate weights).

model     = SubsetWeightsModel(n_subsets)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-4)
ALPHA     = 1.1

train_loader = DataLoader(
    SubsetDataset(train_df, CAT_COLS, [TARGET]), batch_size=256, shuffle=True
)
val_loader = DataLoader(
    SubsetDataset(val_df, CAT_COLS, [TARGET]), batch_size=256, shuffle=False
)

best_val, no_improve, best_state = float("inf"), 0, None

for epoch in range(300):
    model.train()
    for x, y_batch in train_loader:
        optimizer.zero_grad()
        mus, variances, mask = subset_maker.batch_lookup(x)
        subset_mixture_neg_log_posterior(
            model(), y_batch, mus, variances, mask, alpha=ALPHA
        ).backward()
        optimizer.step()

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y_batch in val_loader:
            mus, variances, mask = subset_maker.batch_lookup(x)
            val_loss += subset_mixture_neg_log_posterior(
                model(), y_batch, mus, variances, mask, alpha=ALPHA
            ).item()
    val_loss /= len(val_loader)

    if val_loss < best_val:
        best_val, no_improve = val_loss, 0
        best_state = {k: v.clone() for k, v in model.state_dict().items()}
    else:
        no_improve += 1
    if no_improve >= 20 and epoch >= 50:
        break

model.load_state_dict(best_state)

# ── 4. Predictor and posterior covariance ──────────────────────────────────────
#
# The Laplace approximation treats the MAP estimate η̂ as the center of a
# Gaussian, computes the Hessian of the loss there, and propagates uncertainty
# to the simplex via the softmax Jacobian: Σ_π = J H⁻¹ Jᵀ.

pi_hat    = F.softmax(model.eta.detach(), dim=0)
predictor = SubsetMixturePredictor(subset_maker, pi_hat)
sigma_pi  = compute_posterior_covariance(
    model, subset_maker, train_df, CAT_COLS, TARGET, alpha=ALPHA
)

# ── 5. Predict ─────────────────────────────────────────────────────────────────

y_mean, y_std, aleatoric_std, epistemic_std = predict_with_uncertainty(
    predictor, sigma_pi, test_df, return_components=True
)
y_true = test_df[TARGET].values

print(f"Test RMSE:     {np.sqrt(np.mean((y_mean - y_true)**2)):.3f}")
print(f"95% coverage:  {coverage(y_true, y_mean, y_std, level=0.95):.3f}")

# ── 6. Diagnostic: which subsets drive predictions? ───────────────────────────
#
# weight_table() returns a DataFrame sorted by π_s descending.
# Because the true signal has a "region" main effect and a "season × tier"
# interaction, those two subsets should dominate.

wt = weight_table(subset_maker, pi_hat, top_k=5)
print("\nTop-5 subsets by weight:")
print(wt[["subset", "weight", "n_cells", "cumulative_weight"]].to_string(index=False))
# Expected output (approximately):
#          subset  weight  n_cells  cumulative_weight
#        (region,)    0.41        2               0.41
#   (season, tier)    0.33        6               0.74
#  (region, season)   0.11       ...              0.85

# ── 7. Diagnostic: why this prediction for one test point? ────────────────────
#
# explain_prediction() shows every subset that has a valid training cell for
# this point, along with its cell statistics, renormalized weight, and additive
# contribution to the final prediction.

row = test_df.iloc[[0]]
exp = explain_prediction(predictor, row)

print(f"\nPrediction: {exp.attrs['predicted_mean']:.3f}  "
      f"(true: {y_true[0]:.3f})")
print(f"Uncertainty: total={y_std[0]:.3f}  "
      f"aleatoric={aleatoric_std[0]:.3f}  epistemic={epistemic_std[0]:.3f}")
print("\nPer-subset breakdown:")
print(exp[["subset", "cell_mean", "masked_weight", "contribution"]].to_string(index=False))
# Each row is a training-data statistic you can verify from train_df directly.
# The prediction equals the sum of the "contribution" column.

# ── 8. Diagnostic: are the intervals well-calibrated? ────────────────────────

cal = calibration_stats(y_true, y_mean, y_std)
print("\nCalibration:")
print(cal.to_string(index=False))
# Values near nominal → well-calibrated.
# SMM is typically slightly conservative (empirical ≥ nominal).

Understanding the diagnostics

Weight table

      subset  weight  n_cells  cumulative_weight
    (region,)   0.41        2               0.41
(season, tier)   0.33        6               0.74
        ...

A high weight on (region,) means that region alone carries much of the predictive signal. A high weight on (season, tier) means the season–tier interaction is informative beyond either feature alone. Subsets with negligible weight contribute little: the model has automatically selected which interactions matter.

Prediction breakdown

         subset  cell_mean  masked_weight  contribution
       (region,)       8.1           0.52          4.21
   (season, tier)       6.8           0.36          2.45
  (region, season)      7.9           0.09          0.71

Every row traces back to a specific group of training examples. The predicted value equals the sum of the contribution column. This is not a black box.
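
The sketch below illustrates one way the masked_weight and contribution columns could relate to the global weights. It assumes that subsets without a matching training cell are dropped and the remaining weights are rescaled to sum to one; the numbers and variable names are hypothetical, and the package may implement this step differently.

```python
import numpy as np

# Hypothetical global weights over four subsets and the cell means matched by
# one test point; the last subset has no training cell for this point.
pi        = np.array([0.40, 0.30, 0.10, 0.20])
cell_mean = np.array([8.1, 6.8, 7.9, 0.0])
valid     = np.array([True, True, True, False])

# Zero out subsets without a matching cell, then renormalize to sum to 1.
masked_weight = np.where(valid, pi, 0.0)
masked_weight = masked_weight / masked_weight.sum()

# Each subset's additive share of the prediction.
contribution = masked_weight * cell_mean
prediction = contribution.sum()

print(masked_weight)   # weights of the three valid subsets, rescaled
print(prediction)      # equals the sum of the contribution column
```

Under this assumption the prediction is still a convex combination of cell means, just over the subsets that actually have data for the point.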

Uncertainty decomposition

Component  Source                                  Grows when…
Aleatoric  Variance within matched training cells  Training cell has high spread
Epistemic  Laplace posterior over mixture weights  Data is sparse; fewer subsets are active
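
As a rough illustration of the two components, the sketch below pushes a covariance over logits through the softmax Jacobian (Σ_π = J H⁻¹ Jᵀ) and then combines it with cell statistics. The logits, the inverse Hessian, and the cell statistics are hypothetical. The two variance formulas (weighted within-cell variance and a delta-method term μᵀ Σ_π μ) are assumptions chosen for illustration, not necessarily what predict_with_uncertainty computes.

```python
import numpy as np

def softmax(eta):
    z = np.exp(eta - eta.max())
    return z / z.sum()

eta = np.array([1.0, 0.5, -1.0])          # hypothetical MAP logits
sigma_eta = np.diag([0.04, 0.09, 0.25])   # hypothetical inverse Hessian H^-1
pi = softmax(eta)

# Softmax Jacobian: d pi_i / d eta_j = pi_i * (delta_ij - pi_j).
J = np.diag(pi) - np.outer(pi, pi)
sigma_pi = J @ sigma_eta @ J.T

mu = np.array([8.1, 6.8, 7.9])    # matched cell means (hypothetical)
var = np.array([4.0, 3.5, 5.0])   # matched cell variances (hypothetical)

aleatoric_var = float(pi @ var)             # weighted within-cell variance
epistemic_var = float(mu @ sigma_pi @ mu)   # delta-method variance of pi . mu
total_std = np.sqrt(aleatoric_var + epistemic_var)
```

Because the weights must sum to one, every row of Σ_π sums to zero: uncertainty in one weight is always offset by the others.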

Calibration table

 nominal  empirical
    0.50       0.51
    0.80       0.82
    0.95       0.96

Empirical values close to the nominal ones mean that the stated intervals contain the true value at the stated rate. A reliability diagram can be plotted with plot_calibration() (requires matplotlib).
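
For readers who want to check such a table by hand, here is a standard-library sketch of empirical coverage for central Gaussian intervals. The function name empirical_coverage and the synthetic data are hypothetical, and treating the intervals as Gaussian is an assumption; the package's coverage() and calibration_stats() may differ in detail.

```python
import random
from statistics import NormalDist

def empirical_coverage(y_true, y_mean, y_std, level):
    """Fraction of targets inside the central Gaussian interval at `level`."""
    z = NormalDist().inv_cdf(0.5 + level / 2.0)
    hits = sum(abs(y - m) <= z * s for y, m, s in zip(y_true, y_mean, y_std))
    return hits / len(y_true)

# Synthetic check: targets drawn from exactly the predicted distribution,
# so empirical coverage should land close to nominal.
rng = random.Random(0)
y_mean = [0.0] * 5000
y_std = [2.0] * 5000
y_true = [rng.gauss(m, s) for m, s in zip(y_mean, y_std)]

table = [(lvl, empirical_coverage(y_true, y_mean, y_std, lvl))
         for lvl in (0.50, 0.80, 0.95)]
for nominal, empirical in table:
    print(f"{nominal:.2f}  {empirical:.3f}")
```

Empirical values above nominal indicate conservative (too wide) intervals; values below nominal indicate overconfident ones.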


API reference

Data and training

Symbol                                                               Description
SubsetMaker(df, cat_cols, [target])                                  Build powerset lookup table from training data
SubsetWeightsModel(n_subsets)                                        Trainable logit parameter η of length |S|
SubsetDataset(df, cat_cols, [target])                                PyTorch Dataset wrapping a DataFrame
subset_mixture_neg_log_posterior(logits, y, mus, vars, mask, alpha)  Training loss (NLL + Dirichlet prior)
subset_mixture_mse(logits, y, mus, mask)                             MSE loss for warmup or debugging
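
The following NumPy sketch shows what a loss of the form "Gaussian-mixture NLL plus Dirichlet prior" can look like. It is not the package's implementation: the function neg_log_posterior, the per-row renormalization over valid subsets, and the relative scaling of the prior against the batch-mean NLL are all assumptions made for illustration.

```python
import numpy as np

def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    top = a.max(axis=axis, keepdims=True)
    return np.squeeze(top, axis=axis) + np.log(np.exp(a - top).sum(axis=axis))

def neg_log_posterior(logits, y, mus, variances, mask, alpha):
    """Mean Gaussian-mixture NLL minus a symmetric Dirichlet(alpha) log-prior.

    logits: [S]; y: [B]; mus, variances, mask: [B, S] (mask is boolean).
    Assumes every row of `mask` has at least one True entry.
    """
    log_pi = logits - logsumexp(logits, axis=0)            # log-softmax, [S]
    # Renormalize the weights over the subsets that have a matching cell.
    log_w = np.where(mask, log_pi[None, :], -np.inf)       # [B, S]
    log_w = log_w - logsumexp(log_w, axis=1)[:, None]
    # Masked cells get dummy statistics so they cannot produce NaNs.
    safe_mu = np.where(mask, mus, 0.0)
    safe_var = np.where(mask, variances, 1.0)
    log_norm = -0.5 * (np.log(2.0 * np.pi * safe_var)
                       + (y[:, None] - safe_mu) ** 2 / safe_var)
    nll = -logsumexp(log_w + log_norm, axis=1).mean()
    log_prior = (alpha - 1.0) * log_pi.sum()               # up to a constant
    return nll - log_prior
```

With alpha > 1 the prior term penalizes weights near zero, which matches the description of discouraging degenerate weights; with alpha = 1 the loss reduces to the plain NLL.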

Inference

Symbol                                        Description
SubsetMixturePredictor(subset_maker, pi_hat)  Inference wrapper
predictor.predict(df, return_debug=False)     Point predictions; return_debug=True also returns per-example weight matrix [B, |S|] and fallback mask

Uncertainty

Symbol                                                                                Description
compute_posterior_covariance(model, subset_maker, train_df, cat_cols, target, alpha)  Laplace approximation → Σ_π [S, S]
predict_with_uncertainty(predictor, sigma_pi, df, return_components=False)            Mean + total std; optionally aleatoric and epistemic stds
coverage(y_true, y_mean, y_std, level=0.95)                                           Empirical interval coverage at given level

Diagnostics

Symbol                                                         Description
weight_table(subset_maker, pi_hat, top_k=None)                 DataFrame of subsets ranked by learned weight
explain_prediction(predictor, row_df)                          Per-subset contribution breakdown for one test point
calibration_stats(y_true, y_mean, y_std, levels=None)          Empirical vs. nominal coverage at multiple levels
plot_calibration(y_true, y_mean, y_std, levels=None, ax=None)  Reliability diagram (requires matplotlib)

Requirements

Core: torch >= 2.0, numpy >= 1.24, pandas >= 2.0, scipy >= 1.11

plot_calibration additionally requires matplotlib.


Citation

@article{danielson2025smm,
  title   = {Subset Mixture Model: Interpretable Aggregation of Partition Estimators},
  author  = {Danielson, Aaron John},
  journal = {Transactions on Machine Learning Research},
  year    = {2025},
}

License

MIT © Aaron John Danielson
