
Interface for using TabPFN and library to train TabPFN


TabPFN

TabPFN is a neural network that has learned to do tabular data prediction. This is the original CUDA-supporting PyTorch implementation.

We created a Colab notebook that lets you try out our scikit-learn interface.

Installation

pip install tabpfn

If you want to train and evaluate our method as we did in the paper (including baselines), please install with:

pip install tabpfn[full]

To run the autogluon and autosklearn baselines, please create a separate environment and install autosklearn==0.14.5 or autogluon==0.4.0 there; they cannot be installed in the same environment as our other baselines.

Getting started

A simple usage of our sklearn interface is:

from sklearn.metrics import accuracy_score
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

from tabpfn import TabPFNClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

# N_ensemble_configurations controls the number of model predictions that are ensembled with feature and class rotations (see our work for details).
# When N_ensemble_configurations > #features * #classes, no further averaging is applied.

classifier = TabPFNClassifier(device='cpu', N_ensemble_configurations=32)

classifier.fit(X_train, y_train)
y_eval, p_eval = classifier.predict(X_test, return_winning_probability=True)

print('Accuracy', accuracy_score(y_test, y_eval))
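The comment in the example says that extra ensemble members stop helping once N_ensemble_configurations exceeds #features * #classes. The sketch below illustrates why, assuming each member is identified by one (class shift, feature shift) pair. The helpers `ensemble_configurations`, `rotate` and `unrotate_proba` are hypothetical and not part of the tabpfn API, and the library may order or sample the pairs differently than this sketch does.

```python
import itertools

import numpy as np


def ensemble_configurations(n_features, n_classes, n_ensemble):
    """Hypothetical: list the (class_shift, feature_shift) pairs used by n_ensemble members.

    Only n_features * n_classes distinct pairs exist, so requesting more
    members than that cannot add new configurations.
    """
    all_pairs = list(itertools.product(range(n_classes), range(n_features)))
    return all_pairs[:n_ensemble]


def rotate(X, y, n_classes, class_shift, feature_shift):
    """Apply one configuration: cyclically shift feature columns and class labels."""
    X_rot = np.roll(X, shift=feature_shift, axis=1)
    y_rot = (y + class_shift) % n_classes
    return X_rot, y_rot


def unrotate_proba(proba_rot, class_shift):
    """Map probabilities over rotated labels back to the original class order."""
    # Original class c was relabeled (c + class_shift) % n_classes.
    return np.roll(proba_rot, -class_shift, axis=1)


# Breast cancer example above: 30 features, 2 classes, so 60 distinct pairs.
configs = ensemble_configurations(n_features=30, n_classes=2, n_ensemble=32)
print(len(configs))                              # 32
print(len(ensemble_configurations(30, 2, 100)))  # 60: capped, no extra averaging
```

Averaging the `unrotate_proba` outputs of all members would give the ensemble prediction in this sketch; once every pair has been used, additional members could only repeat an existing pair.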

TabPFN Usage

TabPFN is different from other methods you might know for tabular classification. Here, we list some tips and tricks that might help you understand how to use it best.

  • Do not preprocess inputs to TabPFN. TabPFN pre-processes inputs internally: it applies a z-score normalization ((x - train_x.mean()) / train_x.std()) per feature, fitted on the training set, and log-scales outliers heuristically. Finally, TabPFN applies a PowerTransform to all features for every second ensemble member. This pre-processing is important to make sure that the real-world dataset lies in the distribution of the synthetic datasets seen during training. So, to get the best results, do not apply a PowerTransformation to the inputs yourself.
  • TabPFN expects scalar values only (you need to encode categoricals as integers, e.g., with OrdinalEncoder). It works best on data that does not contain any categorical or NaN data (see Appendix B.1 of the paper).
  • TabPFN ensembles multiple input encodings by default. It feeds different index rotations of the features and labels to the model for each ensemble member. You can control the ensembling with TabPFNClassifier(..., N_ensemble_configurations=?).
  • TabPFN does not use any statistics from the test set. That means predicting each test example one-by-one will yield the same result as feeding the whole test set together.
  • TabPFN is differentiable in principle; only the pre-processing is not, since it relies on NumPy.
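The tips above describe what TabPFN does internally, so the following is for understanding only: you should pass raw, ordinal-encoded features to TabPFNClassifier rather than apply these steps yourself. The class `TrainFittedPreprocessor` is hypothetical, and the outlier rule (logarithmic growth beyond an assumed threshold of 4 standard deviations) is an illustrative stand-in; TabPFN's actual heuristic may differ. Because every statistic is fitted on the training set, transforming test rows one at a time matches transforming them as a batch.

```python
import numpy as np
from sklearn.preprocessing import OrdinalEncoder, PowerTransformer


class TrainFittedPreprocessor:
    """Hypothetical sketch of the internal steps described above, not TabPFN's code."""

    def __init__(self, outlier_threshold=4.0, use_power=False):
        self.outlier_threshold = outlier_threshold  # assumed value, for illustration
        self.use_power = use_power  # TabPFN enables a PowerTransform for every second member

    def fit(self, X_train):
        # All statistics come from the training set only.
        self.mean_ = X_train.mean(axis=0)
        std = X_train.std(axis=0)
        self.std_ = np.where(std > 0, std, 1.0)  # avoid dividing by zero
        if self.use_power:
            self.power_ = PowerTransformer().fit(self._clip(self._zscore(X_train)))
        return self

    def _zscore(self, X):
        return (X - self.mean_) / self.std_

    def _clip(self, Z):
        # Assumed heuristic: beyond +/- threshold, values grow logarithmically.
        t = self.outlier_threshold
        excess = np.maximum(np.abs(Z) - t, 0.0)
        return np.where(excess > 0, np.sign(Z) * (t + np.log1p(excess)), Z)

    def transform(self, X):
        Z = self._clip(self._zscore(X))
        return self.power_.transform(Z) if self.use_power else Z


# Categorical strings must become integers first (OrdinalEncoder sorts categories).
colors = np.array([["red"], ["green"], ["blue"], ["green"]])
codes = OrdinalEncoder().fit_transform(colors)  # [[2.], [1.], [0.], [1.]]

rng = np.random.default_rng(0)
X_train = rng.normal(size=(100, 4))
X_test = rng.normal(size=(10, 4))
X_test[0, 0] = 50.0  # an outlier

for use_power in (False, True):
    pre = TrainFittedPreprocessor(use_power=use_power).fit(X_train)
    batch = pre.transform(X_test)
    one_by_one = np.vstack([pre.transform(row[None, :]) for row in X_test])
    print(use_power, np.allclose(batch, one_by_one))  # True for both settings
```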

Our Paper

Read our paper for more information about the setup (or contact us ☺️). If you use our method, please cite us using:

@inproceedings{
  hollmann2023tabpfn,
  title={Tab{PFN}: A Transformer That Solves Small Tabular Classification Problems in a Second},
  author={Noah Hollmann and Samuel M{\"u}ller and Katharina Eggensperger and Frank Hutter},
  booktitle={The Eleventh International Conference on Learning Representations},
  year={2023},
  url={https://openreview.net/forum?id=cp5PvcI6w8_}
}

License

Copyright 2022 Noah Hollmann, Samuel Müller, Katharina Eggensperger, Frank Hutter

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



Download files


Source Distribution

tabpfn-0.1.10.tar.gz (598.2 kB)

Uploaded Source

Built Distribution

tabpfn-0.1.10-py3-none-any.whl (156.6 kB)

Uploaded Python 3

File details

Details for the file tabpfn-0.1.10.tar.gz.

File metadata

  • Download URL: tabpfn-0.1.10.tar.gz
  • Upload date:
  • Size: 598.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.7.13

File hashes

Hashes for tabpfn-0.1.10.tar.gz
Algorithm Hash digest
SHA256 4567b288bb9c8ba5971cf0634e4b0cfd752a31f751564691cb42f6fc9983fc7b
MD5 bbbd5dff63b0b22269bf954c3225fd3a
BLAKE2b-256 ec4d4665cf3da497d1f4bd63ad7a7c9d804629490d1740b6bbe9df7d47a32803


File details

Details for the file tabpfn-0.1.10-py3-none-any.whl.

File metadata

  • Download URL: tabpfn-0.1.10-py3-none-any.whl
  • Upload date:
  • Size: 156.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.7.13

File hashes

Hashes for tabpfn-0.1.10-py3-none-any.whl
Algorithm Hash digest
SHA256 779b4274d4472990735ff53f0279c409f4a8e38bca1254966af3562e3ec970e0
MD5 44b52105a3b45a20cc38e306a8ea3e5f
BLAKE2b-256 4d755684251c20807d52b1c1a727aa654a46962919d71d7d826819d20b35e0dd

