
PyTorch implementation of TabNet


TabNet: Attentive Interpretable Tabular Learning


TabNet is a deep learning architecture designed specifically for tabular data, combining interpretability and high predictive performance. This package provides a modern, maintained implementation of TabNet in PyTorch, supporting classification, regression, multitask learning, and unsupervised pretraining.

Installation

Install TabNet using pip:

pip install pytorch-tabnet2

What is TabNet?

TabNet is an interpretable neural network architecture for tabular data, introduced by Arik & Pfister (2019). It uses sequential attention to select which features to reason from at each decision step, enabling both high performance and interpretability. TabNet learns sparse feature masks, allowing users to understand which features are most important for each prediction. The method is particularly effective for structured/tabular datasets where traditional deep learning models often underperform compared to tree-based methods.
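To make "sequential attention" and "sparse feature masks" concrete, here is a minimal NumPy sketch of the mask computation described in the TabNet paper. At each decision step i, learned scores are scaled by a prior P[i-1] and normalized with sparsemax, M[i] = sparsemax(P[i-1] · h_i(a[i-1])). The prior is then updated as P[i] = P[i-1] · (γ − M[i]) so that features already used are down-weighted. This is an illustration, not this package's code: the random scores stand in for the output of the learned attentive transformer, and the helper names `sparsemax` and `attentive_masks` and the `gamma` default are assumptions chosen for the example.

```python
import numpy as np


def sparsemax(z: np.ndarray) -> np.ndarray:
    """Euclidean projection of a 1-D score vector onto the probability simplex."""
    z = np.asarray(z, dtype=float)
    z_sorted = np.sort(z)[::-1]              # scores in descending order
    k = np.arange(1, z.size + 1)
    cumsum = np.cumsum(z_sorted)
    support = 1 + k * z_sorted > cumsum      # True for the top-k_z entries (a prefix)
    k_z = k[support][-1]                     # size of the support
    tau = (cumsum[k_z - 1] - 1) / k_z        # threshold subtracted from every score
    return np.maximum(z - tau, 0.0)          # entries below tau become exactly 0


def attentive_masks(step_scores: np.ndarray, gamma: float = 1.3) -> np.ndarray:
    """Hypothetical helper: TabNet-style masks for a single sample.

    step_scores: (n_steps, n_features) stand-in for the attentive transformer
    output at each decision step (random here, learned in a real model).
    """
    n_steps, n_features = step_scores.shape
    prior = np.ones(n_features)              # P[0]: every feature fully available
    masks = np.zeros((n_steps, n_features))
    for i in range(n_steps):
        masks[i] = sparsemax(prior * step_scores[i])
        prior = prior * (gamma - masks[i])   # down-weight features used at this step
    return masks


rng = np.random.default_rng(0)
masks = attentive_masks(rng.normal(size=(3, 6)))
print(np.round(masks, 3))  # each row is a sparse distribution over the 6 features
```

Sparsemax is used instead of softmax because it can assign exactly zero weight to a feature, which is what makes the masks sparse and readable. The paper notes that γ = 1 pushes each feature toward being used at a single step, while larger values of γ allow a feature to be reused across steps.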

Key aspects of TabNet:

  • Attentive Feature Selection: At each step, TabNet learns which features to focus on, improving both accuracy and interpretability.
  • Interpretable Masks: The model produces feature masks that highlight the importance of each feature for individual predictions.
  • End-to-End Learning: Supports classification, regression, multitask, and unsupervised pretraining tasks.
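The per-prediction importances mentioned above are obtained by combining the masks from all decision steps. In the TabNet paper, each step's mask is weighted by how much that step contributes to the output, η[i] = Σ_c ReLU(d_c[i]), and the weighted sum is normalized across features. The NumPy sketch below follows that formula. It does not call this package's API: `aggregate_importance`, `masks`, and `step_outputs` are hypothetical names, and the arrays stand in for values a trained model would produce.

```python
import numpy as np


def aggregate_importance(masks: np.ndarray, step_outputs: np.ndarray) -> np.ndarray:
    """Per-sample feature importance from TabNet-style step masks (sketch).

    masks:        (n_steps, n_samples, n_features) mask values M[i]
    step_outputs: (n_steps, n_samples, n_d) decision-step outputs d[i]
    returns:      (n_samples, n_features), each row summing to 1
    """
    eta = np.maximum(step_outputs, 0.0).sum(axis=-1)   # (n_steps, n_samples) step weights
    weighted = (eta[..., None] * masks).sum(axis=0)     # (n_samples, n_features)
    totals = weighted.sum(axis=1, keepdims=True)
    # Rows whose step outputs were all non-positive get zero importance.
    return np.divide(weighted, totals, out=np.zeros_like(weighted), where=totals > 0)


# Toy example: 2 steps, 1 sample, 3 features.
masks = np.array([[[1.0, 0.0, 0.0]],
                  [[0.0, 0.5, 0.5]]])
step_outputs = np.array([[[1.5, -0.5, 0.5]],    # eta = 1.5 + 0 + 0.5 = 2.0
                         [[2.0, -1.0, 0.0]]])   # eta = 2.0
importance = aggregate_importance(masks, step_outputs)
print(importance)  # expected importances for the sample: 0.5, 0.25, 0.25
```

Averaging these rows over a dataset and renormalizing gives a global feature ranking, while the individual rows explain single predictions.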

What problems does pytorch-tabnet handle?

  • TabNetClassifier: binary and multi-class classification problems.
  • TabNetRegressor: single-target and multi-target regression problems.
  • TabNetMultiTaskClassifier: multi-task classification problems, where each task may be binary or multi-class.
  • MultiTabNetRegressor: multi-task regression problems; essentially TabNetRegressor with multiple targets.

Usage

Documentation

Basic Examples

Classification

import numpy as np
from pytorch_tabnet import TabNetClassifier

# Generate dummy data
X_train = np.random.rand(100, 10)
y_train = np.random.randint(0, 2, 100)
X_valid = np.random.rand(20, 10)
y_valid = np.random.randint(0, 2, 20)
X_test = np.random.rand(10, 10)

clf = TabNetClassifier()
clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)])
preds = clf.predict(X_test)
print('Predictions:', preds)

Regression

import numpy as np
from pytorch_tabnet import TabNetRegressor

# Generate dummy data
X_train = np.random.rand(100, 10)
y_train = np.random.rand(100).reshape(-1, 1)
X_valid = np.random.rand(20, 10)
y_valid = np.random.rand(20).reshape(-1, 1)
X_test = np.random.rand(10, 10)

reg = TabNetRegressor()
reg.fit(X_train, y_train, eval_set=[(X_valid, y_valid)])
preds = reg.predict(X_test)
print('Predictions:', preds)

Multi-task Classification

import numpy as np
from pytorch_tabnet import TabNetMultiTaskClassifier

# Generate dummy data
X_train = np.random.rand(100, 10)
y_train = np.random.randint(0, 2, (100, 3))  # 3 tasks
X_valid = np.random.rand(20, 10)
y_valid = np.random.randint(0, 2, (20, 3))
X_test = np.random.rand(10, 10)

clf = TabNetMultiTaskClassifier()
clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)])
preds = clf.predict(X_test)
print('Predictions:', preds)

See the nbs/ folder for more complete examples and notebooks.

Further Reading

License & Credits

  • TabNet architecture by Arik & Pfister (2019); original PyTorch implementation by the DreamQuark team
  • Maintained and improved by Daniel Avdar and contributors
  • See LICENSE for details



Download files


Source Distribution

pytorch_tabnet2-4.6.0.tar.gz (40.1 kB)

Uploaded Source

Built Distribution


pytorch_tabnet2-4.6.0-py3-none-any.whl (70.8 kB)

Uploaded Python 3

File details

Details for the file pytorch_tabnet2-4.6.0.tar.gz.

File metadata

  • Download URL: pytorch_tabnet2-4.6.0.tar.gz
  • Upload date:
  • Size: 40.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.10.4 (publish, Ubuntu 24.04, CI)

File hashes

Hashes for pytorch_tabnet2-4.6.0.tar.gz
Algorithm Hash digest
SHA256 1bd223655932a513a91b23a50a7044b60656507225daabda65689a18e549b5b0
MD5 364f2adf2b7ba5723eef883e4b78ff6d
BLAKE2b-256 14a1cfef6063f17bc681988a277d3644eec2c05aa16d089f2dece2bce98b6262


File details

Details for the file pytorch_tabnet2-4.6.0-py3-none-any.whl.

File metadata

  • Download URL: pytorch_tabnet2-4.6.0-py3-none-any.whl
  • Upload date:
  • Size: 70.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.10.4 (publish, Ubuntu 24.04, CI)

File hashes

Hashes for pytorch_tabnet2-4.6.0-py3-none-any.whl
Algorithm Hash digest
SHA256 f2cb42810258baf04cc0940164814b26ba8d9251557604ba1253e8a904364764
MD5 9cb0c36b7025dab56be0854adddba927
BLAKE2b-256 81340d27f7770e60b28401c853e9938c42bc1d1bb704c3bb49412d5d49707d3a

