ML models + benchmark for tabular data classification and regression
PyTabKit: Tabular ML models and benchmarking (NeurIPS 2024)
Paper | Documentation | RealMLP-TD-S standalone implementation | Grinsztajn et al. benchmark code | Data archive
PyTabKit provides scikit-learn interfaces for the modern tabular classification and regression methods benchmarked in our paper (see below). It also contains the code we used to benchmark these methods.
Installation
pip install pytabkit
- If you want to use TabR, you have to manually install faiss, which is only available on conda.
- Please install torch separately if you want to control the version (CPU/GPU etc.)
- Use pytabkit[autogluon,extra,hpo,bench] to install additional dependencies for AutoGluon models, extra preprocessing, hyperparameter optimization methods beyond random search (hyperopt/SMAC), and the benchmarking part. For the hpo part, you might need to install swig (e.g. via pip) if the build of pyrfr fails. See also the documentation. To run the data download, you need one of rar, unrar, or 7-zip to be installed on the system.
Using the ML models
Most of our machine learning models are directly available via scikit-learn interfaces. For example, you can use RealMLP-TD for classification as follows:
from pytabkit import RealMLP_TD_Classifier
model = RealMLP_TD_Classifier() # or TabR_S_D_Classifier, CatBoost_TD_Classifier, etc.
model.fit(X_train, y_train)
model.predict(X_test)
The code above will automatically select a GPU if available,
try to detect categorical columns in dataframes,
preprocess numerical variables and regression targets (no standardization required),
and use a training-validation split for early stopping.
All of this (and much more) can be configured through the constructor
and the parameters of the fit() method.
For example, it is possible to do bagging
(ensembling of models on 5-fold cross-validation)
simply by passing n_cv=5
to the constructor.
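To illustrate what bagging over cross-validation folds means, here is a minimal sketch using only scikit-learn. It is not PyTabKit's implementation: the helper cv_bagging_predict_proba and the use of LogisticRegression as a stand-in model are assumptions made for illustration. One model is trained per fold, and the predicted probabilities are averaged.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, train_test_split


def cv_bagging_predict_proba(base_model, X_train, y_train, X_test, n_cv=5, random_state=0):
    """Train one clone of base_model per CV fold and average their probabilities.

    Hypothetical helper for illustration; not part of PyTabKit.
    """
    splitter = StratifiedKFold(n_splits=n_cv, shuffle=True, random_state=random_state)
    fold_probas = []
    for fit_idx, _val_idx in splitter.split(X_train, y_train):
        # Each model sees a different (n_cv - 1) / n_cv share of the training data.
        # The held-out fold could serve as a validation set for early stopping.
        model = clone(base_model).fit(X_train[fit_idx], y_train[fit_idx])
        fold_probas.append(model.predict_proba(X_test))
    # Average the class probabilities over the fold models.
    return np.mean(fold_probas, axis=0)


# Synthetic data, only to make the sketch runnable.
X, y = make_classification(n_samples=300, n_features=8, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)
proba = cv_bagging_predict_proba(LogisticRegression(max_iter=1000), X_train, y_train, X_test)
y_pred = proba.argmax(axis=1)
```

With PyTabKit itself, none of this needs to be written by hand: passing n_cv=5 to the constructor requests the ensembling described above.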
Here is an example for some of the parameters that can be set explicitly:
from pytabkit import RealMLP_TD_Classifier
model = RealMLP_TD_Classifier(device='cpu', random_state=0, n_cv=1, n_refit=0,
                              n_epochs=256, batch_size=256, hidden_sizes=[256] * 3,
                              val_metric_name='cross_entropy',
                              use_ls=False,  # for metrics like AUC / log-loss
                              lr=0.04, verbosity=2)
model.fit(X_train, y_train, X_val, y_val, cat_col_names=['Education'])
model.predict_proba(X_test)
See this notebook for more examples. Missing numerical values are currently not allowed and need to be imputed beforehand.
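Since missing numerical values must be imputed before calling fit(), one option is to impute with scikit-learn first. The following is a minimal sketch; median imputation is an assumption chosen for illustration, and another strategy may suit your data better. The imputer is fitted on the training data only and then applied to the test data.

```python
import numpy as np
from sklearn.impute import SimpleImputer

# Toy numerical data with missing values (NaN).
X_train = np.array([[1.0, 10.0], [np.nan, 20.0], [3.0, np.nan], [5.0, 40.0]])
X_test = np.array([[np.nan, 30.0], [2.0, np.nan]])

# Learn per-column medians on the training data only, to avoid leakage from the test data.
imputer = SimpleImputer(strategy="median")
X_train_imputed = imputer.fit_transform(X_train)
X_test_imputed = imputer.transform(X_test)

# X_train_imputed and X_test_imputed contain no NaN and can be passed to fit() / predict().
```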
Available ML models
Our ML models are available in up to three variants, all with best-epoch selection:
- library defaults (D)
- our tuned defaults (TD)
- random search hyperparameter optimization (HPO), sometimes also the Tree-structured Parzen Estimator (HPO-TPE)
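Best-epoch selection means that the model state from the epoch with the best validation score is kept, rather than the state after the last epoch. The sketch below illustrates the idea on a list of made-up validation losses; the function name and the numbers are hypothetical and not taken from PyTabKit.

```python
def select_best_epoch(val_losses):
    """Return the 0-based index of the epoch with the lowest validation loss.

    Ties are resolved in favor of the earliest epoch.
    """
    if not val_losses:
        raise ValueError("val_losses must not be empty")
    return min(range(len(val_losses)), key=lambda epoch: val_losses[epoch])


# Made-up validation losses for six epochs: the loss rises again after epoch 3.
val_losses = [0.92, 0.71, 0.64, 0.61, 0.66, 0.70]
best_epoch = select_best_epoch(val_losses)
# In a real training loop, one would restore the parameters saved at best_epoch.
```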
We provide the following ML models:
- RealMLP (TD, HPO): Our new neural net models with tuned defaults (TD) or random search hyperparameter optimization (HPO)
- XGB, LGBM, CatBoost (D, TD, HPO, HPO-TPE): Interfaces for gradient-boosted tree libraries XGBoost, LightGBM, CatBoost
- MLP, ResNet, FTT (D, HPO): Models from Revisiting Deep Learning Models for Tabular Data
- MLP-PLR: MLP with numerical embeddings from On Embeddings for Numerical Features in Tabular Deep Learning
- TabR-S (D): TabR model from TabR: Tabular Deep Learning Meets Nearest Neighbors
- Ensemble-TD: Weighted ensemble of all TD models (RealMLP, XGB, LGBM, CatBoost)
Benchmarking code
Our benchmarking code has functionality for
- dataset download
- running methods in parallel on single-node/multi-node/multi-GPU hardware, with automatic scheduling that tries to respect RAM constraints
- analyzing/plotting results
For more details, see the documentation.
Preprocessing code
While many preprocessing methods are implemented in this repository, a standalone version of our robust scaling + smooth clipping can be found here.
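As a rough illustration of robust scaling followed by smooth clipping, here is a minimal NumPy sketch. It assumes the formula described in the paper: center by the median, scale by the interquartile range (with a fallback to the min-max range), then apply x / sqrt(1 + (x/3)^2). The function name robust_scale_smooth_clip is hypothetical, and the linked standalone version remains the reference implementation.

```python
import numpy as np


def robust_scale_smooth_clip(X_train, X):
    """Scale columns using training-set statistics, then clip smoothly to (-3, 3)."""
    X_train = np.asarray(X_train, dtype=float)
    X = np.asarray(X, dtype=float)
    median = np.median(X_train, axis=0)
    q25, q75 = np.quantile(X_train, [0.25, 0.75], axis=0)
    value_range = X_train.max(axis=0) - X_train.min(axis=0)

    # Scale factor: 1 / IQR; fall back to 2 / (max - min) if the IQR is zero;
    # use 0 for constant columns.
    scale = np.zeros_like(median)
    use_iqr = q75 > q25
    use_range = ~use_iqr & (value_range > 0)
    scale[use_iqr] = 1.0 / (q75[use_iqr] - q25[use_iqr])
    scale[use_range] = 2.0 / value_range[use_range]

    Z = (X - median) * scale
    # Smooth clipping: close to the identity near 0, saturating at +-3 for outliers.
    return Z / np.sqrt(1.0 + (Z / 3.0) ** 2)


# Toy data: the first column has an extreme outlier, the second is constant.
X_train = np.array([[0.0, 5.0], [1.0, 5.0], [2.0, 5.0], [3.0, 5.0], [1000.0, 5.0]])
X_out = robust_scale_smooth_clip(X_train, X_train)
```

The outlier in the first column is mapped to a value just below 3 instead of dominating the scale, and the constant column is mapped to zeros.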
Citation
If you use this repository for research purposes, please cite our paper:
@inproceedings{holzmuller2024better,
title={Better by default: {S}trong pre-tuned {MLPs} and boosted trees on tabular data},
author={Holzm{\"u}ller, David and Grinsztajn, Leo and Steinwart, Ingo},
booktitle = {Neural {Information} {Processing} {Systems}},
year={2024}
}
Contributors
- David Holzmüller (main developer)
- Léo Grinsztajn (deep learning baselines, plotting)
- Ingo Steinwart (UCI dataset download)
- Katharina Strecker (PyTorch-Lightning interface)
- Jérôme Dockès (deployment, continuous integration)
Acknowledgements
Code from other repositories is acknowledged as well as possible in code comments. In particular, we used code from https://github.com/yandex-research/rtdl and sub-packages (Apache 2.0 license), code from https://github.com/catboost/benchmarks/ (Apache 2.0 license), and https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html (Apache 2.0 license).
Older versions
This code is for the NeurIPS version and arXiv v2 of the paper. The code for the arXiv v1 version is in older commits and archived at DaRUS.