Transformer adapted for tabular data domain

TabularTransformer is a lightweight, end-to-end deep learning framework built with PyTorch, leveraging the power of the Transformer architecture. It is designed to be scalable and efficient with the following advantages:

  • Streamlined workflow with no need for preprocessing or handling missing values.
  • Unleashes the power of the Transformer architecture on the tabular data domain.
  • Native GPU support through PyTorch.
  • Minimal APIs to get started quickly.
  • Capable of handling large-scale data.

Get Started and Documentation

Our primary documentation is at https://echosprint.github.io/TabularTransformer/ and is generated from this repository.

Installation:

$ pip install tabular-transformer

Usage

Here we take the Adult Income dataset as an example to show the usage of the tabular-transformer package; for more examples, see the notebooks folder in this repo.

import tabular_transformer as ttf
import torch

income_dataset_path = ttf.prepare_income_dataset()

categorical_cols = [
    'workclass', 'education',
    'marital.status', 'occupation',
    'relationship', 'race', 'sex',
    'native.country', 'income']

# all remaining columns are numerical
numerical_cols = []

income_reader = ttf.DataReader(
    file_path=income_dataset_path,
    ensure_categorical_cols=categorical_cols,
    ensure_numerical_cols=numerical_cols,
    label='income',
)

# hold out 20% as the test set; -1 assigns all remaining rows to train
split = income_reader.split_data({'test': 0.2, 'train': -1})

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# prefer bfloat16 when the GPU supports it, otherwise fall back to float16
dtype = 'bfloat16' if torch.cuda.is_available() \
    and torch.cuda.is_bf16_supported() else 'float16'

ts = ttf.TrainSettings(device=device, dtype=dtype)

tp = ttf.TrainParameters(max_iters=3000, learning_rate=5e-4,
                         output_dim=1, loss_type='BINCE',
                         batch_size=128, eval_interval=100,
                         eval_iters=20, warmup_iters=100,
                         validate_split=0.2)

hp = ttf.HyperParameters(dim=64, n_layers=6)

trainer = ttf.Trainer(hp=hp, ts=ts)

trainer.train(
    data_reader=income_reader(file_path=split['train']),
    tp=tp)

# load the trained model from the saved checkpoint
predictor = ttf.Predictor(checkpoint='out/ckpt.pt')

# write predictions for the held-out test split to a CSV file
predictor.predict(
    data_reader=income_reader(file_path=split['test']),
    save_as="prediction_income.csv"
)
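After prediction, you can evaluate the saved CSV yourself. Below is a minimal sketch of computing accuracy against known labels; the CSV content and column name here are hypothetical stand-ins for whatever `predictor.predict` actually writes:

```python
import csv
import io

def accuracy(y_true, y_pred):
    """Fraction of positions where predicted label matches the truth."""
    assert len(y_true) == len(y_pred)
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# Hypothetical prediction CSV content; the real file is written by
# predictor.predict(save_as="prediction_income.csv").
fake_csv = io.StringIO("income\n<=50K\n>50K\n<=50K\n")
preds = [row["income"] for row in csv.DictReader(fake_csv)]

truth = ["<=50K", ">50K", ">50K"]
print(accuracy(truth, preds))  # 2 of 3 labels match
```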

Comparison

We used the Higgs dataset for our comparison experiment. Details of the data are listed in the following table:

Training Samples | Features | Test Set Description                 | Task
10,500,000       | 28       | Last 500,000 samples as the test set | Binary classification
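The test set above is simply the tail of the data. As a sketch, the split can be expressed as plain index slicing, using the Higgs sample counts from the table:

```python
def tail_split(n_total, n_test):
    """Return (train_indices, test_indices) where the last
    n_test samples form the held-out test set."""
    split_at = n_total - n_test
    return range(split_at), range(split_at, n_total)

train_idx, test_idx = tail_split(10_500_000, 500_000)
print(len(train_idx), len(test_idx))  # 10000000 500000
```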

We computed the metric only on the test set; see the benchmark source.

Data  | Metric | XGBoost  | XGBoost_Hist | LightGBM | TabularTransformer
Higgs | AUC    | 0.839593 | 0.845314     | 0.845724 | 0.848628
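AUC here is the standard ROC AUC. A minimal pure-Python sketch of how it can be computed from model scores, via the normalized Mann-Whitney U statistic (not the benchmark's actual evaluation code):

```python
def roc_auc(labels, scores):
    """ROC AUC as the probability that a randomly chosen positive
    example scores higher than a randomly chosen negative one
    (ties count as half a win)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```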

To reproduce the results, please check the source code.

Support

Open bug reports and feature requests on GitHub issues.

Reference Papers

Xin Huang, Ashish Khetan, Milan Cvitkovic, and Zohar Karnin. "TabTransformer: Tabular Data Modeling Using Contextual Embeddings". arXiv, 2020.

Prannay Khosla, Piotr Teterwak, Chen Wang, Aaron Sarna, Yonglong Tian, Phillip Isola, Aaron Maschinot, Ce Liu, and Dilip Krishnan. "Supervised Contrastive Learning". arXiv, 2020.

Roman Levin, Valeriia Cherepanova, Avi Schwarzschild, Arpit Bansal, C. Bayan Bruss, Tom Goldstein, Andrew Gordon Wilson, and Micah Goldblum. "Transfer Learning with Deep Tabular Models". arXiv, 2022.

License

This project is licensed under the terms of the MIT license. See LICENSE for additional details.

Thanks

The prototype of this project is adapted from the Python parts of Andrej Karpathy's Llama2.c. Andrej is a mentor; we wish him great success with his startup.
