Adaptation of TabPFN to work with large tabular datasets.

Ensemble TabPFN

TabPFN is a transformer architecture proposed by Hollmann et al. for classification on small tabular datasets. It is a Prior-Data Fitted Network that has been trained once and does not require fine-tuning for new datasets. It works by relating new data to the distribution of prior synthetic data it has seen during training. In a machine learning pipeline, this network can be "fit" on a training dataset in under a second and can generate predictions for the test set in a single forward pass. However, the current architecture has limitations: the training dataset can contain only up to 1000 samples with up to 100 numerical features, and the network can predict only up to 10 classes in a multi-class classification problem. EnsembleTabPFN addresses the first two of these limitations by extending the original model to work with datasets containing more than 1000 samples and more than 100 features. EnsembleTabPFN is fully compatible with the scikit-learn API and can be used in a modelling pipeline.
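As a quick illustration of the limits quoted above, the following sketch checks a dataset's dimensions against them. The helper names (`check_tabpfn_limits`, `needs_ensemble`) are hypothetical and are not part of the ensemble-tabpfn package; only the three numeric limits come from the description.

```python
# Limits of the original TabPFN as stated in the project description.
MAX_SAMPLES = 1000
MAX_FEATURES = 100
MAX_CLASSES = 10


def check_tabpfn_limits(n_samples, n_features, n_classes):
    """Report which of the stated TabPFN limits a dataset exceeds.

    Hypothetical helper for illustration; not part of ensemble-tabpfn.
    """
    return {
        "too_many_samples": n_samples > MAX_SAMPLES,
        "too_many_features": n_features > MAX_FEATURES,
        "too_many_classes": n_classes > MAX_CLASSES,
    }


def needs_ensemble(n_samples, n_features):
    """True when the dataset exceeds the sample or feature limit,
    i.e. the two limits that EnsembleTabPFN is described as lifting."""
    return n_samples > MAX_SAMPLES or n_features > MAX_FEATURES


if __name__ == "__main__":
    # A dataset with 5000 rows, 30 features and 3 classes.
    print(check_tabpfn_limits(5000, 30, 3))
    print(needs_ensemble(5000, 30))
```

Note that the class limit is reported separately: according to the description, EnsembleTabPFN does not address it.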

Installation

From source

git clone https://github.com/ersilia-os/ensemble-tabpfn.git
cd ensemble-tabpfn
pip install .

From PyPI

pip install ensemble-tabpfn

Using Poetry

git clone https://github.com/ersilia-os/ensemble-tabpfn.git
cd ensemble-tabpfn
poetry install --without dev,test,docs
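
Whichever installation route is used, a small check like the one below can confirm that the package is importable in the active environment. It is a minimal sketch using only the standard library; the module name `ensemble_tabpfn` is taken from the import in the Usage section, and the helper `is_installed` is hypothetical.

```python
import importlib.util


def is_installed(module_name):
    """Return True if a top-level module can be found by the import system."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # Module name as used in the Usage section of this README.
    print("ensemble_tabpfn importable:", is_installed("ensemble_tabpfn"))
```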

Usage

from ensemble_tabpfn import EnsembleTabPFN
from sklearn.metrics import accuracy_score

clf = EnsembleTabPFN()
clf.fit(X_train, y_train)
y_hat = clf.predict(X_test)
acc = accuracy_score(y_test, y_hat)
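
The snippet above assumes that `X_train`, `X_test`, `y_train` and `y_test` already exist. The following sketch shows one way to produce them and run the same fit/predict/score steps end to end. The synthetic dataset, the split ratio and the helper names (`make_split`, `fit_and_score`) are assumptions for illustration; only `EnsembleTabPFN()`, `fit` and `predict` come from the usage shown above.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split


def make_split(n_samples=2000, n_features=20, seed=0):
    """Build a synthetic binary classification task and split it 75/25.

    The default of 2000 samples is chosen to exceed the 1000-sample
    limit of the original TabPFN mentioned in the description.
    """
    X, y = make_classification(
        n_samples=n_samples, n_features=n_features, random_state=seed
    )
    # Returns X_train, X_test, y_train, y_test (in that order).
    return train_test_split(X, y, test_size=0.25, random_state=seed)


def fit_and_score(clf, X_train, X_test, y_train, y_test):
    """Fit any scikit-learn compatible classifier and return test accuracy."""
    clf.fit(X_train, y_train)
    y_hat = clf.predict(X_test)  # predict on features, not on labels
    return accuracy_score(y_test, y_hat)


if __name__ == "__main__":
    try:
        from ensemble_tabpfn import EnsembleTabPFN
    except ImportError:
        print("ensemble-tabpfn is not installed; see the Installation section.")
    else:
        X_train, X_test, y_train, y_test = make_split()
        acc = fit_and_score(EnsembleTabPFN(), X_train, X_test, y_train, y_test)
        print(f"accuracy: {acc:.3f}")
```

Because `fit_and_score` relies only on `fit` and `predict`, any scikit-learn compatible classifier can be passed in place of `EnsembleTabPFN()` for comparison.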

Download files

Source Distribution

ensemble_tabpfn-0.1.1.tar.gz (17.9 kB)

Built Distribution

ensemble_tabpfn-0.1.1-py3-none-any.whl (19.4 kB, Python 3)
