Skip to main content

Interpretability Callbacks for Tensorflow 2.0

Project description

tf-explain

Pypi Version Build Status Documentation Status Python Versions Tensorflow Versions Code style: black

tf-explain implements interpretability methods as Tensorflow 2.0 callbacks to ease neural network's understanding.

Installation

tf-explain is available on PyPi as an alpha release. To install it:

virtualenv venv -p python3.6
pip install tf-explain

tf-explain is compatible with Tensorflow 2. It is not declared as a dependency to let you choose between CPU and GPU versions. Additionally to the previous install, run:

# For CPU version
pip install tensorflow==2.0.0-beta1
# For GPU version
pip install tensorflow-gpu==2.0.0-beta1

Available Methods

  1. Activations Visualization
  2. Occlusion Sensitivity
  3. Grad CAM (Class Activation Maps)
  4. SmoothGrad

Activations Visualization

Visualize how a given input comes out of a specific activation layer

from tf_explain.callbacks.activations_visualization import ActivationsVisualizationCallback

model = [...]

callbacks = [
    ActivationsVisualizationCallback(
        validation_data=(x_val, y_val),
        layers_name=["activation_1"],
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)

Occlusion Sensitivity

Visualize how parts of the image affects neural network's confidence by occluding parts iteratively

from tf_explain.callbacks.occlusion_sensitivity import OcclusionSensitivityCallback

model = [...]

callbacks = [
    OcclusionSensitivityCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        patch_size=4,
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)

Occlusion Sensitivity for Tabby class (stripes differentiate tabby cat from other ImageNet cat classes)

Grad CAM

Visualize how parts of the image affects neural network's output by looking into the activation maps

From Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization

from tf_explain.callbacks.grad_cam import GradCAMCallback

model = [...]

callbacks = [
    GradCAMCallback(
        validation_data=(x_val, y_val),
        layer_name="activation_1",
        class_index=0,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)

SmoothGrad

Visualize stabilized gradients on the inputs towards the decision

From SmoothGrad: removing noise by adding noise

from tf_explain.callbacks.smoothgrad import SmoothGradCallback

model = [...]

callbacks = [
    SmoothGradCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        num_samples=20,
        noise=1.,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)

Visualizing the results

When you use the callbacks, the output files are created in the logs directory.

You can see them in tensorboard with the following command: tensorboard --logdir logs

Roadmap

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

tf-explain-0.0.2a0.tar.gz (14.4 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

tf_explain-0.0.2a0-py3-none-any.whl (28.2 kB view details)

Uploaded Python 3

File details

Details for the file tf-explain-0.0.2a0.tar.gz.

File metadata

  • Download URL: tf-explain-0.0.2a0.tar.gz
  • Upload date:
  • Size: 14.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.2 CPython/3.6.8

File hashes

Hashes for tf-explain-0.0.2a0.tar.gz
Algorithm Hash digest
SHA256 582a5b00fe00bae653c33be363b909c9d6a49ac79d4fb4476718823d8c684e37
MD5 88d5a531e221239dc1371de16147a667
BLAKE2b-256 a33f22186426aa9c8c6d0580211550a1014c4727e33d7d3a067e2323ecea0cbf

See more details on using hashes here.

File details

Details for the file tf_explain-0.0.2a0-py3-none-any.whl.

File metadata

  • Download URL: tf_explain-0.0.2a0-py3-none-any.whl
  • Upload date:
  • Size: 28.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.2 CPython/3.6.8

File hashes

Hashes for tf_explain-0.0.2a0-py3-none-any.whl
Algorithm Hash digest
SHA256 8dcefe762ab9ed70e6999953bdc3a6181540de549cec871b45231849d7319f36
MD5 731568d430d3a4856b85f3f02035bafb
BLAKE2b-256 af54cb2869ba0da4cf282af546441bab3e1d7a611bfc8811acd2345b41295038

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page