Interpretable empirical-Bayes aggregation of partition estimators for categorical regression

Subset Mixture Model (SMM)

SMM is an interpretable, empirical-Bayes method for regression on datasets with categorical features. It aggregates partition-based conditional-mean estimators over all non-empty feature subsets using learned simplex weights, adaptively balancing bias and variance across partition granularities.

Key idea

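Let $\mathcal{S}$ be the set of all non-empty feature subsets. Each subset $s \in \mathcal{S}$ induces a partition of the covariate space; writing $m(s,\mathbf{x})$ for the cell of that partition containing $\mathbf{x}$, the natural estimator of the conditional expectation is the empirical mean of the target within that cell. SMM learns a convex combination of these estimators: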

$$\hat{f}(\mathbf{x}) = \sum_{s \in \mathcal{S}} \hat{\pi}_s \cdot \hat{\mu}_{m(s,\mathbf{x})}(s)$$
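To make the formula concrete, here is a minimal pure-Python sketch of the estimator on a toy dataset. It is not the package's implementation: the helper names (`nonempty_subsets`, `cell_means`, `mixture_predict`) and the weights are hypothetical, and the sketch ignores cells that never appear in the training data.

```python
from collections import defaultdict
from itertools import combinations


def nonempty_subsets(features):
    """All non-empty subsets of the feature names, smallest first."""
    return [s for k in range(1, len(features) + 1)
            for s in combinations(features, k)]


def cell_means(rows, targets, subset):
    """Empirical target mean in each cell of the partition induced by `subset`."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for row, y in zip(rows, targets):
        cell = tuple(row[f] for f in subset)
        sums[cell] += y
        counts[cell] += 1
    return {cell: sums[cell] / counts[cell] for cell in sums}


def mixture_predict(x, subsets, means, weights):
    """Convex combination of the cell means that x falls into."""
    return sum(w * means[s][tuple(x[f] for f in s)]
               for s, w in zip(subsets, weights))


rows = [
    {"a": 0, "b": 0}, {"a": 0, "b": 1},
    {"a": 1, "b": 0}, {"a": 1, "b": 1},
]
targets = [1.0, 3.0, 5.0, 7.0]

subsets = nonempty_subsets(["a", "b"])  # [("a",), ("b",), ("a", "b")]
means = {s: cell_means(rows, targets, s) for s in subsets}
weights = [0.25, 0.25, 0.5]             # hypothetical simplex weights, not learned
prediction = mixture_predict({"a": 1, "b": 0}, subsets, means, weights)
print(prediction)  # 0.25 * 6.0 + 0.25 * 3.0 + 0.5 * 5.0 = 4.75
```

Coarse subsets such as `("a",)` pool many rows per cell (low variance, higher bias), while the full subset `("a", "b")` uses one row per cell here (low bias, higher variance); the weights trade these off.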

The learned weights $\hat{\pi}_s$ are directly interpretable: they reveal which feature interactions drive predictions on average. Uncertainty is propagated from the MAP weight estimates via a Laplace approximation, yielding aleatoric/epistemic decompositions without post-hoc calibration.
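One way to read the decomposition, sketched below under stated assumptions: because the prediction is linear in the weights, a Gaussian approximation with covariance $\Sigma_\pi$ over the weights implies a weight-uncertainty variance of $\boldsymbol{\mu}^\top \Sigma_\pi \boldsymbol{\mu}$ at a given input, where $\boldsymbol{\mu}$ stacks the cell means that input falls into. The aleatoric term here is taken to be the weight-averaged within-cell variance. Both the function `decompose_variance` and this particular split are illustrative assumptions; the package may define the terms differently.

```python
import math


def decompose_variance(pi, mu, cell_var, sigma_pi):
    """Assumed split of predictive variance at a single input.

    pi       : simplex weights, one per subset
    mu       : cell mean selected by each subset for this input
    cell_var : within-cell target variance for each of those cells
    sigma_pi : covariance matrix of the weights (list of lists)
    """
    n = len(pi)
    # Aleatoric (assumption): weight-averaged noise inside the selected cells.
    aleatoric = sum(p * v for p, v in zip(pi, cell_var))
    # Epistemic: quadratic form mu^T Sigma_pi mu, since f(x) is linear in pi.
    epistemic = sum(mu[i] * sigma_pi[i][j] * mu[j]
                    for i in range(n) for j in range(n))
    return aleatoric, epistemic


pi = [0.25, 0.25, 0.5]
mu = [6.0, 3.0, 5.0]
cell_var = [2.0, 8.0, 1.0]
# Hypothetical covariance; rows sum to zero because the weights sum to one.
sigma_pi = [
    [0.02, -0.01, -0.01],
    [-0.01, 0.02, -0.01],
    [-0.01, -0.01, 0.02],
]

aleatoric, epistemic = decompose_variance(pi, mu, cell_var, sigma_pi)
total_std = math.sqrt(aleatoric + epistemic)
print(aleatoric, epistemic, total_std)
```

When the subsets agree on the prediction (all entries of `mu` equal), the epistemic term vanishes for any covariance whose rows sum to zero, which matches the intuition that weight uncertainty only matters where the estimators disagree.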

Installation

pip install subset-mixture-model

Or from source:

git clone https://github.com/aaronjdanielson/subset-mixture-model
cd subset-mixture-model
pip install -e .

Quick start

import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from smm import (
    SubsetMaker, SubsetWeightsModel, SubsetDataset,
    subset_mixture_neg_log_posterior, SubsetMixturePredictor,
    compute_posterior_covariance, predict_with_uncertainty,
)

# --- your data (integer-coded categorical features) ---
train_df = pd.read_csv("train.csv")
val_df   = pd.read_csv("val.csv")
test_df  = pd.read_csv("test.csv")

cat_cols = ["feature_a", "feature_b", "feature_c"]
target   = "y"

# --- build lookup table ---
subset_maker = SubsetMaker(train_df, cat_cols, [target])
n_subsets = len(subset_maker.lookup)

# --- train ---
model     = SubsetWeightsModel(n_subsets)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
loader    = DataLoader(SubsetDataset(train_df, cat_cols, [target]),
                       batch_size=64, shuffle=True)

for epoch in range(100):
    for x, y in loader:
        optimizer.zero_grad()
        mus, variances, mask = subset_maker.batch_lookup(x)
        loss = subset_mixture_neg_log_posterior(
            model(), y, mus, variances, mask, alpha=1.1)
        loss.backward()
        optimizer.step()

# --- predict with uncertainty ---
pi_hat    = F.softmax(model.eta.detach(), dim=0)
predictor = SubsetMixturePredictor(subset_maker, pi_hat)
sigma_pi  = compute_posterior_covariance(
    model, subset_maker, train_df, cat_cols, target, alpha=1.1)

y_mean, y_std = predict_with_uncertainty(predictor, sigma_pi, test_df)
# y_mean: point predictions
# y_std:  total predictive standard deviation (aleatoric + epistemic)

Interpretability

import numpy as np

subsets = list(subset_maker.lookup.keys())
top_idx = np.argsort(pi_hat.numpy())[::-1][:10]

for rank, i in enumerate(top_idx):
    print(f"{rank+1:2d}. {subsets[i]}  π={pi_hat[i]:.4f}")
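Beyond ranking individual subsets, the weights can be summarized per feature or per interaction order. The sketch below is one possible summary, not part of the package: it assumes each subset is represented as a tuple of feature names (the actual key format of `subset_maker.lookup` may differ), and `feature_weight_totals` and `weight_by_order` are hypothetical helpers.

```python
def feature_weight_totals(subsets, weights):
    """Total weight of all subsets that contain each feature."""
    totals = {}
    for subset, w in zip(subsets, weights):
        for feature in subset:
            totals[feature] = totals.get(feature, 0.0) + w
    return totals


def weight_by_order(subsets, weights):
    """Total weight grouped by subset size (1 = single features, 2 = pairs, ...)."""
    by_order = {}
    for subset, w in zip(subsets, weights):
        by_order[len(subset)] = by_order.get(len(subset), 0.0) + w
    return by_order


# Hypothetical subsets and weights for two features.
example_subsets = [("a",), ("b",), ("a", "b")]
example_weights = [0.5, 0.125, 0.375]

totals = feature_weight_totals(example_subsets, example_weights)
orders = weight_by_order(example_subsets, example_weights)
print(totals)  # {'a': 0.875, 'b': 0.5}
print(orders)  # {1: 0.625, 2: 0.375}
```

Per-feature totals can exceed one in sum because a subset with several features contributes to each of them; the per-order totals always sum to one.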

Features

  • Interpretable by construction: learned weights reveal which feature interactions matter
  • Principled uncertainty: aleatoric/epistemic decomposition via Laplace approximation
  • Efficient training: only $2^D - 1$ logits optimized; lookup table precomputed once
  • No post-hoc calibration: well-calibrated predictive intervals out of the box
  • Scalable to $D \le 15$ features; $k$-way truncation available for larger $D$
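The scaling limits above follow from simple counting. The sketch below assumes that "$k$-way truncation" means keeping only subsets of at most $k$ features; that reading, and the helper names, are assumptions rather than documented behavior.

```python
from math import comb


def n_subsets_full(d):
    """Number of non-empty subsets of d features."""
    return 2 ** d - 1


def n_subsets_truncated(d, k):
    """Number of non-empty subsets with at most k features (assumed truncation rule)."""
    return sum(comb(d, j) for j in range(1, k + 1))


print(n_subsets_full(15))           # 32767
print(n_subsets_full(30))           # 1073741823
print(n_subsets_truncated(30, 3))   # 30 + 435 + 4060 = 4525
```

With 15 features the full model has about 33 thousand logits, while 30 features would need over a billion; restricting to three-way interactions brings that back to a few thousand.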

Datasets supported

Any tabular dataset with categorical features and a continuous target. Categorical features should be integer-coded; string-valued categories must first be encoded as integers.
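For string-valued columns, the encoding should be fitted on the training split and reused for validation and test data so the same category always maps to the same integer. A minimal sketch with hypothetical helpers `fit_codes` and `encode` follows; the `-1` code for unseen categories is a placeholder assumption, since how the package treats unseen categories is not documented here.

```python
def fit_codes(values):
    """Map each distinct category to an integer, in order of first appearance."""
    codes = {}
    for value in values:
        codes.setdefault(value, len(codes))
    return codes


def encode(values, codes, unknown=-1):
    """Apply a fitted mapping; categories not seen during fitting get `unknown`."""
    return [codes.get(value, unknown) for value in values]


train_colors = ["red", "blue", "red", "green"]
test_colors = ["blue", "purple"]

color_codes = fit_codes(train_colors)        # {'red': 0, 'blue': 1, 'green': 2}
train_encoded = encode(train_colors, color_codes)
test_encoded = encode(test_colors, color_codes)
print(train_encoded)  # [0, 1, 0, 2]
print(test_encoded)   # [1, -1]
```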

Citation

@article{danielson2025smm,
  title   = {Subset Mixture Model: Interpretable Empirical-Bayes Aggregation
             of Partition Estimators for Categorical Regression},
  author  = {Danielson, Aaron John},
  journal = {Machine Learning},
  year    = {2025},
  note    = {Under review}
}

License

MIT
