
TabFM: Tabular Foundation Models

TabFM (Tabular Foundation Model) is a scikit-learn-compatible tabular foundation model. It performs zero-shot classification and regression out of the box on tabular datasets with mixed column types.

At inference time, TabFM does not train any parameters on your dataset. Instead, it uses in-context learning: it reads your training data as "context" and conditions on it to predict labels for new test samples.
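To make the fit/predict split concrete, here is a deliberately simple stand-in: a k-nearest-neighbors classifier, whose `fit` also does no optimization and only stores the training rows as context. This is an analogy for the workflow only. TabFM conditions a pretrained network on the context, while the hypothetical `ContextOnlyKNN` class below is a toy and is not part of the tabfm package.

```python
import math
from collections import Counter


class ContextOnlyKNN:
    """Toy classifier whose fit() only stores the training data as context.

    Hypothetical illustration: no parameters are learned in fit(); every
    prediction is computed by reading the stored context rows.
    """

    def __init__(self, k=3):
        self.k = k

    def fit(self, X, y):
        # No training step: keep the labeled rows as "context".
        self.context_X_ = [[float(v) for v in row] for row in X]
        self.context_y_ = list(y)
        return self

    def predict(self, X):
        predictions = []
        for row in X:
            # Rank context rows by Euclidean distance to the query row.
            ranked = sorted(
                range(len(self.context_X_)),
                key=lambda i: math.dist(self.context_X_[i], row),
            )
            nearest_labels = [self.context_y_[i] for i in ranked[: self.k]]
            # Majority vote among the k nearest context rows.
            predictions.append(Counter(nearest_labels).most_common(1)[0][0])
        return predictions
```

All the work happens in `predict`, which is the property this toy shares with in-context learning; the quality of the predictions is, of course, not comparable.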

This is not an officially supported Google product.


Installation

To install TabFM, clone the repository and install it locally with the backend of your choice:

JAX (CPU):

git clone https://github.com/google-research/tabfm.git
cd tabfm
pip install -e ".[jax]"

JAX (GPU):

git clone https://github.com/google-research/tabfm.git
cd tabfm
pip install -e ".[jax,cuda]"

PyTorch (CPU/GPU):

git clone https://github.com/google-research/tabfm.git
cd tabfm
pip install -e ".[pytorch]"

Note: For PyTorch with GPU support, ensure you have the appropriate PyTorch version installed for your CUDA version before installing TabFM.

Requirements

For a complete list of pinned dependencies and versions, please see requirements.txt. The core requirements depend on the backend you choose:

  • Python >= 3.11
  • Hugging Face Hub (for downloading pre-trained weights)
  • JAX Backend:
    • JAX (specifically jax==0.10.1)
    • Flax (specifically flax==0.12.7, using the modern flax.nnx API)
  • PyTorch Backend:
    • PyTorch (specifically torch==2.12.1+cpu or a GPU version)
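To confirm that an environment matches the list above, you can query installed distributions with the standard library's `importlib.metadata`. The sketch below compares installed versions against the pins quoted above (requirements.txt remains the authoritative source) and ignores local version suffixes such as `+cpu`. The `PINNED` table and the function names are illustrative, not part of TabFM.

```python
import sys
from importlib import metadata

# Pins copied from the Requirements list above; update them if your
# checkout's requirements.txt says otherwise.
PINNED = {"jax": "0.10.1", "flax": "0.12.7", "torch": "2.12.1"}


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_environment(pins=PINNED):
    """Report the Python version check and, per package, installed vs. pinned."""
    report = {"python_ok": sys.version_info >= (3, 11)}
    for package, pinned in pins.items():
        version = installed_version(package)
        report[package] = {
            "installed": version,
            # Compare the public version only, so "2.12.1+cpu" matches "2.12.1".
            "matches_pin": version is not None and version.split("+")[0] == pinned,
        }
    return report
```

Only one backend needs to be installed, so a missing `torch` entry is expected on a JAX-only setup and vice versa.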

Quick Start (TabFM v1.0.0)

We provide pre-trained weights for the TabFM v1.0.0 release. The library handles downloading and loading these weights automatically. You can choose to load the model using either the JAX or PyTorch backend.
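The Quick Start examples hard-code one backend and comment out the other. If you want a script that runs on whichever backend is installed, you could choose the module at runtime. The sketch below prefers JAX and falls back to PyTorch. The per-backend package lists (`jax` and `flax`, `torch`) follow the Requirements above, the helper names are hypothetical, and it assumes, as the examples show, that both `tabfm_v1_0_0_jax` and `tabfm_v1_0_0_pytorch` expose a `load()` function.

```python
import importlib
from importlib import util

# Backend module names as they appear in the Quick Start examples.
BACKEND_MODULES = {"jax": "tabfm_v1_0_0_jax", "pytorch": "tabfm_v1_0_0_pytorch"}
# Packages each backend needs (assumption based on the Requirements list).
BACKEND_PACKAGES = {"jax": ("jax", "flax"), "pytorch": ("torch",)}


def is_installed(package):
    """True if a top-level package can be found without importing it."""
    return util.find_spec(package) is not None


def pick_backend(preference=("jax", "pytorch"), installed=is_installed):
    """Return the first backend in `preference` whose packages are all present."""
    for backend in preference:
        if all(installed(pkg) for pkg in BACKEND_PACKAGES[backend]):
            return backend
    raise RuntimeError("Neither the JAX nor the PyTorch backend is installed.")


def load_tabfm_model(preference=("jax", "pytorch")):
    """Load the v1.0.0 weights with the first available backend (hypothetical helper)."""
    name = BACKEND_MODULES[pick_backend(preference)]
    package = importlib.import_module("tabfm")
    # Mirror `from tabfm import <name>`: use the attribute if the package
    # exposes it, otherwise import it as a submodule.
    module = getattr(package, name, None)
    if module is None:
        module = importlib.import_module(f"tabfm.{name}")
    return module.load()
```

The returned object would then be passed to `TabFMClassifier(model=...)` or `TabFMRegressor(model=...)` exactly as in the examples.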

1. Classification Example

import numpy as np
import pandas as pd
from tabfm import TabFMClassifier

# Choose your backend:

# OPTION A: JAX Backend
from tabfm import tabfm_v1_0_0_jax as tabfm_v1_0_0
model = tabfm_v1_0_0.load()

# OPTION B: PyTorch Backend
# from tabfm import tabfm_v1_0_0_pytorch as tabfm_v1_0_0
# model = tabfm_v1_0_0.load()

# Initialize scikit-learn compatible classifier (works with either backend model)
clf = TabFMClassifier(model=model)

# Prepare your dataset (supports mixed numerical and categorical features)
X_train = pd.DataFrame({
    "age": [25.0, 45.0, 35.0, 50.0],
    "job": ["engineer", "manager", "engineer", "manager"],
    "income": [80000, 120000, 90000, 130000]
})
y_train = np.array(["low_risk", "high_risk", "low_risk", "high_risk"])

X_test = pd.DataFrame({
    "age": [30.0, 48.0],
    "job": ["engineer", "manager"],
    "income": [85000, 125000]
})

# Fit classifier (prepares ordinal encoders and numerical scalers)
clf.fit(X_train, y_train)

# Predict classes and probabilities
predictions = clf.predict(X_test)
probabilities = clf.predict_proba(X_test)

print("Predictions:", predictions)
print("Class Probabilities:\n", probabilities)
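The comment on `clf.fit` says that fitting prepares ordinal encoders and numerical scalers. The sketch below shows what such a step commonly looks like: string columns are mapped to integer codes learned from the training rows, and numeric columns are standardized with the training mean and standard deviation. It is a hypothetical, standard-library-only illustration (`fit_preprocessor` and `transform` are made-up names), not TabFM's actual preprocessing. Mapping unseen categories to -1 and giving constant columns a unit scale are assumptions.

```python
import statistics


def fit_preprocessor(columns):
    """Learn per-column encoders from a dict of column name -> list of values.

    Hypothetical sketch: all-string columns get an ordinal encoder, other
    columns get a standard scaler fitted on the training values.
    """
    params = {}
    for name, values in columns.items():
        if all(isinstance(v, str) for v in values):
            # Ordinal encoder: sorted unique categories -> 0, 1, 2, ...
            categories = sorted(set(values))
            params[name] = ("ordinal", {c: i for i, c in enumerate(categories)})
        else:
            floats = [float(v) for v in values]
            mean = statistics.fmean(floats)
            std = statistics.pstdev(floats) or 1.0  # guard against constant columns
            params[name] = ("scale", (mean, std))
    return params


def transform(columns, params):
    """Apply fitted encoders; unseen categories map to -1 (an assumption)."""
    out = {}
    for name, (kind, p) in params.items():
        if kind == "ordinal":
            out[name] = [p.get(v, -1) for v in columns[name]]
        else:
            mean, std = p
            out[name] = [(float(v) - mean) / std for v in columns[name]]
    return out
```

Applied to the classification data above (written as plain dicts of lists), `job` becomes 0 for "engineer" and 1 for "manager", and `age` and `income` have mean 0 and standard deviation 1 on the training rows. Test rows reuse the training statistics, which is why they are fitted once and applied twice.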

2. Regression Example

import numpy as np
import pandas as pd
from tabfm import TabFMRegressor

# Choose your backend:

# OPTION A: JAX Backend
from tabfm import tabfm_v1_0_0_jax as tabfm_v1_0_0
model = tabfm_v1_0_0.load()

# OPTION B: PyTorch Backend
# from tabfm import tabfm_v1_0_0_pytorch as tabfm_v1_0_0
# model = tabfm_v1_0_0.load()

# Initialize scikit-learn compatible regressor (works with either backend model)
reg = TabFMRegressor(model=model)

# Prepare your dataset
X_train = pd.DataFrame({
    "sqft": [1200, 2500, 1500, 3000],
    "neighborhood": ["A", "B", "A", "C"]
})
y_train = np.array([250000, 550000, 310000, 620000])

X_test = pd.DataFrame({
    "sqft": [1800, 2800],
    "neighborhood": ["A", "B"]
})

# Fit and Predict
reg.fit(X_train, y_train)
predictions = reg.predict(X_test)

print("Predicted Prices:", predictions)
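The regression example prints predictions for unlabeled rows. When you hold out some rows with known targets, standard error metrics make the output easier to judge. The helper below is a plain-Python sketch with an illustrative name; scikit-learn's `mean_absolute_error` and `mean_squared_error` (take the square root for RMSE) compute the same quantities.

```python
import math


def regression_errors(y_true, y_pred):
    """Return the mean absolute error and root mean squared error."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("y_true and y_pred must be non-empty and the same length")
    residuals = [float(t) - float(p) for t, p in zip(y_true, y_pred)]
    mae = sum(abs(r) for r in residuals) / len(residuals)
    rmse = math.sqrt(sum(r * r for r in residuals) / len(residuals))
    return {"mae": mae, "rmse": rmse}
```

For example, with held-out targets `y_holdout` and the regressor's predictions for those rows, `regression_errors(y_holdout, predictions)` reports both errors in the target's units (here, price).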

Examples Directory

The examples/ folder contains runnable scripts for both classification and regression.

To run the classification example, execute:

python examples/classification_example.py

To switch between the JAX and PyTorch backends, edit these files as described in the comments inside them.
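After switching backends, you may want to exercise every script in the folder at once. A small runner can execute each `*.py` file with the current interpreter and collect the exit codes. This is a hypothetical convenience helper, not part of the repository; it assumes the scripts are run from the repository root, as in the command above.

```python
import subprocess
import sys
from pathlib import Path


def run_examples(directory="examples"):
    """Run every *.py script in `directory` and return {script name: exit code}."""
    results = {}
    for script in sorted(Path(directory).glob("*.py")):
        # Use the current interpreter so the scripts see the same installed backend.
        completed = subprocess.run([sys.executable, str(script)], check=False)
        results[script.name] = completed.returncode
    return results
```

A non-zero exit code for a script usually means its selected backend is not installed or an assertion inside it failed; the script's own output is printed as it runs.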


Running Tests

You can run the unit tests directly using Python's unittest module:

# Run all tests (requires both JAX and PyTorch installed)
PYTHONPATH=. python3 -m unittest discover -s tabfm/src/ -p "*_test.py"

# Or run specific test files:
PYTHONPATH=. python3 -m unittest tabfm/src/pytorch/model_test.py
PYTHONPATH=. python3 -m unittest tabfm/src/classifier_and_regressor_pytorch_test.py

Alternatively, if you have Bazel installed, you can run tests with:

bazel test //...



Download files


Source Distribution

tabfm-1.0.0.tar.gz (84.6 kB)

Uploaded Source

Built Distribution


tabfm-1.0.0-py3-none-any.whl (78.8 kB)

Uploaded Python 3

File details

Details for the file tabfm-1.0.0.tar.gz.

File metadata

  • Download URL: tabfm-1.0.0.tar.gz
  • Upload date:
  • Size: 84.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-requests/2.34.2

File hashes

Hashes for tabfm-1.0.0.tar.gz

  • SHA256: 0e77dc167c8569207f563faa2057373f5d3e8225a6849e4dae1c6495564a7125
  • MD5: f2790d46af458c3456145cdc5882c4ab
  • BLAKE2b-256: 8df8b4d696f7693c62ec3e2132e50119fcc4fc9aa2f1a71ac1d0ebcbbfa58264


File details

Details for the file tabfm-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: tabfm-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 78.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-requests/2.34.2

File hashes

Hashes for tabfm-1.0.0-py3-none-any.whl

  • SHA256: 1ecb937235878cffd7e642acd35b7cfbbc1c7c44f3bf0fde17011e783b65258a
  • MD5: fc6d90bf40c50b05019a4744c8a507a6
  • BLAKE2b-256: dba825975fbb1ec166783f51d3f31ef7dedc126e56e314be09939b373c660b11
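Because the published digests are listed on this page, you can check a downloaded distribution file before installing it. The sketch below streams the file through Python's `hashlib` and compares its SHA256 hex digest with the published value. The `EXPECTED_SHA256` table is copied from the hashes above; the function names are illustrative.

```python
import hashlib
from pathlib import Path

# SHA256 digests copied from the file details on this page.
EXPECTED_SHA256 = {
    "tabfm-1.0.0.tar.gz": "0e77dc167c8569207f563faa2057373f5d3e8225a6849e4dae1c6495564a7125",
    "tabfm-1.0.0-py3-none-any.whl": "1ecb937235878cffd7e642acd35b7cfbbc1c7c44f3bf0fde17011e783b65258a",
}


def sha256_of(path, chunk_size=1 << 20):
    """Hash a file in chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path):
    """True if the file's SHA256 matches the published digest for its name."""
    path = Path(path)
    expected = EXPECTED_SHA256.get(path.name)
    if expected is None:
        raise KeyError(f"No published digest for {path.name}")
    return sha256_of(path) == expected
```

The lookup is keyed on the file name, so keep the downloaded file's original name when calling `verify`.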

