TabularTransformer: a Transformer adapted for the tabular data domain
TabularTransformer is a lightweight, end-to-end deep learning framework for tabular data, built with PyTorch and based on the Transformer architecture. It is designed to be scalable and efficient, with the following advantages:
- Streamlined workflow: no preprocessing or handling of missing values required.
- Brings the power of the Transformer architecture to the tabular data domain.
- Native GPU support through PyTorch.
- Minimal API surface, so you can get started quickly.
- Scales to large datasets.
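To make the "no preprocessing" point concrete, the snippet below builds the kind of raw CSV that the framework is meant to consume directly: mixed categorical and numerical columns with empty cells. This is a stdlib-only illustration; the column names merely mimic the Adult Income dataset and nothing here calls the package itself.

```python
import csv
import io

# Raw CSV with one missing numerical cell (age) and one missing
# categorical cell (workclass). Per the README, files like this can be
# handed to the framework without imputation or manual encoding.
raw = (
    "age,workclass,income\n"
    "39,State-gov,<=50K\n"
    "50,,>50K\n"
    ",Private,<=50K\n"
)

rows = list(csv.DictReader(io.StringIO(raw)))
missing = sum(1 for r in rows for v in r.values() if v == "")
print(f"{len(rows)} rows, {missing} missing cells")  # 3 rows, 2 missing cells
```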
Get Started and Documentation
Our primary documentation is at https://echosprint.github.io/TabularTransformer/ and is generated from this repository.
Installation:
$ pip install tabular-transformer
Usage
Here we use the Adult Income dataset as an example to show the usage of the tabular_transformer package; for more examples, see the notebooks folder in this repo.
import tabular_transformer as ttf
import torch

# Download the Adult Income dataset and return its local file path.
income_dataset_path = ttf.prepare_income_dataset()

categorical_cols = [
    'workclass', 'education',
    'marital.status', 'occupation',
    'relationship', 'race', 'sex',
    'native.country', 'income']

numerical_cols = [
    'age', 'fnlwgt', 'education.num',
    'capital.gain', 'capital.loss',
    'hours.per.week']

income_reader = ttf.DataReader(
    file_path=income_dataset_path,
    ensure_categorical_cols=categorical_cols,
    ensure_numerical_cols=numerical_cols,
    label='income',
)

# Hold out 20% of the data as the test set; -1 assigns the remainder to 'train'.
split = income_reader.split_data({'test': 0.2, 'train': -1})

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_available() \
    and torch.cuda.is_bf16_supported() else 'float16'

ts = ttf.TrainSettings(device=device, dtype=dtype)

tp = ttf.TrainParameters(
    max_iters=3000, learning_rate=5e-4,
    output_dim=1, loss_type='BINCE',
    batch_size=128, eval_interval=100,
    eval_iters=20, warmup_iters=100,
    validate_split=0.2)

hp = ttf.HyperParameters(dim=64, n_layers=6)

trainer = ttf.Trainer(hp=hp, ts=ts)
trainer.train(
    data_reader=income_reader(file_path=split['train']),
    tp=tp)

# Load the trained checkpoint and write predictions for the test split.
predictor = ttf.Predictor(checkpoint='out/ckpt.pt')
predictor.predict(
    data_reader=income_reader(file_path=split['test']),
    save_as="prediction_income.csv")
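Once the prediction file is written, you can score it against ground-truth labels with plain Python. The snippet below is a sketch under stated assumptions: it fabricates a tiny in-memory CSV with hypothetical `income` (actual) and `prediction` columns, since the exact column layout of the predictor's output file may differ from this.

```python
import csv
import io

# Hypothetical rows standing in for prediction_income.csv joined with
# ground-truth labels: (actual, predicted) pairs.
pairs = [("<=50K", "<=50K"), (">50K", ">50K"), (">50K", "<=50K"), ("<=50K", "<=50K")]

buf = io.StringIO()
writer = csv.writer(buf)
writer.writerow(["income", "prediction"])  # assumed column names
writer.writerows(pairs)

# Read it back and compute simple accuracy.
buf.seek(0)
matches = total = 0
for row in csv.DictReader(buf):
    total += 1
    matches += row["income"] == row["prediction"]
accuracy = matches / total
print(f"accuracy = {accuracy:.2f}")  # accuracy = 0.75
```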
Comparison
We used the Higgs dataset for our comparison experiment. Details of the data are listed in the following table:

Training Samples | Features | Test Set Description | Task |
---|---|---|---|
10,500,000 | 28 | Last 500,000 samples as the test set | Binary classification |
Metrics were computed only on the test set; see the benchmark source for details.
Data | Metric | XGBoost | XGBoost_Hist | LightGBM | TabularTransformer |
---|---|---|---|---|---|
Higgs | AUC | 0.839593 | 0.845314 | 0.845724 | 0.848628 |
To reproduce the results, please check the source code.
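The comparison metric above is AUC (area under the ROC curve), which equals the probability that a randomly chosen positive example is scored higher than a randomly chosen negative one, with ties counted as half. As a minimal stdlib sketch of how that number is computed (the function name `auc_score` is ours, not part of any of the benchmarked packages):

```python
from typing import Sequence

def auc_score(labels: Sequence[int], scores: Sequence[float]) -> float:
    """Rank-based AUC: P(score of a random positive > score of a random
    negative), counting ties as 0.5. O(P*N) pairwise version for clarity."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

print(auc_score([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

A production benchmark would use a sorting-based O(n log n) implementation, but the pairwise form makes the definition explicit.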
Support
Open bug reports and feature requests on GitHub issues.
Reference Papers
Xin Huang, Ashish Khetan, Milan Cvitkovic, and Zohar Karnin. "TabTransformer: Tabular Data Modeling Using Contextual Embeddings". arXiv, 2020.
Prannay Khosla, Piotr Teterwak, Chen Wang, Aaron Sarna, Yonglong Tian, Phillip Isola, Aaron Maschinot, Ce Liu, and Dilip Krishnan. "Supervised Contrastive Learning". arXiv, 2020.
Roman Levin, Valeriia Cherepanova, Avi Schwarzschild, Arpit Bansal, C. Bayan Bruss, Tom Goldstein, Andrew Gordon Wilson, and Micah Goldblum. "Transfer Learning with Deep Tabular Models". arXiv, 2022.
License
This project is licensed under the terms of the MIT license. See LICENSE for additional details.
Thanks
The prototype of this project is adapted from the Python parts of Andrej Karpathy's Llama2.c. Andrej is a mentor; we wish him great success with his startup.