
A TF2 Keras implementation of TabNet

Project description

TF2 Keras implementation of TabNet

TabNet is a novel deep learning architecture for tabular data. It reasons over multiple decision steps and uses sequential attention to select which features to use at each decision step. You can find more information in the original research paper, "TabNet: Attentive Interpretable Tabular Learning" by Arik and Pfister.
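To make "sequential attention" more concrete, here is a minimal NumPy sketch of how TabNet-style feature masks can be formed, following the formulation in the paper: at each step the attention logits are scaled by a prior that down-weights features already used (controlled by the relaxation factor), then projected with sparsemax (Martins & Astudillo, 2016) so that some mask entries are exactly zero; the entropy of the masks serves as a sparsity regularizer. The helper `feature_masks` and the random logits are hypothetical stand-ins for the learned attentive transformer, and this is not the tabnet_keras code. The `relaxation_factor`, `epsilon` and `mask_type="sparsemax"` settings in the Usage section presumably correspond to these knobs, and `lambda_sparse` presumably weights the returned sparsity term.

```python
import numpy as np


def sparsemax(z: np.ndarray) -> np.ndarray:
    """Row-wise sparsemax: Euclidean projection of each row onto the probability simplex."""
    z_sorted = np.sort(z, axis=-1)[..., ::-1]            # each row in descending order
    k = np.arange(1, z.shape[-1] + 1)                    # 1, 2, ..., n_features
    cumsum = np.cumsum(z_sorted, axis=-1)
    in_support = 1 + k * z_sorted > cumsum               # sorted entries that stay non-zero
    k_z = in_support.sum(axis=-1, keepdims=True)         # support size per row
    tau = (np.take_along_axis(cumsum, k_z - 1, axis=-1) - 1) / k_z  # threshold per row
    return np.maximum(z - tau, 0.0)


def feature_masks(step_logits, relaxation_factor=1.3, epsilon=1e-15):
    """Hypothetical helper: turn per-step attention logits into TabNet-style masks.

    step_logits: list of arrays of shape (batch, n_features), one per decision step,
    standing in for the output of the learned attentive transformer.
    """
    prior = np.ones_like(step_logits[0])                 # P[0] = 1: all features available
    masks, entropy = [], 0.0
    for logits in step_logits:
        mask = sparsemax(prior * logits)                 # M[i] = sparsemax(P[i-1] * h_i)
        prior = prior * (relaxation_factor - mask)       # P[i] = prod_j (gamma - M[j])
        # mask entropy, summed over features and averaged over the batch
        entropy += np.mean(np.sum(-mask * np.log(mask + epsilon), axis=-1))
        masks.append(mask)
    return masks, entropy / len(step_logits)             # averaged over decision steps


rng = np.random.default_rng(0)
logits = [rng.normal(size=(4, 6)) for _ in range(3)]    # 3 steps, batch of 4, 6 features
masks, sparsity_loss = feature_masks(logits)
for m in masks:
    print(np.round(m, 3))                                # each row sums to 1; entries can be exactly 0
print(sparsity_loss)
```

With a relaxation factor of 1, a feature that received the full mask weight at one step gets a prior of zero and cannot be selected again; larger values let features be reused across steps.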

Installation

$ pip install tabnet_keras

Usage

import tensorflow as tf
from tabnet_keras import TabNetRegressor, TabNetClassifier

tabnet_params = {
    "decision_dim": 16,
    "attention_dim": 16,
    "n_steps": 3,
    "n_shared_glus": 2,
    "n_dependent_glus": 2,
    "relaxation_factor": 1.3,
    "epsilon": 1e-15,
    "momentum": 0.98,
    "mask_type": "sparsemax",  # can be 'sparsemax' or 'softmax'
    "lambda_sparse": 1e-3,
    "virtual_batch_splits": 8,  # number of splits for ghost batch normalization; ideally divides batch_size evenly
}

# X: feature matrix of shape (n_samples, n_features), e.g. a float32 NumPy array

### Regression
# y: targets of shape (n_samples, 1)
model = TabNetRegressor(n_regressors=1, **tabnet_params)
model.compile(loss="mean_squared_error", optimizer=tf.keras.optimizers.Adam(0.01),
              metrics=[tf.keras.metrics.RootMeanSquaredError()])
model.fit(X, y, epochs=100, batch_size=1024)

### Classification
# y: one-hot labels of shape (n_samples, 10), as expected by categorical cross-entropy
# assuming out_activation=None means no output activation (raw logits),
# the loss must apply the softmax itself via from_logits=True
model = TabNetClassifier(n_classes=10, out_activation=None, **tabnet_params)
model.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(0.01))
model.fit(X, y, epochs=100, batch_size=1024)
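The `virtual_batch_splits` parameter refers to ghost batch normalization: during training, each batch is split into smaller "virtual" batches and each one is normalized with its own mean and variance. With `batch_size = 1024` and `virtual_batch_splits = 8`, each virtual batch holds 128 rows. The sketch below shows only the training-time forward pass in NumPy; `ghost_batch_norm` is a hypothetical name, and the sketch omits the learnable scale/shift and the moving averages (presumably what `momentum` governs), so it is not the library's implementation.

```python
import numpy as np


def ghost_batch_norm(x: np.ndarray, virtual_batch_splits: int = 8, epsilon: float = 1e-5) -> np.ndarray:
    """Training-mode ghost batch norm: normalize each virtual batch with its own statistics.

    Simplified sketch: no learnable scale/shift and no moving averages.
    """
    # np.array_split tolerates uneven splits; an even split gives equal-sized virtual batches
    chunks = np.array_split(x, virtual_batch_splits, axis=0)
    normalized = [(c - c.mean(axis=0)) / np.sqrt(c.var(axis=0) + epsilon) for c in chunks]
    return np.concatenate(normalized, axis=0)


rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(1024, 16))    # one batch of 1024 rows, 16 features
out = ghost_batch_norm(x, virtual_batch_splits=8)
print(out.shape)                                       # (1024, 16)
print(x.shape[0] // 8)                                 # 128 rows per virtual batch
```

Because each virtual batch is normalized independently, a batch size that is a multiple of `virtual_batch_splits` keeps every virtual batch the same size, which is presumably why the parameter comment recommends an even split.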

Acknowledgment

Most of the code is taken with minor changes from this repository.



Download files

Source Distribution

tabnet_keras-1.2.0.tar.gz (11.9 kB)

Uploaded Source

Built Distribution

tabnet_keras-1.2.0-py3-none-any.whl (16.9 kB)

Uploaded Python 3

File details

Details for the file tabnet_keras-1.2.0.tar.gz.

File metadata

  • Download URL: tabnet_keras-1.2.0.tar.gz
  • Size: 11.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.12

File hashes

Hashes for tabnet_keras-1.2.0.tar.gz
Algorithm     Hash digest
SHA256        1b975913cc85bd1f9d908d5cc3673202a25e73511c89ae80e697bc14565819c7
MD5           13b4708141860da1d3dec14da191d637
BLAKE2b-256   95ded0287064c8f499788796efe13856389b75f74d9645456c9e36efacc19205

File details

Details for the file tabnet_keras-1.2.0-py3-none-any.whl.

File hashes

Hashes for tabnet_keras-1.2.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        784a53a8266ff06bf9a25c8097844a6faf98e91b963d4bc991ff5a993d22e2cc
MD5           5b1b8eccc5812e3f1fd9e0c8123405c4
BLAKE2b-256   0009f821044b11e4550b79fab25428342f7087b1075d0d1ac99bb99ab2062dcd
