Extensible tabular deep learning: TabularResNet, DANet, and TabularLNN with first-class sample_weight, custom objectives, and custom metrics.

masaMLP

Extensible tabular deep learning: TabularResNet, DANet, and TabularLNN behind sklearn-compatible estimators with first-class sample_weight, custom objectives, custom metrics, and early stopping on any metric. masaMLP is the sibling library of repleafgbm (same author, same API philosophy) and covers the neural side of tabular ML.

Status: alpha (0.1.x). Built with heavy use of Claude Code (coding and architecture design).

Why masaMLP

Excellent tabular DL libraries exist — pytabkit ships state-of-the-art models like RealMLP and TabM, and rtdl provides reference modules. What they don't make easy is extension: sample_weight in fit, custom training losses, custom evaluation metrics, and early stopping driven by them. masaMLP is built around exactly those hooks:

  • fit(X, y, sample_weight=..., eval_set=...): LightGBM-style and sklearn-compatible. Weights flow through a single reduction, (loss * w).sum() / w.sum(), that every objective shares.
  • Custom objectives are per-sample torch losses: a plain function, or an nn.Module if the loss has trainable parameters. Because the trainer owns the weighted reduction, your loss gets correct sample_weight and class_weight handling for free.
  • Custom metrics are plain NumPy callables via make_metric, and any of them (minimize or maximize) can drive early stopping with best-epoch weight restoration.
  • Multiclass classification, multioutput regression, class_weight, and label smoothing are supported natively, and built-in preprocessing (quantile scaling, missing values, categorical embeddings) lets DataFrames go straight into fit.
  • CPU / CUDA / MPS behind device="auto": device-resident tensors with no DataLoader overhead, automatic full-batch mode for small data, bf16 AMP on CUDA, opt-in torch.compile with eager fallback.
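The weighted reduction in the first two bullets can be sketched in NumPy. This is a minimal sketch, not masaMLP's trainer: it assumes that class_weight="balanced" follows scikit-learn's n_samples / (n_classes * count) heuristic and that class weights are multiplied into sample_weight before the reduction. The helper names are hypothetical, and the real trainer works on torch tensors.

```python
import numpy as np


def balanced_class_weights(y):
    """scikit-learn-style "balanced" weights: n_samples / (n_classes * count_c)."""
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return dict(zip(classes.tolist(), weights.tolist()))


def weighted_mean_loss(per_sample_loss, y, sample_weight=None, class_weight=None):
    """Reduce per-sample losses with (loss * w).sum() / w.sum()."""
    loss = np.asarray(per_sample_loss, dtype=float)
    if sample_weight is None:
        w = np.ones_like(loss)
    else:
        w = np.asarray(sample_weight, dtype=float)
    if class_weight == "balanced":
        # Assumption: class weights multiply into the sample weights.
        cw = balanced_class_weights(y)
        w = w * np.array([cw[c] for c in np.asarray(y).tolist()])
    return float((loss * w).sum() / w.sum())


y = np.array([0, 0, 0, 1])
loss = np.array([0.1, 0.2, 0.3, 2.0])
print(weighted_mean_loss(loss, y))                           # unweighted mean
print(weighted_mean_loss(loss, y, class_weight="balanced"))  # minority row counts 3x a majority row
```

Because the reduction divides by w.sum(), multiplying every weight by the same constant leaves the loss unchanged; only relative weights matter.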

masaMLP deliberately does not try to re-benchmark the field — see docs/attribution.md for the research and libraries it builds on.

Models

| name | source | notes |
|---|---|---|
| resnet | Gorishniy et al. 2021 (arXiv:2106.11959) | default; strong baseline |
| realmlp | Holzmüller et al. 2024 (arXiv:2407.04491) | RealMLP-TD-S architecture (scaling layer, NTP linear layers, SELU/Mish); pair with masamlp.realmlp_params(task) for the full training recipe |
| ft_transformer | Gorishniy et al. 2021 (arXiv:2106.11959) | feature tokens + [CLS] + PreNorm/ReGLU transformer, per the rtdl reference |
| tab_transformer | Huang et al. 2020 (arXiv:2012.06678) | transformer over categorical tokens; numerics bypass (or embed via num_embedding) |
| danet | Chen et al. AAAI 2022 (arXiv:2112.02962) | Abstract Layers with learnable sparse feature groups (in-house entmax15) |
| tabr | Gorishniy et al. 2023 (arXiv:2307.14338) | retrieval-augmented: nearest training rows are aggregated into each prediction |
| modernnca | Ye et al. 2024 (arXiv:2407.03257) | soft-nearest-neighbor aggregation with stochastic candidate sampling; pairs well with num_embedding="plr-lite" |
| gandalf | Joseph & Raj 2022 (arXiv:2207.08548) | GFLU stages: learnable sparse feature masks (t-softmax) with GRU-style gating; exposes feature_importances() |
| grn | GRN blocks from TFT, Lim et al. 2021 (arXiv:1912.09363) | stack of Gated Residual Networks over embedded features (masaMLP's own composition) |
| lnn | CfC cells, Hasani et al. 2022 | experimental liquid-network adaptation for static tabular data; see docs/lnn.md |
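The tabr and modernnca rows both predict by aggregating training rows. A minimal NumPy sketch of soft-nearest-neighbor aggregation in the spirit of ModernNCA: a softmax over negative Euclidean distances weights the candidates' targets. It leaves out the learned embedding and the stochastic candidate sampling, and the temperature argument is an illustrative knob of ours; this is not masaMLP's implementation.

```python
import numpy as np


def soft_nn_predict(queries, candidates, candidate_targets, temperature=1.0):
    """Soft-nearest-neighbor aggregation: softmax(-distance)-weighted average of targets.

    queries: (m, d), candidates: (n, d), candidate_targets: (n,) or (n, k).
    ModernNCA would first pass rows through a learned embedding; here they are raw.
    """
    q = np.asarray(queries, dtype=float)
    c = np.asarray(candidates, dtype=float)
    t = np.asarray(candidate_targets, dtype=float)
    # Pairwise Euclidean distances, shape (m, n).
    dist = np.sqrt(((q[:, None, :] - c[None, :, :]) ** 2).sum(axis=-1))
    logits = -dist / temperature
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ t


cands = np.array([[0.0, 0.0], [2.0, 0.0]])
labels = np.array([0.0, 1.0])
# First query is equidistant from both candidates; second sits next to the label-0 row.
print(soft_nn_predict([[1.0, 0.0], [0.1, 0.0]], cands, labels))
```

For classification, passing one-hot targets of shape (n, k) returns class probabilities, since each output row is a convex combination of one-hot vectors.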

Third-party architectures plug in with register_model and get the whole estimator surface (weights, objectives, metrics, early stopping) for free.

RealMLP insights are composable options

The tricks from the RealMLP paper are estimator-level options usable with any model (lnn included), not baked into one architecture:

  • numeric_scaler="rssc" — robust scale + smooth clip preprocessing
  • cat_encoding="onehot" — RealMLP-style one-hot (binary → ±1, missing → 0)
  • num_embedding="pbld" | "plr" | "plr-lite" | "pl" | "periodic" — the numeric embedding zoo (arXiv:2203.05556 + PBLD); token models (ft_transformer, tab_transformer) use the same options as feature tokenizers
  • model_params={"num_scaling": True} — learnable per-feature input scale
  • lr_scheduler="coslog4", optimizer_betas=(0.9, 0.95) — the training schedule
  • clip_predictions=True (regressor) — clip to the observed target range
  • n_ens=k: seed ensembling as in pytabkit's RealMLP. k members are trained with seeds random_state + i and their predictions are averaged on the probability (classification) or value (regression) scale; this works with every model, including the retrieval ones. For BatchNorm-free models, ens_mode="vectorized" trains all members in a single vmapped forward/backward pass (torch.func), pytabkit's speed trick.
  • weight_decay_schedule="flat_cos" — RealMLP-TD's scheduled weight decay (param groups can opt out, e.g. biases)
  • masamlp.realmlp_td_params(task) — the full RealMLP-TD recipe: parametric activations, flat_cos-scheduled dropout and weight decay, PBLD embeddings with their own lr factor, and hybrid categorical encoding (one-hot ≤ 9 categories, embeddings above)

from masamlp import MasaClassifier, realmlp_params, realmlp_td_params

clf = MasaClassifier(**realmlp_params("classification"))    # the TD-S recipe
clf = MasaClassifier(**{**realmlp_params("classification"),
                        "num_embedding": "pbld"})           # toward RealMLP-TD
clf = MasaClassifier(**realmlp_td_params("classification")) # the full RealMLP-TD recipe

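numeric_scaler="rssc" names RealMLP's robust scaling plus smooth clipping. A minimal NumPy sketch following our reading of the RealMLP paper: center each column on its median, divide by the interquartile range (falling back to half the min-max range, and to a zero scale for constant columns), then apply the smooth clip f(x) = x / sqrt(1 + (x/3)^2). The function name is hypothetical, and masaMLP's exact quantiles and fallbacks may differ.

```python
import numpy as np


def robust_scale_smooth_clip(X):
    """RS+SC sketch: center on the median, scale by the IQR, then smooth-clip."""
    X = np.asarray(X, dtype=float)
    median = np.median(X, axis=0)
    q25, q75 = np.quantile(X, [0.25, 0.75], axis=0)
    iqr = q75 - q25
    span = X.max(axis=0) - X.min(axis=0)
    # Fall back to half the range when the IQR is zero; constant columns get scale 0.
    with np.errstate(divide="ignore"):
        scale = np.where(iqr > 0, 1.0 / iqr, np.where(span > 0, 2.0 / span, 0.0))
    z = (X - median) * scale
    # Smooth clip: close to the identity for small |z|, bounded by +/-3.
    return z / np.sqrt(1.0 + (z / 3.0) ** 2)


# Column 0 has one extreme value; column 1 is constant.
X = np.column_stack([[1.0, 2.0, 3.0, 4.0, 1000.0], [5.0] * 5])
print(robust_scale_smooth_clip(X))
```

Because the clip is bounded, the single extreme value (1000) lands just below 3 instead of dominating the column the way it would under plain standardization.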
Install

pip install masamlp        # torch, numpy, pandas, scikit-learn

Quickstart

import numpy as np
from masamlp import MasaClassifier, make_metric

def f1(y_true, y_proba):
    pred = y_proba >= 0.5
    tp = np.sum(pred & (y_true == 1))
    return 2 * tp / max(pred.sum() + (y_true == 1).sum(), 1)

clf = MasaClassifier(
    model="resnet",
    eval_metric=make_metric(f1, name="f1", minimize=False),
    early_stopping_rounds=15,
    class_weight="balanced",
)
clf.fit(X_train, y_train, sample_weight=w_train, eval_set=[(X_val, y_val)])
proba = clf.predict_proba(X_test)
print(clf.best_iteration_, clf.best_score_, clf.evals_result_["valid_0"]["f1"][:3])
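minimize=False matters because early stopping compares each epoch's validation score against the best so far. A minimal pure-Python sketch of patience-based early stopping with best-epoch restoration, as described for early_stopping_rounds. The function name, the strict-improvement rule, and the 0-based epoch indexing are assumptions, so masaMLP's best_iteration_ may be counted differently.

```python
def early_stopping_trace(scores, rounds, minimize):
    """Replay per-epoch validation scores the way patience-based early stopping would.

    Returns (best_epoch, best_score, last_epoch_run), all 0-based.
    """
    best_epoch, best_score = None, None
    for epoch, score in enumerate(scores):
        improved = best_score is None or (
            score < best_score if minimize else score > best_score
        )
        if improved:
            # A trainer would snapshot the model weights here.
            best_epoch, best_score = epoch, score
        elif epoch - best_epoch >= rounds:
            # `rounds` epochs without improvement: stop and restore the snapshot.
            return best_epoch, best_score, epoch
    return best_epoch, best_score, len(scores) - 1


# A maximized metric such as F1 peaks at epoch 2, then stalls for 3 epochs.
print(early_stopping_trace([0.60, 0.70, 0.72, 0.71, 0.70, 0.69], rounds=3, minimize=False))
```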

Custom objective (regression, asymmetric loss):

import torch
from masamlp import MasaRegressor

def asymmetric_mse(y_true, raw_pred):          # -> per-sample (n,) tensor
    err = raw_pred - y_true                    # raw_pred: (n, out_dim)
    # under-predictions (err < 0) cost 4x as much as over-predictions
    return torch.where(err < 0, 4.0 * err**2, err**2).mean(dim=1)

reg = MasaRegressor(model="danet", objective=asymmetric_mse)
reg.fit(X, y, sample_weight=w)                 # weights just work
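To see what asymmetric_mse computes without installing torch, here is a NumPy mirror (the name asymmetric_mse_np is hypothetical) followed by the trainer's weighted reduction. The mirror assumes y_true arrives with the same (n, out_dim) shape as raw_pred. If y_true were (n,) while raw_pred is (n, 1), broadcasting would silently produce an (n, n) error matrix, so the target shape is worth checking in any custom objective.

```python
import numpy as np


def asymmetric_mse_np(y_true, raw_pred):
    """NumPy mirror of asymmetric_mse: (n, out_dim) inputs -> (n,) per-sample losses."""
    err = raw_pred - y_true
    # Under-predictions (err < 0) are penalized 4x.
    return np.where(err < 0, 4.0 * err**2, err**2).mean(axis=1)


y_true = np.array([[1.0], [1.0], [1.0]])
raw_pred = np.array([[0.0], [2.0], [1.0]])   # under, over, exact
per_sample = asymmetric_mse_np(y_true, raw_pred)
w = np.array([1.0, 3.0, 1.0])
print(per_sample)                            # per-sample losses
print((per_sample * w).sum() / w.sum())      # the trainer's weighted reduction
```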

Save/load uses a plain directory (manifest.json plus tensors). Tensors are loaded with weights_only=True, so loading never executes arbitrary pickled code:

reg.save_model("model_dir")
reg2 = MasaRegressor.load_model("model_dir")
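masaMLP's actual file layout is not described here. As a minimal sketch of the same idea (a JSON manifest for configuration plus raw array files that are read without unpickling), the example below uses NumPy's allow_pickle=False in place of torch.load(weights_only=True). save_dir, load_dir, and the one-.npy-per-tensor layout are hypothetical.

```python
import json
import tempfile
from pathlib import Path

import numpy as np


def save_dir(path, params, tensors):
    """Write a JSON manifest plus one .npy file per named array (hypothetical layout)."""
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    manifest = {"params": params, "tensors": sorted(tensors)}
    (path / "manifest.json").write_text(json.dumps(manifest))
    for name, arr in tensors.items():
        np.save(path / f"{name}.npy", np.asarray(arr), allow_pickle=False)


def load_dir(path):
    """Read the manifest, then load arrays with allow_pickle=False so no pickled code runs."""
    path = Path(path)
    manifest = json.loads((path / "manifest.json").read_text())
    tensors = {
        name: np.load(path / f"{name}.npy", allow_pickle=False)
        for name in manifest["tensors"]
    }
    return manifest["params"], tensors


with tempfile.TemporaryDirectory() as tmp:
    save_dir(tmp, {"model": "danet", "lr": 1e-3}, {"layer0.weight": np.eye(2)})
    params, tensors = load_dir(tmp)
    print(params["model"], tensors["layer0.weight"].shape)
```

With allow_pickle=False, NumPy refuses to load object arrays instead of unpickling them, which is the analogous guarantee to weights_only=True on the torch side.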

Devices

device="auto" resolves cuda > mps > cpu. CUDA gets bf16 AMP by default and optional compile=True; MPS and CPU train in float32. Details and caveats: docs/devices.md.
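A minimal sketch of the resolution order described above, written without torch so it runs anywhere; with torch installed, the two availability flags would come from torch.cuda.is_available() and torch.backends.mps.is_available(). The function names, and the assumption that an explicit device string is passed through unchanged, are ours.

```python
def resolve_device(requested, cuda_available, mps_available):
    """Resolve device="auto" with cuda > mps > cpu priority; pass explicit values through."""
    if requested != "auto":
        return requested
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def default_compute_dtype(device):
    """Per the README: bf16 AMP by default on CUDA, float32 on MPS and CPU."""
    return "bfloat16" if device == "cuda" else "float32"


# An Apple-silicon machine without CUDA resolves to MPS and trains in float32.
device = resolve_device("auto", cuda_available=False, mps_available=True)
print(device, default_compute_dtype(device))
```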

Development

pip install -e ".[dev]"
bash scripts/check.sh      # ruff + pytest + examples/quickstart.py

Development rules live in CLAUDE.md; roadmap in docs/roadmap.md.

License

MIT. Architecture attributions: docs/attribution.md.


Download files

Source Distribution

masamlp-0.1.0.tar.gz (69.2 kB)

Uploaded Source

Built Distribution

masamlp-0.1.0-py3-none-any.whl (67.5 kB)

Uploaded Python 3

File details

Details for the file masamlp-0.1.0.tar.gz.

File metadata

  • Download URL: masamlp-0.1.0.tar.gz
  • Size: 69.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for masamlp-0.1.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | a55e8b3fe75edc4004236c679d87f19e42f7da01ae861ba5c5a1ae77d52fce40 |
| MD5 | c0811e9e4c857b3d8744c64480d405df |
| BLAKE2b-256 | c5fd8b3a8967c2fd277ede3a574648a2ab516f44a66d1af7a635cce0bdb5807a |

Provenance

The following attestation bundles were made for masamlp-0.1.0.tar.gz:

Publisher: publish.yml on Matapanino/masamlp

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file masamlp-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: masamlp-0.1.0-py3-none-any.whl
  • Size: 67.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for masamlp-0.1.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | f544c5819afbfb9bb3ad05b2020cfd1dd01390b79e04a4536cb076a67e586945 |
| MD5 | 2fa83d2ed0365db46c0f58cfbe1efd73 |
| BLAKE2b-256 | 5585f29223fda4dff2ae43941c1584fa88cc944e8ebe67a39f51636caeee3ec4 |

Provenance

The following attestation bundles were made for masamlp-0.1.0-py3-none-any.whl:

Publisher: publish.yml on Matapanino/masamlp

