PyTorch Interpret
Interpreting deep learning models in PyTorch.
A simple-to-use PyTorch library for interpreting your deep learning results, using both visualisations and attributions. Inspired by TensorFlow Lucid.
Installation
Install from PyPI:
pip install interpret-pytorch
Or, install the latest code from GitHub:
pip install git+https://github.com/ttumiel/interpret
Dependencies
interpret requires a working installation of PyTorch.
Contents
Tutorials
Run the tutorials in the browser using Google Colab.
| Tutorial | Link |
|---|---|
| Introduction to interpret | |
| Visualisation Tutorial | |
| Miscellaneous Methods Tutorial | |
Visualisation
Visualisation is a technique that generates an input that optimises a particular objective within a trained network. By looking at these generated inputs, we can understand what the network has learned to detect. For an in-depth explanation of visualisation, see Feature Visualisation.
Quickstart
Generating a visualisation is done by loading a trained network, selecting the objective to optimise for, and running the optimisation. An example using a pretrained network from torchvision is shown below.
```python
from interpret import OptVis
import torchvision

# Get the PyTorch neural network
network = torchvision.models.vgg11(pretrained=True)

# Select a layer from the network. Use get_layer_names()
# to see a list of layer names and sizes.
layer = 'features/18'
channel = 12

# Create an OptVis object from a PyTorch model
optvis = OptVis.from_layer(network, layer=layer, channel=channel)

# Create visualisation
optvis.vis()
```
Parameterisations
Images can be parameterised in several different ways. As long as the parameterisation is differentiable, the input can be optimised for a particular layer. For code examples, see the Visualisation Tutorial Notebook.
The default parameterisation is in spatial and colour decorrelated space.
We can also parameterise in regular pixel space, but the resulting visualisations tend to be noisier. Another parameterisation is a compositional pattern-producing network (CPPN), which can generate arbitrary-resolution images that have the effect of "light paintings."
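To illustrate the simplest case, here is a minimal pixel-space parameterisation in plain PyTorch: the raw pixels themselves are the parameters, optimised by gradient ascent on a channel's mean activation. The tiny untrained network stands in for a real trained model; this is a conceptual sketch, not the library's API.

```python
import torch
import torch.nn as nn

# A tiny stand-in network; in practice this would be a trained model
# such as torchvision's vgg11.
net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1),
    nn.ReLU(),
)

def visualise_pixel_space(model, channel, size=64, steps=50, lr=0.05):
    """Naive pixel-space parameterisation: optimise raw pixels to
    maximise the mean activation of one channel of the final layer."""
    img = torch.randn(1, 3, size, size, requires_grad=True)
    opt = torch.optim.Adam([img], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        acts = model(img)
        loss = -acts[0, channel].mean()  # negate: ascend the activation
        loss.backward()
        opt.step()
    return img.detach()

result = visualise_pixel_space(net, channel=3)
print(result.shape)  # torch.Size([1, 3, 64, 64])
```

The decorrelated parameterisations the library defaults to add a fixed transform in front of this optimisation loop; the loop itself stays the same.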
Objectives
The objective on which to optimise can also be manipulated to create different visualisations. We can add objectives together to get compound objectives or negate them to get negative neurons. See the Visualisation Tutorial Notebook for examples.
Layer Objective
A LayerObjective can be created easily using the from_layer OptVis class method, in which we can choose the layer, channel and neuron to optimise for. We can also manually create two LayerObjective instances and add them together to form a compound objective, or negate an objective to minimise, rather than maximise, a particular neuron.
Layer objectives are fairly flexible. You can select any layer in the network and capture the output of that particular layer. We can visualise the last layer of the network, generating class visualisations of the different classes in ImageNet.
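Conceptually, a layer objective is a differentiable scalar captured with a forward hook, and compound or negated objectives are just sums and negations of such scalars. A plain-PyTorch sketch of that idea follows (the tiny network, layer index, and helper names are illustrative, not the library's API):

```python
import torch
import torch.nn as nn

# Stand-in for a trained network
net = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(),
                    nn.Conv2d(4, 8, 3, padding=1), nn.ReLU())

def layer_objective(model, layer_idx, channel):
    """Return a function mapping an input image to the mean activation
    of one channel at the chosen layer, captured via a forward hook."""
    def objective(img):
        acts = {}
        handle = model[layer_idx].register_forward_hook(
            lambda mod, inp, out: acts.setdefault('out', out))
        model(img)
        handle.remove()
        return acts['out'][0, channel].mean()
    return objective

obj_a = layer_objective(net, 2, channel=1)
obj_b = layer_objective(net, 2, channel=5)

img = torch.randn(1, 3, 32, 32, requires_grad=True)
compound = obj_a(img) + obj_b(img)   # compound objective
negated = -obj_a(img)                # "negative neuron" objective
compound.backward()
print(img.grad is not None)  # True
```

Because each objective is an ordinary scalar tensor, arithmetic on objectives composes for free; the library wraps this pattern behind its objective classes.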
Deep Dream Objective
The deep dream objective optimises for "interestingness" across an entire layer. We can create this objective from an input image and select a layer using the from_dream class method.
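As a rough sketch of what such an objective computes (plain PyTorch with a stand-in network, not the library's from_dream API): one deep-dream step ascends the gradient of a whole-layer norm, amplifying whatever the layer already responds to in the image.

```python
import torch
import torch.nn as nn

# Stand-in for a trained network
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())

def dream_step(model, img, lr=0.01):
    """One deep-dream step: nudge the image to amplify whatever the
    layer already responds to, by ascending the layer's squared norm."""
    img = img.clone().requires_grad_(True)
    loss = model(img).pow(2).mean()  # "interestingness" across the layer
    loss.backward()
    with torch.no_grad():
        # Normalised gradient ascent keeps the step size stable
        return img + lr * img.grad / (img.grad.abs().mean() + 1e-8)

dreamed = torch.rand(1, 3, 32, 32)  # stand-in for a loaded photo
for _ in range(5):
    dreamed = dream_step(net, dreamed)
print(dreamed.shape)  # torch.Size([1, 3, 32, 32])
```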
Attribution
Network attribution is done by feeding a particular input into the trained network and generating a saliency map that highlights the parts of the input that the network responds to most strongly.
Quickstart
```python
from interpret import Gradcam, norm
from PIL import Image
import torchvision

network = torchvision.models.vgg11(pretrained=True)
input_img = Image.open('image.jpg')

# Normalise the input image and turn it into a tensor
input_data = norm(input_img)

# Select the class that we are attributing to
class_number = 207

# Choose a layer for Grad-CAM
layer = 'features/20'

# Generate a Grad-CAM attribution map
saliency_map = Gradcam(network, input_data, im_class=class_number, layer=layer)
saliency_map.show()
```
Miscellaneous Interpretations
Included in interpret are a few additional interpretation methods that don't fit neatly into either visualisation or attribution.
Plot Top Losses
Plot the inputs that result in the largest loss. Useful for identifying where your network is most unsure, or where the inputs don't actually fit the given label (e.g. a mislabelled image). You can also enable a Grad-CAM attribution overlay for each image so that you can tell where the network is looking.
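The underlying ranking can be sketched in plain PyTorch (synthetic logits and labels stand in for a real model and validation set; this is not the library's plotting API): compute the unreduced per-sample loss and take the top k.

```python
import torch
import torch.nn.functional as F

# Synthetic batch: logits from some model and ground-truth labels
torch.manual_seed(0)
logits = torch.randn(100, 10)
labels = torch.randint(0, 10, (100,))

# Per-sample loss (no reduction), then the k worst offenders
losses = F.cross_entropy(logits, labels, reduction='none')
top_losses, top_idx = losses.topk(5)
print(top_idx.tolist())  # indices of the 5 most confusing inputs
```

These indices are then used to look up and display the offending inputs, optionally with an attribution overlay.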
Plot Confusion Matrix
Plot a confusion matrix for a multi-class classification or binned regression objective.
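The underlying computation can be sketched in plain PyTorch (synthetic predictions and labels for illustration):

```python
import torch

num_classes = 4
preds = torch.tensor([0, 1, 2, 2, 3, 0, 1, 1])
labels = torch.tensor([0, 1, 1, 2, 3, 0, 2, 1])

# Rows index the true class, columns the predicted class;
# off-diagonal entries are the misclassifications worth plotting.
cm = torch.zeros(num_classes, num_classes, dtype=torch.long)
for t, p in zip(labels, preds):
    cm[t, p] += 1
print(cm)
```

For a binned regression objective, the continuous targets are first discretised into bins and the same matrix is built over bin indices.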
Plot Dataset Examples
Plot the dataset examples that maximise a particular LayerObjective from the visualisation objectives described above. Useful for finding clear, real examples of what the network is looking for in a particular visualisation.
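A sketch of the underlying ranking (plain PyTorch with a stand-in network and a random "dataset"; names are illustrative, not the library's API): score every example by the objective's activation and keep the highest scorers.

```python
import torch
import torch.nn as nn

# Stand-in for a trained network and a real dataset
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
dataset = torch.randn(20, 3, 16, 16)

# Score each example by the mean activation of one channel,
# then keep the examples that excite that channel the most.
channel = 2
with torch.no_grad():
    scores = net(dataset)[:, channel].mean(dim=(1, 2))
best = scores.topk(4).indices
print(best.shape)  # torch.Size([4])
```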
Project details
Download files
File details
Details for the file interpret-pytorch-0.2.1.tar.gz.
File metadata
- Download URL: interpret-pytorch-0.2.1.tar.gz
- Upload date:
- Size: 40.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.4.0 requests-toolbelt/0.9.1 tqdm/4.47.0 CPython/3.7.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 26d0890c78a2ad605be5bd83dec202789ad4b0512b65fa121787a243a53826d6 |
| MD5 | 3ba925db509375b52c67a6b7416d3616 |
| BLAKE2b-256 | a41aa1a32768f57cecf9e404d13f5bbda93df96b8cd86cd3df95d8a2360a556f |
File details
Details for the file interpret_pytorch-0.2.1-py3-none-any.whl.
File metadata
- Download URL: interpret_pytorch-0.2.1-py3-none-any.whl
- Upload date:
- Size: 48.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.4.0 requests-toolbelt/0.9.1 tqdm/4.47.0 CPython/3.7.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1ef8cb80cc31fc9293d4af22823f07535c050e50dd4d20ab7b729a53e1c4a030 |
| MD5 | 2e0b40a67714a157a74c12bc11e1dfba |
| BLAKE2b-256 | eac0898a60e80037b972f973e1f181b3e7ba6697c3204f2644b3e2aab8206e3b |