
TabPFN: Foundation model for tabular data


TabPFN Extensions


[!WARNING]

Experimental Code Notice

Please note that the extensions in this repository are experimental.

  • They are less rigorously tested than the core tabpfn library.
  • APIs are subject to change without notice in future releases. We welcome your feedback and contributions to help improve and stabilize them!

Interactive Notebook Tutorial

[!TIP]

Dive right in with our interactive Colab notebook! It's the best way to get a hands-on feel for TabPFN, walking you through installation, classification, and regression examples.

Open In Colab

Installation

# Install from the GitHub repository, including all optional extras
pip install "tabpfn-extensions[all] @ git+https://github.com/PriorLabs/tabpfn-extensions.git"
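After installing, you can check which of the relevant distributions are present using only the standard library's `importlib.metadata`. The distribution names are the ones used in the pip commands on this page. `installed_version` is a hypothetical helper for illustration and is not part of tabpfn-extensions.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Report the extensions package and both possible backends.
for dist in ("tabpfn-extensions", "tabpfn", "tabpfn-client"):
    print(f"{dist}: {installed_version(dist) or 'not installed'}")
```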

Available Extensions

  • post_hoc_ensembles: Improve performance with model combination
  • interpretability: Explain TabPFN predictions with SHAP values and feature selection
  • many_class: Handle classification with more classes than TabPFN's default limit
  • classifier_as_regressor: Use TabPFN's classifier for regression tasks
  • hpo: Automatic hyperparameter tuning for TabPFN
  • rf_pfn: Combine TabPFN with decision trees and random forests
  • unsupervised: Data generation and outlier detection
  • embedding: Get TabPFN's internal dense sample embeddings
  • tabebm: Data augmentation using TabPFN-based Energy-Based Models
  • pval_crt: Statistical feature relevance testing (p-values)

See the Documentation section below for guides, examples, and per-extension READMEs.
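The many_class entry handles more classes than TabPFN's default limit (the examples list below gives that limit as 10 classes). One generic way to do this is an output-code decomposition: give every class a short random codeword over an alphabet of at most 10 symbols, train one model per codeword position on the relabeled targets, and at prediction time score each class by how well its codeword matches the sub-models' probabilities. The sketch below covers only the codebook and decoding steps in plain Python; `make_codebook`, `relabel`, and `decode` are hypothetical helpers, the base classifiers are omitted, and this is not presented as the extension's actual implementation.

```python
import random
from typing import List, Sequence


def make_codebook(n_classes: int, alphabet_size: int, n_codes: int, seed: int = 0) -> List[List[int]]:
    """Random codebook: row m maps each original class to a symbol in [0, alphabet_size)."""
    if alphabet_size ** n_codes < n_classes:
        raise ValueError("codebook too small to give every class a unique codeword")
    rng = random.Random(seed)
    while True:
        rows = [[rng.randrange(alphabet_size) for _ in range(n_classes)] for _ in range(n_codes)]
        # Column k is the codeword of class k; retry until all codewords differ.
        codewords = {tuple(row[k] for row in rows) for k in range(n_classes)}
        if len(codewords) == n_classes:
            return rows


def relabel(y: Sequence[int], row: Sequence[int]) -> List[int]:
    """Targets for one sub-problem: each original label becomes its symbol in this row."""
    return [row[label] for label in y]


def decode(sub_probas: Sequence[Sequence[float]], codebook: Sequence[Sequence[int]]) -> List[float]:
    """Score class k by summing each sub-model's probability for class k's symbol."""
    n_classes = len(codebook[0])
    return [sum(p[row[k]] for p, row in zip(sub_probas, codebook)) for k in range(n_classes)]


# 25 classes, each sub-problem has at most 10 classes, 4 sub-models.
codebook = make_codebook(n_classes=25, alphabet_size=10, n_codes=4, seed=0)
y_first = relabel([3, 17, 24], codebook[0])  # targets for the first sub-model
```

Because every class has a distinct codeword, perfectly confident sub-models give the true class the maximum possible score (one point per sub-model), while every other class misses at least one point.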

Backend Options

Many TabPFN extensions work with either of two TabPFN implementations:

  1. **TabPFN Package** - Full PyTorch implementation for local inference:

    pip install tabpfn
    
  2. **TabPFN Client** - Lightweight API client for cloud-based inference:

    pip install tabpfn-client
    

Choose the backend that fits your needs - most extensions work with either option!

The exceptions are post_hoc_ensembles and embedding, which work only with the local tabpfn package.
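To pick a backend programmatically, the rule above fits in a few lines: check which package is importable without actually importing it, and refuse the client-only path for post_hoc_ensembles and embedding. `available_backends`, `pick_backend`, and `LOCAL_ONLY` are hypothetical names. Preferring the local package whenever it is installed is an assumption of this sketch; the decision tree further down keys that choice on GPU availability instead.

```python
from importlib.util import find_spec
from typing import Dict

# Extensions that, per this README, only work with the local `tabpfn` package.
LOCAL_ONLY = {"post_hoc_ensembles", "embedding"}


def available_backends() -> Dict[str, bool]:
    """Check which backend packages are importable, without importing them."""
    return {
        "tabpfn": find_spec("tabpfn") is not None,
        "tabpfn_client": find_spec("tabpfn_client") is not None,
    }


def pick_backend(extension: str, available: Dict[str, bool]) -> str:
    """Prefer the local package; fall back to the API client where the extension allows it."""
    if available.get("tabpfn"):
        return "tabpfn"
    if extension in LOCAL_ONLY:
        raise RuntimeError(f"{extension} requires the local package: pip install tabpfn")
    if available.get("tabpfn_client"):
        return "tabpfn_client"
    raise RuntimeError("No backend installed: pip install tabpfn or pip install tabpfn-client")


print(pick_backend("interpretability", {"tabpfn": False, "tabpfn_client": True}))
```

In practice you would pass `available_backends()` as the second argument; the literal dict above just makes the example deterministic.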

Documentation

Documentation for tabpfn-extensions is spread across several sources. If you are new to the project, the examples are usually the fastest way to get started; for deeper conceptual guides, see the TabPFN Docs pages.

Examples

Runnable scripts and notebooks for extensions and general use cases live in the examples/ directory of this repository:

  • embedding/ — access TabPFN's internal dense sample embeddings
  • hpo/ — automatic hyperparameter tuning
  • interpretability/ — SHAP values, partial dependence plots, feature selection
  • large_datasets/ — working with datasets beyond TabPFN's default limits
  • many_class/ — classification with more than 10 classes
  • phe/ — post-hoc ensembles
  • pval_crt/ — statistical feature relevance testing
  • rf_pfn/ — TabPFN combined with decision trees and random forests
  • survival/ — survival analysis
  • tabebm/ — data augmentation via TabEBM
  • unsupervised/ — data generation, imputation, and outlier detection
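The large_datasets/ examples deal with data beyond TabPFN's default limits, and the decision tree further down routes datasets with more than 10k samples to that guide (its node id is `subsample`). As a hedged illustration of the subsampling idea, not the code in the repository's large-datasets example: split the rows into disjoint random chunks that each stay under a size cap, fit one model per chunk, and average their predictions on the same test rows. `subsample_chunks` and `average_predictions` are hypothetical helpers, and the 10,000-row default cap simply mirrors the diagram's threshold.

```python
import random
from typing import List, Sequence


def subsample_chunks(n_samples: int, max_per_chunk: int = 10_000, seed: int = 0) -> List[List[int]]:
    """Shuffle row indices and deal them into the fewest chunks that respect the size cap."""
    if max_per_chunk <= 0:
        raise ValueError("max_per_chunk must be positive")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_chunks = -(-n_samples // max_per_chunk)  # ceiling division
    # Striding gives balanced chunks: sizes differ by at most one row.
    return [indices[i::n_chunks] for i in range(n_chunks)]


def average_predictions(per_model: Sequence[Sequence[float]]) -> List[float]:
    """Average, row by row, the predictions of models fit on different chunks."""
    n_models = len(per_model)
    return [sum(column) / n_models for column in zip(*per_model)]


chunks = subsample_chunks(25_000)
```

For 25,000 rows this yields three chunks of roughly 8,300 rows each. Each chunk would train its own model, and `average_predictions` would combine, for example, their positive-class probabilities for each test row.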

TabPFN Docs pages

In-depth guides for selected extensions are available on docs.priorlabs.ai:

Per-extension READMEs

Some extensions ship a dedicated README alongside their source code:

Interactive notebook

The main TabPFN demo notebook also covers several extensions — in particular the unsupervised and interpretability extensions:

Open In Colab

License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.

TabPFN Workflow at a Glance

Follow this decision tree to build your model and choose the right extensions from the ecosystem. It walks through the key questions about your hardware, task, data, and performance needs, and most nodes link to the matching guide or example.

---
config:
  theme: 'default'
  themeVariables:
    edgeLabelBackground: 'white'
---
graph LR
    %% 1. DEFINE COLOR SCHEME & STYLES
    classDef default fill:#fff,stroke:#333,stroke-width:2px,color:#333;
    classDef start_node fill:#e8f5e9,stroke:#43a047,stroke-width:2px,color:#333;
    classDef process_node fill:#e0f2f1,stroke:#00796b,stroke-width:2px,color:#333;
    classDef decision_node fill:#fff8e1,stroke:#ffa000,stroke-width:2px,color:#333;

    style Infrastructure fill:#fff,stroke:#ccc,stroke-width:5px;
    style Unsupervised fill:#fff,stroke:#ccc,stroke-width:5px;
    style Data fill:#fff,stroke:#ccc,stroke-width:5px;
    style Performance fill:#fff,stroke:#ccc,stroke-width:5px;
    style Interpretability fill:#fff,stroke:#ccc,stroke-width:5px;

    %% 2. DEFINE GRAPH STRUCTURE
    subgraph Infrastructure
        start((Start)) --> gpu_check["GPU available?"];
        gpu_check -- Yes --> local_version["Use TabPFN<br/>(local PyTorch)"];
        gpu_check -- No --> api_client["Use TabPFN-Client<br/>(cloud API)"];
        task_type["What is<br/>your task?"]
    end

    local_version --> task_type
    api_client --> task_type

    end_node((Workflow<br/>Complete));

    subgraph Unsupervised
        unsupervised_type["Select<br/>Unsupervised Task"];
        unsupervised_type --> imputation["Imputation"]
        unsupervised_type --> data_gen["Data<br/>Generation"];
        unsupervised_type --> tabebm["Data<br/>Augmentation"];
        unsupervised_type --> density["Outlier<br/>Detection"];
        unsupervised_type --> embedding["Get<br/>Embeddings"];
    end


    subgraph Data
        data_check["Data Checks"];
        model_choice["Samples > 10k or<br/>Classes > 10?"]
        data_check -- "Table Contains Text Data?" --> api_backend_note["Note: API client has<br/>native text support"];
        api_backend_note --> model_choice;
        data_check -- "Time-Series Data?" --> ts_features["Use Time-Series<br/>Features"];
        ts_features --> model_choice;
        data_check -- "Purely Tabular" --> model_choice;
        model_choice -- "No" --> finetune_check;
        model_choice -- "Yes, >10k samples" --> subsample["Large Datasets Guide<br/>"];
        model_choice -- "Yes, >10 classes" --> many_class["Many-Class<br/>Method"];
    end

    subgraph Performance
        finetune_check["Need<br/>Finetuning?"];
        performance_check["Need Even Better Performance?"];
        speed_check["Need faster inference<br/>at prediction time?"];
        kv_cache["Enable KV Cache<br/>(fit_mode='fit_with_cache')<br/><small>Faster predict; +Memory ~O(N×F)</small>"];
        tuning_complete["Tuning Complete"];

        finetune_check -- Yes --> finetuning["Finetuning"];
        finetune_check -- No --> performance_check;

        finetuning --> performance_check;

        performance_check -- No --> tuning_complete;
        performance_check -- Yes --> hpo["HPO"];
        performance_check -- Yes --> post_hoc["Post-Hoc<br/>Ensembling"];
        performance_check -- Yes --> more_estimators["More<br/>Estimators"];
        performance_check -- Yes --> speed_check;

        speed_check -- Yes --> kv_cache;
        speed_check -- No --> tuning_complete;

        hpo --> tuning_complete;
        post_hoc --> tuning_complete;
        more_estimators --> tuning_complete;
        kv_cache --> tuning_complete;
    end

    subgraph Interpretability

        tuning_complete --> interpretability_check;

        interpretability_check["Need<br/>Interpretability?"];

        interpretability_check --> feature_selection["Feature Selection"];
        interpretability_check --> partial_dependence["Partial Dependence Plots"];
        interpretability_check --> shapley["Explain with<br/>SHAP"];
        interpretability_check --> shap_iq["Explain with<br/>SHAP IQ"];
        interpretability_check -- No --> end_node;

        feature_selection --> end_node;
        partial_dependence --> end_node;
        shapley --> end_node;
        shap_iq --> end_node;

    end

    %% 3. LINK SUBGRAPHS AND PATHS
    task_type -- "Classification or Regression" --> data_check;
    task_type -- "Unsupervised" --> unsupervised_type;

    subsample --> finetune_check;
    many_class --> finetune_check;

    %% 4. APPLY STYLES
    class start,end_node start_node;
    class local_version,api_client,imputation,data_gen,tabebm,density,embedding,api_backend_note,ts_features,subsample,many_class,finetuning,feature_selection,partial_dependence,shapley,shap_iq,hpo,post_hoc,more_estimators,kv_cache process_node;
    class gpu_check,task_type,unsupervised_type,data_check,model_choice,finetune_check,interpretability_check,performance_check,speed_check decision_node;
    class tuning_complete process_node;

    %% 5. ADD CLICKABLE LINKS (INCLUDING KV CACHE EXAMPLE)
    click local_version "https://github.com/PriorLabs/TabPFN" "TabPFN Backend Options" _blank
    click api_client "https://github.com/PriorLabs/tabpfn-client" "TabPFN API Client" _blank
    click api_backend_note "https://github.com/PriorLabs/tabpfn-client" "TabPFN API Backend" _blank
    click unsupervised_type "https://github.com/PriorLabs/tabpfn-extensions" "TabPFN Extensions" _blank
    click imputation "https://github.com/PriorLabs/tabpfn-extensions/blob/main/examples/unsupervised/imputation.py" "TabPFN Imputation Example" _blank
    click data_gen "https://github.com/PriorLabs/tabpfn-extensions/blob/main/examples/unsupervised/generate_data.py" "TabPFN Data Generation Example" _blank
    click tabebm "https://github.com/PriorLabs/tabpfn-extensions/blob/main/examples/tabebm/tabebm_augment_real_world_data.ipynb" "TabEBM Data Augmentation Example" _blank
    click density "https://github.com/PriorLabs/tabpfn-extensions/blob/main/examples/unsupervised/density_estimation_outlier_detection.py" "TabPFN Density Estimation/Outlier Detection Example" _blank
    click embedding "https://github.com/PriorLabs/tabpfn-extensions/tree/main/examples/embedding" "TabPFN Embedding Example" _blank
    click ts_features "https://github.com/PriorLabs/tabpfn-time-series" "TabPFN Time-Series Example" _blank
    click many_class "https://github.com/PriorLabs/tabpfn-extensions/blob/main/examples/many_class/many_class_classifier_example.py" "Many Class Example" _blank
    click finetuning "https://github.com/PriorLabs/TabPFN/blob/main/examples/finetune_classifier.py" "Finetuning Example" _blank
    click feature_selection "https://github.com/PriorLabs/tabpfn-extensions/blob/main/examples/interpretability/feature_selection.py" "Feature Selection Example" _blank
    click partial_dependence "https://github.com/PriorLabs/tabpfn-extensions/blob/main/examples/interpretability/pdp_example.py" "Partial Dependence Plots Example" _blank
    click shapley "https://github.com/PriorLabs/tabpfn-extensions/blob/main/examples/interpretability/shap_example.py" "Shapley Values Example" _blank
    click shap_iq "https://github.com/PriorLabs/tabpfn-extensions/blob/main/examples/interpretability/shapiq_example.py" "SHAP IQ Example" _blank
    click post_hoc "https://github.com/PriorLabs/tabpfn-extensions/blob/main/examples/phe/phe_example.py" "Post-Hoc Ensemble Example" _blank
    click hpo "https://github.com/PriorLabs/tabpfn-extensions/blob/main/examples/hpo/tuned_tabpfn.py" "HPO Example" _blank
    click subsample "https://github.com/PriorLabs/tabpfn-extensions/blob/main/examples/large_datasets/large_datasets_example.py" "Large Datasets Example" _blank
    click kv_cache "https://github.com/PriorLabs/TabPFN/blob/main/examples/kv_cache_fast_prediction.py" "KV Cache Fast Prediction Example" _blank
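For readers whose viewer does not render the Mermaid source above, the diagram's first branches (backend, task type, data checks) can be paraphrased as a small function. This is a plain-Python restatement of the diagram's edges, not part of any package; `recommend` and its parameters are hypothetical.

```python
from typing import List


def recommend(has_gpu: bool, task: str, n_samples: int = 0, n_classes: int = 0,
              has_text: bool = False, is_time_series: bool = False) -> List[str]:
    """Paraphrase the diagram's first branches as an ordered list of suggested steps."""
    # Infrastructure: a GPU points to the local package, otherwise to the cloud API client.
    steps = ["TabPFN (local PyTorch)" if has_gpu else "TabPFN-Client (cloud API)"]
    if task == "unsupervised":
        steps.append("Unsupervised: imputation, generation, augmentation, outliers, or embeddings")
        return steps
    # Data checks for classification or regression.
    if has_text:
        steps.append("Note: API client has native text support")
    if is_time_series:
        steps.append("Time-series features (tabpfn-time-series)")
    if n_samples > 10_000:
        steps.append("Large Datasets Guide")
    if n_classes > 10:
        steps.append("Many-Class method")
    return steps


print(recommend(has_gpu=True, task="classification", n_samples=50_000, n_classes=3))
```

The later stages of the diagram (finetuning, HPO, post-hoc ensembling, more estimators, KV cache, interpretability) are opt-in choices rather than checks on the data, so the sketch leaves them out.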

Anonymized Telemetry

This project collects fully anonymous usage telemetry. You can opt out of telemetry entirely or opt in to extended telemetry.

The data is used only to improve the stability of the relevant products and compute environments and to guide future improvements.

  • No personal data is collected
  • No code, model inputs, or outputs are ever sent
  • Data is strictly anonymous and cannot be linked to individuals

For details on telemetry, please see our Telemetry Reference and our Privacy Policy.

To opt out, set the following environment variable:

export TABPFN_DISABLE_TELEMETRY=1
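If you cannot change the shell environment (for example inside a notebook), the same variable can be set from Python. This sketch assumes the variable is read when telemetry is first initialized, so it sets it before any TabPFN import; exporting it in the shell as shown above remains the route this README describes.

```python
import os

# Assumption: the variable must be present before the TabPFN packages
# initialize telemetry, so set it before importing them.
os.environ["TABPFN_DISABLE_TELEMETRY"] = "1"

# import tabpfn  # import TabPFN packages only after the variable is set
```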

For Contributors

Interested in adding your own extension? We welcome contributions!

We use uv to manage the project's environment, so install that first.

# Clone and set up for development
git clone https://github.com/PriorLabs/tabpfn-extensions.git
cd tabpfn-extensions
uv sync
source .venv/bin/activate

# If your extension adds optional dependencies in pyproject.toml, install them with:
uv sync --extra <your-extension-name>

# Test your extension with fast mode
FAST_TEST_MODE=1 pytest tests/test_your_extension.py -v

See our Contribution Guide for more details.
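Inside a test file, one way to honor FAST_TEST_MODE is to shrink the workload when the variable is set. The helper names, the sizes, and the assumption that the flag is set to "1" are illustrative; check the existing tests in tests/ for the convention the repository actually uses.

```python
import os
from typing import Dict


def fast_test_mode() -> bool:
    """True when the suite runs with FAST_TEST_MODE=1 (value convention assumed)."""
    return os.environ.get("FAST_TEST_MODE", "0") == "1"


def workload_sizes() -> Dict[str, int]:
    """Hypothetical knobs: smaller data and fewer estimators in fast mode."""
    if fast_test_mode():
        return {"n_samples": 30, "n_estimators": 1}
    return {"n_samples": 300, "n_estimators": 8}
```

A test would then build its toy dataset with `workload_sizes()["n_samples"]` rows and pass `workload_sizes()["n_estimators"]` to the estimator under test.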

Contributors


Built with ❤️ by the TabPFN community



Download files


Source Distribution

tabpfn_extensions-0.3.0.tar.gz (134.4 kB)

Uploaded Source

Built Distribution


tabpfn_extensions-0.3.0-py3-none-any.whl (123.6 kB)

Uploaded Python 3

File details

Details for the file tabpfn_extensions-0.3.0.tar.gz.

File metadata

  • Download URL: tabpfn_extensions-0.3.0.tar.gz
  • Upload date:
  • Size: 134.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.8.22

File hashes

Hashes for tabpfn_extensions-0.3.0.tar.gz
Algorithm Hash digest
SHA256 f0e09aaf464eacbc57a6c0ece01a7521cbd7b0a02393aeb8306009fe5860f0f3
MD5 8c0c1cdd2d741fc2de148a4be313f833
BLAKE2b-256 a152558f7a5d2d6686b411ff5390a8bd5982315a9d48e92a33a04c31b20b568e
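To check a downloaded file against the SHA256 digest listed above, the standard library's `hashlib` is enough. This is a minimal sketch; `sha256_of` and `verify` are hypothetical helpers, and the same check applies to the wheel using its own digest. pip can also enforce digests at install time with `--require-hashes`.

```python
import hashlib
from pathlib import Path

# SHA256 listed above for tabpfn_extensions-0.3.0.tar.gz.
EXPECTED_SHA256 = "f0e09aaf464eacbc57a6c0ece01a7521cbd7b0a02393aeb8306009fe5860f0f3"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hash a file in fixed-size chunks so large downloads need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """Compare the file's digest with the published one (case-insensitive)."""
    return sha256_of(path) == expected.lower()
```

Usage: `verify(Path("tabpfn_extensions-0.3.0.tar.gz"))` returns True only for an unmodified copy of the source distribution.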


File details

Details for the file tabpfn_extensions-0.3.0-py3-none-any.whl.

File metadata

File hashes

Hashes for tabpfn_extensions-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 bbaf2413dae0972d682703d86c1cd6f61c6d32171ab08f3ae5bccab8447dac31
MD5 15f4b300c7fb20fe0cdbddf40092bf45
BLAKE2b-256 46413e67e7517818a3d183f6ff8d62cd3a1b390f96154eef5d45e82e8dfe439e

