torch-supercluster

Supervised prototype-routing for interpretable tabular prediction

SuperCluster learns K prediction prototypes jointly with a supervised objective. Each example is routed to prototypes via cross-attention, and the prediction is an exact convex combination of prototype-level predictions.

This gives you a model that is simultaneously:

  • Accurate — matches or exceeds MLP accuracy on tabular benchmarks
  • Interpretable — each prototype maps to a readable "prediction archetype"
  • Self-analysing — over-specify K and the model tells you how many distinct prediction regimes the data contains

Architecture:

x → MLP encoder → cross-attention (→ prototype centers C) → soft routing π
prediction:  ŷ = σ(π · s)   [binary]
             softmax(π @ S)  [multi-class]
             π · s           [regression]
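
Since the routing weights π are non-negative and sum to 1, the output is a true convex combination of prototype-level predictions. A minimal illustration of the binary case (pi and s here are stand-ins for the model's routing weights and prototype scores, not the package's internals):

import torch

pi = torch.softmax(torch.randn(4, 8), dim=-1)   # routing weights π: [N, K], each row sums to 1
s = torch.randn(8)                              # prototype logits s: [K]
prob = torch.sigmoid(pi @ s)                    # ŷ = σ(π · s): convex mix of prototype logits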

Installation

pip install torch-supercluster

Requires Python ≥ 3.9 and PyTorch ≥ 2.0.

Quick start

Binary classification

from supercluster import SuperCluster, center_diversity_loss
import torch
import torch.nn as nn

model = SuperCluster(
    input_dim=10,
    embedding_dim=64,
    num_clusters=8,       # set larger than expected; surplus prototypes collapse
)

x = torch.randn(256, 10)                    # toy input batch: [N, 10]
y = torch.randint(0, 2, (256, 1)).float()   # binary targets: [N, 1] floats for BCE

logits, cluster_weights = model(x)          # logits: [N, 1]
loss = nn.BCEWithLogitsLoss()(logits, y)
loss = loss + 0.1 * center_diversity_loss(model.centers)
loss.backward()

Multi-class classification

model = SuperCluster(
    input_dim=54, embedding_dim=128,
    num_clusters=8, num_classes=7,
)
logits, cluster_weights = model(x)               # x: [N, 54] → logits: [N, 7]
loss = nn.CrossEntropyLoss()(logits, y_long)     # y_long: [N] integer class labels in {0, ..., 6}

Regression

model = SuperCluster(
    input_dim=10, embedding_dim=64, num_clusters=5,
)
preds, cluster_weights = model(x)                # preds: [N, 1], no sigmoid applied
loss = nn.MSELoss()(preds, y)                    # y: [N, 1] float targets

Reading prototype predictions

import torch.nn.functional as F

# Binary: prototype predicted probability
for k in range(model.num_clusters):
    prob = torch.sigmoid(model.prototype_scores[k]).item()
    print(f"Prototype {k}: P(y=1) = {prob:.3f}")

# Multi-class: prototype predicted class distribution
for k in range(model.num_clusters):
    probs = F.softmax(model.prototype_scores[k], dim=0)
    print(f"Prototype {k}: {probs.tolist()}")

How it works

SuperCluster separates two concerns that prior prototype models conflate:

Parameter                                  Role
model.centers ∈ ℝ^{K×d}                    Routing geometry — where prototypes live in embedding space
model.prototype_scores ∈ ℝ^K or ℝ^{K×C}   Prediction — each prototype's output value

The center-diversity loss pushes centers apart on the unit hypersphere without touching prediction scores, preventing the "overconfident prototype" pathology common when routing geometry and prediction are conflated.
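
As an illustration of the idea (a sketch, not necessarily the package's exact formula), such a penalty can be written as the mean pairwise cosine similarity between L2-normalized centers:

import torch

def cosine_diversity_penalty(centers):                   # sketch of a center-diversity penalty
    c = torch.nn.functional.normalize(centers, dim=-1)   # project centers onto the unit hypersphere
    sim = c @ c.T                                        # [K, K] pairwise cosine similarities
    mask = ~torch.eye(sim.shape[0], dtype=torch.bool)    # drop self-similarity on the diagonal
    return sim[mask].mean()                              # minimizing this pushes centers apart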

Predictive regime discovery

Set K larger than you expect and examine which prototypes are occupied. The model concentrates assignment mass on exactly as many prototypes as the data has distinct prediction behaviours — surplus prototypes collapse to zero routing mass. This gives an operational measure of how many prediction regimes the data contains.

from supercluster import effective_prototype_count

_, cw = model(X_test)
k_eff = effective_prototype_count(cw)          # entropy-based count
active = model.active_prototypes(cw)           # list of occupied prototype indices
print(f"K_eff = {k_eff:.2f}, K_active = {len(active)}/{model.num_clusters}")

Model parameters

Parameter              Default   Description
input_dim              required  Raw feature dimension
embedding_dim          required  Latent dimension d
num_clusters           required  Number of prototypes K
num_classes            1         1 = binary/regression; C > 1 = multi-class
encoder_layers         4         MLP encoder depth
encoder_hidden_size    256       MLP encoder width
num_attn_heads         8         Cross-attention heads (must divide embedding_dim)
num_cross_attn_layers  2         Cross-attention depth L
dropout                0.1       Dropout rate
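
For reference, a fully spelled-out constructor call using the defaults above (assuming each table entry is a keyword argument of the same name):

model = SuperCluster(
    input_dim=54,
    embedding_dim=128,
    num_clusters=8,
    num_classes=7,               # C > 1 selects multi-class output
    encoder_layers=4,            # remaining values match the defaults in the table
    encoder_hidden_size=256,
    num_attn_heads=8,            # must divide embedding_dim (128 / 8 = 16)
    num_cross_attn_layers=2,
    dropout=0.1,
)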

Training recommendations

Setting           Recommendation
Optimizer         AdamW or Adam, lr = 1e-3
Diversity weight  λ_div = 0.1 — add λ_div * center_diversity_loss(model.centers) to the main loss
K                 1.5–3× your expected regime count; performance is flat across K
Patience          Early-stop on validation loss with patience of 50–100 epochs
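
Putting these recommendations together, a minimal end-to-end loop might look like this (the toy data and early-stopping bookkeeping are illustrative; the model API follows the Quick start above):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from supercluster import SuperCluster, center_diversity_loss

# Toy data standing in for a real tabular dataset
X = torch.randn(1024, 10)
y = torch.randint(0, 2, (1024, 1)).float()
train_loader = DataLoader(TensorDataset(X[:768], y[:768]), batch_size=64, shuffle=True)
X_val, y_val = X[768:], y[768:]

model = SuperCluster(input_dim=10, embedding_dim=64, num_clusters=8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()
lambda_div = 0.1

best_val, bad_epochs, patience = float("inf"), 0, 100
for epoch in range(1000):
    model.train()
    for xb, yb in train_loader:
        optimizer.zero_grad()
        logits, _ = model(xb)
        loss = criterion(logits, yb) + lambda_div * center_diversity_loss(model.centers)
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        val_logits, _ = model(X_val)
        val = criterion(val_logits, y_val).item()
    if val < best_val:
        best_val, bad_epochs = val, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:        # early stop after `patience` epochs without improvement
            break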

Empirical results

Dataset                        MLP (accuracy / AUC)   SuperCluster (accuracy / AUC)   K_active
NBA shots (binary, 72k)        63.3%                  63.4%                           3/8
Bank Marketing (binary, 45k)   88.9% / 78.7%          89.0% / 77.8%                   3/8
Adult Income (binary, 45k)     85.0% / 91.0%          85.2% / 90.7%                   2/8
Credit Default (binary, 30k)   81.7% / 76.8%          82.0% / 76.8%                   2/8
Covertype (7-class, 581k)      95.7%                  95.5%                           6/8

Citation

If you use this package in academic work, please cite:

@article{danielson2026supercluster,
  title   = {SuperCluster: Learning Prediction Prototypes via
             Target-Guided Cross-Attention Clustering},
  author  = {Danielson, Aaron John},
  journal = {Data Mining and Knowledge Discovery},
  year    = {2026},
}

License

MIT
