ML models + benchmark for tabular data classification and regression

PyTabKit: Tabular ML models and benchmarking code

This repository accompanies our paper

Better by default: Strong pre-tuned MLPs and boosted trees on tabular data

It contains code for applying tabular ML methods and (optionally) benchmarking them on our meta-train and meta-test benchmarks.

Installation

pip install pytabkit
  • To use TabR, you have to install faiss manually; it is only available via conda.
  • Install torch beforehand if you want to control its version (CPU/GPU, etc.).
  • Use pytabkit[full] to also install the benchmarking part of the library. See also the documentation.

Using the ML models

Most of our machine learning models are directly available via scikit-learn interfaces. For example, you can use RealMLP-TD for classification as follows:

from pytabkit.models.sklearn.sklearn_interfaces import RealMLP_TD_Classifier

model = RealMLP_TD_Classifier()  # or TabR_S_D_Classifier, CatBoost_TD_Classifier, etc.
model.fit(X_train, y_train)
model.predict(X_test)

The code above will automatically select a GPU if available, try to detect categorical columns in dataframes, preprocess numerical variables and regression targets (no standardization required), and use a training-validation split for early stopping. All of this (and much more) can be configured through the constructor and the parameters of the fit() method. For example, you can enable bagging (ensembling of models trained on 5-fold cross-validation) simply by passing n_cv=5 to the constructor. Here is an example showing some of the parameters that can be set explicitly:

from pytabkit.models.sklearn.sklearn_interfaces import RealMLP_TD_Classifier

model = RealMLP_TD_Classifier(device='cpu', random_state=0, n_cv=1, n_refit=0,
                              verbosity=2, val_metric_name='cross_entropy',
                              n_epochs=256, batch_size=256, hidden_sizes=[256] * 3,
                              lr=0.04, use_ls=False)
model.fit(X_train, y_train, val_idxs=val_idxs, cat_features=cat_features)
model.predict_proba(X_test)

See this notebook for more examples.
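To illustrate what the n_cv=5 bagging option means conceptually, here is a minimal plain-Python sketch (not the PyTabKit implementation): one model is trained per cross-validation fold, each fold serving once as validation data, and the per-fold predictions are averaged at test time.

```python
def kfold_indices(n_samples, n_folds):
    """Split range(n_samples) into n_folds contiguous validation folds."""
    fold_sizes = [n_samples // n_folds + (1 if i < n_samples % n_folds else 0)
                  for i in range(n_folds)]
    folds, start = [], 0
    for size in fold_sizes:
        folds.append(list(range(start, start + size)))
        start += size
    return folds

def bagged_predict(fold_models, x):
    """Average the predicted probability vectors of the per-fold models."""
    preds = [model(x) for model in fold_models]
    return [sum(p[i] for p in preds) / len(preds) for i in range(len(preds[0]))]
```

In PyTabKit itself, all of this happens inside fit() and predict_proba() once n_cv is set; the sketch only shows the averaging idea.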

Available ML models

Our ML models are available in up to three variants, all with best-epoch selection:

  • library defaults (D)
  • our tuned defaults (TD)
  • random-search hyperparameter optimization (HPO), sometimes also tree-structured Parzen estimator (HPO-TPE)
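"Best-epoch selection" means that, whichever variant supplies the hyperparameters, the model state from the epoch with the best validation metric is kept. A minimal sketch of that selection step (conceptual only, not the library's internals):

```python
def best_epoch(val_metrics):
    """Return (epoch_index, metric) for the lowest validation metric,
    e.g. cross-entropy recorded once per training epoch."""
    best_idx = min(range(len(val_metrics)), key=val_metrics.__getitem__)
    return best_idx, val_metrics[best_idx]
```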

We provide the following ML models:

  • RealMLP (TD, HPO): Our new neural net models with tuned defaults (TD) or random search hyperparameter optimization (HPO)
  • XGB, LGBM, CatBoost (D, TD, HPO, HPO-TPE): Interfaces for gradient-boosted tree libraries XGBoost, LightGBM, CatBoost
  • MLP, ResNet (D, HPO): Models from Revisiting Deep Learning Models for Tabular Data
  • TabR-S (D): TabR model from TabR: Tabular Deep Learning Meets Nearest Neighbors
  • Ensemble-TD: Weighted ensemble of all TD models (RealMLP, XGB, LGBM, CatBoost)
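As a rough illustration of what a weighted ensemble like Ensemble-TD computes, the following sketch combines per-model class probabilities with non-negative weights summing to one (the actual weights and member models are chosen by the library; this is not its implementation):

```python
def weighted_ensemble(probas, weights):
    """probas: one probability vector per member model;
    weights: non-negative, summing to one. Returns the combined vector."""
    n_classes = len(probas[0])
    return [sum(w * p[c] for w, p in zip(weights, probas))
            for c in range(n_classes)]
```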

Benchmarking code

Our benchmarking code provides functionality for

  • dataset download
  • running methods in parallel on single-node/multi-node/multi-GPU hardware, with automatic scheduling that tries to respect RAM constraints
  • analyzing/plotting results

For more details, we refer to the documentation.
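To give an idea of what "scheduling that tries to respect RAM constraints" involves, here is a tiny greedy sketch in plain Python (a hypothetical illustration, not the PyTabKit scheduler): queued jobs are started whenever their estimated RAM fits into the remaining budget, and oversized jobs run alone.

```python
def schedule(jobs, ram_budget):
    """jobs: list of (name, ram_estimate) tuples.
    Returns batches of job names; each batch fits into ram_budget."""
    batches = []
    remaining = list(jobs)
    while remaining:
        batch, free, deferred = [], ram_budget, []
        for name, ram in remaining:
            if ram <= free:
                batch.append(name)
                free -= ram
            else:
                deferred.append((name, ram))
        if not batch:  # a single job exceeds the whole budget: run it alone
            name, _ = deferred.pop(0)
            batch.append(name)
        batches.append(batch)
        remaining = deferred
    return batches
```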

Citation

If you use this repository for research purposes, please cite TODO.

Contributors

  • David Holzmüller (Main developer)
  • Léo Grinsztajn (Deep learning baselines, plotting)
  • Ingo Steinwart (UCI dataset download)
  • Katharina Strecker (PyTorch-Lightning interface)

Acknowledgements

Code from other repositories is acknowledged as thoroughly as possible in code comments. In particular, we used code from https://github.com/yandex-research/rtdl and its sub-packages (Apache 2.0 license), code from https://github.com/catboost/benchmarks/ (Apache 2.0 license), and code from https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html (Apache 2.0 license).
