ML models + benchmark for tabular data classification and regression

PyTabKit: Tabular ML models and benchmarking (NeurIPS 2024)

Links: Paper · Documentation · RealMLP-TD-S standalone implementation · Grinsztajn et al. benchmark code · Data archive

PyTabKit provides scikit-learn interfaces for modern tabular classification and regression methods benchmarked in our paper (see below). It also contains the code we used to benchmark these methods on our benchmarks.

[Figure: meta-test benchmark results]

Installation

pip install pytabkit
  • If you want to use TabR, you have to manually install faiss, which is only available on conda.
  • Please install torch separately if you want to control the version (CPU/GPU, etc.).
  • Use pytabkit[autogluon,extra,hpo,bench,dev] to install additional dependencies for AutoGluon models, extra preprocessing, hyperparameter optimization methods beyond random search (hyperopt/SMAC), the benchmarking part, and testing/documentation.
  • For the hpo part, you might need to install swig (e.g., via pip) if the build of pyrfr fails. See also the documentation.
  • To run the data download, you need one of rar, unrar, or 7-zip installed on the system.

Using the ML models

Most of our machine learning models are directly available via scikit-learn interfaces. For example, you can use RealMLP-TD for classification as follows:

from pytabkit import RealMLP_TD_Classifier

model = RealMLP_TD_Classifier()  # or TabR_S_D_Classifier, CatBoost_TD_Classifier, etc.
model.fit(X_train, y_train)
model.predict(X_test)

The code above will automatically select a GPU if available, try to detect categorical columns in dataframes, preprocess numerical variables and regression targets (no standardization required), and use a training-validation split for early stopping. All of this (and much more) can be configured through the constructor and the parameters of the fit() method. For example, you can enable bagging (ensembling of models trained on 5-fold cross-validation) simply by passing n_cv=5 to the constructor. Here is an example showing some of the parameters that can be set explicitly:
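To illustrate what cross-validation bagging (n_cv=5) does conceptually, here is a sketch using plain scikit-learn rather than PyTabKit's internals: one model is trained per fold, and their predicted probabilities are averaged at test time.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

X, y = make_classification(n_samples=200, n_features=10, random_state=0)
X_train, y_train, X_test = X[:150], y[:150], X[150:]

# Train one model per fold on that fold's training part ...
fold_models = []
for train_idx, _val_idx in KFold(n_splits=5, shuffle=True, random_state=0).split(X_train):
    fold_models.append(
        LogisticRegression(max_iter=1000).fit(X_train[train_idx], y_train[train_idx])
    )

# ... and average the fold models' predicted probabilities at test time.
proba = np.mean([m.predict_proba(X_test) for m in fold_models], axis=0)
```

In PyTabKit, the validation part of each fold additionally serves for early stopping (best-epoch selection), so no data is wasted on a fixed holdout split.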

from pytabkit import RealMLP_TD_Classifier

model = RealMLP_TD_Classifier(device='cpu', random_state=0, n_cv=1, n_refit=0,
                              n_epochs=256, batch_size=256, hidden_sizes=[256] * 3,
                              val_metric_name='cross_entropy',
                              use_ls=False,  # for metrics like AUC / log-loss
                              lr=0.04, verbosity=2)
model.fit(X_train, y_train, X_val, y_val, cat_col_names=['Education'])
model.predict_proba(X_test)

See this notebook for more examples. Missing numerical values are currently not allowed and need to be imputed beforehand.
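Since missing numerical values are not allowed, impute them before fitting; a minimal sketch using scikit-learn's SimpleImputer (median imputation is just one possible choice):

```python
import numpy as np
from sklearn.impute import SimpleImputer

X_train = np.array([[1.0, np.nan], [2.0, 3.0], [4.0, 5.0]])

# Fit the imputer on the training data and reuse it for validation/test data.
imputer = SimpleImputer(strategy='median')
X_train_imp = imputer.fit_transform(X_train)  # NaN in column 1 -> median 4.0
# model.fit(X_train_imp, y_train)  # then fit the PyTabKit model as usual
```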

Available ML models

Our ML models are available in up to three variants, all with best-epoch selection:

  • library defaults (D)
  • our tuned defaults (TD)
  • random-search hyperparameter optimization (HPO), sometimes also tree-structured Parzen estimator (HPO-TPE)

We provide the following ML models:

Benchmarking code

Our benchmarking code has functionality for

  • dataset download
  • running methods with high parallelism on single-node/multi-node/multi-GPU hardware, with automatic scheduling that tries to respect RAM constraints
  • analyzing/plotting results

For more details, we refer to the documentation.

Preprocessing code

While many preprocessing methods are implemented in this repository, a standalone version of our robust scaling + smooth clipping can be found here.
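The general idea can be sketched as follows. This is a hedged illustration, not PyTabKit's exact implementation: the interquartile scaling statistic and the smooth-clip function x / sqrt(1 + (x/3)^2) are assumptions here.

```python
import numpy as np

def robust_scale_smooth_clip(col: np.ndarray) -> np.ndarray:
    """Scale a numerical column by a quantile-based (outlier-robust) statistic,
    then squash extreme values with a smooth, monotone clipping function."""
    q25, q75 = np.quantile(col, [0.25, 0.75])
    scale = q75 - q25
    if scale == 0.0:
        scale = 1.0  # fall back for (near-)constant columns
    x = (col - np.median(col)) / scale
    # Smooth clipping: roughly linear near 0, bounded in (-3, 3) for large |x|.
    return x / np.sqrt(1.0 + (x / 3.0) ** 2)
```

Unlike hard clipping, this transform is differentiable everywhere and preserves the ordering of values, while still bounding the influence of outliers.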

Citation

If you use this repository for research purposes, please cite our paper:

@inproceedings{holzmuller2024better,
  title={Better by default: {S}trong pre-tuned {MLPs} and boosted trees on tabular data},
  author={Holzm{\"u}ller, David and Grinsztajn, Leo and Steinwart, Ingo},
  booktitle={Neural {Information} {Processing} {Systems}},
  year={2024}
}

Contributors

  • David Holzmüller (main developer)
  • Léo Grinsztajn (deep learning baselines, plotting)
  • Ingo Steinwart (UCI dataset download)
  • Katharina Strecker (PyTorch-Lightning interface)
  • Lennart Purucker (some features/fixes)
  • Jérôme Dockès (deployment, continuous integration)

Acknowledgements

Code from other repositories is acknowledged as well as possible in code comments. Especially, we used code from https://github.com/yandex-research/rtdl and sub-packages (Apache 2.0 license), code from https://github.com/catboost/benchmarks/ (Apache 2.0 license), and https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html (Apache 2.0 license).

Releases (see git tags)

  • v1.1.0:
    • Included TabM
    • Replaced __ by _ in parameter names for MLP, MLP-PLR, ResNet, and FTT, to comply with scikit-learn interface requirements.
    • Fixed non-determinism in NN baselines by initializing the random state of quantile (and KDI) preprocessing transforms.
    • The n_threads parameter is no longer ignored by NNs.
    • Changes by Lennart Purucker: added a time limit for RealMLP, added support for lightning (while still allowing pytorch-lightning), made skorch a lazy import, and removed the msgpack_numpy dependency.
  • v1.0.0: Release for the NeurIPS version and arXiv v2.
    • More baselines (MLP-PLR, FT-Transformer, TabR-HPO, RF-HPO), plus some unpolished internal interfaces for other methods, especially the ones in AutoGluon.
    • Updated benchmarking code (configurations, plots), including the new version of the Grinsztajn et al. benchmark.
    • Updated fit() parameters in scikit-learn interfaces, etc.
  • v0.0.1: First release for arXiv v1. Code and data are archived at DaRUS.
