torch-supercluster
Supervised prototype-routing for interpretable tabular prediction
SuperCluster learns K prediction prototypes jointly with a supervised objective. Each example is routed to prototypes via cross-attention, and the prediction is an exact convex combination of prototype-level predictions.
This gives you a model that is simultaneously:
- Accurate — matches or exceeds MLP accuracy on tabular benchmarks
- Interpretable — each prototype maps to a readable "prediction archetype"
- Self-analysing — over-specify K and the model tells you how many distinct prediction regimes the data contains
x → MLP encoder → cross-attention (with prototype centers C) → soft routing π

prediction:  ŷ = σ(π · s)         [binary]
             ŷ = softmax(π @ S)   [multi-class]
             ŷ = π · s            [regression]
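For intuition, here is what the routing-to-prediction step computes, as a standalone PyTorch sketch. The names pi and s stand in for the routing weights π and prototype scores s above; this is illustrative, not the package's internal code:

import torch

N, K = 4, 8
pi = torch.softmax(torch.randn(N, K), dim=-1)   # soft routing weights; each row sums to 1
s = torch.randn(K)                              # per-prototype logits (binary case)

logits = pi @ s                  # exact convex combination of prototype logits, shape [N]
probs = torch.sigmoid(logits)    # ŷ for binary classification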
Installation
pip install torch-supercluster
Requires Python ≥ 3.9 and PyTorch ≥ 2.0.
Quick start
Binary classification
from supercluster import SuperCluster, center_diversity_loss
import torch
import torch.nn as nn

x = torch.randn(256, 10)                     # x: [N, 10] features
y = torch.randint(0, 2, (256, 1)).float()    # y: [N, 1] binary targets

model = SuperCluster(
    input_dim=10,
    embedding_dim=64,
    num_clusters=8,  # set larger than expected; surplus prototypes collapse
)

logits, cluster_weights = model(x)           # logits: [N, 1], cluster_weights: [N, K]
loss = nn.BCEWithLogitsLoss()(logits, y)
loss = loss + 0.1 * center_diversity_loss(model.centers)
loss.backward()
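Since the prediction is a convex combination of prototype outputs, each row of cluster_weights is a probability distribution over the K prototypes. A quick sanity check, assuming the shapes from the snippet above:

assert cluster_weights.shape == (256, 8)                 # [N, K]
assert torch.allclose(cluster_weights.sum(dim=1),
                      torch.ones(256), atol=1e-5)        # rows sum to 1
top_proto = cluster_weights.argmax(dim=1)                # dominant prototype per example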
Multi-class classification
x = torch.randn(256, 54)               # x: [N, 54] features
y_long = torch.randint(0, 7, (256,))   # y_long: [N] integer class labels

model = SuperCluster(
    input_dim=54, embedding_dim=128,
    num_clusters=8, num_classes=7,
)

logits, cluster_weights = model(x)     # logits: [N, 7]
loss = nn.CrossEntropyLoss()(logits, y_long)
Regression
x = torch.randn(256, 10)    # x: [N, 10] features
y = torch.randn(256, 1)     # y: [N, 1] continuous targets

model = SuperCluster(
    input_dim=10, embedding_dim=64, num_clusters=5,
)

preds, cluster_weights = model(x)   # preds: [N, 1], no sigmoid applied
loss = nn.MSELoss()(preds, y)
Reading prototype predictions
import torch
import torch.nn.functional as F

# Binary: each prototype's predicted probability
for k in range(model.num_clusters):
    prob = torch.sigmoid(model.prototype_scores[k]).item()
    print(f"Prototype {k}: P(y=1) = {prob:.3f}")

# Multi-class: each prototype's predicted class distribution
for k in range(model.num_clusters):
    probs = F.softmax(model.prototype_scores[k], dim=0)
    print(f"Prototype {k}: {probs.tolist()}")
How it works
SuperCluster separates two concerns that prior prototype models conflate:
| Parameter | Role |
|---|---|
| model.centers ∈ ℝ^{K×d} | Routing geometry: where prototypes live in embedding space |
| model.prototype_scores ∈ ℝ^K or ℝ^{K×C} | Prediction: each prototype's output value |
The center-diversity loss pushes centers apart on the unit hypersphere without touching prediction scores, preventing the "overconfident prototype" pathology common when routing geometry and prediction are conflated.
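The package exports center_diversity_loss; its exact form isn't documented here, but penalties of this kind typically repel unit-normalized centers from one another. A minimal sketch of one such formulation (an illustrative assumption, not necessarily the package's implementation):

import torch
import torch.nn.functional as F

def cosine_repulsion(centers: torch.Tensor) -> torch.Tensor:
    """Mean pairwise cosine similarity between distinct unit-normalized centers."""
    c = F.normalize(centers, dim=-1)        # project centers onto the unit hypersphere
    sim = c @ c.T                           # [K, K] cosine similarities
    k = centers.shape[0]
    off_diag = sim - torch.eye(k, device=centers.device)
    return off_diag.sum() / (k * (k - 1))   # lower means more spread out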
Predictive regime discovery
Set K larger than you expect and examine which prototypes are occupied. The model concentrates assignment mass on exactly as many prototypes as the data has distinct prediction behaviours — surplus prototypes collapse to zero routing mass. This gives an operational measure of how many prediction regimes the data contains.
from supercluster import effective_prototype_count
_, cw = model(X_test)
k_eff = effective_prototype_count(cw) # entropy-based count
active = model.active_prototypes(cw) # list of occupied prototype indices
print(f"K_eff = {k_eff:.2f}, K_active = {len(active)}/{model.num_clusters}")
Model parameters
| Parameter | Default | Description |
|---|---|---|
| input_dim | required | Raw feature dimension |
| embedding_dim | required | Latent dimension d |
| num_clusters | required | Number of prototypes K |
| num_classes | 1 | 1 = binary/regression; C > 1 = multi-class |
| encoder_layers | 4 | MLP encoder depth |
| encoder_hidden_size | 256 | MLP encoder width |
| num_attn_heads | 8 | Cross-attention heads (must divide embedding_dim) |
| num_cross_attn_layers | 2 | Cross-attention depth L |
| dropout | 0.1 | Dropout rate |
Training recommendations
| Setting | Recommendation |
|---|---|
| Optimizer | AdamW or Adam, lr = 1e-3 |
| Diversity weight | λ_div = 0.1 — add λ_div * center_diversity_loss(model.centers) to main loss |
| K | 1.5–3× your expected regime count; performance is flat across K |
| Patience | Early-stop on validation loss with patience 50–100 epochs |
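Putting these together, a minimal training-loop sketch following the table's recommendations. The data here is a synthetic placeholder, and the loop assumes the binary quick-start setup:

import copy
import torch
import torch.nn as nn
from supercluster import SuperCluster, center_diversity_loss

# placeholder synthetic data; substitute your own train/validation split
x_train, y_train = torch.randn(512, 10), torch.randint(0, 2, (512, 1)).float()
x_val, y_val = torch.randn(128, 10), torch.randint(0, 2, (128, 1)).float()

model = SuperCluster(input_dim=10, embedding_dim=64, num_clusters=8)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()
lambda_div = 0.1   # diversity weight from the table above

best_val, best_state, patience, bad_epochs = float("inf"), None, 100, 0
for epoch in range(2000):
    model.train()
    logits, _ = model(x_train)
    loss = criterion(logits, y_train) + lambda_div * center_diversity_loss(model.centers)
    opt.zero_grad()
    loss.backward()
    opt.step()

    model.eval()
    with torch.no_grad():
        val_logits, _ = model(x_val)
        val_loss = criterion(val_logits, y_val).item()
    if val_loss < best_val:
        best_val, best_state, bad_epochs = val_loss, copy.deepcopy(model.state_dict()), 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:   # early stop on validation loss
            break

if best_state is not None:
    model.load_state_dict(best_state)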
Empirical results
| Dataset | MLP (accuracy / AUC) | SuperCluster (accuracy / AUC) | K_active |
|---|---|---|---|
| NBA shots (binary, 72k) | 63.3% | 63.4% | 3/8 |
| Bank Marketing (binary, 45k) | 88.9% / 78.7% | 89.0% / 77.8% | 3/8 |
| Adult Income (binary, 45k) | 85.0% / 91.0% | 85.2% / 90.7% | 2/8 |
| Credit Default (binary, 30k) | 81.7% / 76.8% | 82.0% / 76.8% | 2/8 |
| Covertype (7-class, 581k) | 95.7% | 95.5% | 6/8 |
Citation
If you use this package in academic work, please cite:
@article{danielson2026supercluster,
  title   = {SuperCluster: Learning Prediction Prototypes via
             Target-Guided Cross-Attention Clustering},
  author  = {Danielson, Aaron John},
  journal = {Data Mining and Knowledge Discovery},
  year    = {2026},
}
License
MIT