TabPFN-Wide: Extension of the TabPFN-2 foundation model, specifically designed for wide datasets (many features, few samples).

TabPFN-Wide

[!NOTE] DOI: 10.48550/arXiv.2510.06162
Authors: Christopher Kolberg*, Jules Kreuer*, Jonas Huurdeman*, Sofiane Ouaari, Katharina Eggensperger, Nico Pfeifer
Built with PriorLabs-TabPFN.

TabPFN-Wide is an extension of the TabPFN-2 foundation model, specifically designed for wide datasets (many features, few samples), such as multi-omics data. It allows for training and evaluating large-scale tabular models that can handle thousands of features.

[!IMPORTANT] This repository provides a release (v0.1.0) of the tabpfnwide package along with a suite of scripts for training, feature-smearing analysis, and biological interpretation used in the TabPFN-Wide paper. Later releases will include bug fixes and new features and are not tied to the original TabPFN-Wide paper.

Publication

The TabPFN-Wide preprint is available at arXiv.

License

The model weights and code of the tabpfnwide project are licensed under the Prior Labs License Version 1.1.

[!IMPORTANT] The license includes an attribution requirement. If you use this work to improve an AI model, you must include "TabPFN" in the model name and display "Built with PriorLabs-TabPFN". See LICENSE for details.

Quick Start

Installation

Using pip:

pip install tabpfnwide

From Source:

pip install "tabpfnwide @ git+https://github.com/not-a-feature/TabPFN-Wide.git"

Model Weights

Model weights are automatically downloaded from GitHub Releases upon first use and cached in ~/.tabpfnwide/models/. If you are running in an offline environment, you can manually download the .pt files from the Releases page and place them in that directory.
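
For offline setups, it can help to confirm that the cache directory actually contains weight files before constructing a classifier. The following is a minimal sketch that assumes only the cache location and the .pt extension stated above; the helper name `find_cached_weights` is hypothetical and not part of the tabpfnwide API.

```python
from pathlib import Path


def find_cached_weights(cache_dir=None):
    """Return the sorted .pt files found in the weight cache directory.

    The default location mirrors the path given in this README. The helper
    itself is illustrative and not part of tabpfnwide.
    """
    if cache_dir is None:
        cache_dir = Path.home() / ".tabpfnwide" / "models"
    cache_dir = Path(cache_dir)
    if not cache_dir.is_dir():
        return []
    return sorted(p for p in cache_dir.iterdir() if p.suffix == ".pt")


if __name__ == "__main__":
    weights = find_cached_weights()
    if weights:
        for path in weights:
            print(f"found cached weights: {path.name}")
    else:
        print("no cached weights found; download the .pt files manually")
```

The check only looks at file names. It does not verify that a file is complete or matches a particular model name.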

Basic Usage

TabPFN-Wide follows the scikit-learn classifier interface (fit, predict, predict_proba).

from tabpfnwide.classifier import TabPFNWideClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

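# Load an example dataset. Breast cancer has only 30 features, so it is not
# 'wide'; substitute your own wide data (many features, few samples) here.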
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Initialize the classifier with a wide model (e.g., handles up to 5k features)
clf = TabPFNWideClassifier(model_name="wide-v2-5k", device="cpu")

# Fit and predict
clf.fit(X_train, y_train)
predictions = clf.predict(X_test)

See demo_run_prediction.py for more details.
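
Before fitting, a quick shape check can show whether a dataset is wide and whether it fits the chosen model. This is a minimal sketch with two assumptions: the default limit of 5000 features is read off the "5k" in the model name used above, and "wide" is taken to mean more features than samples. The function name `check_wide_input` is hypothetical and not part of tabpfnwide.

```python
def check_wide_input(n_samples, n_features, max_features=5000):
    """Summarize whether a dataset shape suits a feature-limited model.

    max_features=5000 is an assumption based on the "5k" in the model name.
    """
    if n_samples <= 0 or n_features <= 0:
        raise ValueError("n_samples and n_features must be positive")
    return {
        # Simple heuristic: more columns than rows.
        "is_wide": n_features > n_samples,
        "within_limit": n_features <= max_features,
    }


if __name__ == "__main__":
    # Shape of the breast cancer data before splitting: 569 samples, 30 features.
    print(check_wide_input(569, 30))
    # A hypothetical multi-omics shape: 120 samples, 4000 features.
    print(check_wide_input(120, 4000))
```

With NumPy arrays, the call can be written as `check_wide_input(*X_train.shape)`.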

Interpreting the Model

TabPFN-Wide provides built-in tools to extract attention-based feature importance.

import numpy as np

# Note: Attention maps require n_estimators=1 and features_per_group=1
clf = TabPFNWideClassifier(
    model_name="wide-v2-5k",
    save_attention_maps=True,
    n_estimators=1,
    features_per_group=1
)
clf.fit(X_train, y_train)

# Attention maps are recorded during the forward pass, so call predict
# (or predict_proba) before reading them. Each call resets the buffers,
# so the readouts below reflect this most recent forward pass only.
clf.predict_proba(X_test)

# 1. Get raw feature-to-feature attention maps (list of maps per layer)
attn_per_layer = clf.get_attention_maps()
avg_attn_matrix = np.mean(attn_per_layer, axis=0)

# 2. Get direct feature importance based on label-to-feature attention
# This represents how much the model's prediction 'attends' to each input feature
importances = clf.get_attention_to_label()

See demo_attention_maps.py for more details.
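
Once importances are available, a common next step is to rank features by score. The sketch below assumes that the importances form a one-dimensional sequence with one score per input feature; the return shape of get_attention_to_label is not documented here, so check it first. The helper `top_features` is hypothetical and not part of tabpfnwide.

```python
def top_features(importances, feature_names=None, k=10):
    """Return the k highest-scoring (name, score) pairs, best first.

    Assumes `importances` is a 1-D sequence with one score per feature.
    """
    scores = [float(s) for s in importances]
    if feature_names is None:
        feature_names = [f"feature_{i}" for i in range(len(scores))]
    if len(feature_names) != len(scores):
        raise ValueError("feature_names and importances differ in length")
    # sorted() is stable, so tied scores keep their original feature order.
    ranked = sorted(
        zip(feature_names, scores), key=lambda pair: pair[1], reverse=True
    )
    return ranked[:k]


if __name__ == "__main__":
    # Made-up scores for illustration only, not model output.
    for name, score in top_features([0.1, 0.7, 0.2], ["a", "b", "c"], k=2):
        print(f"{name}: {score:.2f}")
```

Attention-based scores indicate where the model attends, so treat the ranking as a starting point for interpretation.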

Continued Pretraining

[!NOTE] Training scripts require the dev dependencies.

The training logic lives in the training/ directory. You can run training jobs with the provided Python script or the shell wrapper.

Using the Python script:

python training/train.py \
    --prior_type mlp_scm \
    --prior_max_features 100 \
    --batch_size 8

Using the shell script:

bash training/train.sh
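
To launch several runs with different settings, the Python invocation above can be assembled programmatically. This is a minimal sketch that uses only the three flags shown above; the helper `build_train_command` is hypothetical, and the alternative batch sizes are illustrative values rather than recommendations.

```python
import shlex
import sys


def build_train_command(prior_type="mlp_scm", prior_max_features=100, batch_size=8):
    """Build the argv list for training/train.py from the flags shown above."""
    return [
        sys.executable,
        "training/train.py",
        "--prior_type", str(prior_type),
        "--prior_max_features", str(prior_max_features),
        "--batch_size", str(batch_size),
    ]


if __name__ == "__main__":
    # Print one command per batch size. To execute a command, pass the list
    # to subprocess.run(cmd, check=True) from the repository root.
    for batch_size in (4, 8, 16):
        print(shlex.join(build_train_command(batch_size=batch_size)))
```

Building the command as a list avoids shell quoting issues when values contain spaces.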

Evaluation & Analysis

bash analysis/run_analysis.sh "$CHECKPOINT_PATH" "$OUTPUT_DIR"

See analysis/analysis.sbatch for more details.


For the original TabPFN work, please cite:

@article{hollmann2025tabpfn,
 title={Accurate predictions on small data with a tabular foundation model},
 author={Hollmann, Noah and M{\"u}ller, Samuel and Purucker, Lennart and
         Krishnakumar, Arjun and K{\"o}rfer, Max and Hoo, Shi Bin and
         Schirrmeister, Robin Tibor and Hutter, Frank},
 journal={Nature},
 year={2025},
 month={01},
 day={09},
 doi={10.1038/s41586-024-08328-6},
 publisher={Springer Nature},
 url={https://www.nature.com/articles/s41586-024-08328-6},
}

Development & Support

Contact: For issues, please open a ticket on the Issue Tracker.
