Extensible tabular deep learning: TabularResNet, DANet, and TabularLNN with first-class sample_weight, custom objectives, and custom metrics.
masaMLP
Extensible tabular deep learning — TabularResNet, DANet, and TabularLNN behind sklearn-compatible estimators with first-class sample_weight, custom objectives, custom metrics, and early stopping on any metric. The sibling library of repleafgbm (same author, same API philosophy), for the neural side of tabular ML.
Status: alpha (0.2.x). Built with heavy use of Claude Code (coding and architecture design).
Why masaMLP
Excellent tabular DL libraries exist — pytabkit
ships state-of-the-art models like RealMLP and TabM, and
rtdl provides reference modules.
What they don't make easy is extension: sample_weight in fit, custom
training losses, custom evaluation metrics, and early stopping driven by
them. masaMLP is built around exactly those hooks:
- `fit(X, y, sample_weight=..., eval_set=...)`: LightGBM-style, sklearn compatible. Weights flow through a single reduction `(loss * w).sum() / w.sum()` that every objective shares.
- Custom objectives are per-sample torch losses: a plain function (or `nn.Module` with trainable parameters). Because the trainer owns the weighted reduction, your loss gets correct `sample_weight` and `class_weight` handling for free.
- Custom metrics are plain NumPy callables via `make_metric`, and any of them (minimize or maximize) can drive early stopping with best-epoch weight restoration.
- Multiclass, multioutput regression, class_weight, and label smoothing are supported natively; built-in preprocessing (quantile scaling, missing values, categorical embeddings) means DataFrames go straight into `fit`.
- CPU / CUDA / MPS behind `device="auto"`: device-resident tensors with no DataLoader overhead, automatic full-batch mode for small data, bf16 AMP on CUDA, opt-in `torch.compile` with eager fallback.
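To make the shared reduction concrete, here is a minimal sketch of `(loss * w).sum() / w.sum()` in plain Python. The function name `weighted_mean_loss` is hypothetical and is not part of masaMLP; the library applies the same formula to torch tensors.

```python
def weighted_mean_loss(per_sample_loss, sample_weight):
    """Reduce per-sample losses as (loss * w).sum() / w.sum()."""
    if len(per_sample_loss) != len(sample_weight):
        raise ValueError("loss and weight must have the same length")
    total_weight = sum(sample_weight)
    if total_weight <= 0:
        raise ValueError("sample weights must sum to a positive value")
    weighted_sum = sum(l * w for l, w in zip(per_sample_loss, sample_weight))
    return weighted_sum / total_weight


losses = [0.5, 2.0, 1.0]

# Uniform weights reduce to the ordinary mean.
uniform = weighted_mean_loss(losses, [1.0, 1.0, 1.0])

# A weight of 3 on the second sample behaves like repeating it three times.
upweighted = weighted_mean_loss(losses, [1.0, 3.0, 1.0])
duplicated = weighted_mean_loss([0.5, 2.0, 2.0, 2.0, 1.0], [1.0] * 5)
```

Because the reduction divides by the weight sum rather than the sample count, rescaling all weights by a constant leaves the loss unchanged.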
masaMLP deliberately does not try to re-benchmark the field — see docs/attribution.md for the research and libraries it builds on.
Models
| name | source | notes |
|---|---|---|
| `resnet` | Gorishniy et al. 2021 (arXiv:2106.11959) | default; strong baseline |
| `realmlp` | Holzmüller et al. 2024 (arXiv:2407.04491) | RealMLP-TD-S architecture (scaling layer, NTP linear layers, SELU/Mish); pair with `masamlp.realmlp_params(task)` for the full training recipe |
| `ft_transformer` | Gorishniy et al. 2021 (arXiv:2106.11959) | feature tokens + [CLS] + PreNorm/ReGLU transformer, per the rtdl reference |
| `tab_transformer` | Huang et al. 2020 (arXiv:2012.06678) | transformer over categorical tokens; numerics bypass (or embed via `num_embedding`) |
| `danet` | Chen et al. AAAI 2022 (arXiv:2112.02962) | Abstract Layers with learnable sparse feature groups (in-house entmax15) |
| `tabr` | Gorishniy et al. 2023 (arXiv:2307.14338) | retrieval-augmented: nearest training rows are aggregated into each prediction |
| `modernnca` | Ye et al. 2024 (arXiv:2407.03257) | soft-nearest-neighbor aggregation with stochastic candidate sampling; pairs well with `num_embedding="plr-lite"` |
| `gandalf` | Joseph & Raj 2022 (arXiv:2207.08548) | GFLU stages: learnable sparse feature masks (t-softmax) with GRU-style gating; exposes `feature_importances()` |
| `grn` | GRN blocks from TFT, Lim et al. 2021 (arXiv:1912.09363) | stack of Gated Residual Networks over embedded features (masaMLP's own composition) |
| `lnn` | CfC cells, Hasani et al. 2022 | experimental liquid-network adaptation for static tabular data; see docs/lnn.md |
Third-party architectures plug in with register_model and get the whole
estimator surface (weights, objectives, metrics, early stopping) for free.
RealMLP insights are composable options
The tricks from the RealMLP paper are estimator-level options usable with
any model (lnn included), not baked into one architecture:
- `numeric_scaler="rssc"`: robust scale + smooth clip preprocessing
- `cat_encoding="onehot"`: RealMLP-style one-hot (binary → ±1, missing → 0)
- `num_embedding="pbld" | "plr" | "plr-lite" | "pl" | "periodic"`: the numeric embedding zoo (arXiv:2203.05556 + PBLD); token models (`ft_transformer`, `tab_transformer`) use the same options as feature tokenizers
- `model_params={"num_scaling": True}`: learnable per-feature input scale
- `lr_scheduler="coslog4"`, `optimizer_betas=(0.9, 0.95)`: the training schedule
- `clip_predictions=True` (regressor): clip to the observed target range
- `n_ens=k`: seed ensembling as in pytabkit's RealMLP: k members trained with seeds `random_state + i`, predictions averaged on the probability / value scale; works with every model including the retrieval ones. `ens_mode="vectorized"` trains all members in one vmapped forward/backward (`torch.func`) for BatchNorm-free models (pytabkit's speed trick)
- `weight_decay_schedule="flat_cos"`: RealMLP-TD's scheduled weight decay (param groups can opt out, e.g. biases)
- `ema_decay=0.999`: exponential moving average (Polyak averaging) of the weights; evaluation, early stopping, and the final model all use the averaged parameters
- `candidate_budget=N`: bound the retrieval corpus of `tabr` / `modernnca` with a seeded, class-stratified subsample (keeps memory/compute in check on large data; no-op for other models)
- `masamlp.realmlp_td_params(task)`: the full RealMLP-TD recipe: parametric activations, flat_cos-scheduled dropout and weight decay, PBLD embeddings with their own lr factor, and hybrid categorical encoding (one-hot ≤ 9 categories, embeddings above)
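Two of the options above, seed ensembling (`n_ens`) and weight averaging (`ema_decay`), come down to simple arithmetic. The following is a rough sketch in plain Python, not masaMLP's implementation: the helper names `member_seeds`, `average_probabilities`, and `ema_update` are hypothetical, and the demo uses `decay=0.9` only so that the effect is visible within a few hundred steps.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one row of class logits."""
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def member_seeds(random_state, n_ens):
    """Seeds random_state + i for i in 0..n_ens-1."""
    return [random_state + i for i in range(n_ens)]


def average_probabilities(member_logits):
    """Average one row's predictions across members on the probability scale."""
    member_probs = [softmax(logits) for logits in member_logits]
    n_classes = len(member_probs[0])
    return [
        sum(probs[c] for probs in member_probs) / len(member_probs)
        for c in range(n_classes)
    ]


def ema_update(ema_params, params, decay):
    """One Polyak step per parameter: ema = decay * ema + (1 - decay) * value."""
    return {
        name: decay * ema_params[name] + (1.0 - decay) * value
        for name, value in params.items()
    }


# Seed ensembling: three members scoring one binary-classification row.
seeds = member_seeds(random_state=42, n_ens=3)
ensemble_proba = average_probabilities([[2.0, 0.0], [0.0, 0.0], [-1.0, 1.0]])

# EMA: a weight that bounces between 0.0 and 2.0 settles near 1.0 once averaged.
ema = {"w": 0.0}
for step in range(200):
    raw = {"w": 2.0 if step % 2 == 0 else 0.0}
    ema = ema_update(ema, raw, decay=0.9)
```

Averaging after the softmax keeps the ensemble output a valid probability vector, which is why the averaging is described as happening on the probability scale.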
```python
from masamlp import MasaClassifier, realmlp_params

clf = MasaClassifier(**realmlp_params("classification"))  # the TD-S recipe
clf = MasaClassifier(**{**realmlp_params("classification"),
                        "num_embedding": "pbld"})  # toward RealMLP-TD
```
Install
```bash
pip install masamlp   # torch, numpy, pandas, scikit-learn
```
Quickstart
```python
import numpy as np
from masamlp import MasaClassifier, make_metric

# X_train, y_train, w_train, X_val, y_val, and X_test are your own data.

def f1(y_true, y_proba):
    pred = y_proba >= 0.5
    tp = np.sum(pred & (y_true == 1))
    return 2 * tp / max(pred.sum() + (y_true == 1).sum(), 1)

clf = MasaClassifier(
    model="resnet",
    eval_metric=make_metric(f1, name="f1", minimize=False),
    early_stopping_rounds=15,
    class_weight="balanced",
)
clf.fit(X_train, y_train, sample_weight=w_train, eval_set=[(X_val, y_val)])
proba = clf.predict_proba(X_test)
print(clf.best_iteration_, clf.best_score_, clf.evals_result_["valid_0"]["f1"][:3])
```
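The early-stopping behaviour used in the quickstart can be pictured with a small stand-alone loop. This is a sketch under assumptions, not masaMLP's trainer: it assumes `early_stopping_rounds` means "stop after this many epochs without improvement" (the LightGBM convention), it expects a non-empty list of scores, and the function name and the `{"epoch": ...}` stand-in for model weights are hypothetical.

```python
import copy


def run_early_stopping(scores, patience, minimize=False):
    """Walk per-epoch validation scores; return (best_epoch, best_score, restored_state)."""
    sign = 1.0 if minimize else -1.0  # compare sign * score so lower is always better
    best_key = float("inf")
    best_epoch, best_state = -1, None
    for epoch, score in enumerate(scores):
        state = {"epoch": epoch}  # stand-in for the model weights after this epoch
        key = sign * score
        if key < best_key:
            # New best: remember the epoch and snapshot the weights.
            best_key, best_epoch = key, epoch
            best_state = copy.deepcopy(state)
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` epochs: stop training.
            break
    return best_epoch, scores[best_epoch], best_state


# A maximized metric such as F1: training stops before the last epoch is reached.
f1_per_epoch = [0.60, 0.70, 0.72, 0.71, 0.69, 0.70, 0.90]
best_f1 = run_early_stopping(f1_per_epoch, patience=3, minimize=False)

# A minimized metric such as a loss uses the same loop with the sign flipped.
loss_per_epoch = [1.0, 0.5, 0.6, 0.7]
best_loss = run_early_stopping(loss_per_epoch, patience=2, minimize=True)
```

The returned snapshot is the state from the best epoch, not the last one, which is what best-epoch weight restoration refers to.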
Custom objective (regression, asymmetric loss):
```python
import torch
from masamlp import MasaRegressor

def asymmetric_mse(y_true, raw_pred):  # -> per-sample (n,) tensor
    err = raw_pred - y_true            # raw_pred: (n, out_dim)
    return torch.where(err < 0, 4.0 * err**2, err**2).mean(dim=1)

reg = MasaRegressor(model="danet", objective=asymmetric_mse)
reg.fit(X, y, sample_weight=w)  # weights just work
```
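To see what this objective does to predictions, the scalar version below mirrors the same rule in plain Python: errors where the prediction falls below the target cost four times as much. The function names and the grid search are illustrative only and are not part of masaMLP.

```python
def asymmetric_squared_error(y_true, raw_pred, under_penalty=4.0):
    """Squared error, multiplied by under_penalty when raw_pred < y_true."""
    err = raw_pred - y_true
    return under_penalty * err**2 if err < 0 else err**2


def total_loss(targets, constant):
    """Summed loss of predicting the same constant for every target."""
    return sum(asymmetric_squared_error(t, constant) for t in targets)


# Under-predicting by 1 costs four times as much as over-predicting by 1.
under = asymmetric_squared_error(y_true=10.0, raw_pred=9.0)
over = asymmetric_squared_error(y_true=10.0, raw_pred=11.0)

# The best constant prediction for targets 0 and 1 is pulled above their mean.
targets = [0.0, 1.0]
candidates = [i / 100 for i in range(101)]
best_constant = min(candidates, key=lambda c: total_loss(targets, c))
```

With plain squared error the best constant would be the mean, 0.5; the 4x penalty on under-prediction moves it to 0.8, so a model trained with this loss leans toward over-predicting.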
Save/load is a plain directory (manifest.json + tensors, loaded with
weights_only=True — no pickle execution):
```python
reg.save_model("model_dir")
reg2 = MasaRegressor.load_model("model_dir")
```
Devices
device="auto" resolves cuda > mps > cpu. CUDA gets bf16 AMP by default and
optional compile=True; MPS and CPU train in float32. Details and caveats:
docs/devices.md.
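The resolution order can be written down in a few lines. This is only a sketch of the stated rule, with hypothetical function names; the availability flags are passed in so that the example runs without PyTorch installed. In a real setup they would typically come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`.

```python
def resolve_device(requested, cuda_available, mps_available):
    """Resolve "auto" in the order cuda > mps > cpu; pass explicit choices through."""
    if requested != "auto":
        return requested
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def default_train_dtype(device):
    """bf16 autocast on CUDA by default; float32 on MPS and CPU."""
    return "bfloat16" if device == "cuda" else "float32"


on_gpu_box = resolve_device("auto", cuda_available=True, mps_available=False)
on_mac = resolve_device("auto", cuda_available=False, mps_available=True)
on_laptop = resolve_device("auto", cuda_available=False, mps_available=False)
forced_cpu = resolve_device("cpu", cuda_available=True, mps_available=False)
```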
Development
```bash
pip install -e ".[dev]"
bash scripts/check.sh   # ruff + pytest + examples/quickstart.py
```
Development rules live in CLAUDE.md; roadmap in docs/roadmap.md.
License
MIT. Architecture attributions: docs/attribution.md.
File details
Details for the file masamlp-0.2.0.tar.gz.
File metadata
- Download URL: masamlp-0.2.0.tar.gz
- Upload date:
- Size: 74.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 20763b4e6bc2718816bd0df5accefa51f89fe5625ec67360a6c2ad41430bd440 |
| MD5 | da28c3590996b7c1158acadb470fa36e |
| BLAKE2b-256 | 0eccfa56f6703907ebdf66c9c4234ab19ef1129a9934e0fc75896c50ba0e078d |
Provenance
The following attestation bundles were made for masamlp-0.2.0.tar.gz:
Publisher: publish.yml on Matapanino/masamlp

Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: masamlp-0.2.0.tar.gz
- Subject digest: 20763b4e6bc2718816bd0df5accefa51f89fe5625ec67360a6c2ad41430bd440
- Sigstore transparency entry: 2052942682
- Sigstore integration time:
- Permalink: Matapanino/masamlp@0ba4667e851f43511e2d7dcb3fb24a417c3d8961
- Branch / Tag: refs/tags/v0.2.0
- Owner: https://github.com/Matapanino
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@0ba4667e851f43511e2d7dcb3fb24a417c3d8961
- Trigger Event: push
File details
Details for the file masamlp-0.2.0-py3-none-any.whl.
File metadata
- Download URL: masamlp-0.2.0-py3-none-any.whl
- Upload date:
- Size: 69.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 360d5d7e3b6f921052ffb312bf2bdda7be54ccfd260ef2ab4b2885c7a8a8fc60 |
| MD5 | edc6ec36c03fe01ac04a1c033e1f0523 |
| BLAKE2b-256 | 535110877ed4a3f723e4da17c241d27d471d68115f95f7ffeb26627215854f74 |
Provenance
The following attestation bundles were made for masamlp-0.2.0-py3-none-any.whl:
Publisher: publish.yml on Matapanino/masamlp

Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: masamlp-0.2.0-py3-none-any.whl
- Subject digest: 360d5d7e3b6f921052ffb312bf2bdda7be54ccfd260ef2ab4b2885c7a8a8fc60
- Sigstore transparency entry: 2052943279
- Sigstore integration time:
- Permalink: Matapanino/masamlp@0ba4667e851f43511e2d7dcb3fb24a417c3d8961
- Branch / Tag: refs/tags/v0.2.0
- Owner: https://github.com/Matapanino
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@0ba4667e851f43511e2d7dcb3fb24a417c3d8961
- Trigger Event: push