
Project description

README

TabNet : Attentive Interpretable Tabular Learning

This is a PyTorch implementation of TabNet (Arik, S. O., & Pfister, T. (2019). TabNet: Attentive Interpretable Tabular Learning. arXiv preprint arXiv:1908.07442.) https://arxiv.org/pdf/1908.07442.pdf.

Installation

You can install it with pip by running: pip install pytorch-tabnet

If you want to use it locally within a Docker container:

git clone git@github.com:dreamquark-ai/tabnet.git

cd tabnet to get inside the repository

make start to build and get inside the container

poetry install to install all the dependencies, including jupyter

make notebook inside the same terminal

You can then follow the link to a Jupyter notebook with TabNet installed.

A GPU version is available and should work, but it is not officially supported yet.

How to use it?

The implementation makes it easy to try different TabNet architectures. All you need to do is change the network parameters and training parameters. All parameters are briefly described below; to get a better understanding of what each parameter does, please refer to the original paper.

You can also get comfortable with how the code works by playing with the notebook tutorials for the adult census income dataset and the forest cover type dataset.
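
A minimal usage sketch is shown below, assuming a scikit-learn-like classifier interface; the import path, class name, and fit/predict signatures are illustrative assumptions and may differ from this version's actual API (the tutorial notebooks show the exact usage):

    import numpy as np

    # Assumed import path and class name -- check the tutorial notebooks,
    # the exact API of this release may differ.
    from pytorch_tabnet.tab_model import TabNetClassifier

    # Toy data: 1000 rows, 54 numerical features, 7 classes.
    X_train = np.random.rand(1000, 54).astype(np.float32)
    y_train = np.random.randint(0, 7, size=1000)

    clf = TabNetClassifier(
        input_dim=54,    # number of features
        output_dim=7,    # number of classes
        nd=8, na=8,      # the paper suggests nd = na
        n_steps=3,
        gamma=1.3,
    )
    clf.fit(X_train, y_train)      # hypothetical fit signature
    preds = clf.predict(X_train)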

Network parameters

  • input_dim : int

    Number of initial features of the dataset

  • output_dim : int

Size of the desired output. Ex:

    • 1 for regression task
    • 2 for binary classification
    • N > 2 for multiclass classification
  • nd : int

    Width of the decision prediction layer. Bigger values give more capacity to the model, at the risk of overfitting. Values typically range from 8 to 64.

  • na : int

    Width of the attention embedding for each mask. According to the paper nd=na is usually a good choice.

  • n_steps : int

    Number of steps in the architecture (usually between 3 and 10).

  • gamma : float

    This is the coefficient for feature reuse in the masks. A value close to 1 will make mask selection least correlated between layers. Values typically range from 1.0 to 2.0.

  • cat_idxs : list of int

    List of categorical feature indices.

  • cat_emb_dim : list of int

    List of embedding sizes, one for each categorical feature (see the sketch after this list).

  • n_independent : int

    Number of independent Gated Linear Unit (GLU) layers at each step. Usual values range from 1 to 5 (default=2).

  • n_shared : int

    Number of shared Gated Linear Unit (GLU) layers at each step. Usual values range from 1 to 5 (default=2).

  • virtual_batch_size : int

    Size of the mini-batches used for Ghost Batch Normalization.
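
For the categorical parameters above, here is a small sketch of how cat_idxs and cat_emb_dim could be built from a pandas DataFrame; the column names and the embedding-size rule of thumb are illustrative and not part of the library:

    import pandas as pd
    from sklearn.preprocessing import LabelEncoder

    # Toy frame: "workclass" and "education" are the categorical columns.
    df = pd.DataFrame({
        "age": [25, 32, 47],
        "workclass": ["Private", "State-gov", "Private"],
        "education": ["Bachelors", "HS-grad", "Masters"],
    })
    categorical_columns = ["workclass", "education"]

    # Positional indices of the categorical columns -> cat_idxs
    cat_idxs = [df.columns.get_loc(c) for c in categorical_columns]  # [1, 2]

    # Encode categories as integers and pick one embedding size per column,
    # e.g. min(50, (cardinality + 1) // 2) as a common rule of thumb.
    cat_dims = []
    for c in categorical_columns:
        df[c] = LabelEncoder().fit_transform(df[c])
        cat_dims.append(df[c].nunique())
    cat_emb_dim = [min(50, (d + 1) // 2) for d in cat_dims]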

Training parameters

  • max_epochs : int (default = 200)

    Maximum number of epochs for training.

  • patience : int (default = 15)

    Number of consecutive epochs without improvement before performing early stopping.

  • lr : float (default = 0.02)

    Initial learning rate used for training. As mentioned in the original paper, a large initial learning rate of 0.02 with decay is a good option.

  • clip_value : float (default None)

    If a float is given, gradients will be clipped at clip_value.

  • lambda_sparse : float (default = 1e-3)

    This is the extra sparsity loss coefficient as proposed in the original paper. The bigger this coefficient is, the sparser your model will be in terms of feature selection. Depending on the difficulty of your problem, reducing this value could help.

  • model_name : str (default = 'DQTabNet')

    Name of the model used when saving to disk; you can customize this to easily retrieve and reuse your trained models.

  • saving_path : str (default = './')

    Path defining where to save models.

  • scheduler_fn : torch.optim.lr_scheduler (default = None)

    PyTorch scheduler used to change the learning rate during training.

  • scheduler_params: dict

    Parameter dictionary for scheduler_fn. Ex: {"gamma": 0.95, "step_size": 10} (see the sketch after this list).

  • verbose : int (default=-1)

    Verbosity for notebook plots; set to 1 to see every epoch.
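
To show how the scheduler options fit together, here is a hedged sketch using torch.optim.lr_scheduler.StepLR, which matches the {"gamma": 0.95, "step_size": 10} example above; the class name and the idea of passing training parameters to the constructor are assumptions for illustration, as in the earlier sketch:

    import torch

    # Assumed import path and class name, as in the earlier sketch.
    from pytorch_tabnet.tab_model import TabNetClassifier

    # scheduler_fn expects a PyTorch scheduler class; scheduler_params are the
    # keyword arguments it will be constructed with during training.
    scheduler_fn = torch.optim.lr_scheduler.StepLR
    scheduler_params = {"gamma": 0.95, "step_size": 10}  # multiply lr by 0.95 every 10 epochs

    clf = TabNetClassifier(
        input_dim=54,
        output_dim=7,
        lr=0.02,                 # large initial learning rate, decayed by the scheduler
        max_epochs=200,
        patience=15,
        lambda_sparse=1e-3,
        clip_value=1.0,
        scheduler_fn=scheduler_fn,
        scheduler_params=scheduler_params,
        verbose=1,
    )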



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pytorch_tabnet-0.1.0.tar.gz (13.5 kB)

Uploaded Source

Built Distribution

pytorch_tabnet-0.1.0-py3-none-any.whl (12.7 kB)

Uploaded Python 3

File details

Details for the file pytorch_tabnet-0.1.0.tar.gz.

File metadata

  • Download URL: pytorch_tabnet-0.1.0.tar.gz
  • Upload date:
  • Size: 13.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/0.12.17 CPython/3.7.4 Linux/4.15.0-66-generic

File hashes

Hashes for pytorch_tabnet-0.1.0.tar.gz
  • SHA256: aa9d5f33308494f8578ae86ca37662093ba1d23d67ac79d5462e5d0f820520ee
  • MD5: 67c890cea42c6c48d325f30c20e3fb77
  • BLAKE2b-256: 8474c3fc722ac74286b1ea33fca1007f5f3e6d16a8445b59e6c9ad19f0739089


File details

Details for the file pytorch_tabnet-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: pytorch_tabnet-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 12.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/0.12.17 CPython/3.7.4 Linux/4.15.0-66-generic

File hashes

Hashes for pytorch_tabnet-0.1.0-py3-none-any.whl
  • SHA256: 64dadbd92c624bea12ea7ba81cd8c285e8ad5e9b7ed1971436654352f845e608
  • MD5: 57cfe56e6c633b33f05eb5341edc9fe7
  • BLAKE2b-256: f6fb85e2d723867374139223b281fb7e7d257ad484cd35420289ffc6c3a8dc67

