
Automated implementation of Google TabNet.

Project description

Auto-Tabnet

Auto-TabNet is an implementation of Google's TabNet model built on dreamquark-ai's PyTorch implementation (pytorch-tabnet), with hyperparameter search performed by Optuna.

Google's TabNet was proposed in 2019 as a way to apply deep neural networks effectively to tabular data.

TabNet combines a feature transformer, an attentive transformer, and feature masking to perform soft feature selection with controllable sparsity, all trained end to end. At each decision step, the attentive transformer selects which input features the model should reason from, and the feature transformer processes the selected features into more useful representations that capture complex patterns in the data. Because each step concentrates on the features the attentive transformer judges most important, TabNet spends its capacity on the most informative inputs, which contributes to its accuracy and also makes its feature selection interpretable.
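To make the attentive transformer's feature masking more concrete, here is a minimal NumPy sketch of the masking step as described in the TabNet paper: a learned projection of the previous step's output is scaled by a prior and passed through sparsemax, which yields a sparse, non-negative mask summing to 1 per row; the prior is then multiplied by (gamma - mask) so that features already used are down-weighted at later steps. This is not code from auto-tabnet or pytorch-tabnet. The weights are random, batch normalization is omitted, and the same hypothetical input `a` is reused at every step, whereas real TabNet feeds each step the output of its feature transformer.

```python
import numpy as np

def sparsemax(z: np.ndarray) -> np.ndarray:
    """Project each row of z onto the probability simplex (sparsemax)."""
    z_sorted = np.sort(z, axis=-1)[..., ::-1]            # sort each row in descending order
    k = np.arange(1, z.shape[-1] + 1)
    cumsum = np.cumsum(z_sorted, axis=-1)
    in_support = 1 + k * z_sorted > cumsum               # True for entries that stay non-zero
    k_z = in_support.sum(axis=-1, keepdims=True)         # support size per row
    tau = (np.take_along_axis(cumsum, k_z - 1, axis=-1) - 1) / k_z
    return np.maximum(z - tau, 0.0)

def attentive_step(a, prior, W, gamma=1.3):
    """One simplified attentive-transformer step (hypothetical weights W, no batch norm)."""
    mask = sparsemax(prior * (a @ W))                    # sparse feature-selection mask
    new_prior = prior * (gamma - mask)                   # down-weight features already used
    return mask, new_prior

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 6))    # 4 rows, 6 input features
a = rng.normal(size=(4, 8))    # stand-in for the previous step's processed features
W = rng.normal(size=(8, 6))    # hypothetical learned projection weights
prior = np.ones_like(x)        # every feature starts equally available

masks = []
for step in range(3):
    mask, prior = attentive_step(a, prior, W)
    masked_x = mask * x        # what the feature transformer would receive at this step
    masks.append(mask)
```

The relaxation factor gamma controls sparsity across steps: with gamma = 1 a feature can effectively be used at only one decision step, while larger values allow it to be reused.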

Requirements

Python 3.7 or later
pip (Python package manager)

Installation

With pip:

pip install auto-tabnet

Source Code

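If you want to use it locally within a virtualenv:

  • Clone the repository and change into it.
git clone https://github.com/Femme-js/auto-tabnet.git
cd auto-tabnet
  • Create and activate a virtual environment (on Windows, run env\Scripts\activate instead of the source command).
virtualenv env
source env/bin/activate
  • Install the dependencies from the requirements.txt file.
pip install -r requirements.txt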

Usage

from auto_tabnet import AutoTabnetClassifier

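clf = AutoTabnetClassifier(X, y, X_test)

Here X and y are the training features and targets, and X_test contains the test features to predict on.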

To get predictions on the test data:

predictions = clf.predict()

To get the ROC AUC score:

roc_auc = clf.get_roc_auc_score()
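The package does not document which data get_roc_auc_score() evaluates on, so as background only, here is a minimal NumPy sketch of the binary ROC AUC itself: the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one, with ties counted as one half. The function name `binary_roc_auc` is hypothetical; scikit-learn's `roc_auc_score` computes the same quantity for binary labels.

```python
import numpy as np

def binary_roc_auc(y_true, scores) -> float:
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("ROC AUC needs at least one positive and one negative example")
    diff = pos[:, None] - neg[None, :]                   # every positive vs every negative
    wins = np.sum(diff > 0) + 0.5 * np.sum(diff == 0)    # ties count as half a win
    return float(wins / (len(pos) * len(neg)))

print(binary_roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

In the example, three of the four positive/negative pairs are ordered correctly (0.35 is ranked below the negative scored 0.4), giving 0.75. A score of 0.5 corresponds to random ranking and 1.0 to perfect separation.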

To get the best hyperparameters tuned by Optuna:

best_params = clf.get_best_params()

The targets in y should all be of a single type (e.g. all strings or all integers).
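One quick way to check this requirement before fitting is to collect the Python types of the labels. The helpers below (`check_single_label_type` and `labels_as_strings`) are hypothetical and not part of auto-tabnet; casting every label to `str` is just one possible fix for mixed labels.

```python
from typing import Iterable, List

def check_single_label_type(y: Iterable) -> type:
    """Return the common Python type of all labels, or raise if they are mixed."""
    types = {type(label) for label in y}
    if not types:
        raise ValueError("no targets given")
    if len(types) > 1:
        names = sorted(t.__name__ for t in types)
        raise TypeError(f"targets mix several types: {names}")
    return types.pop()

def labels_as_strings(y: Iterable) -> List[str]:
    """Cast every label to str (note: 1 and "1" become the same label)."""
    return [str(label) for label in y]

mixed = [1, "2", 3]
try:
    check_single_label_type(mixed)
except TypeError as err:
    print(err)  # targets mix several types: ['int', 'str']

fixed = labels_as_strings(mixed)   # ['1', '2', '3']
print(check_single_label_type(fixed).__name__)  # str
```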


Download files


Source Distribution

auto_tabnet-0.0.1.tar.gz (16.0 kB)

Built Distribution

auto_tabnet-0.0.1-py3-none-any.whl (16.3 kB)

File details

Details for the file auto_tabnet-0.0.1.tar.gz.

File metadata

  • Download URL: auto_tabnet-0.0.1.tar.gz
  • Size: 16.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.12

File hashes

Hashes for auto_tabnet-0.0.1.tar.gz
Algorithm Hash digest
SHA256 57d874c13c96c1f7f3298dad15374ed194f7df087c073b43237fd5ad7f3a023c
MD5 ec88cbb8e7f4e74e809c11ee0b87979d
BLAKE2b-256 83f8e15aee21380791324ea90127b1a5496aac0609210525659afc9b373d275f

File details

Details for the file auto_tabnet-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: auto_tabnet-0.0.1-py3-none-any.whl
  • Size: 16.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.12

File hashes

Hashes for auto_tabnet-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 7d4f7dbd1774fe4dc9b2d54944f236eef96f51dc4df7236a8536343f23e89d23
MD5 09ebf563dfab765559631ad0ca3aca7d
BLAKE2b-256 9b7c10280e33c098f214587fdc52b7740f6f1d0e946ea37e29d4b14c20a27ece
