A standard framework for using Deep Learning for tabular data

PyTorch Tabular

PyTorch Tabular aims to make Deep Learning with Tabular data easy and accessible to real-world cases and research alike. The core principles behind the design of the library are:

  • Low Resistance Usability
  • Easy Customization
  • Scalable and Easier to Deploy

It has been built on the shoulders of giants like PyTorch (obviously) and PyTorch Lightning.

Installation

Although the installation includes PyTorch, the recommended approach is to install PyTorch first, from the official PyTorch website, picking the build that matches the CUDA version on your machine.

Once you have PyTorch installed, use:

 pip install pytorch_tabular[all]

to install the complete library with extra dependencies.

Or:

 pip install pytorch_tabular

for the bare essentials.

The sources for pytorch_tabular can be downloaded from the GitHub repo.

You can clone the public repository:

git clone https://github.com/manujosephv/pytorch_tabular

Once you have a copy of the source, you can install it with:

python setup.py install
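Because the recommended route is to install PyTorch yourself before PyTorch Tabular, it can help to confirm what is in place afterwards. The sketch below uses only the standard library: it reports whether `torch` and `pytorch_tabular` are importable and, if `torch` is present, whether it can see a CUDA device via `torch.cuda.is_available()`. The helper name `check_install` is hypothetical and not part of the library.

```python
import importlib
import importlib.util


def check_install() -> dict:
    """Report which relevant packages are importable.

    Returns a dict with "torch" and "pytorch_tabular" (bool: importable)
    and "cuda" (bool if torch is installed, otherwise None).
    """
    status = {
        "torch": importlib.util.find_spec("torch") is not None,
        "pytorch_tabular": importlib.util.find_spec("pytorch_tabular") is not None,
        "cuda": None,
    }
    if status["torch"]:
        torch = importlib.import_module("torch")
        # True only if this PyTorch build has CUDA support and a GPU is visible
        status["cuda"] = bool(torch.cuda.is_available())
    return status


if __name__ == "__main__":
    for name, value in check_install().items():
        print(f"{name}: {value}")
```

If `cuda` comes back False on a machine with an NVIDIA GPU, the installed PyTorch build most likely does not match the local CUDA setup.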

Documentation

For complete documentation with tutorials, visit []

Available Models

To implement new models, see the "How to implement new models" tutorial. It covers basic as well as advanced architectures.

Usage

from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig
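from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig

# Assumes train, val and test are pandas DataFrames with a "target" column, and
# num_col_names / cat_col_names are lists of the continuous and categorical column names
data_config = DataConfig(
    target=['target'],  # target should always be a list. Multi-targets are only supported for regression; multi-task classification is not implemented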
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate
    batch_size=1024,
    max_epochs=100,
    gpus=1,  # Number of GPUs to use. None means CPU, -1 means all GPUs
)
optimizer_config = OptimizerConfig()

model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # Number of nodes in each layer
    activation="LeakyReLU",  # Activation between each layer
    learning_rate=1e-3,
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
tabular_model.fit(train=train, validation=val)
result = tabular_model.evaluate(test)
pred_df = tabular_model.predict(test)
tabular_model.save_model("examples/basic")
loaded_model = TabularModel.load_from_checkpoint("examples/basic")
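The usage example expects `train`, `val` and `test` to be pandas DataFrames and `num_col_names` / `cat_col_names` to list the continuous and categorical columns. One way to produce them is sketched below, using a tiny made-up DataFrame; the data, the `split_frame` helper and the 70/15/15 split are assumptions for illustration, not part of PyTorch Tabular. Column types are inferred from the dtypes and the rows are shuffled before splitting.

```python
import pandas as pd


def split_frame(df: pd.DataFrame, val_frac: float = 0.15, test_frac: float = 0.15, seed: int = 42):
    """Shuffle df and split it into train / validation / test DataFrames."""
    shuffled = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    n = len(shuffled)
    n_test = round(n * test_frac)
    n_val = round(n * val_frac)
    test = shuffled.iloc[:n_test]
    val = shuffled.iloc[n_test:n_test + n_val]
    train = shuffled.iloc[n_test + n_val:]
    return train, val, test


# Made-up dataset: two numeric features, one categorical feature, binary target
n_rows = 20
df = pd.DataFrame({
    "age": [20 + (i * 7) % 45 for i in range(n_rows)],
    "income": [30_000.0 + 1_500.0 * i for i in range(n_rows)],
    "city": [["Delhi", "Paris", "Lima"][i % 3] for i in range(n_rows)],
    "target": [i % 2 for i in range(n_rows)],
})

target_col = "target"
feature_cols = [c for c in df.columns if c != target_col]
# Numeric dtypes -> continuous columns; everything else -> categorical columns
num_col_names = [c for c in feature_cols if pd.api.types.is_numeric_dtype(df[c])]
cat_col_names = [c for c in feature_cols if c not in num_col_names]

train, val, test = split_frame(df)
```

The dtype rule is only a heuristic: integer columns with few distinct values (codes, flags) are often better treated as categorical, so it is worth reviewing the inferred lists before building the DataConfig.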

Blogs

Future Roadmap (Contributions are Welcome)

  1. Add GaussRank as Feature Transformation
  2. Add the ability to use custom activations in CategoryEmbeddingModel
  3. Add differential (layer-wise) dropouts in CategoryEmbeddingModel
  4. Add Fourier Encoding for cyclic time variables
  5. Integrate Optuna Hyperparameter Tuning
  6. Add Text and Image Modalities for mixed modal problems
  7. Add Variable Importance
  8. Integrate SHAP for interpretability
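Two of the roadmap items are simple enough to sketch in plain Python. GaussRank ranks each value, rescales the ranks into the open interval (0, 1) and passes them through the inverse normal CDF, so the transformed feature is approximately standard normal whatever its original distribution. Fourier (sine/cosine) encoding maps a cyclic variable such as hour of day onto the unit circle, so that 23:00 and 00:00 end up close together. Both functions are illustrative only; `gauss_rank` and `cyclic_encode` are hypothetical names, not part of PyTorch Tabular.

```python
import math
from statistics import NormalDist


def gauss_rank(values):
    """GaussRank: map values to approximately standard-normal scores via their ranks.

    Tied values receive the average of the ranks they span.
    """
    n = len(values)
    order = sorted(range(n), key=lambda i: values[i])
    ranks = [0.0] * n
    i = 0
    while i < n:
        j = i
        # Extend j over a run of tied values
        while j + 1 < n and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg_rank = (i + j) / 2 + 1  # 1-based average rank of the tie run
        for k in range(i, j + 1):
            ranks[order[k]] = avg_rank
        i = j + 1
    normal = NormalDist()  # mean 0, standard deviation 1
    # (rank - 0.5) / n lies strictly inside (0, 1), so inv_cdf is always defined
    return [normal.inv_cdf((r - 0.5) / n) for r in ranks]


def cyclic_encode(value, period):
    """Encode a cyclic value (e.g. hour of day with period=24) as (sin, cos)."""
    angle = 2 * math.pi * value / period
    return math.sin(angle), math.cos(angle)
```

For example, `gauss_rank([10, 1000, 3, 50])` depends only on the order of the inputs, so the outlier 1000 receives the same score it would if it were 51. Similarly, `cyclic_encode(23, 24)` lands next to `cyclic_encode(0, 24)` on the circle.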

DL Models

  1. DNF-Net: A Neural Architecture for Tabular Data
  2. Attention augmented differentiable forest for tabular data
  3. XBNet: An Extremely Boosted Neural Network
  4. Revisiting Deep Learning Models for Tabular Data

Citation

If you use PyTorch Tabular for a scientific publication, we would appreciate citations to the published software and the following paper:

@misc{joseph2021pytorch,
      title={PyTorch Tabular: A Framework for Deep Learning with Tabular Data}, 
      author={Manu Joseph},
      year={2021},
      eprint={2104.13638},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
  • Zenodo Software Citation
@article{manujosephv_2021, 
    title={manujosephv/pytorch_tabular: v0.5.0-alpha}, 
    DOI={10.5281/zenodo.4732773}, 
    abstractNote={<p>First Alpha Release</p>}, 
    publisher={Zenodo}, 
    author={manujosephv}, 
    year={2021}, 
    month={May}
}

History

0.0.1 (2021-01-26)

  • First release on PyPI.

0.2.0 (2021-02-07)

  • Fixed an issue with torch.clip and torch version
  • Fixed an issue with gpus parameter in TrainerConfig, by setting default value to None for CPU
  • Added feature to use custom sampler in the training dataloader
  • Updated documentation and added a new tutorial for imbalanced classification

0.3.0 (2021-03-02)

  • Fixed a bug on inference

0.4.0 (2021-03-18)

  • Added AutoInt Model
  • Added Mixture Density Networks
  • Refactored the classes to separate backbones from the head of the models
  • Changed model saving and loading to work with custom parameters passed to fit

0.5.0 (2021-03-18)

  • Added more documentation
  • Added Zenodo citation

0.6.0 (2021-06-21)

  • Upgraded PyTorch Lightning to version 1.3.6
  • Changed the way the gpus parameter is handled to avoid confusion: None is CPU, -1 is all GPUs, and an int is the number of GPUs
  • Added a few more Trainer parameters, such as deterministic and auto_select_gpus
  • Some bug fixes and changes to docs
  • Added seed_everything to the fit method to ensure reproducibility
  • Refactored data_aware_initialization to be part of the BaseModel. Subclasses can override the method to implement data-aware initialization techniques
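The 0.6.0 convention for `gpus` (None for CPU, -1 for all GPUs, a positive integer for that many GPUs) can be summarised as a small resolver. This is a hypothetical helper written for illustration, not code from PyTorch Tabular or PyTorch Lightning; `available` stands in for the number of GPUs on the machine, and treating 0 as CPU is this sketch's own choice.

```python
from typing import Optional


def resolve_gpus(gpus: Optional[int], available: int) -> int:
    """Return how many GPUs a `gpus` value would select (0 means CPU).

    Follows the convention in the 0.6.0 notes:
    None -> CPU, -1 -> all available GPUs, positive int -> that many GPUs.
    """
    if gpus is None or gpus == 0:
        # None is CPU per the changelog; 0 is also treated as CPU in this sketch
        return 0
    if gpus == -1:
        return available
    if gpus > 0:
        if gpus > available:
            raise ValueError(f"requested {gpus} GPUs, but only {available} are available")
        return gpus
    raise ValueError(f"unsupported gpus value: {gpus!r}")


if __name__ == "__main__":
    for setting in (None, -1, 1):
        print(setting, "->", resolve_gpus(setting, available=2))
```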

0.7.0 (2021-09-01)

  • Implemented TabTransformer and FTTransformer models
  • Added the capability to save a model trained on GPU and load it on CPU
  • Made the temp folder specific to PyTorch Tabular to avoid conflicts with other tmp folders
  • Some bug fixes
  • Fixed an error in the Advanced Tutorial in the docs

Download files

Source Distribution

pytorch_tabular-0.7.0.tar.gz (2.0 MB)
