Interpretable empirical-Bayes aggregation of partition estimators for categorical regression

Subset Mixture Model (SMM)

SMM is an interpretable, empirical-Bayes method for regression on datasets with categorical features. It aggregates partition-based conditional-mean estimators over all non-empty feature subsets using learned simplex weights, adaptively balancing bias and variance across partition granularities.

Key idea

Each feature subset $s$ induces a partition of the covariate space and a natural estimator of the conditional expectation — its empirical cell mean. SMM learns a convex combination of these estimators:

$$\hat{f}(\mathbf{x}) = \sum_{s \in \mathcal{S}} \hat{\pi}_s \cdot \hat{\mu}_{m(s,\mathbf{x})}(s)$$

The learned weights $\hat{\pi}_s$ are directly interpretable: they reveal which feature interactions drive predictions on average. Uncertainty is propagated from the MAP weight estimates via a Laplace approximation, yielding aleatoric/epistemic decompositions without post-hoc calibration.
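
To make the convex combination concrete, here is a minimal sketch on a toy dataset. It does not use the smm package: the data, the helper names (`cell_mean`, `predict`) and the fixed weights `pi` are all made up for illustration, whereas SMM learns the weights from data.

```python
from itertools import combinations

import numpy as np
import pandas as pd

# Toy training data with two integer-coded categorical features (made up).
df = pd.DataFrame({
    "a": [0, 0, 1, 1],
    "b": [0, 1, 0, 1],
    "y": [1.0, 2.0, 3.0, 6.0],
})
features = ["a", "b"]

# All non-empty feature subsets: ("a",), ("b",), ("a", "b").
subsets = [s for k in range(1, len(features) + 1)
           for s in combinations(features, k)]

def cell_mean(s, x):
    """Empirical mean of y over training rows that match x on the features in s."""
    mask = np.ones(len(df), dtype=bool)
    for f in s:
        mask &= (df[f] == x[f]).to_numpy()
    return df.loc[mask, "y"].mean()

# Hypothetical simplex weights, one per subset (learned in SMM, fixed here).
pi = np.array([0.2, 0.3, 0.5])

def predict(x):
    """Convex combination of the cell means of the cells that x falls into."""
    mus = np.array([cell_mean(s, x) for s in subsets])
    return float(pi @ mus)

y_hat = predict({"a": 1, "b": 1})  # 0.2 * 4.5 + 0.3 * 4.0 + 0.5 * 6.0
```

The coarse partitions (`a` alone, `b` alone) average over more rows and so have lower variance but higher bias; the fine partition (`a` and `b` jointly) is the opposite. The weights decide how much each granularity contributes.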

Installation

pip install subset-mixture-model

Or from source:

git clone https://github.com/aaronjdanielson/subset-mixture-model
cd subset-mixture-model
pip install -e .

Quick start

import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from smm import (
    SubsetMaker, SubsetWeightsModel, SubsetDataset,
    subset_mixture_neg_log_posterior, SubsetMixturePredictor,
    compute_posterior_covariance, predict_with_uncertainty,
)

# --- your data (integer-coded categorical features) ---
train_df = pd.read_csv("train.csv")
val_df   = pd.read_csv("val.csv")
test_df  = pd.read_csv("test.csv")

cat_cols = ["feature_a", "feature_b", "feature_c"]
target   = "y"

# --- build lookup table ---
subset_maker = SubsetMaker(train_df, cat_cols, [target])
n_subsets = len(subset_maker.lookup)

# --- train ---
model     = SubsetWeightsModel(n_subsets)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
loader    = DataLoader(SubsetDataset(train_df, cat_cols, [target]),
                       batch_size=64, shuffle=True)

for epoch in range(100):
    for x, y in loader:
        optimizer.zero_grad()
        mus, variances, mask = subset_maker.batch_lookup(x)
        loss = subset_mixture_neg_log_posterior(
            model(), y, mus, variances, mask, alpha=1.1)
        loss.backward()
        optimizer.step()

# --- predict with uncertainty ---
pi_hat    = F.softmax(model.eta.detach(), dim=0)
predictor = SubsetMixturePredictor(subset_maker, pi_hat)
sigma_pi  = compute_posterior_covariance(
    model, subset_maker, train_df, cat_cols, target, alpha=1.1)

y_mean, y_std = predict_with_uncertainty(predictor, sigma_pi, test_df)
# y_mean: point predictions
# y_std:  total predictive standard deviation (aleatoric and epistemic uncertainty combined)
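
The quick start recovers the weights by applying a softmax to `model.eta`. As a rough illustration of why that yields simplex weights, the NumPy sketch below maps arbitrary logits to non-negative weights that sum to one. The `softmax` helper and the logit values are hypothetical and independent of the package.

```python
import numpy as np

def softmax(eta):
    """Map unconstrained logits to non-negative weights that sum to one."""
    z = eta - eta.max()  # shift for numerical stability; the result is unchanged
    w = np.exp(z)
    return w / w.sum()

eta = np.array([0.0, 1.0, -2.0])  # hypothetical logits, one per feature subset
pi = softmax(eta)                 # largest logit receives the largest weight
```

Optimizing the logits rather than the weights themselves keeps every gradient step on the simplex without an explicit constraint.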

Interpretability

import numpy as np

subsets = list(subset_maker.lookup.keys())
top_idx = np.argsort(pi_hat.numpy())[::-1][:10]

for rank, i in enumerate(top_idx):
    print(f"{rank+1:2d}. {subsets[i]}  π={pi_hat[i]:.4f}")

Features

  • Interpretable by construction: learned weights reveal which feature interactions matter
  • Principled uncertainty: aleatoric/epistemic decomposition via Laplace approximation
  • Efficient training: only $2^D - 1$ logits optimized; lookup table precomputed once
  • No post-hoc calibration: well-calibrated predictive intervals out of the box
  • Scalable to D ≤ 15 features; $k$-way truncation available for larger $D$
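
To give a sense of scale for the last two points, the sketch below counts the subsets (and hence logits) involved. It assumes that $k$-way truncation means keeping only subsets of at most $k$ features; the source does not spell this out, and `n_subsets` is a hypothetical helper, not part of the package.

```python
from math import comb

def n_subsets(d, k=None):
    """Number of non-empty subsets of d features, optionally of size at most k."""
    k = d if k is None else min(k, d)
    return sum(comb(d, j) for j in range(1, k + 1))

full_15 = n_subsets(15)        # 2**15 - 1 = 32767
trunc_30 = n_subsets(30, k=3)  # 30 + 435 + 4060 = 4525
```

Under that assumption, a 3-way truncation on 30 features needs fewer logits than the full model on 15.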

Datasets supported

Any tabular dataset with a continuous target and categorical features that are either integer-coded already or string-valued and encoded as integers beforehand.
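
One possible way to integer-code string features with pandas is sketched below. The column name and values are made up. Fixing the category order on the training split keeps the codes consistent across splits. How SMM treats levels that never appear in training is not stated here, so the `-1` code pandas assigns to them should be handled explicitly before prediction.

```python
import pandas as pd

train = pd.DataFrame({"color": ["red", "blue", "red"], "y": [1.0, 2.0, 3.0]})
test = pd.DataFrame({"color": ["blue", "green"]})

# Fix the category order on the training split so codes agree across splits.
categories = sorted(train["color"].unique())

train["color"] = pd.Categorical(train["color"], categories=categories).codes
test["color"] = pd.Categorical(test["color"], categories=categories).codes
# Levels unseen in training (here "green") are coded as -1 by pandas.
```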

Citation

@article{danielson2025smm,
  title   = {Subset Mixture Model: Interpretable Empirical-Bayes Aggregation
             of Partition Estimators for Categorical Regression},
  author  = {Danielson, Aaron John},
  journal = {Machine Learning},
  year    = {2025},
  note    = {Under review}
}

License

MIT
