
TabNet: Attentive Interpretable Tabular Learning

This is a PyTorch implementation of TabNet (Arik, S. O., & Pfister, T. (2019). TabNet: Attentive Interpretable Tabular Learning. arXiv preprint arXiv:1908.07442, https://arxiv.org/pdf/1908.07442.pdf).

Installation

You can install it with pip by running: pip install pytorch-tabnet

If you want to use it locally within a Docker container:

  • git clone git@github.com:dreamquark-ai/tabnet.git to clone the repository
  • cd tabnet to get inside the repository
  • make start to build and get inside the container
  • poetry install to install all the dependencies, including Jupyter
  • make notebook inside the same terminal

You can then follow the link to open a Jupyter notebook with TabNet installed.

A GPU version is available and should work, but it is not supported yet.

How to use it?

The implementation makes it easy to try different architectures of TabNet: all you need to do is change the network parameters and training parameters. All parameters are briefly described below; for a better understanding of what each parameter does, please refer to the original paper.

You can also get comfortable with how the code works by playing with the tutorial notebooks for the Adult Census Income and Forest Cover Type datasets.

Network parameters

  • input_dim : int

    Number of initial features of the dataset

  • output_dim : int

    Size of the desired output, e.g.:

    • 1 for a regression task
    • 2 for binary classification
    • N > 2 for multiclass classification
  • nd : int

    Width of the decision prediction layer. Bigger values give more capacity to the model, at the risk of overfitting. Values typically range from 8 to 64.

  • na : int

    Width of the attention embedding for each mask. According to the paper, nd = na is usually a good choice.

  • n_steps : int

    Number of decision steps in the architecture (usually between 3 and 10).

  • gamma : float

    Coefficient for feature reuse in the masks. A value close to 1 makes mask selection less correlated between decision steps. Values range from 1.0 to 2.0.

  • cat_idxs : list of int

    List of categorical feature indices.

  • cat_emb_dim : list of int

    List of embedding sizes, one for each categorical feature.

  • n_independent : int

    Number of independent Gated Linear Unit layers at each step. Usual values range from 1 to 5 (default = 2).

  • n_shared : int

    Number of Gated Linear Unit layers shared across all steps. Usual values range from 1 to 5 (default = 2).

  • virtual_batch_size : int

    Size of the mini batches used for Ghost Batch Normalization
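To make the shape-related parameters concrete, here is a minimal pure-Python sketch that derives input_dim, output_dim, cat_idxs and cat_emb_dim from a column schema. The helpers infer_output_dim and categorical_params, the column names, the cardinalities and the embedding-size rule are all hypothetical illustrations, not part of the pytorch-tabnet API.

```python
from typing import Dict, List, Sequence


def infer_output_dim(task: str, n_classes: int = 0) -> int:
    """Hypothetical helper: map a task type to output_dim."""
    if task == "regression":
        return 1  # a single continuous target
    if task == "classification":
        if n_classes < 2:
            raise ValueError("classification needs at least 2 classes")
        return n_classes  # 2 for binary, N > 2 for multiclass
    raise ValueError(f"unknown task: {task!r}")


def categorical_params(columns: Sequence[str],
                       cardinalities: Dict[str, int]) -> Dict[str, List[int]]:
    """Hypothetical helper: build cat_idxs and cat_emb_dim from column names.

    The rule min(50, (cardinality + 1) // 2) is an assumed heuristic,
    not something this README prescribes.
    """
    cat_idxs = [i for i, name in enumerate(columns) if name in cardinalities]
    cat_emb_dim = [min(50, (cardinalities[columns[i]] + 1) // 2) for i in cat_idxs]
    return {"cat_idxs": cat_idxs, "cat_emb_dim": cat_emb_dim}


# Illustrative schema: two numeric and two categorical columns (made-up cardinalities).
columns = ["age", "color", "city", "hours_per_week"]
cardinalities = {"color": 9, "city": 16}

params = {
    "input_dim": len(columns),  # number of initial features, before any embedding
    "output_dim": infer_output_dim("classification", n_classes=2),
    **categorical_params(columns, cardinalities),
}
# params == {"input_dim": 4, "output_dim": 2, "cat_idxs": [1, 2], "cat_emb_dim": [5, 8]}
```

The dictionary only collects the values; how they are passed to the model depends on the library version.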
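The interplay between n_steps and gamma comes from the paper rather than from this README: at step i the mask is M[i] = sparsemax(P[i-1] * h_i(a[i-1])), and the prior scale is updated as P[i] = prod_{j=1..i} (gamma - M[j]), starting from P[0] = 1. The pure-Python sketch below uses made-up logits standing in for h_i(a[i-1]). With gamma = 1.0, the feature fully selected at step 1 gets a prior of 0 and receives no weight at step 2; with gamma = 1.5 it keeps a prior of 0.5 and is reused. It illustrates the equations and is not the library's implementation.

```python
from typing import List


def sparsemax(z: List[float]) -> List[float]:
    """Euclidean projection of z onto the probability simplex (sparse softmax)."""
    zs = sorted(z, reverse=True)
    cumsum, k, support_sum = 0.0, 0, 0.0
    for i, v in enumerate(zs, start=1):
        cumsum += v
        if 1.0 + i * v > cumsum:  # the i-th largest value is still in the support
            k, support_sum = i, cumsum
    tau = (support_sum - 1.0) / k  # threshold subtracted from every logit
    return [max(v - tau, 0.0) for v in z]


def attentive_masks(logits_per_step: List[List[float]], gamma: float) -> List[List[float]]:
    """Masks M[1..n_steps] with the prior update P[i] = P[i-1] * (gamma - M[i])."""
    prior = [1.0] * len(logits_per_step[0])  # P[0]: every feature fully available
    masks = []
    for logits in logits_per_step:
        mask = sparsemax([p * a for p, a in zip(prior, logits)])
        prior = [p * (gamma - m) for p, m in zip(prior, mask)]
        masks.append(mask)
    return masks


# Made-up logits standing in for h_i(a[i-1]); n_steps = 2, three features.
logits = [[3.0, 1.0, 0.2], [3.0, 1.0, 0.2]]
masks_gamma_1 = attentive_masks(logits, gamma=1.0)   # step 2 is about [0.0, 0.9, 0.1]
masks_gamma_15 = attentive_masks(logits, gamma=1.5)  # step 2 is [0.5, 0.5, 0.0]
```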
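n_independent and n_shared count Gated Linear Unit layers. In the paper, each such layer is a fully connected layer followed by batch normalization and a GLU; shared layers reuse the same weights at every decision step, while independent layers have their own weights per step. A GLU splits its pre-activations into two halves a and b and returns a * sigmoid(b). A minimal pure-Python sketch of the gating alone, without learned weights:

```python
import math
from typing import List


def _sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def glu(pre_activations: List[float]) -> List[float]:
    """Gated Linear Unit: first half of the inputs gated by sigmoid of the second half."""
    half = len(pre_activations) // 2
    if len(pre_activations) != 2 * half:
        raise ValueError("GLU expects an even number of pre-activations")
    values, gates = pre_activations[:half], pre_activations[half:]
    return [v * _sigmoid(g) for v, g in zip(values, gates)]


# Gates of 0 let half of each value through: glu([2.0, -4.0, 0.0, 0.0]) == [1.0, -2.0]
```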
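Ghost Batch Normalization computes batch-norm statistics over small virtual batches of virtual_batch_size samples rather than over the whole batch. The pure-Python sketch below normalizes a single feature this way in training mode; the learnable scale/shift and the running statistics used at inference are deliberately left out, and it is not the library's code.

```python
import math
from typing import List


def ghost_batch_norm(values: List[float], virtual_batch_size: int,
                     eps: float = 1e-5) -> List[float]:
    """Normalize each virtual batch of one feature with its own mean and variance."""
    out: List[float] = []
    for start in range(0, len(values), virtual_batch_size):
        chunk = values[start:start + virtual_batch_size]  # last chunk may be smaller
        mean = sum(chunk) / len(chunk)
        # Biased variance, as batch normalization uses during training.
        var = sum((v - mean) ** 2 for v in chunk) / len(chunk)
        out.extend((v - mean) / math.sqrt(var + eps) for v in chunk)
    return out


batch = [1.0, 3.0, 10.0, 30.0]
ghost = ghost_batch_norm(batch, virtual_batch_size=2)  # roughly [-1, 1, -1, 1]
full = ghost_batch_norm(batch, virtual_batch_size=len(batch))  # ordinary batch norm
```

Each virtual batch is centered and scaled on its own, so the two halves of the batch end up on the same scale even though their raw magnitudes differ by a factor of ten.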

Training parameters

  • max_epochs : int (default = 200)

    Maximum number of epochs for training.

  • patience : int (default = 15)

    Number of consecutive epochs without improvement before performing early stopping.

  • lr : float (default = 0.02)

    Initial learning rate used for training. As mentioned in the original paper, a large initial learning rate of 0.02 with decay is a good option.

  • clip_value : float (default None)

    If a float is given, the gradient will be clipped at clip_value.

  • lambda_sparse : float (default = 1e-3)

    This is the extra sparsity loss coefficient as proposed in the original paper. The bigger this coefficient is, the sparser your model will be in terms of feature selection. Depending on the difficulty of your problem, reducing this value could help.

  • model_name : str (default = 'DQTabNet')

    Name of the model used for saving to disk; you can customize this to easily retrieve and reuse your trained models.

  • saving_path : str (default = './')

    Path defining where to save models.

  • scheduler_fn : torch.optim.lr_scheduler (default = None)

    PyTorch scheduler to change learning rates during training.

  • scheduler_params: dict

    Parameters dictionary for the scheduler_fn, e.g. {"gamma": 0.95, "step_size": 10}

  • verbose : int (default=-1)

    Verbosity for notebook plots; set to 1 to see every epoch.
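patience and max_epochs interact as follows. This minimal sketch replays a list of validation losses and reports the best epoch and how many epochs actually ran. It assumes that "improvement" means a strictly lower validation loss, which this README does not specify; the library may track a different metric.

```python
from typing import List, Tuple


def early_stopping_epochs(val_losses: List[float], max_epochs: int = 200,
                          patience: int = 15) -> Tuple[int, int]:
    """Return (best_epoch, epochs_run) for a sequence of per-epoch validation losses."""
    best_loss = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    epochs_run = 0
    for epoch, loss in enumerate(val_losses[:max_epochs]):
        epochs_run = epoch + 1
        if loss < best_loss:  # assumed criterion: strictly lower loss
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # early stopping
    return best_epoch, epochs_run


losses = [0.9, 0.7, 0.6, 0.65, 0.62, 0.61, 0.6, 0.5]
# With patience=3, training stops after epoch 5 and never sees the later 0.5.
```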
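The README does not state whether clip_value bounds each gradient entry or the overall gradient norm. The sketch below assumes norm clipping, which is what torch.nn.utils.clip_grad_norm_ does (ignoring the small epsilon PyTorch adds to the norm); treat it as an illustration on a flat list of gradients, not the library's code.

```python
import math
from typing import List, Optional


def clip_by_global_norm(grads: List[float], clip_value: Optional[float]) -> List[float]:
    """Rescale gradients so their L2 norm is at most clip_value (None disables clipping)."""
    if clip_value is None:
        return list(grads)
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= clip_value:
        return list(grads)  # already small enough, left untouched
    scale = clip_value / norm
    return [g * scale for g in grads]  # same direction, norm == clip_value


clipped = clip_by_global_norm([3.0, 4.0], clip_value=1.0)  # norm 5.0 rescaled to 1.0
```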
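lambda_sparse weights the paper's sparsity regularizer, the mean entropy of the attention masks: L_sparse is the sum over steps i, samples b and features j of -M_bj[i] * log(M_bj[i] + eps), divided by n_steps * B, and the total loss is the task loss plus lambda_sparse * L_sparse. A minimal pure-Python sketch of that formula, with masks given as nested lists; the eps value is an assumption.

```python
import math
from typing import List

Matrix = List[List[float]]  # rows = samples in the batch, columns = features


def sparsity_loss(masks: List[Matrix], eps: float = 1e-15) -> float:
    """Mean entropy of the masks over decision steps and samples."""
    n_steps = len(masks)
    batch_size = len(masks[0])
    total = 0.0
    for step_mask in masks:
        for row in step_mask:
            total += sum(-m * math.log(m + eps) for m in row)
    return total / (n_steps * batch_size)


def total_loss(task_loss: float, masks: List[Matrix], lambda_sparse: float = 1e-3) -> float:
    """Task loss plus the weighted sparsity regularizer."""
    return task_loss + lambda_sparse * sparsity_loss(masks)


one_hot = [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]]  # 1 step, 2 samples
uniform = [[[0.25] * 4, [0.25] * 4]] * 3                   # 3 steps, 2 samples
```

One-hot masks give a loss close to 0, while uniform masks over D features give log(D), which is why a larger lambda_sparse pushes the model toward sparser feature selection.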
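The scheduler_params example above has the same keys as the constructor of torch.optim.lr_scheduler.StepLR (step_size, gamma), so presumably scheduler_fn=torch.optim.lr_scheduler.StepLR is the intended pairing. Assuming the scheduler is stepped once per epoch, the learning rate follows the closed form below, written in pure Python so it runs without PyTorch. This gamma is the scheduler's decay factor and is unrelated to the network parameter gamma.

```python
def step_lr(initial_lr: float, epoch: int, step_size: int, gamma: float) -> float:
    """Step decay: multiply the learning rate by gamma every step_size epochs."""
    return initial_lr * gamma ** (epoch // step_size)


scheduler_params = {"gamma": 0.95, "step_size": 10}
# Starting from the default lr = 0.02: 0.02, 0.019, 0.01805, ... every 10 epochs.
lrs = [step_lr(0.02, epoch, **scheduler_params) for epoch in (0, 10, 20, 30)]
```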

Download files

Source Distribution

pytorch_tabnet-0.1.1.tar.gz (13.5 kB)

Uploaded Source

Built Distribution

pytorch_tabnet-0.1.1-py3-none-any.whl (12.7 kB)

Uploaded Python 3

File details

Details for the file pytorch_tabnet-0.1.1.tar.gz.

File metadata

  • Download URL: pytorch_tabnet-0.1.1.tar.gz
  • Upload date:
  • Size: 13.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/0.12.17 CPython/3.7.4 Linux/4.15.0-66-generic

File hashes

Hashes for pytorch_tabnet-0.1.1.tar.gz
Algorithm     Hash digest
SHA256        12f85435366a086cf4c3f1cad2add46fb58947e62ce71e4d5b00dcbbe6a81d03
MD5           3ae5da53d445e559e4d3f971dabb1c20
BLAKE2b-256   e3f68252ae176647d677f5c357b90c8edc38f469535d02830cf24cb6bd2aaec0

File details

Details for the file pytorch_tabnet-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: pytorch_tabnet-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 12.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/0.12.17 CPython/3.7.4 Linux/4.15.0-66-generic

File hashes

Hashes for pytorch_tabnet-0.1.1-py3-none-any.whl
Algorithm     Hash digest
SHA256        603dd8e55b32dd29fd82cee584723fec05c399f6aa72a9300dc3e3735bb13d7a
MD5           2e9b5cad778db03452ad97da114c504f
BLAKE2b-256   94d148076f4d5d3486f2101ed4f5771dd00ec8b34e647c2be5bddc3525f27a98
