masaMLP

Extensible tabular deep learning — TabularResNet, DANet, and TabularLNN behind sklearn-compatible estimators with first-class sample_weight, custom objectives, custom metrics, and early stopping on any metric. The sibling library of repleafgbm (same author, same API philosophy), for the neural side of tabular ML.

Status: alpha (0.3.x). Built with heavy use of Claude Code (coding and architecture design).

Why masaMLP

Excellent tabular DL libraries exist — pytabkit ships state-of-the-art models like RealMLP and TabM, and rtdl provides reference modules. What they don't make easy is extension: sample_weight in fit, custom training losses, custom evaluation metrics, and early stopping driven by them. masaMLP is built around exactly those hooks:

  • fit(X, y, sample_weight=..., eval_set=...): LightGBM-style and sklearn-compatible. Weights flow through a single reduction, (loss * w).sum() / w.sum(), that every objective shares.
  • Custom objectives are per-sample torch losses — a plain function (or nn.Module with trainable parameters). Because the trainer owns the weighted reduction, your loss gets correct sample_weight and class_weight handling for free.
  • Custom metrics are plain NumPy callables via make_metric, and any of them (minimize or maximize) can drive early stopping with best-epoch weight restoration.
  • Multiclass, multioutput regression, class_weight, and label smoothing are supported natively; built-in preprocessing (quantile scaling, missing values, categorical embeddings) means DataFrames go straight into fit.
  • CPU / CUDA / MPS / multi-GPU behind device="auto": device-resident tensors with no DataLoader overhead, automatic full-batch mode for small data, per-model bf16 AMP on CUDA, opt-in torch.compile with eager fallback — and when several GPUs are detected, n_ens members train concurrently, one worker per GPU.
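
The shared reduction named in the first bullet can be illustrated outside the library. The following is a minimal NumPy sketch, not masaMLP's trainer code: the helper names weighted_mean_loss and squared_error are hypothetical, and the only thing taken from the text above is the formula (loss * w).sum() / w.sum() applied to a per-sample loss vector.

```python
import numpy as np

def weighted_mean_loss(per_sample_loss, sample_weight=None):
    """Reduce a per-sample loss vector to a scalar: (loss * w).sum() / w.sum()."""
    loss = np.asarray(per_sample_loss, dtype=float)
    if sample_weight is None:
        w = np.ones_like(loss)          # no weights: plain mean
    else:
        w = np.asarray(sample_weight, dtype=float)
    return float((loss * w).sum() / w.sum())

def squared_error(y_true, y_pred):
    """A per-sample objective: returns one loss value per row, no reduction."""
    return (np.asarray(y_pred, dtype=float) - np.asarray(y_true, dtype=float)) ** 2

y_true = np.array([0.0, 1.0, 2.0])
y_pred = np.array([0.0, 3.0, 2.0])

per_sample = squared_error(y_true, y_pred)                   # [0, 4, 0]
unweighted = weighted_mean_loss(per_sample)                  # 4 / 3
weighted = weighted_mean_loss(per_sample, [1.0, 3.0, 1.0])   # 12 / 5
```

Because the objective only returns per-sample values, up-weighting the second row changes the scalar loss without the objective knowing about weights at all, which is the design point the bullets describe.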

masaMLP deliberately does not try to re-benchmark the field — see docs/attribution.md for the research and libraries it builds on.

Models

| name | source | notes |
| --- | --- | --- |
| resnet | Gorishniy et al. 2021 (arXiv:2106.11959) | default; strong baseline |
| realmlp | Holzmüller et al. 2024 (arXiv:2407.04491) | RealMLP-TD-S architecture (scaling layer, NTP linear layers, SELU/Mish); pair with masamlp.realmlp_params(task) for the full training recipe |
| ft_transformer | Gorishniy et al. 2021 (arXiv:2106.11959) | feature tokens + [CLS] + PreNorm/ReGLU transformer, per the rtdl reference |
| tab_transformer | Huang et al. 2020 (arXiv:2012.06678) | transformer over categorical tokens; numerics bypass (or embed via num_embedding) |
| danet | Chen et al. AAAI 2022 (arXiv:2112.02962) | Abstract Layers with learnable sparse feature groups (in-house entmax15) |
| tabr | Gorishniy et al. 2023 (arXiv:2307.14338) | retrieval-augmented: nearest training rows are aggregated into each prediction |
| modernnca | Ye et al. 2024 (arXiv:2407.03257) | soft-nearest-neighbor aggregation with stochastic candidate sampling; pairs well with num_embedding="plr-lite" |
| gandalf | Joseph & Raj 2022 (arXiv:2207.08548) | GFLU stages: learnable sparse feature masks (t-softmax) with GRU-style gating; exposes feature_importances() |
| grn | GRN blocks from TFT, Lim et al. 2021 (arXiv:1912.09363) | stack of Gated Residual Networks over embedded features (masaMLP's own composition) |
| lnn | CfC cells, Hasani et al. 2022 | experimental liquid-network adaptation for static tabular data; see docs/lnn.md |

Third-party architectures plug in with register_model and get the whole estimator surface (weights, objectives, metrics, early stopping) for free.

RealMLP insights are composable options

The tricks from the RealMLP paper are estimator-level options usable with any model (lnn included), not baked into one architecture:

  • numeric_scaler="rssc" — robust scale + smooth clip preprocessing
  • cat_encoding="onehot" — RealMLP-style one-hot (binary → ±1, missing → 0)
  • num_embedding="pbld" | "plr" | "plr-lite" | "pl" | "periodic" — the numeric embedding zoo (arXiv:2203.05556 + PBLD); token models (ft_transformer, tab_transformer) use the same options as feature tokenizers
  • model_params={"num_scaling": True} — learnable per-feature input scale
  • lr_scheduler="coslog4", optimizer_betas=(0.9, 0.95) — the training schedule
  • clip_predictions=True (regressor) — clip to the observed target range
  • n_ens=k — seed ensembling as in pytabkit's RealMLP: k members trained with seeds random_state + i, predictions averaged on the probability / value scale; works with every model including the retrieval ones. ens_mode="vectorized" trains all members in one vmapped forward/backward (torch.func) for BatchNorm-free models — pytabkit's speed trick
  • weight_decay_schedule="flat_cos" — RealMLP-TD's scheduled weight decay (param groups can opt out, e.g. biases)
  • ema_decay=0.999 — exponential moving average (Polyak averaging) of the weights; evaluation, early stopping, and the final model all use the averaged parameters
  • candidate_budget=N — bound the retrieval corpus of tabr/modernnca with a seeded, class-stratified subsample (keeps memory/compute in check on large data; no-op for other models)
  • masamlp.realmlp_td_params(task): the full RealMLP-TD recipe, with parametric activations, flat_cos-scheduled dropout and weight decay, PBLD embeddings with their own lr factor, and hybrid categorical encoding (one-hot ≤ 9 categories, embeddings above)

from masamlp import MasaClassifier, realmlp_params

clf = MasaClassifier(**realmlp_params("classification"))    # the TD-S recipe
clf = MasaClassifier(**{**realmlp_params("classification"),
                        "num_embedding": "pbld"})           # toward RealMLP-TD

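The n_ens option above describes seeds of the form random_state + i and averaging on the probability scale. A minimal NumPy sketch of that idea follows; it is not the library's implementation. The helpers member_seeds, softmax, and ensemble_proba are hypothetical names, and the random logits stand in for the outputs of trained ensemble members.

```python
import numpy as np

def member_seeds(random_state, n_ens):
    """Seeds for the ensemble members: random_state + i for i in 0..n_ens-1."""
    return [random_state + i for i in range(n_ens)]

def softmax(logits):
    """Row-wise softmax with the usual max-subtraction for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def ensemble_proba(member_logits):
    """Convert each member's logits to probabilities, then average the probabilities."""
    probs = np.stack([softmax(l) for l in member_logits])   # (n_ens, n, n_classes)
    return probs.mean(axis=0)                               # (n, n_classes)

seeds = member_seeds(42, 3)
# Placeholder for real model outputs: one (4 rows, 3 classes) logit matrix per seed.
fake_logits = [np.random.default_rng(s).normal(size=(4, 3)) for s in seeds]
proba = ensemble_proba(fake_logits)
```

Averaging after the softmax keeps every row a valid probability distribution; averaging raw logits first would generally give a different result.
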
Install

pip install masamlp        # torch, numpy, pandas, scikit-learn

Quickstart

import numpy as np
from masamlp import MasaClassifier, make_metric

def f1(y_true, y_proba):
    pred = y_proba >= 0.5
    tp = np.sum(pred & (y_true == 1))
    return 2 * tp / max(pred.sum() + (y_true == 1).sum(), 1)

clf = MasaClassifier(
    model="resnet",
    eval_metric=make_metric(f1, name="f1", minimize=False),
    early_stopping_rounds=15,
    class_weight="balanced",
)
# X_train, y_train, w_train, X_val, y_val, X_test: your own arrays or DataFrames
clf.fit(X_train, y_train, sample_weight=w_train, eval_set=[(X_val, y_val)])
proba = clf.predict_proba(X_test)
print(clf.best_iteration_, clf.best_score_, clf.evals_result_["valid_0"]["f1"][:3])
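
What early stopping on a maximized metric with best-epoch restoration means can be sketched in plain Python. This is an illustration of the general technique under stated assumptions, not masaMLP's training loop: train_with_early_stopping, step_fn, and eval_fn are hypothetical names, epochs are counted from 0 here (the library's best_iteration_ convention may differ), and the toy "model" is a single integer.

```python
import copy

def train_with_early_stopping(step_fn, eval_fn, state, max_epochs, patience, minimize):
    """Run step_fn once per epoch; stop after `patience` epochs without improvement.

    Returns the best state seen (not the last one), its epoch, its score,
    and the full metric history.
    """
    best_score, best_epoch = None, -1
    best_state = copy.deepcopy(state)
    history = []
    for epoch in range(max_epochs):
        state = step_fn(state)
        score = eval_fn(state)
        history.append(score)
        if best_score is None:
            improved = True
        elif minimize:
            improved = score < best_score
        else:
            improved = score > best_score
        if improved:
            best_score, best_epoch = score, epoch
            best_state = copy.deepcopy(state)      # snapshot for later restoration
        elif epoch - best_epoch >= patience:
            break                                  # patience exhausted
    return best_state, best_epoch, best_score, history

# Toy problem: the "weights" are one integer and the metric peaks at w == 3.
step = lambda s: {"w": s["w"] + 1}
metric = lambda s: -float((s["w"] - 3) ** 2)       # higher is better

best_state, best_epoch, best_score, history = train_with_early_stopping(
    step, metric, {"w": 0}, max_epochs=50, patience=2, minimize=False
)
```

Training continues for two epochs past the peak before stopping, but the returned state is the snapshot taken at the peak.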

Custom objective (regression, asymmetric loss):

import torch
from masamlp import MasaRegressor

def asymmetric_mse(y_true, raw_pred):          # -> per-sample (n,) tensor
    err = raw_pred - y_true                    # raw_pred: (n, out_dim)
    return torch.where(err < 0, 4.0 * err**2, err**2).mean(dim=1)

reg = MasaRegressor(model="danet", objective=asymmetric_mse)
reg.fit(X, y, sample_weight=w)                 # weights just work

Save/load is a plain directory (manifest.json + tensors, loaded with weights_only=True — no pickle execution):

reg.save_model("model_dir")
reg2 = MasaRegressor.load_model("model_dir")

Devices

device="auto" resolves cuda > mps > cpu. CUDA gets bf16 AMP by default (per-model: the retrieval models opt out) and optional compile=True; MPS and CPU train in float32. With multiple GPUs and n_ens > 1, ensemble members are sharded across all GPUs and trained concurrently; opt out with device="cuda:0". Details and caveats: docs/devices.md.
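
The resolution order above can be written down as a small pure-Python sketch. This is illustrative only and does not reproduce the library's logic: resolve_device and assign_members are hypothetical names, the availability flags are passed in explicitly (in a real setup they would come from calls such as torch.cuda.is_available(), torch.backends.mps.is_available(), and torch.cuda.device_count()), and the round-robin placement of ensemble members is an assumption about what "sharded across all GPUs" could look like.

```python
def resolve_device(requested, cuda_available, mps_available):
    """Resolve "auto" with the priority cuda > mps > cpu; pass explicit choices through."""
    if requested != "auto":
        return requested
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"

def assign_members(n_ens, n_gpus):
    """Assumed round-robin placement: member i goes to GPU i % n_gpus."""
    if n_gpus < 1:
        raise ValueError("assign_members needs at least one GPU")
    return [f"cuda:{i % n_gpus}" for i in range(n_ens)]

on_gpu_box = resolve_device("auto", cuda_available=True, mps_available=False)
on_macbook = resolve_device("auto", cuda_available=False, mps_available=True)
on_cpu_box = resolve_device("auto", cuda_available=False, mps_available=False)
pinned = resolve_device("cuda:0", cuda_available=True, mps_available=False)
placement = assign_members(n_ens=5, n_gpus=2)
```

Passing an explicit device string such as "cuda:0" bypasses the automatic choice, which matches the opt-out described above.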

Development

pip install -e ".[dev]"
bash scripts/check.sh      # ruff + pytest + examples/quickstart.py

Development rules live in CLAUDE.md; roadmap in docs/roadmap.md.

License

MIT. Architecture attributions: docs/attribution.md.
