
PyTorch implementation of TabNet

Project description


TabNet : Attentive Interpretable Tabular Learning

This is a PyTorch implementation of TabNet (Arik, S. O., & Pfister, T. (2019). TabNet: Attentive Interpretable Tabular Learning. arXiv preprint arXiv:1908.07442.) https://arxiv.org/pdf/1908.07442.pdf.


Installation

You can install it using pip by running: pip install pytorch-tabnet

If you want to use it locally within a Docker container:

  • git clone git@github.com:dreamquark-ai/tabnet.git to clone the repository

  • cd tabnet to get inside the repository

  • make start to build and get inside the container

  • poetry install to install all the dependencies, including jupyter

  • make notebook inside the same terminal

You can then follow the link to a Jupyter notebook with TabNet installed.

A GPU version is available and should work, but it is not officially supported yet.

How to use it?

The implementation makes it easy to try different TabNet architectures: all you need to do is change the network parameters and the training parameters. All parameters are briefly described below; for a better understanding of what each parameter does, please refer to the original paper.

You can also get comfortable with how the code works by playing with the tutorial notebooks for the Adult Census Income dataset and the Forest Cover Type dataset.
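Below is a minimal usage sketch. The TabNetClassifier class name, its import path and the fit signature are assumptions made for illustration and may differ between releases; the constructor and fit arguments mirror the network and training parameters described in the sections below.

    import numpy as np
    from pytorch_tabnet.tab_model import TabNetClassifier  # import path assumed, check your installed version

    # Toy binary classification data: 10 numerical features.
    X_train = np.random.rand(1000, 10).astype(np.float32)
    y_train = np.random.randint(0, 2, size=1000)
    X_valid = np.random.rand(200, 10).astype(np.float32)
    y_valid = np.random.randint(0, 2, size=200)

    # Network parameters go to the constructor (see "Network parameters" below).
    clf = TabNetClassifier(input_dim=10, output_dim=2, nd=8, na=8, n_steps=3)

    # Training parameters drive the fit loop; the validation set is used for early stopping.
    # The exact fit signature may vary from one version to the next.
    clf.fit(X_train, y_train, X_valid, y_valid)

    preds = clf.predict(X_valid)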

Network parameters

  • input_dim : int

    Number of initial features of the dataset

  • output_dim : int

    Size of the desired output. Ex :

    • 1 for regression task
    • 2 for binary classification
    • N > 2 for multiclass classification
  • nd : int

    Width of the decision prediction layer. Bigger values give more capacity to the model, at the risk of overfitting. Values typically range from 8 to 64.

  • na : int

    Width of the attention embedding for each mask. According to the paper nd=na is usually a good choice.

  • n_steps : int

    Number of steps in the architecture (usually between 3 and 10).

  • gamma : float

    This is the coefficient for feature reuse in the masks. A value close to 1 will make mask selection the least correlated between layers. Values range from 1.0 to 2.0.

  • cat_idxs : list of int

    List of indices of the categorical features.

  • cat_emb_dim : list of int

    List of embedding sizes, one for each categorical feature.

  • n_independent : int

    Number of independent Gated Linear Unit (GLU) layers at each step. Usual values range from 1 to 5 (default=2).

  • n_shared : int

    Number of shared Gated Linear Unit (GLU) layers at each step. Usual values range from 1 to 5 (default=2).

  • virtual_batch_size : int

    Size of the mini-batches used for Ghost Batch Normalization.
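As an illustration only, a plain Python dictionary collecting typical values for the parameters above could look like this (the values are examples for a hypothetical dataset with 14 features, three of them categorical; this dictionary is not a required format):

    # Example network configuration: binary classification, 14 features,
    # columns 1, 3 and 5 are categorical.
    network_params = {
        "input_dim": 14,
        "output_dim": 2,            # binary classification
        "nd": 16,                   # decision layer width, typically 8 to 64
        "na": 16,                   # attention embedding width, nd = na is a good default
        "n_steps": 5,               # usually between 3 and 10
        "gamma": 1.5,               # feature reuse coefficient, 1.0 to 2.0
        "cat_idxs": [1, 3, 5],      # indices of the categorical columns
        "cat_emb_dim": [3, 2, 4],   # one embedding size per categorical column
        "n_independent": 2,
        "n_shared": 2,
        "virtual_batch_size": 128,
    }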

Training parameters

  • max_epochs : int (default = 200)

    Maximum number of epochs for training.

  • patience : int (default = 15)

    Number of consecutive epochs without improvement before performing early stopping.

  • lr : float (default = 0.02)

    Initial learning rate used for training. As mentioned in the original paper, a large initial learning rate of 0.02 with decay is a good option.

  • clip_value : float (default None)

    If a float is given, gradients will be clipped at clip_value.

  • lambda_sparse : float (default = 1e-3)

    This is the extra sparsity loss coefficient as proposed in the original paper. The bigger this coefficient is, the sparser your model will be in terms of feature selection. Depending on the difficulty of your problem, reducing this value could help.

  • model_name : str (default = 'DQTabNet')

    Name used when saving the model to disk; you can customize this to easily retrieve and reuse your trained models.

  • saving_path : str (default = './')

    Path defining where to save models.

  • scheduler_fn : torch.optim.lr_scheduler (default = None)

    PyTorch scheduler used to change the learning rate during training.

  • scheduler_params: dict

    Parameters dictionary for the scheduler_fn. Ex : {"gamma": 0.95, "step_size": 10}

  • verbose : int (default=-1)

    Verbosity for notebook plots; set to 1 to see every epoch.
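As a sketch of how the training parameters fit together, the decayed learning rate suggested in the paper can be expressed with a standard PyTorch scheduler such as torch.optim.lr_scheduler.StepLR; the dictionary below only groups example values using the keyword names from this README, and is not a required format:

    import torch

    # Example training configuration following the parameters above.
    training_params = {
        "max_epochs": 200,
        "patience": 15,                 # early stopping after 15 epochs without improvement
        "lr": 0.02,                     # large initial learning rate, decayed by the scheduler
        "clip_value": 1.0,              # optional gradient clipping
        "lambda_sparse": 1e-3,          # sparsity loss coefficient
        "model_name": "DQTabNet",
        "saving_path": "./",
        "scheduler_fn": torch.optim.lr_scheduler.StepLR,
        "scheduler_params": {"gamma": 0.95, "step_size": 10},  # multiply lr by 0.95 every 10 epochs
        "verbose": 1,
    }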



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pytorch_tabnet-0.1.2.tar.gz (14.2 kB)

Uploaded Source

Built Distribution

pytorch_tabnet-0.1.2-py3-none-any.whl (13.5 kB)

Uploaded Python 3

File details

Details for the file pytorch_tabnet-0.1.2.tar.gz.

File metadata

  • Download URL: pytorch_tabnet-0.1.2.tar.gz
  • Upload date:
  • Size: 14.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/0.12.17 CPython/3.7.5 Linux/5.0.0-29-generic

File hashes

Hashes for pytorch_tabnet-0.1.2.tar.gz
Algorithm Hash digest
SHA256 9161a2a1135fc324039c2d5f056055ec3e38f9d2a37e24dd9bd271f5cfc1c4aa
MD5 c98d4fca7c9611f20eed9379d3d4e76d
BLAKE2b-256 1b3062b421b5e9819a5c6e672c8612be0bff774f8f33355ce39a556b15a80fe8


File details

Details for the file pytorch_tabnet-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: pytorch_tabnet-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 13.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/0.12.17 CPython/3.7.5 Linux/5.0.0-29-generic

File hashes

Hashes for pytorch_tabnet-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 a9d8c9db46c88d4d830a51551ba7f3984857e187eea2edf6ede984ac7c9290a1
MD5 6b96e7a89df21d60a605d727cd8ad4a6
BLAKE2b-256 f91c1fb71f2b90dffaf5da60be7e25ab74ab2ec6945bfd443cb9b3c6fced39dd

