Transformer adapted for tabular data domain
TabularTransformer is a lightweight, end-to-end deep learning framework built with PyTorch, leveraging the power of the Transformer architecture. It is designed to be scalable and efficient, with the following advantages:
- Streamlined workflow: no preprocessing or missing-value handling required.
- Brings the power of the Transformer architecture to the tabular data domain.
- Native GPU support through PyTorch.
- Minimal API to get started quickly.
- Capable of handling large-scale data.
Get Started and Documentation
Our primary documentation is at https://echosprint.github.io/TabularTransformer/ and is generated from this repository.
Installation:
$ pip install tabular-transformer
Usage
Here we use the Adult Income dataset as an example to demonstrate the tabular_transformer
package; for more examples, see the notebooks folder in this repo.
import tabular_transformer as ttf
import pandas as pd
import torch
# download the dataset
income_dataset_path = ttf.prepare_income_dataset()
class IncomeDataReader(ttf.DataReader):
# ensure columns are interpreted correctly
ensure_categorical_cols = []
ensure_numerical_cols = []
# load data
def read_data_file(self, file_path):
df = pd.read_csv(file_path)
return df
income_reader = IncomeDataReader(income_dataset_path)
# split 20% as `test`, the rest as `train`
split = income_reader.split_data({'test': 0.2, 'train': -1})
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_available() \
and torch.cuda.is_bf16_supported() else 'float16'
ts = ttf.TrainSettings(device=device, dtype=dtype)
tp = ttf.TrainParameters(train_epochs=15, learning_rate=5e-4,
batch_size=128, eval_interval=100,
eval_iters=20, warmup_iters=100,
validate_split=0.2)
hp = ttf.HyperParameters(dim=64, n_layers=6)
# use `HyperParameters` and `TrainSettings` to construct `Trainer`
trainer = ttf.Trainer(hp=hp, ts=ts)
# use split `train` to train with `TrainParameters`
trainer.train(data_reader=IncomeDataReader(split['train']), tp=tp)
# load the PyTorch checkpoint
predictor = ttf.Predictor(checkpoint='out/ckpt.pt')
# use split `test` to predict
predictor.predict(
data_reader=IncomeDataReader(split['test']),
save_as="prediction_income.csv"
)
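In the `split_data` call above, `-1` acts as a remainder marker: `test` takes 20% of the rows and `train` receives everything left over. The sketch below illustrates that semantics in plain Python; it is an illustration only, not the library's actual implementation, and the helper name `split_by_fraction` is hypothetical.

```python
import random

def split_by_fraction(rows, spec, seed=42):
    """Split `rows` into named parts. A fraction in (0, 1] takes that
    share of the rows; a single -1 entry receives all remaining rows.
    Hypothetical helper mirroring the semantics of
    DataReader.split_data({'test': 0.2, 'train': -1})."""
    rng = random.Random(seed)
    shuffled = rows[:]
    rng.shuffle(shuffled)          # randomize before slicing
    splits, start = {}, 0
    remainder_key = None
    for name, frac in spec.items():
        if frac == -1:
            remainder_key = name   # filled with whatever is left
            continue
        n = int(len(rows) * frac)
        splits[name] = shuffled[start:start + n]
        start += n
    if remainder_key is not None:
        splits[remainder_key] = shuffled[start:]
    return splits

rows = list(range(100))
parts = split_by_fraction(rows, {'test': 0.2, 'train': -1})
print(len(parts['test']), len(parts['train']))  # 20 80
```

Splitting before training, as in the usage above, keeps the held-out `test` rows entirely out of the training run.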
Support
Open bug reports and feature requests on GitHub issues.
Reference Papers
Xin Huang, Ashish Khetan, Milan Cvitkovic, and Zohar Karnin. "TabTransformer: Tabular Data Modeling Using Contextual Embeddings". arXiv, 2020.
Prannay Khosla, Piotr Teterwak, Chen Wang, Aaron Sarna, Yonglong Tian, Phillip Isola, Aaron Maschinot, Ce Liu, and Dilip Krishnan. "Supervised Contrastive Learning". arXiv, 2020.
Roman Levin, Valeriia Cherepanova, Avi Schwarzschild, Arpit Bansal, C. Bayan Bruss, Tom Goldstein, Andrew Gordon Wilson, and Micah Goldblum. "Transfer Learning with Deep Tabular Models". arXiv, 2022.
License
This project is licensed under the terms of the MIT license. See LICENSE for additional details.
Thanks
The prototype of this project is adapted from the Python parts of Andrej Karpathy's Llama2.c. Andrej is a mentor; we wish him great success with his startup.