
Tensorflow 2.0 implementation of TabNet of any configuration.


TabNet for Tensorflow 2.0

A Tensorflow 2.0 port for the paper TabNet: Attentive Interpretable Tabular Learning, whose original codebase is available at https://github.com/google-research/google-research/blob/master/tabnet.

The architecture figure in the paper shows that the model is built from blocks in two stages: one attends to the input features and the other constructs the output of the model.

Differences from Paper

There are two major differences from the paper and the official implementation.

  1. This implementation offers a choice of normalization method: the regular Batch Normalization used in the paper, or Group Normalization.

    • The paper uses very large batch sizes to stabilize Batch Normalization and obtain good generalization, which makes training computationally expensive.
    • Group Normalization offers a reasonable alternative, because its statistics are computed per sample and are therefore independent of the batch size.
    • Set num_groups to 1 to normalize all features of a sample as a single group (Layer Normalization type behaviour), or to -1 to give each feature its own group (Instance Normalization type behaviour).
  2. This implementation does not strictly need feature columns as input.

    • While this model was originally developed for tabular data, there is no hard requirement that tabular data be the only type of input it accepts.
    • By passing feature_columns=None and explicitly specifying the input dimensionality of the data (using num_features), even image data (flattened into a long vector) can yield semi-interpretable results.
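The batch-size argument above can be seen in a small pure-Python sketch. It compares a simplified Batch Normalization (per-feature statistics across the batch) with a simplified Group Normalization (per-sample statistics within contiguous feature groups). Both omit the learned scale and offset, and treating num_groups=-1 as "one group per feature" is an assumption of this sketch, not a statement about this package's code.

```python
import math


def group_norm(batch, num_groups, eps=1e-5):
    """Simplified Group Normalization over a [batch][features] list of lists.

    Statistics are computed per sample, within contiguous feature groups, so a
    sample's output never depends on the other samples in the batch.
    Learned scale/offset are omitted. num_groups == -1 is treated as
    "one group per feature" (an assumption made for this sketch).
    """
    num_features = len(batch[0])
    groups = num_features if num_groups == -1 else num_groups
    if num_features % groups != 0:
        raise ValueError("num_features must be divisible by num_groups")
    size = num_features // groups
    out = []
    for sample in batch:
        normed = []
        for g in range(groups):
            chunk = sample[g * size:(g + 1) * size]
            mean = sum(chunk) / size
            var = sum((v - mean) ** 2 for v in chunk) / size
            normed.extend((v - mean) / math.sqrt(var + eps) for v in chunk)
        out.append(normed)
    return out


def batch_norm(batch, eps=1e-5):
    """Simplified Batch Normalization: per-feature statistics across the batch."""
    n = len(batch)
    columns = list(zip(*batch))
    means = [sum(col) / n for col in columns]
    variances = [sum((v - m) ** 2 for v in col) / n for col, m in zip(columns, means)]
    return [
        [(v - m) / math.sqrt(s + eps) for v, m, s in zip(row, means, variances)]
        for row in batch
    ]


if __name__ == "__main__":
    batch = [[1.0, 2.0, 3.0, 4.0], [10.0, 0.0, -5.0, 7.0]]
    # Group norm: the first sample gets the same output alone or inside a batch.
    print(group_norm([batch[0]], num_groups=1)[0] == group_norm(batch, num_groups=1)[0])  # True
    # Batch norm: the first sample's output changes with the rest of the batch.
    print(batch_norm([batch[0]])[0] == batch_norm(batch)[0])  # False
```

With num_groups=1 the whole feature vector of a sample is normalized together, which is why it behaves like Layer Normalization and does not need large batches.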

Installation

  • For the latest release:
$ pip install --upgrade tabnet
  • For the master branch:
$ pip install git+https://github.com/titu1994/tf-TabNet.git

Since Tensorflow can run on either a CPU or a GPU, the package can be installed with the matching conditional requirements by adding [cpu] or [gpu]:

$ pip install tabnet[cpu]
$ pip install tabnet[gpu]

Usage

The script tabnet.py can be imported to provide either the TabNet building block or the TabNetClassification and TabNetRegression models, which add appropriate heads on top of the basic TabNet model. If the classification or regression head needs to be customized, it is recommended to compose a new model with TabNet as its base.

from tabnet import TabNet, TabNetClassifier

model = TabNetClassifier(feature_list, num_classes, ...)

Stacked TabNets

Regular TabNets can be stacked into multiple layers, which improves model capacity at the cost of interpretability.

from tabnet import StackedTabNetClassifier

model = StackedTabNetClassifier(feature_list, num_classes, num_layers, ...)

Because the models use custom objects, custom_objects.py must be imported in an evaluation-only script.

Mask Visualization

The masks of a TabNet can be obtained through the following TabNet class properties:

  • feature_selection_masks: Returns a list of one or more masks from the intermediate decision steps. Number of masks = number of decision steps - 1.
  • aggregate_feature_selection_mask: Returns a single tensor which is the average activation of the masks over that batch of training samples.

These masks are available as TabNet.feature_selection_masks and TabNet.aggregate_feature_selection_mask. Since the TabNetClassification and TabNetRegression models are composed of TabNet, the masks can be accessed through model.tabnet, for example model.tabnet.feature_selection_masks.
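To make "average activation of the masks" concrete, here is a minimal pure-Python sketch that averages per-step masks elementwise and then averages over the batch to get one score per feature. The helper names aggregate_masks and feature_importance and the example mask values are hypothetical, and this is an illustrative simplification rather than the library's actual computation.

```python
def aggregate_masks(step_masks):
    """Average per-step feature-selection masks elementwise.

    step_masks: list of masks, each a [batch][features] list of lists
    (one entry per intermediate decision step). Illustrative only; the
    library computes its aggregate mask internally and may differ.
    """
    if not step_masks:
        raise ValueError("expected at least one mask")
    num_masks = len(step_masks)
    num_rows = len(step_masks[0])
    num_features = len(step_masks[0][0])
    return [
        [sum(mask[i][j] for mask in step_masks) / num_masks for j in range(num_features)]
        for i in range(num_rows)
    ]


def feature_importance(mask):
    """Average a [batch][features] mask over the batch: one score per feature."""
    return [sum(col) / len(mask) for col in zip(*mask)]


# Two hypothetical per-step masks for 2 samples x 3 features.
m1 = [[0.5, 0.5, 0.0], [1.0, 0.0, 0.0]]
m2 = [[0.0, 0.5, 0.5], [0.0, 0.0, 1.0]]
agg = aggregate_masks([m1, m2])
print(agg)  # [[0.25, 0.5, 0.25], [0.5, 0.0, 0.5]]
print(feature_importance(agg))  # [0.375, 0.25, 0.375]
```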

Mask Generation must be in Eager Execution Mode

Note: Because of AutoGraph, when the Keras fit() or predict() APIs are used, the outputs of the model will generally be graph Tensors rather than EagerTensors. Since the masks are generated inside the Model.call() method, the model must be forced to run in eager execution mode, not graph mode, to obtain them.

There are two ways to force the model into eager mode:

  1. Take a batch of tensor data and call the model directly on it:
x, _ = next(iter(tf_dataset))  # Assuming it generates an (x, y) tuple.
_ = model(x)  # This forces eager execution.
  2. Alternatively, build a separate model with the dynamic=True flag passed to its constructor, load the trained weights into it, and call predict(x) on it. This should also force eager execution mode.
new_model = TabNetClassification(..., dynamic=True)
new_model.load_weights('path/to/weights')

x, _ = next(iter(tf_dataset))  # Assuming it generates an (x, y) tuple.
new_model.predict(x)  # Runs eagerly because the model was built with dynamic=True.

After the model has been forced into eager execution mode, the masks can be visualized in TensorBoard as follows:

import tensorflow as tf

# `model` must already have been run in eager mode (see above).
writer = tf.summary.create_file_writer("logs/")
with writer.as_default():
    # One image per intermediate decision step.
    for i, mask in enumerate(model.tabnet.feature_selection_masks):
        print("Saving mask {} of shape {}".format(i + 1, mask.shape))
        tf.summary.image('mask_at_iter_{}'.format(i + 1), step=0, data=mask, max_outputs=1)
        writer.flush()

    # A single image for the aggregate mask over the batch.
    agg_mask = model.tabnet.aggregate_feature_selection_mask
    print("Saving aggregate mask of shape", agg_mask.shape)
    tf.summary.image("Aggregate Mask", step=0, data=agg_mask, max_outputs=1)
    writer.flush()
writer.close()
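tf.summary.image expects its data argument to have shape [k, h, w, c]. If you extract a mask yourself as a 2D [batch, features] tensor, it needs a leading image axis and a trailing channel axis before it can be logged. The sketch below only illustrates that target layout on plain nested lists; the helpers to_image_layout and shape_of are hypothetical and not part of this package.

```python
def to_image_layout(mask):
    """Convert a [batch][features] mask into a [1][batch][features][1] nested list.

    This mirrors the [k, h, w, c] layout tf.summary.image expects: k=1 image
    whose rows are samples and columns are features, with one grayscale channel.
    """
    return [[[[float(value)] for value in row] for row in mask]]


def shape_of(nested):
    """Return the shape of a rectangular nested list."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return shape


mask = [[0.2, 0.8, 0.0], [1.0, 0.0, 0.0]]  # hypothetical: 2 samples x 3 features
print(shape_of(to_image_layout(mask)))  # [1, 2, 3, 1]
```

In TensorFlow itself, the same change can be made with tf.expand_dims(tf.expand_dims(mask, 0), -1).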

Requirements

  • Tensorflow 2.0+ (1.14+ with V2 compat enabled may be sufficient for 1.x)
  • Tensorflow-datasets (only required to run train_iris.py)



Download files


Source Distribution

tabnet-0.1.6.tar.gz (15.9 kB)

Built Distribution

tabnet-0.1.6-py2.py3-none-any.whl (16.3 kB, Python 2 and Python 3)

File details

Details for the file tabnet-0.1.6.tar.gz.

File metadata

  • Download URL: tabnet-0.1.6.tar.gz
  • Upload date:
  • Size: 15.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.1.3.post20200330 requests-toolbelt/0.9.1 tqdm/4.44.1 CPython/3.7.4

File hashes

Hashes for tabnet-0.1.6.tar.gz
Algorithm Hash digest
SHA256 e7b901197cb67435255067cb2bc7de327858970a04db7b35c6f253101cf02f80
MD5 95cf78db4e7ab76a413a54f3280dbd9b
BLAKE2b-256 8584470e95d5abe5eb5ece8cd89b89e5ff678e1b9e6e23a00045f8dd315fe6ba


File details

Details for the file tabnet-0.1.6-py2.py3-none-any.whl.

File metadata

  • Download URL: tabnet-0.1.6-py2.py3-none-any.whl
  • Upload date:
  • Size: 16.3 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.1.3.post20200330 requests-toolbelt/0.9.1 tqdm/4.44.1 CPython/3.7.4

File hashes

Hashes for tabnet-0.1.6-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 d7ad9eed4265f704c777018d1a1a9bbae05882dd1765fa5471c90bea801ca7b1
MD5 dff0539cfd5b4197eae5ec5d858716e9
BLAKE2b-256 4da13ccd47a091b7674fbac8a1bd504bbd941fbbe495657a719faaa701248a21

