Class activation maps for your PyTorch CNN models

TorchCAM: class activation explorer

Simple way to leverage the class-specific activation of convolutional layers in PyTorch.

Quick Tour

Setting your CAM

TorchCAM uses PyTorch's hooking mechanisms to retrieve everything needed to produce the class activation map, with no extra effort from the user. Each CAM object acts as a wrapper around your model.
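To give an idea of what a hooking mechanism looks like, here is a minimal sketch of a PyTorch forward hook that records the output of one layer during a regular forward pass. It is an illustration only, not TorchCAM's internal code: the toy model, the `captured` dictionary and the `save_activation` function are hypothetical names chosen for this example.

```python
import torch
from torch import nn

# Toy CNN standing in for a real backbone (hypothetical, not part of TorchCAM)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
).eval()

captured = {}

def save_activation(module, inputs, output):
    # A forward hook receives the module, its inputs and its output
    captured["feature_map"] = output.detach()

# Look the layer up by name ("2" is the second convolution), then attach the hook
target = dict(model.named_modules())["2"]
handle = target.register_forward_hook(save_activation)

with torch.no_grad():
    logits = model(torch.rand(1, 3, 32, 32))

handle.remove()  # detach the hook once it is no longer needed
print(captured["feature_map"].shape)  # torch.Size([1, 16, 32, 32])
```

The forward pass itself is unchanged: the hook only observes the layer output, which is why an extractor built this way can stay out of the user's inference code.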

The documentation has the exhaustive list of supported CAM methods. Pick one and use it as follows:

# Define your model
from torchvision.models import resnet18
model = resnet18(pretrained=True).eval()

# Set your CAM extractor
from torchcam.cams import SmoothGradCAMpp
cam_extractor = SmoothGradCAMpp(model)

Please note that by default, the layer at which the CAM is retrieved is set to the last non-reduced convolutional layer. If you wish to investigate a specific layer, use the target_layer argument in the constructor.

Retrieving the class activation map

Once your CAM extractor is set, you only need to run inference with your model as usual. If any additional information is required, the extractor will get it for you automatically.

from torchvision.io.image import read_image
from torchvision.transforms.functional import normalize, resize, to_pil_image
from torchvision.models import resnet18
from torchcam.cams import SmoothGradCAMpp

model = resnet18(pretrained=True).eval()
cam_extractor = SmoothGradCAMpp(model)
# Get your input
img = read_image("path/to/your/image.png")
# Preprocess it for your chosen model
input_tensor = normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Add a batch dimension and feed the preprocessed tensor to the model
out = model(input_tensor.unsqueeze(0))
# Retrieve the CAM by passing the class index and the model output
activation_map = cam_extractor(out.squeeze(0).argmax().item(), out)
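For intuition about what such a map contains, the sketch below computes the basic CAM formulation with NumPy: for a network that ends with global average pooling followed by a linear layer, the map for a class is the sum of the last feature maps weighted by that class's linear weights. This is not TorchCAM's implementation and not the SmoothGradCAMpp method used above. The shapes, the random values and the `basic_cam` helper are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical values: 16 feature maps of size 7x7 and a 10-class linear head
feature_maps = rng.random((16, 7, 7))        # (channels, height, width)
fc_weights = rng.standard_normal((10, 16))   # (classes, channels)

def basic_cam(feature_maps, fc_weights, class_idx):
    # Weight each channel by the class's linear weight, then sum over channels
    cam = np.tensordot(fc_weights[class_idx], feature_maps, axes=1)
    # Rescale to [0, 1] for display
    span = cam.max() - cam.min()
    if span == 0:
        return np.zeros_like(cam)
    return (cam - cam.min()) / span

# Class scores: global average pooling followed by the linear layer (no bias)
scores = fc_weights @ feature_maps.mean(axis=(1, 2))
class_idx = int(scores.argmax())

cam = basic_cam(feature_maps, fc_weights, class_idx)
print(cam.shape)  # (7, 7)
```

Before rescaling, the spatial average of the map equals the class score, which is what makes the map readable as a per-location contribution to that score.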

If you want to visualize your heatmap, you only need to cast the CAM to a numpy ndarray:

import matplotlib.pyplot as plt
# Visualize the raw CAM
plt.imshow(activation_map.numpy()); plt.axis('off'); plt.tight_layout(); plt.show()

Or if you wish to overlay it on your input image:

import matplotlib.pyplot as plt
from torchcam.utils import overlay_mask

# Resize the CAM and overlay it
result = overlay_mask(to_pil_image(img), to_pil_image(activation_map, mode='F'), alpha=0.5)
# Display it
plt.imshow(result); plt.axis('off'); plt.tight_layout(); plt.show()
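As a rough picture of what an overlay involves, the following NumPy sketch upsamples a low-resolution map to the image size and alpha-blends it with the image. It is a simplified stand-in, not the code of `overlay_mask`: the `blend_heatmap` helper, the nearest-neighbour resizing, the red-only colormap and the convention that `alpha` weights the original image are all assumptions of this example.

```python
import numpy as np

def blend_heatmap(image, cam, alpha=0.5):
    """Alpha-blend a [0, 1] heatmap onto an RGB uint8 image (illustrative helper)."""
    h, w, _ = image.shape
    # Nearest-neighbour upsampling of the CAM to the image size
    rows = np.arange(h) * cam.shape[0] // h
    cols = np.arange(w) * cam.shape[1] // w
    upsampled = cam[rows][:, cols]
    # Crude colormap: the CAM intensity drives the red channel only
    heat = np.zeros_like(image, dtype=np.float64)
    heat[..., 0] = 255 * upsampled
    # alpha is the weight of the original image in this sketch
    blended = alpha * image + (1 - alpha) * heat
    return blended.astype(np.uint8)

image = np.full((224, 224, 3), 128, dtype=np.uint8)  # placeholder gray image
cam = np.random.default_rng(0).random((7, 7))        # placeholder CAM in [0, 1)
result = blend_heatmap(image, cam)
print(result.shape, result.dtype)  # (224, 224, 3) uint8
```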

Setup

Python 3.6 (or higher) and pip/conda are required to install TorchCAM.

Stable release

You can install the latest stable release of the package from PyPI as follows:

pip install torchcam

or using conda:

conda install -c frgfm torchcam

Developer installation

Alternatively, if you wish to use the latest features of the project that haven't made their way to a release yet, you can install the package from source:

git clone https://github.com/frgfm/torch-cam.git
pip install -e torch-cam/.

CAM Zoo

This project is developed and maintained by the repo owner, but the implementation was based on the following research papers:

What else

Documentation

The full package documentation is available here for detailed specifications.

Demo app

A minimal demo app is provided for you to play with the supported CAM methods!

You will need an extra dependency (Streamlit) for the app to run:

pip install -r demo/requirements.txt

You can then easily run your app in your default browser by running:

streamlit run demo/app.py

Example script

An example script is provided for you to benchmark the heatmaps produced by multiple CAM approaches on the same image:

python scripts/cam_example.py --model resnet18 --class-idx 232

All script arguments can be checked using python scripts/cam_example.py --help

Contributing

Feeling like extending the range of possibilities of CAM? Or perhaps submitting a paper implementation? Any sort of contribution is greatly appreciated!

You can find a short guide in CONTRIBUTING to help grow this project!

License

Distributed under the MIT License. See LICENSE for more information.
