
Tabular deep learning made simple: a scikit-learn compatible model zoo built on PyTorch and Lightning.


DeepTab: Tabular Deep Learning Made Simple

DeepTab is a Python library for deep learning on tabular data, built on PyTorch and Lightning with a scikit-learn compatible API. It offers 15 neural architectures, from Mamba-inspired state space models and Transformers to tree ensembles and MLP baselines, each available as a classifier, regressor, or distributional (LSS) model. One fit/predict/evaluate workflow covers everyday modeling, architecture research, and production deployment.

Why DeepTab?

  • Familiar interface. A scikit-learn fit/predict/evaluate API that drops into existing pipelines, including GridSearchCV.
  • Automatic preprocessing. Feature-type detection, encoding, scaling, and missing-value handling are powered by PreTab and applied for you.
  • One model, three tasks. Every architecture ships as a classifier, a regressor, and a distributional (LSS) variant for uncertainty quantification.
  • A broad model zoo. 15 stable architectures plus experimental models, all behind the same interface, with selection guidance.
  • Built for real data. Mixed feature types, class imbalance, GPU acceleration, and early stopping work out of the box.

⚡ What's New in v2.0

v2.0 is a ground-up restructuring of DeepTab. The high-level estimator API (MambularClassifier().fit(...)) is largely unchanged, but the internal package layout, configuration objects, and import paths have moved.

⚠️ Upgrading from v1? Packages were reorganised, the Default<Arch>Config classes were renamed to <Arch>Config, and the data modules were renamed to TabularDataModule / TabularDataset. Code that only uses the high-level estimators mostly keeps working; code that imported internal modules needs updating. See the FAQ for v1 support and upgrade notes.

Configuration and data

  • Split-config API: The model, preprocessing, and training each have their own configuration object, so you can tune one concern without disturbing the others. This is the first thing you reach for in v2.
  • Typed data layer: TabularDataset, TabularDataModule, and FeatureSchema give the data pipeline an explicit, inspectable contract, with stratified splitting controlled through TrainerConfig.
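Stratified splitting keeps each class's share the same in the training and validation partitions. The plain-Python sketch below shows the idea. It is not DeepTab's implementation: stratified_split and its val_fraction and seed parameters are hypothetical, and in DeepTab this behaviour is configured through TrainerConfig instead.

```python
import random
from collections import defaultdict


def stratified_split(labels, val_fraction=0.2, seed=0):
    """Hypothetical helper (not DeepTab's API): return (train_idx, val_idx)
    so that every class keeps roughly the same share in both partitions."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)  # private RNG, so the split is reproducible
    train_idx, val_idx = [], []
    for label in sorted(by_class):  # fixed class order for determinism
        indices = by_class[label]
        rng.shuffle(indices)
        n_val = round(len(indices) * val_fraction)
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return sorted(train_idx), sorted(val_idx)


labels = ["churn"] * 20 + ["stay"] * 80
train_idx, val_idx = stratified_split(labels, val_fraction=0.25)
print(len(train_idx), len(val_idx))  # 75 25
```

Both classes lose the same fraction to validation (5 of 20 "churn", 20 of 80 "stay"), so the 20/80 class balance survives the split.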

Models

  • New stable models: AutoInt, ENODE, and TabR.
  • New experimental models: Tangos, Trompt, and ModernNCA, under evaluation for promotion.

Training and evaluation

  • Observability and experiment tracking: ObservabilityConfig adds structured lifecycle logging via structlog and one-line MLflow or TensorBoard tracking, with every run saved to an organised directory tree. It is opt-in and silent by default.
  • Registry-driven training: Every torch.optim optimizer, learning-rate scheduler, and loss is selectable by name through TrainerConfig, and you can register your own at runtime.
  • Unified metrics: deeptab.metrics ships 25+ metric classes for regression, classification, and distributional models, auto-selected per task through a registry.
  • Reproducibility: set_seed and seed_context seed Python, NumPy, and PyTorch across CPU, CUDA, and MPS, including the DataLoader and sampler generators.
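To illustrate what seeding Python, NumPy, and PyTorch together involves, here is a minimal sketch of a seed helper and a context manager that restores the previous random state on exit. The names set_seed_sketch and seed_context_sketch are hypothetical; this is not DeepTab's set_seed or seed_context. It silently skips NumPy or PyTorch when they are not installed, and it restores only the CPU PyTorch generator, not CUDA/MPS state or DataLoader generators.

```python
import contextlib
import random

# NumPy and PyTorch are optional in this sketch; missing libraries are skipped.
try:
    import numpy as np
except ImportError:
    np = None
try:
    import torch
except ImportError:
    torch = None


def set_seed_sketch(seed: int) -> None:
    """Hypothetical stand-in for a set_seed helper: seed every available global RNG."""
    random.seed(seed)
    if np is not None:
        np.random.seed(seed)
    if torch is not None:
        torch.manual_seed(seed)  # seeds the default generators on all devices


@contextlib.contextmanager
def seed_context_sketch(seed: int):
    """Hypothetical: seed inside the block, then restore the previous global state.

    Only Python, NumPy, and PyTorch CPU states are saved and restored here.
    """
    py_state = random.getstate()
    np_state = np.random.get_state() if np is not None else None
    torch_state = torch.get_rng_state() if torch is not None else None
    set_seed_sketch(seed)
    try:
        yield
    finally:
        random.setstate(py_state)
        if np_state is not None:
            np.random.set_state(np_state)
        if torch_state is not None:
            torch.set_rng_state(torch_state)


with seed_context_sketch(42):
    first = [random.random() for _ in range(3)]
with seed_context_sketch(42):
    second = [random.random() for _ in range(3)]
print(first == second)  # True
```

Restoring the state on exit keeps a seeded block from silently changing the random stream seen by code that runs after it.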

Deployment

  • Deployment-safe inference: InferenceModel wraps a fitted estimator in a read-only prediction surface with schema validation and task-type enforcement. Training methods are deliberately absent, so a served model cannot be re-fitted by accident.
  • Self-describing artifacts: save and load go through a single .deeptab format that bundles the architecture, feature schema, preprocessing, task type, and package versions alongside the weights, so a saved model carries everything needed to reload it.
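A self-describing artifact can be as simple as an archive that stores a metadata manifest next to the weights. The sketch below uses a zip file with manifest.json and weights.bin entries. That layout, the manifest keys, and the save_bundle / load_bundle helpers are assumptions for illustration, not the actual .deeptab format.

```python
import io
import json
import zipfile


def save_bundle(path_or_file, weights: bytes, manifest: dict) -> None:
    """Hypothetical helper: store a JSON manifest next to raw weight bytes."""
    with zipfile.ZipFile(path_or_file, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        zf.writestr("manifest.json", json.dumps(manifest, indent=2, sort_keys=True))
        zf.writestr("weights.bin", weights)


def load_bundle(path_or_file, expected_versions=None):
    """Hypothetical helper: read the bundle back and refuse mismatched versions."""
    with zipfile.ZipFile(path_or_file) as zf:
        manifest = json.loads(zf.read("manifest.json"))
        weights = zf.read("weights.bin")
    recorded = manifest.get("package_versions", {})
    for package, version in (expected_versions or {}).items():
        if recorded.get(package) != version:
            raise ValueError(
                f"{package}: bundle has {recorded.get(package)!r}, expected {version!r}"
            )
    return manifest, weights


manifest = {
    "architecture": "Mambular",
    "task_type": "classification",
    "feature_schema": {"age": "numerical", "city": "categorical"},
    "package_versions": {"deeptab": "2.0.0"},
}
buffer = io.BytesIO()  # in-memory file; a path such as "model.zip" works too
save_bundle(buffer, b"\x00\x01\x02", manifest)
buffer.seek(0)
loaded_manifest, loaded_weights = load_bundle(buffer, expected_versions={"deeptab": "2.0.0"})
```

Keeping the metadata in the same file as the weights means a loader can reject an incompatible artifact before it tries to rebuild the model.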


🏃 Quickstart

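from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

from deeptab.models import MambularClassifier

# Toy data so the snippet runs end to end; use your own DataFrame or array here
X, y = make_classification(n_samples=1000, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Initialize and fit (sklearn-compatible)
model = MambularClassifier()
model.fit(X_train, y_train, max_epochs=50)

# Predict
predictions = model.predict(X_test)
probabilities = model.predict_proba(X_test)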

That's it! DeepTab handles preprocessing, batching, and training automatically.

Works with pandas & numpy: Pass DataFrames or arrays, and DeepTab auto-detects feature types.

Available Models

DeepTab provides 15 stable architectures across five families: State Space Models (Mambular, MambaTab, MambAttention), Transformers (FTTransformer, TabTransformer, SAINT, AutoInt), residual networks (ResNet, TabR), tree-inspired models (NODE, ENODE, NDTF), and general baselines (MLP, TabM, TabulaRNN). Three experimental models (ModernNCA, Tangos, Trompt) are under evaluation for promotion.

See the Model Zoo for detailed comparisons, complexity analysis, and selection guidance.

Stable Models

Category            Model           Architecture                         Best For
State Space Models  Mambular        Stacked Mamba over feature tokens    General-purpose tabular modeling
                    MambaTab        Lightweight Mamba SSM                Small datasets and fast training
                    MambAttention   Mamba with feature attention         Feature-interaction-heavy data
Transformers        FTTransformer   Feature Tokenizer + Transformer      Strong attention-based baseline
                    TabTransformer  Transformer over categorical tokens  Categorical-heavy data
                    SAINT           Row and column attention             Small or label-scarce datasets
                    AutoInt         Self-attentive feature interactions  Automatic high-order interactions
Residual Networks   ResNet          Residual MLP                         Fast dense baseline
                    TabR            Retrieval-augmented MLP/kNN          Large datasets with neighbor signal
Tree-Inspired       NODE            Neural oblivious decision ensembles  Differentiable tree inductive bias
                    ENODE           Embedded NODE-style soft trees       Tree-inspired modeling with embeddings
                    NDTF            Neural decision tree forest          Differentiable forest experiments
Other               MLP             Feedforward dense network            Fastest baseline
                    TabM            Parameter-efficient ensemble MLP     Strong efficient baseline
                    TabulaRNN       Recurrent feature-sequence model     Sequential feature modeling

Experimental Models ⚠️

⚠️ API Not Stable: Experimental models may change in minor releases. Always pin the exact version: deeptab==x.y.z

  • ModernNCA: Neighborhood Component Analysis (metric learning)
  • Tangos: Gradient orthogonalization approach
  • Trompt: Prompt-based learning for tabular data

Task Variants

All models come in three variants:

  • *Classifier: Classification (binary & multi-class)
  • *Regressor: Regression (point estimates)
  • *LSS: Distributional regression (full distribution prediction)

Consistent API: All models use the same interface, so you can swap architectures without changing code.

📚 Documentation

Full documentation: deeptab.readthedocs.io


🛠️ Installation

Basic installation:

pip install deeptab

With experiment tracking and structured logging:

pip install 'deeptab[tracking]'   # MLflow + TensorBoard loggers
pip install 'deeptab[logs]'       # structured logging via structlog
pip install 'deeptab[all]'        # every optional backend

Faster Mamba models (optional CUDA kernels):

pip install mamba-ssm

Mamba kernels are optional: They give a 20-30% speedup for Mamba-based models on a compatible NVIDIA GPU (CUDA 11.6+). If the install fails or no GPU is present, DeepTab falls back to a pure-PyTorch implementation automatically.

Lightweight by default: Tracking backends are optional and imported lazily, so a plain pip install deeptab stays small. Install only the extras you actually use.
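Both optional kernels and lazily imported tracking backends follow a common pattern: attempt the import only when the feature is needed, and fall back if it fails. A minimal sketch, assuming hypothetical helper names (optional_import, select_backend) and leaving GPU detection to the caller; DeepTab's own fallback logic may check more conditions.

```python
import importlib


def optional_import(module_name: str):
    """Hypothetical helper: import a module on demand, or return None if unavailable."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None


def select_backend(gpu_available: bool) -> str:
    """Hypothetical: prefer mamba-ssm kernels only with a GPU and a working import."""
    # The import is attempted lazily, inside the function, and only if a GPU exists.
    if gpu_available and optional_import("mamba_ssm") is not None:
        return "mamba-ssm CUDA kernels"
    return "pure PyTorch"


print(select_backend(gpu_available=False))  # pure PyTorch
```

Because nothing is imported at module load time, a plain install never pays for backends it does not use.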

Requirements: Python 3.10+, PyTorch 2.2+, Lightning 2.3.3+

GPU Support: See the installation guide for CUDA setup.

Usage

Basic Workflow

from deeptab.models import MambularClassifier
from deeptab.configs import MambularConfig, PreprocessingConfig, TrainerConfig

# 1. Initialize with configuration (optional - defaults work well!)
model_config = MambularConfig(d_model=64, n_layers=6)
prep_config = PreprocessingConfig(numerical_preprocessing="quantile")
trainer_config = TrainerConfig(lr=1e-4, batch_size=256)

model = MambularClassifier(
    model_config=model_config,
    preprocessing_config=prep_config,
    trainer_config=trainer_config
)

# 2. Fit (X can be pandas DataFrame or numpy array)
model.fit(X_train, y_train, max_epochs=50)

# 3. Predict
predictions = model.predict(X_test)
probabilities = model.predict_proba(X_test)

# 4. Evaluate
metrics = model.evaluate(X_test, y_test)
# Regression:      {"rmse": …, "mae": …, "r2": …}
# Classification:  {"accuracy": …, "auroc": …, "log_loss": …}
# LSS (normal):    {"crps": …, "rmse": …, "mae": …}

💡 Tip: Start with defaults (MambularClassifier()) and tune only if needed. See Recommended Configs for guidance.
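The per-task metric sets in the evaluate() comments lend themselves to a registry keyed by task. The sketch below is a pure-Python illustration with hypothetical names (METRICS, register_metric, evaluate_task); it is not the deeptab.metrics API. Only label-based accuracy is shown for classification, because auroc and log_loss need predicted probabilities.

```python
import math

METRICS = {}  # hypothetical registry: task name -> {metric name: function}


def register_metric(task, name):
    """Decorator that files a metric function under a task in METRICS."""
    def decorator(fn):
        METRICS.setdefault(task, {})[name] = fn
        return fn
    return decorator


@register_metric("regression", "rmse")
def rmse(y_true, y_pred):
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


@register_metric("regression", "mae")
def mae(y_true, y_pred):
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


@register_metric("regression", "r2")
def r2(y_true, y_pred):
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


@register_metric("classification", "accuracy")
def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def evaluate_task(task, y_true, y_pred):
    """Compute every metric registered for the given task."""
    return {name: fn(y_true, y_pred) for name, fn in METRICS[task].items()}


print(evaluate_task("classification", [0, 1, 1, 0], [0, 1, 0, 0]))  # {'accuracy': 0.75}
```

Adding a metric is one more decorated function; evaluate_task picks it up without further changes.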

Hyperparameter Tuning

DeepTab models are sklearn-compatible, so you can use GridSearchCV:

from sklearn.model_selection import GridSearchCV
from deeptab.models import MambularClassifier

param_grid = {
    "model_config__d_model": [64, 128, 256],
    "model_config__n_layers": [4, 6, 8],
    "trainer_config__lr": [1e-4, 5e-4, 1e-3],
}

search = GridSearchCV(
    MambularClassifier(),
    param_grid,
    cv=5,
    scoring="accuracy"
)
search.fit(X_train, y_train)
print(f"Best params: {search.best_params_}")
print(f"Best score: {search.best_score_}")
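GridSearchCV addresses nested settings with scikit-learn's double-underscore convention, so "model_config__d_model" means "the d_model field of model_config". The sketch below shows one way such keys could be routed onto dataclass configs with dataclasses.replace. ToyModelConfig, ToyTrainerConfig, and apply_params are hypothetical stand-ins, and DeepTab's actual set_params mechanics may differ.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyModelConfig:  # hypothetical stand-in for MambularConfig
    d_model: int = 64
    n_layers: int = 6


@dataclass(frozen=True)
class ToyTrainerConfig:  # hypothetical stand-in for TrainerConfig
    lr: float = 1e-4
    batch_size: int = 256


def apply_params(components: dict, params: dict) -> dict:
    """Hypothetical: route "component__field" keys onto copies of the configs."""
    updated = dict(components)
    for key, value in params.items():
        component, sep, field_name = key.partition("__")
        if not sep:
            raise ValueError(f"expected 'component__field', got {key!r}")
        # replace() builds a new instance and raises TypeError for unknown fields
        updated[component] = replace(updated[component], **{field_name: value})
    return updated


configs = {"model_config": ToyModelConfig(), "trainer_config": ToyTrainerConfig()}
candidate = apply_params(configs, {"model_config__d_model": 128, "trainer_config__lr": 5e-4})
print(candidate["model_config"])  # ToyModelConfig(d_model=128, n_layers=6)
```

Building new config instances for each candidate, rather than mutating shared ones, keeps one grid point from leaking its settings into the next.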

Built-in HPO: Every estimator exposes optimize_hparams(), which runs Gaussian process Bayesian optimization (via scikit-optimize) over a search space derived from the model config. See the HPO Tutorial.

Distributional Regression (LSS)

Predict a full distribution instead of a single point estimate:

from deeptab.models import MambularLSS

# Choose a distribution family when you fit
model = MambularLSS()
model.fit(X_train, y_train, family="normal", max_epochs=50)

# predict() returns the estimated distribution parameters per sample
# (for "normal", that is the location and scale)
params = model.predict(X_test)

# Evaluate with proper scoring rules selected for the family
metrics = model.evaluate(X_test, y_test)

Available families: normal, lognormal, studentt, gamma, beta, tweedie, poisson, zip, negativebinom, dirichlet, mog, quantile, and more. Each family auto-selects appropriate evaluation metrics (CRPS, deviances, NLL).
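For the normal family, CRPS has a closed form. With z = (y - μ)/σ and Φ, φ the standard normal CDF and PDF, CRPS(N(μ, σ²), y) = σ[z(2Φ(z) - 1) + 2φ(z) - 1/√π]. The sketch below evaluates it in pure Python. crps_normal and mean_crps are hypothetical names, not deeptab.metrics classes, and the sketch assumes you have already split predict()'s output into per-sample location and scale values.

```python
import math
from statistics import NormalDist

_STD_NORMAL = NormalDist()  # standard normal: mean 0, sigma 1


def crps_normal(mu: float, sigma: float, y: float) -> float:
    """Hypothetical helper: closed-form CRPS of N(mu, sigma^2) at y (lower is better)."""
    z = (y - mu) / sigma
    return sigma * (
        z * (2 * _STD_NORMAL.cdf(z) - 1) + 2 * _STD_NORMAL.pdf(z) - 1 / math.sqrt(math.pi)
    )


def mean_crps(locs, scales, ys) -> float:
    """Average CRPS over samples, given per-sample predicted location and scale."""
    return sum(crps_normal(m, s, y) for m, s, y in zip(locs, scales, ys)) / len(ys)


print(round(crps_normal(0.0, 1.0, 0.0), 4))  # 0.2337
```

CRPS is in the units of the target and rewards forecasts that are both centred on the outcome and appropriately sharp.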

Prediction intervals: Turn the predicted parameters into calibrated intervals as shown in the Uncertainty Quantification tutorial.
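For a normal predictive distribution, a central interval follows directly from the predicted location and scale through the normal quantile function. The sketch below also checks coverage, since an interval is only calibrated if the share of held-out targets falling inside it matches the nominal level. normal_interval and empirical_coverage are hypothetical helpers, and the tutorial's method may differ.

```python
from statistics import NormalDist


def normal_interval(mu: float, sigma: float, coverage: float = 0.95):
    """Hypothetical helper: central interval of N(mu, sigma^2) with the given mass."""
    z = NormalDist().inv_cdf(0.5 + coverage / 2)  # about 1.96 for 95%
    return mu - z * sigma, mu + z * sigma


def empirical_coverage(intervals, ys) -> float:
    """Share of observations inside their intervals; compare with the nominal level."""
    hits = sum(low <= y <= high for (low, high), y in zip(intervals, ys))
    return hits / len(ys)


low, high = normal_interval(10.0, 2.0)
print(round(low, 2), round(high, 2))  # 6.08 13.92
```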

Advanced Features

Preprocessing

DeepTab includes comprehensive preprocessing powered by PreTab:

from deeptab.configs import PreprocessingConfig
from deeptab.models import MambularClassifier

prep_config = PreprocessingConfig(
    numerical_preprocessing="ple",  # Piecewise linear encoding
    n_bins=50                       # Number of bins for the encoding
)

model = MambularClassifier(preprocessing_config=prep_config)
model.fit(X_train, y_train, max_epochs=50)
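Piecewise linear encoding replaces one number with a vector over bins: bins entirely below x contribute 1, bins entirely above x contribute 0, and the bin containing x contributes x's fractional position inside it. The sketch below uses quantile bin edges from training data. ple_edges and ple_encode are hypothetical; every component is clipped to [0, 1] for simplicity, and PreTab's edge computation may differ (n_bins here plays a role analogous to n_bins in PreprocessingConfig).

```python
from statistics import quantiles


def ple_edges(values, n_bins):
    """Hypothetical helper: quantile-based bin edges b_0 < b_1 < ... < b_T."""
    cuts = quantiles(values, n=n_bins, method="inclusive")
    return sorted(set([min(values), *cuts, max(values)]))  # drop duplicate edges


def ple_encode(x, edges):
    """Hypothetical helper: one component per bin, each clipped to [0, 1]."""
    encoding = []
    for left, right in zip(edges[:-1], edges[1:]):
        frac = (x - left) / (right - left)
        encoding.append(min(1.0, max(0.0, frac)))
    return encoding


train_values = [0.0, 1.0, 2.0, 3.0, 4.0]
edges = ple_edges(train_values, n_bins=4)
print(edges)                   # [0.0, 1.0, 2.0, 3.0, 4.0]
print(ple_encode(2.5, edges))  # [1.0, 1.0, 0.5, 0.0]
```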

Features:

  • Automatic detection: Feature types detected from data
  • Type-aware: Separate strategies for numerical and categorical features
  • Methods: PLE, quantile transform, splines, standardization, min-max, and robust scaling
  • Pre-trained encodings: Transfer learning for categorical features
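Automatic feature-type detection generally comes down to a heuristic over each column's values. One plausible rule is sketched below: non-numeric columns, and numeric columns with few distinct values, are treated as categorical. The function name and the max_categories threshold are hypothetical, and PreTab's actual rules may differ.

```python
def detect_feature_types(columns: dict, max_categories: int = 10) -> dict:
    """Hypothetical heuristic: map each column name to "numerical" or "categorical".

    `columns` maps a name to a list of values; None marks a missing value.
    Non-numeric columns, and numeric columns with at most `max_categories`
    distinct values, are treated as categorical.
    """
    types = {}
    for name, values in columns.items():
        observed = [v for v in values if v is not None]
        is_numeric = all(
            isinstance(v, (int, float)) and not isinstance(v, bool) for v in observed
        )
        few_levels = len(set(observed)) <= max_categories
        types[name] = "numerical" if is_numeric and not few_levels else "categorical"
    return types


columns = {
    "age": list(range(20, 40)),         # 20 distinct numbers
    "city": ["Berlin", "Paris"] * 10,   # strings
    "n_children": [0, 1, 2, None] * 5,  # numeric, few levels, some missing values
}
print(detect_feature_types(columns))
# {'age': 'numerical', 'city': 'categorical', 'n_children': 'categorical'}
```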

Learn more: Preprocessing is driven by PreprocessingConfig; see the Config System guide and the PreTab project.

Observability & Experiment Tracking

DeepTab can record what happens during training without you writing any callbacks. Pass an ObservabilityConfig when you build a model, and each run captures its hyperparameters, lifecycle events, and final metrics in one self-contained folder.

from deeptab.core.observability import ObservabilityConfig
from deeptab.models import MambularClassifier

obs = ObservabilityConfig(
    experiment_name="churn_baseline",
    structured_logging=True,          # human-readable console + JSON event log
    experiment_trackers=["mlflow"],   # also supports "tensorboard"
)

model = MambularClassifier(observability_config=obs)
model.fit(X_train, y_train, max_epochs=50)

Every fit produces a tidy, reproducible run directory:

deeptab_runs/
  runs/churn_baseline/20260611_174830_8f3a2c/
    config.yaml       # estimator hyperparameters
    lifecycle.jsonl   # structured event log
    summary.json      # final metrics
    checkpoints/best.ckpt
  tensorboard/...
  mlflow/...

Tune the noise: verbosity controls how much is emitted (0 silent, 1 milestones, 2 detailed, 3 debug). The default keeps notebooks quiet.

🔬 For researchers: Lifecycle events such as fit.started, model.created, and train.completed carry structured metadata (sample counts, parameter counts, best validation loss), so you can script experiment sweeps and compare runs programmatically.
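Because events land in lifecycle.jsonl as one JSON object per line, comparing runs can be a short script. The sketch below assumes each event has an "event" key and that train.completed stores its best validation loss under "best_val_loss"; both key names are assumptions, so check a real log before relying on them. read_events, best_val_loss, and rank_runs are hypothetical helpers.

```python
import json
from pathlib import Path


def read_events(path):
    """Parse a JSON-lines event log into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


def best_val_loss(events):
    """Best validation loss from the last train.completed event, or None.

    The "event" and "best_val_loss" keys are assumed, not confirmed field names.
    """
    completed = [e for e in events if e.get("event") == "train.completed"]
    return completed[-1].get("best_val_loss") if completed else None


def rank_runs(root, experiment):
    """Return (best_val_loss, run_id) pairs under root/runs/<experiment>/, best first."""
    scores = []
    for log in Path(root, "runs", experiment).glob("*/lifecycle.jsonl"):
        loss = best_val_loss(read_events(log))
        if loss is not None:
            scores.append((loss, log.parent.name))
    return sorted(scores)


# Example: rank_runs("deeptab_runs", "churn_baseline")
```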

📖 Learn more: Observability

Custom Models

Implement your own architecture with DeepTab's base classes. A model is three small pieces: a dataclass config (subclassing BaseModelConfig), a PyTorch architecture (subclassing BaseModel), and one estimator per task that binds them via _model_cls / _config_cls:

from dataclasses import dataclass, field

import torch
import torch.nn as nn

from deeptab.configs import BaseModelConfig, TrainerConfig
from deeptab.core import BaseModel, get_feature_dimensions
from deeptab.models import SklearnBaseRegressor


@dataclass
class MyCustomConfig(BaseModelConfig):
    layer_sizes: list = field(default_factory=lambda: [128, 64])
    dropout: float = 0.1


class MyCustomModel(BaseModel):
    def __init__(
        self,
        feature_information: tuple,  # (num_info, cat_info, embedding_info)
        num_classes: int = 1,
        config: MyCustomConfig = MyCustomConfig(),  # noqa: B008
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)
        self.save_hyperparameters(ignore=["feature_information"])

        # Input width is derived from the data, never hard-coded.
        input_dim = get_feature_dimensions(*feature_information)

        layers: list[nn.Module] = []
        prev = input_dim
        for size in self.hparams.layer_sizes:
            layers += [nn.Linear(prev, size), nn.ReLU(), nn.Dropout(self.hparams.dropout)]
            prev = size
        layers.append(nn.Linear(prev, num_classes))
        self.layers = nn.Sequential(*layers)

    def forward(self, *data) -> torch.Tensor:
        # data == (num_features, cat_features, embeddings)
        x = torch.cat([t for group in data for t in group], dim=1)
        return self.layers(x)


class MyRegressor(SklearnBaseRegressor):
    _model_cls = MyCustomModel
    _config_cls = MyCustomConfig


# Use like any other DeepTab model
model = MyRegressor(
    model_config=MyCustomConfig(layer_sizes=[256, 128]),
    trainer_config=TrainerConfig(lr=1e-3),
)
model.fit(X_train, y_train, max_epochs=50)

📖 Learn more: Custom Models walks through configs, embeddings, and the *Classifier / *Regressor / *LSS variants.

🛠️ Developer Guide: See Contributing for architecture guidelines.

🏷️ Citation

If you use DeepTab in your research, please cite:

@article{thielmann2024mambular,
  title={Mambular: A Sequential Model for Tabular Deep Learning},
  author={Thielmann, Anton Frederik and Kumar, Manish and Weisser, Christoph and Reuter, Arik and S{\"a}fken, Benjamin and Samiee, Soheila},
  journal={arXiv preprint arXiv:2408.06291},
  year={2024}
}

@article{thielmann2024efficiency,
  title={On the Efficiency of NLP-Inspired Methods for Tabular Deep Learning},
  author={Thielmann, Anton Frederik and Samiee, Soheila},
  journal={arXiv preprint arXiv:2411.17207},
  year={2024}
}

📄 License

DeepTab is licensed under the MIT License. See LICENSE for details.

🤝 Contributing

Contributions are welcome. See the Contributing Guide to get started, and please follow our Code of Conduct.




Download files


Source Distribution

deeptab-2.0.0.tar.gz (204.7 kB)

Uploaded Source

Built Distribution


deeptab-2.0.0-py3-none-any.whl (273.8 kB)

Uploaded Python 3

File details

Details for the file deeptab-2.0.0.tar.gz.

File metadata

  • Download URL: deeptab-2.0.0.tar.gz
  • Size: 204.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for deeptab-2.0.0.tar.gz
Algorithm Hash digest
SHA256 2681791ff046ddf3e9fe355b4ee95aa3d6835c6d02bb04e70dd0c4a321302db8
MD5 275523c579bebc6a8d30adbb9c453cd5
BLAKE2b-256 757227aea767e6b6086190de7e02afb6a5d6182b551e4f3173dd4169a88f9043


Provenance

The following attestation bundles were made for deeptab-2.0.0.tar.gz:

Publisher: publish-pypi.yml on OpenTabular/DeepTab

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file deeptab-2.0.0-py3-none-any.whl.

File metadata

  • Download URL: deeptab-2.0.0-py3-none-any.whl
  • Size: 273.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for deeptab-2.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 312b40c3628e5b601f48a8c51f44a0f197d1f7a4b9491f6551df78379a0e390c
MD5 7966f36eb0ea1e9452d9c0bfeaab6c56
BLAKE2b-256 8c8d515aadd68ce8822d388c462cf73863fb6f5fece102e9d48265f031cde114


Provenance

The following attestation bundles were made for deeptab-2.0.0-py3-none-any.whl:

Publisher: publish-pypi.yml on OpenTabular/DeepTab

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
