Class activation maps for your PyTorch CNN models
TorchCAM: class activation explorer
Simple way to leverage the class-specific activation of convolutional layers in PyTorch.
Setting your CAM
TorchCAM leverages PyTorch hooking mechanisms to seamlessly retrieve all the information required to produce the class activation map, without any additional effort from the user. Each CAM object acts as a wrapper around your model.
You can find the exhaustive list of supported CAM methods in the documentation, then use any of them as follows:
# Define your model
from torchvision.models import resnet18
model = resnet18(pretrained=True).eval()
# Set your CAM extractor
from torchcam.cams import SmoothGradCAMpp
cam_extractor = SmoothGradCAMpp(model)
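To make the hooking mechanism more concrete, here is a minimal sketch in plain PyTorch of how a forward hook can capture a layer's feature maps and, through a tensor hook, their gradients. It uses a small hypothetical toy CNN instead of ResNet-18, and it is not TorchCAM's internal code, only an illustration of the general technique a CAM wrapper relies on.

```python
import torch
from torch import nn

# Hypothetical toy CNN standing in for a real classifier such as resnet18
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10),
).eval()

target_layer = model[2]  # last convolutional layer of the toy model
store = {}


def save_gradient(grad):
    # Called during backward() with d(score)/d(activations)
    store["gradients"] = grad.detach()


def forward_hook(module, inputs, output):
    # Called after each forward pass of target_layer: keep its feature maps
    store["activations"] = output.detach()
    # Ask autograd to report the gradient flowing back into this output
    output.register_hook(save_gradient)


handle = target_layer.register_forward_hook(forward_hook)

x = torch.rand(1, 3, 32, 32)
scores = model(x)                          # forward: fills store["activations"]
class_idx = scores.squeeze(0).argmax().item()
scores[:, class_idx].sum().backward()      # backward: fills store["gradients"]
handle.remove()                            # detach the hook once done

print(store["activations"].shape)  # torch.Size([1, 16, 32, 32])
print(store["gradients"].shape)    # torch.Size([1, 16, 32, 32])
```

Both tensors share the shape of the target layer's output, which is what a gradient-based CAM method needs to weight each channel.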
Please note that by default, the layer at which the CAM is retrieved is set to the last non-reduced convolutional layer. If you wish to investigate a specific layer, use the target_layer argument in the constructor.
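The default layer choice can be approximated with a small heuristic: run a dummy input through the model, record the output shape of every nn.Conv2d, and keep the last one whose spatial dimensions are still larger than 1x1. The helper find_last_spatial_conv below is hypothetical and is not TorchCAM's own layer-selection logic (which may consider other module types); it only illustrates what "last non-reduced convolutional layer" means.

```python
from typing import Optional, Tuple

import torch
from torch import nn


def find_last_spatial_conv(model: nn.Module,
                           input_shape: Tuple[int, ...] = (3, 224, 224)) -> Optional[str]:
    """Hypothetical helper: name of the last Conv2d whose output is still larger than 1x1."""
    shapes = {}
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            # Record the output shape of each convolution, keyed by module name
            hooks.append(module.register_forward_hook(
                lambda mod, inp, out, name=name: shapes.__setitem__(name, out.shape)))
    with torch.no_grad():
        model(torch.zeros(1, *input_shape))  # dummy forward pass, batch of 1
    for hook in hooks:
        hook.remove()
    # dict keeps insertion order, i.e. the order in which the layers first ran
    candidates = [name for name, shape in shapes.items() if shape[-2] > 1 and shape[-1] > 1]
    return candidates[-1] if candidates else None


# Toy model (hypothetical): its last conv reduces the 16x16 maps to 1x1
toy = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, 16), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, 10),
).eval()

print(find_last_spatial_conv(toy, (3, 32, 32)))  # prints 2
```

Pass the model in eval mode so that batch-norm layers accept the single dummy sample. A layer name found this way could then be given as target_layer, assuming the constructor of this release expects the module name as a string (e.g. "layer4" for a torchvision ResNet).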
Retrieving the class activation map
Once your CAM extractor is set, you only need to use your model to infer on your data as usual. If any additional information is required, the extractor will get it for you automatically.
from torchvision.io.image import read_image
from torchvision.transforms.functional import normalize, resize, to_pil_image
from torchvision.models import resnet18
from torchcam.cams import SmoothGradCAMpp
model = resnet18(pretrained=True).eval()
cam_extractor = SmoothGradCAMpp(model)
# Get your input
img = read_image("path/to/your/image.png")
# Preprocess it for your chosen model
input_tensor = normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# Feed the preprocessed data to the model
out = model(input_tensor.unsqueeze(0))
# Retrieve the CAM by passing the class index and the model output
activation_map = cam_extractor(out.squeeze(0).argmax().item(), out)
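For intuition about what happens when the extractor is called with a class index and the model output, here is a minimal NumPy sketch of the plain Grad-CAM computation. It assumes the target layer's activations and the gradients of the chosen class score have already been captured (for instance with hooks), and the function name grad_cam_map is hypothetical. SmoothGradCAMpp derives its channel weights differently (Grad-CAM++ style weights averaged over noisy copies of the input), so treat this as the general recipe rather than that exact class.

```python
import numpy as np


def grad_cam_map(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """activations, gradients: (K, H, W) arrays for one image at the target layer."""
    # One weight per channel: the spatial average of its gradient
    weights = gradients.mean(axis=(1, 2))                            # (K,)
    # Weighted sum over channels, then ReLU to keep positive evidence only
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0)  # (H, W)
    # Min-max normalization to [0, 1] (a constant map becomes all zeros)
    span = cam.max() - cam.min()
    return (cam - cam.min()) / span if span > 0 else np.zeros_like(cam)


acts = np.array([[[1., 0.], [0., 1.]],
                 [[0., 2.], [2., 0.]]])
grads = np.array([[[1., 1.], [1., 1.]],
                  [[-1., -1.], [-1., -1.]]])
print(grad_cam_map(acts, grads))  # values: [[1, 0], [0, 1]]
```

Here the first channel gets a positive weight and the second a negative one, so only the locations driven by the first channel survive the ReLU.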
If you want to visualize your heatmap, you only need to cast the CAM to a numpy ndarray:
import matplotlib.pyplot as plt
# Visualize the raw CAM
plt.imshow(activation_map.numpy()); plt.axis('off'); plt.tight_layout(); plt.show()
Or if you wish to overlay it on your input image:
import matplotlib.pyplot as plt
from torchcam.utils import overlay_mask
# Resize the CAM and overlay it
result = overlay_mask(to_pil_image(img), to_pil_image(activation_map, mode='F'), alpha=0.5)
# Display it
plt.imshow(result); plt.axis('off'); plt.tight_layout(); plt.show()
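Overlaying boils down to upsampling the CAM to the image size, mapping it to colors and alpha-blending it with the image. The NumPy sketch below illustrates those steps with two simplifications that are assumptions of this example rather than what overlay_mask does: nearest-neighbour upsampling by an integer factor and a toy red-only colormap. The name overlay_heatmap is hypothetical, and in this sketch alpha weights the original image; check the overlay_mask documentation for its exact convention.

```python
import numpy as np


def overlay_heatmap(img: np.ndarray, cam: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """img: (H, W, 3) uint8 image; cam: (h, w) floats in [0, 1]; H, W multiples of h, w."""
    if img.shape[0] % cam.shape[0] or img.shape[1] % cam.shape[1]:
        raise ValueError("image size must be an integer multiple of the CAM size")
    sy, sx = img.shape[0] // cam.shape[0], img.shape[1] // cam.shape[1]
    # Nearest-neighbour upsampling of the CAM to the image resolution
    upsampled = np.repeat(np.repeat(cam, sy, axis=0), sx, axis=1)
    # Toy colormap: activation intensity drives the red channel only
    heat = np.zeros(img.shape, dtype=np.float64)
    heat[..., 0] = 255.0 * upsampled
    # Alpha blending: alpha weights the image, (1 - alpha) the heatmap
    blended = alpha * img.astype(np.float64) + (1.0 - alpha) * heat
    return np.clip(blended, 0, 255).astype(np.uint8)


img = np.full((4, 4, 3), 100, dtype=np.uint8)
cam = np.array([[1.0, 0.0], [0.0, 1.0]])
result = overlay_heatmap(img, cam, alpha=0.5)
# result[0, 0] is (177, 50, 50) on a "hot" pixel; result[0, 3] is (50, 50, 50)
```

A real colormap (such as the jet-like palettes commonly used for CAMs) and smooth interpolation give more readable overlays; the blending step itself stays the same.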
You can install the latest stable release of the package from PyPI as follows:
pip install torchcam
or using conda:
conda install -c frgfm torchcam
Alternatively, if you wish to use the latest features of the project that haven't made their way to a release yet, you can install the package from source:
git clone https://github.com/frgfm/torch-cam.git
pip install -e torch-cam/.
This project is developed and maintained by the repo owner, but the implementation is based on the following research papers:
- Learning Deep Features for Discriminative Localization: the original CAM paper
- Grad-CAM: GradCAM paper, generalizing CAM to models without global average pooling.
- Grad-CAM++: improvement of Grad-CAM for more accurate pixel-level contribution to the activation.
- Smooth Grad-CAM++: SmoothGrad mechanism coupled with Grad-CAM++.
- Score-CAM: score-weighting of class activation for better interpretability.
- SS-CAM: SmoothGrad mechanism coupled with Score-CAM.
- IS-CAM: integration-based variant of Score-CAM.
- XGrad-CAM: improved version of Grad-CAM in terms of sensitivity and conservation.
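Several of the methods above mainly differ in how they weight the K feature maps A_k of the target layer before taking the ReLU of their weighted sum. The NumPy sketch below shows the weighting rules of CAM, Grad-CAM, Grad-CAM++ (in its closed form, which assumes an exponential applied to the class score) and XGrad-CAM, given activations and gradients of shape (K, H, W). All function names are hypothetical. Score-CAM, SS-CAM and IS-CAM are not covered, since they weight channels using forward passes on masked inputs rather than gradients.

```python
import numpy as np

# acts, grads: (K, H, W) activations and class-score gradients at the target layer


def cam_weights(fc_weight: np.ndarray, class_idx: int) -> np.ndarray:
    # CAM: the classifier row of the class (requires a global-average-pooling + FC head)
    return fc_weight[class_idx]


def gradcam_weights(grads: np.ndarray) -> np.ndarray:
    # Grad-CAM: spatial average of the gradients
    return grads.mean(axis=(1, 2))


def gradcampp_weights(acts: np.ndarray, grads: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # Grad-CAM++: pixel-wise coefficients from the 2nd and 3rd powers of the gradient
    g2, g3 = grads ** 2, grads ** 3
    denom = 2 * g2 + acts.sum(axis=(1, 2), keepdims=True) * g3
    alpha = g2 / np.where(denom != 0, denom, eps)
    # Only positive gradients contribute to the channel weight
    return (alpha * np.maximum(grads, 0)).sum(axis=(1, 2))


def xgradcam_weights(acts: np.ndarray, grads: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # XGrad-CAM: gradients averaged with the normalized activations as spatial weights
    return (grads * acts).sum(axis=(1, 2)) / (acts.sum(axis=(1, 2)) + eps)


def combine(weights: np.ndarray, acts: np.ndarray) -> np.ndarray:
    # Shared final step: ReLU of the weighted sum of feature maps -> (H, W)
    return np.maximum(np.tensordot(weights, acts, axes=1), 0)


acts = np.array([[[1., 0.], [0., 1.]], [[0., 2.], [2., 0.]]])
grads = np.array([[[1., 1.], [1., 1.]], [[-1., -1.], [-1., -1.]]])
for weights in (gradcam_weights(grads), gradcampp_weights(acts, grads),
                xgradcam_weights(acts, grads)):
    print(np.round(weights, 3))
```

On this toy input, Grad-CAM and XGrad-CAM both give a negative weight to the second channel, while Grad-CAM++ simply zeroes it because all of its gradients are negative.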
The full package documentation is available here for detailed specifications.
A minimal demo app is provided for you to play with the supported CAM methods!
You will need an extra dependency (Streamlit) for the app to run:
pip install -r demo/requirements.txt
You can then launch the app in your default browser with:
streamlit run demo/app.py
An example script is provided for you to benchmark the heatmaps produced by multiple CAM approaches on the same image:
python scripts/cam_example.py --model resnet18 --class-idx 232
All script arguments can be checked using:
python scripts/cam_example.py --help
Feeling like extending the range of possibilities of CAM? Or perhaps submitting a paper implementation? Any sort of contribution is greatly appreciated!
You can find a short guide in CONTRIBUTING to help grow this project!
Distributed under the MIT License. See LICENSE for more information.