A Tensorflow 2.0 implementation of TabNet, supporting any model configuration.
TabNet for Tensorflow 2.0
A Tensorflow 2.0 port of the paper TabNet: Attentive Interpretable Tabular Learning, whose original codebase is available at https://github.com/google-research/google-research/blob/master/tabnet.
As the architecture diagram in the paper shows, the model is built of blocks in two stages: one to attend to the input features and another to construct the output of the model.
Differences from Paper
There are two major differences from the paper and the official implementation.

- This implementation offers a choice of normalization method: the regular `Batch Normalization` from the paper, or `Group Normalization`.
  - The paper uses very large batch sizes to stabilize Batch Normalization and obtain good generalization; an issue with this is the computational cost of such batches.
  - Group Normalization (with the number of groups set to 1, aka Instance Normalization) therefore offers a reasonable alternative that is independent of the batch size.
  - One can set `num_groups` to 1 for `Instance Normalization`-type behaviour, or to -1 for `Layer Normalization`-type behaviour, as in the sketch after this list.
- This implementation does not strictly need feature columns as input.
  - While this model was originally developed for tabular data, there is no hard requirement that this be the only type of input it accepts.
  - By passing `feature_columns=None` and explicitly specifying the input dimensionality of the data (using `num_features`), we can get a semi-interpretable result even from image data (after flattening it into a long vector), as the same sketch below shows.
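To make both options concrete, here is a minimal sketch. The `num_groups`, `feature_columns`, and `num_features` argument names follow the descriptions above, but the exact constructor signature should be verified against `tabnet.py`.

```python
import tensorflow as tf
from tabnet import TabNetClassifier

# Tabular input with Group Normalization collapsed to a single group,
# i.e. Instance Normalization (num_groups=-1 would mimic Layer Normalization).
# Feature names here are purely illustrative.
feature_columns = [tf.feature_column.numeric_column(name)
                   for name in ["sepal_length", "sepal_width"]]
tabular_model = TabNetClassifier(feature_columns, num_classes=3,
                                 num_features=2, num_groups=1)

# Non-tabular input: skip feature columns entirely and declare the input
# dimensionality explicitly, e.g. 28x28 images flattened to 784 values.
image_model = TabNetClassifier(feature_columns=None, num_features=784,
                               num_classes=10)
```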
Installation
- For the latest release branch:

```
$ pip install --upgrade tabnet
```

- For the master branch:

```
$ pip install git+https://github.com/titu1994/tf-TabNet.git
```

As Tensorflow can be used with either a CPU or GPU, the package can be installed with the conditional requirements `[cpu]` or `[gpu]` as follows:

```
$ pip install tabnet[cpu]
$ pip install tabnet[gpu]
```
Usage
The script `tabnet.py` can be imported to yield either the `TabNet` building block, or the `TabNetClassification` and `TabNetRegression` models, which add the appropriate heads to the basic `TabNet` model. If the classification or regression head is to be customized, it is recommended to compose a new model with `TabNet` as its base; a sketch of this follows the example below.

```python
from tabnet import TabNet, TabNetClassifier

model = TabNetClassifier(feature_list, num_classes, ...)
```
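One way to attach a custom head is to subclass `tf.keras.Model` and stack new layers on top of the `TabNet` block. This is only a sketch: the constructor arguments mirror the examples in this README and should be checked against the actual signature in `tabnet.py`.

```python
import tensorflow as tf
from tabnet import TabNet

class TabNetWithCustomHead(tf.keras.Model):
    """The TabNet feature extractor followed by a user-defined head."""

    def __init__(self, feature_columns, num_features, num_outputs, **kwargs):
        super().__init__()
        # Argument names here are assumptions; verify them against tabnet.py.
        self.tabnet = TabNet(feature_columns, num_features=num_features, **kwargs)
        # Replace this single Dense layer with any stack of layers needed.
        self.head = tf.keras.layers.Dense(num_outputs)

    def call(self, inputs, training=None):
        features = self.tabnet(inputs, training=training)
        return self.head(features)
```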
Stacked TabNets
Regular TabNets can be stacked into multiple layers, thereby reducing interpretability but improving model capacity.

```python
from tabnet import StackedTabNetClassifier

model = StackedTabNetClassifier(feature_list, num_classes, num_layers, ...)
```
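Training then follows the standard Keras workflow. A minimal sketch, assuming `tf_dataset` yields `(x, y)` batches with integer labels:

```python
# Standard Keras compile-and-fit loop; tf_dataset is assumed to exist.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(tf_dataset, epochs=10)
```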
As the models use custom objects, it is necessary to import `custom_objects.py` in an evaluation-only script.
Mask Visualization
The masks of the TabNet can be obtained by using the following TabNet class properties:

- `feature_selection_masks`: Returns a list of 1 or more masks at intermediate decision steps. Number of masks = number of decision steps - 1.
- `aggregate_feature_selection_mask`: Returns a single tensor which is the average activation of the masks over that batch of training samples.

These masks can be obtained as `TabNet.feature_selection_masks`. Since the `TabNetClassification` and `TabNetRegression` models are composed of `TabNet`, the masks can be obtained as `model.tabnet.*`.
Mask Generation must be in Eager Execution Mode
Note: Due to autograph, the outputs of the model when using the `fit()` or `predict()` Keras APIs will generally be graph-based Tensors, not EagerTensors. Since the masks are generated inside the `Model.call()` method, it is necessary to force the model to behave in Eager execution mode, not in Graph mode.

There are two ways to force the model into eager mode:

- Get tensor data samples, and directly call the model on this data, as below:

```python
x, _ = next(iter(tf_dataset))  # Assuming it generates an (x, y) tuple.
_ = model(x)  # This forces eager execution.
```
- Alternatively, build a separate model (passing the `dynamic=True` flag to the model constructor), load the weights and parameters into this model, and call `model.predict(x)`. This should also force eager execution mode.

```python
new_model = TabNetClassification(..., dynamic=True)
new_model.load_weights('path/to/weights')

x, _ = next(iter(tf_dataset))  # Assuming it generates an (x, y) tuple.
new_model.predict(x)
```
After the model has been forced into Eager Execution mode, the masks can be visualized in Tensorboard as follows:

```python
import tensorflow as tf

writer = tf.summary.create_file_writer("logs/")
with writer.as_default():
    for i, mask in enumerate(model.tabnet.feature_selection_masks):
        print("Saving mask {} of shape {}".format(i + 1, mask.shape))
        tf.summary.image('mask_at_iter_{}'.format(i + 1), step=0, data=mask, max_outputs=1)
        writer.flush()

    agg_mask = model.tabnet.aggregate_feature_selection_mask
    print("Saving aggregate mask of shape", agg_mask.shape)
    tf.summary.image("Aggregate Mask", step=0, data=agg_mask, max_outputs=1)
    writer.flush()

writer.close()
```
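The saved summaries can then be viewed by pointing TensorBoard at the log directory:

```
$ tensorboard --logdir logs/
```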
Requirements
- Tensorflow 2.0+ (1.14+ with V2 compat enabled may be sufficient for 1.x)
- Tensorflow-datasets (only required for evaluating `train_iris.py`)