
Class activation maps for your PyTorch CNN models


Torchcam: class activation explorer


A simple way to leverage the class-specific activation of convolutional layers in PyTorch.


Getting started

Prerequisites

  • Python 3.6 (or more recent)
  • pip

Installation

You can install the package from PyPI using pip:

pip install torchcam

Usage

The example below retrieves the class activation map (CAM) of a specific class on a ResNet architecture.

import requests
from io import BytesIO
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.models import resnet50
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from torchcam.cams import CAM, GradCAM, GradCAMpp
from torchcam.utils import overlay_mask


# Pretrained ImageNet model, switched to evaluation mode (uses running batch-norm statistics)
model = resnet50(pretrained=True).eval()
# Specify the convolutional layer to hook
conv_layer = 'layer4'

# Hook the corresponding layer in the model
gradcam = GradCAMpp(model, conv_layer)

# Get a dog image
URL = 'https://www.woopets.fr/assets/races/000/066/big-portrait/border-collie.jpg'
response = requests.get(URL)

# Forward an image
pil_img = Image.open(BytesIO(response.content), mode='r').convert('RGB')
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
img_tensor = preprocess(pil_img)
out = model(img_tensor.unsqueeze(0))

# Select the class index
classes = {int(key):value for (key, value)
          in requests.get('https://s3.amazonaws.com/outcome-blog/imagenet/labels.json').json().items()}
class_idx = 232

# Use the hooked data to compute activation map
activation_maps = gradcam(out, class_idx)
# Convert the map of the first image in the batch to a PIL image
heatmap = to_pil_image(activation_maps[0].cpu().numpy(), mode='F')

# Plot the result
result = overlay_mask(pil_img, heatmap)
plt.imshow(result)
plt.axis('off')
plt.title(classes.get(class_idx))
plt.tight_layout()
plt.show()


Technical roadmap

The project is currently under development; here are the objectives for the next releases:

  • Parallel CAMs: enable batch processing.
  • Benchmark: compare class activation map computations for different architectures.
  • Signature improvement: automatically retrieve the last convolutional layer.
  • Refine RPN: create a region proposal network using CAM.
  • Task transfer: turn a well-trained classifier into an object detector.
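Batch processing and CAM-based region proposals are still roadmap objectives. As an illustration only, the NumPy sketch below computes Grad-CAM for a whole batch in one vectorized call and derives a crude bounding box by thresholding a map. The (N, K, H, W) layout, the 0.5 threshold and both function names are assumptions, not torchcam API; the box helper is a deliberate simplification (it boxes every pixel above the threshold rather than isolating a connected region).

```python
import numpy as np


def batched_gradcam(activations: np.ndarray, grads: np.ndarray) -> np.ndarray:
    """Hypothetical batched Grad-CAM.

    activations, grads: (N, K, H, W) arrays for N images.
    Returns (N, H, W) maps, each min-max scaled to [0, 1] independently.
    """
    weights = grads.mean(axis=(2, 3))  # (N, K): averaged gradient per image and channel
    cams = np.maximum(np.einsum("nk,nkhw->nhw", weights, activations), 0)
    mins = cams.min(axis=(1, 2), keepdims=True)
    spans = cams.max(axis=(1, 2), keepdims=True) - mins
    # Per-image normalization; constant maps become all zeros
    return np.where(spans > 0, (cams - mins) / np.where(spans > 0, spans, 1), 0.0)


def box_from_map(cam: np.ndarray, threshold: float = 0.5):
    """Hypothetical helper: inclusive box (row0, col0, row1, col1) of pixels >= threshold."""
    rows, cols = np.nonzero(cam >= threshold)
    if rows.size == 0:
        return None
    return int(rows.min()), int(cols.min()), int(rows.max()), int(cols.max())
```

Box coordinates are in the resolution of the feature map, so they would still need to be scaled to the input image size.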

Documentation

The full package documentation, with detailed specifications, is available online. It was built with Sphinx using a theme provided by Read the Docs.

Contributing

Please refer to CONTRIBUTING if you wish to contribute to this project.

Credits

This project is developed and maintained by the repo owner, but the implementation is based on the following papers:

License

Distributed under the MIT License. See LICENSE for more information.


Download files

Source distribution: torchcam-0.1.0.tar.gz (8.2 kB)

Built distribution: torchcam-0.1.0-py3-none-any.whl (7.9 kB), for Python 3
