Grad-CAM for torchvision, timm, and Hugging Face models.

easy_gradcam

A lightweight tool to generate Grad-CAM visualizations for image classification models. It supports popular backbones such as ResNet, Vision Transformers (ViT), and Hugging Face Transformers.

This package is built on top of PyTorch for deep learning model implementation, and provides visualization utilities powered by Matplotlib and Seaborn. It is designed to help users easily train, evaluate, and visualize results with clear and customizable plots.


Installation

pip install easy_gradcam

Quick Start

1. Import dependencies

# === data preprocessing ===
import cv2
import numpy as np
from PIL import Image
from pathlib import Path
import torchvision.transforms as transforms

# === model libraries (import only the one you use) ===
import torchvision.models as models
import timm
from transformers import AutoModelForImageClassification

# === this visualization tool ===
from easy_gradcam.classification import EasyGradCAM
from easy_gradcam.visualization import save_heatmap, save_mix_heatmap

2. Load a model

Pick one of the following backbones. The comment after each example gives a target layer name to use in step 4:

# Example 1: ResNet-50 (from torchvision)
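# `pretrained=True` is deprecated since torchvision 0.13; this is the equivalent `weights` argument.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)   # targets: "layer4"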

# Example 2: ViT (from timm)
model = timm.create_model("vit_base_patch16_224_miil", pretrained=True)   # targets: "blocks.10"

# Example 3: DINOv2 (from Hugging Face)
model = AutoModelForImageClassification.from_pretrained(
    "facebook/dinov2-small-imagenet1k-1-layer"
) # targets: "dinov2.encoder.layer.11"

# Example 4: Your own model
model = CustomModel(...)

model.eval()

2.1 Identify target layers

To find the correct target layer names for your model, you can print the model architecture:

print(model)
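
For large models the printed architecture can be long. As a rough sketch, the dotted names that `print(model)` implies can also be listed directly with PyTorch's `named_modules()`. The helper `list_layer_names` and the toy model below are hypothetical and only illustrate the naming scheme; they are not part of easy_gradcam.

```python
from collections import OrderedDict

import torch.nn as nn


def list_layer_names(model: nn.Module) -> list[str]:
    """Return the dotted name of every submodule (the root has an empty name and is skipped)."""
    return [name for name, _ in model.named_modules() if name]


# Toy stand-in model, used here only so the example runs without downloading weights.
toy = nn.Sequential(OrderedDict([
    ("stem", nn.Conv2d(3, 8, kernel_size=3)),
    ("head", nn.Sequential(OrderedDict([("fc", nn.Linear(8, 2))]))),
]))

names = list_layer_names(toy)
for name in names:
    print(name)
```

Nested modules show up with dots between levels (here `head.fc`), which is the same form as `"blocks.10"` or `"dinov2.encoder.layer.11"` in the examples above.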

3. Prepare an image

img_path = Path("./exp1.jpg")

# Option A: read the image with OpenCV (loads as BGR, so convert to RGB)
img = cv2.imread(str(img_path))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Option B: read the image with Pillow and convert it to a NumPy array
# (requires `import numpy as np`)
img = Image.open(img_path).convert("RGB")
img = np.array(img)

totensor = transforms.ToTensor()
resize = transforms.Resize((224, 224))
normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])

t = totensor(img)
t = resize(t)
t = normalize(t)
t = t.unsqueeze(0)  # add batch dimension

4. Compute Grad-CAM

# Example 1: single target layer (choose one)
gradcam = EasyGradCAM(model, targets="dinov2.encoder.layer.11")

# Example 2: multiple target layers (choose one)
gradcam = EasyGradCAM(model, targets=["dinov2.encoder.layer.10", "dinov2.encoder.layer.11"])  

# Extract features and gradients
feats, grads = gradcam.cal_feat_and_grad(t)

# Generate heatmaps
heats = gradcam.cal_heats(img, feats, grads)
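
For readers who want to see what happens conceptually, the following is a minimal sketch of the standard Grad-CAM computation for a convolutional layer, written in plain PyTorch. It is not easy_gradcam's implementation, which may differ (for example in how transformer tokens are handled). `TinyNet` and `grad_cam` are hypothetical names, and the sketch assumes a batch of one image.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Toy classifier used only to make the sketch runnable."""

    def __init__(self):
        super().__init__()
        self.features = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        x = F.relu(self.features(x))
        return self.fc(self.pool(x).flatten(1))


def grad_cam(model, layer, x, class_idx=None):
    """Standard Grad-CAM for one image: ReLU of gradient-weighted feature maps."""
    store = {}

    def hook(_module, _inputs, output):
        store["feat"] = output                  # feature maps, shape (1, C, H, W)

    handle = layer.register_forward_hook(hook)
    try:
        logits = model(x)
    finally:
        handle.remove()

    if class_idx is None:
        class_idx = int(logits[0].argmax())     # explain the top predicted class

    # Gradient of the class score with respect to the captured feature maps.
    grads = torch.autograd.grad(logits[0, class_idx], store["feat"])[0]
    weights = grads.mean(dim=(2, 3), keepdim=True)        # one weight per channel
    cam = F.relu((weights * store["feat"]).sum(dim=1))    # shape (1, H, W)

    # Rescale to [0, 1]; the clamp avoids dividing by zero for an all-zero map.
    cam = cam - cam.amin(dim=(1, 2), keepdim=True)
    cam = cam / cam.amax(dim=(1, 2), keepdim=True).clamp_min(1e-8)
    return cam.detach()


torch.manual_seed(0)
model = TinyNet().eval()
x = torch.rand(1, 3, 8, 8)
cam = grad_cam(model, model.features, x)
print(cam.shape)
```

The forward hook plays the role of the feature extraction step and `torch.autograd.grad` the gradient step; the channel-wise weighting and ReLU then produce the heatmap.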

5. Save results

output_path = Path("results")
output_path.mkdir(parents=True, exist_ok=True)

for i in range(len(heats)):
    for name in heats[i]:
        # Save plain heatmap; img_path.stem drops the ".jpg" extension from the file name
        save_heatmap(
            save_path=f"{output_path}/{img_path.stem}-{i}-{name}.jpg",
            heat=heats[i][name],
            cmap="jet",
            title="grad-cam"
        )

        # Save overlay with original image
        save_mix_heatmap(
            save_path=f"{output_path}/{img_path.stem}-{i}-{name}-mix.jpg",
            heat=heats[i][name],
            ori_img=img,
            cmap="jet"
        )
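
To illustrate what an overlay involves, here is a rough sketch that blends a heatmap onto an image with NumPy and a Matplotlib colormap. It is not the code behind `save_mix_heatmap`. The function `overlay_heatmap`, its `alpha` parameter, and the assumption that the heatmap is a 2-D array with values in [0, 1] are all illustrative.

```python
import numpy as np
from matplotlib import colormaps


def overlay_heatmap(img_rgb, heat, cmap="jet", alpha=0.5):
    """Blend a 2-D heatmap in [0, 1] onto an RGB uint8 image of shape (H, W, 3)."""
    h, w = img_rgb.shape[:2]

    # Nearest-neighbour upsampling of the heatmap to the image size.
    rows = np.arange(h) * heat.shape[0] // h
    cols = np.arange(w) * heat.shape[1] // w
    heat_big = heat[rows[:, None], cols[None, :]]

    # Map heat values to RGB colours; the colormap returns RGBA, so drop alpha.
    colour = colormaps[cmap](np.clip(heat_big, 0.0, 1.0))[..., :3]

    blended = (1 - alpha) * (img_rgb / 255.0) + alpha * colour
    return (blended * 255).round().astype(np.uint8)


# Dummy inputs: a grey image and a coarse 14x14 heatmap.
image = np.full((224, 224, 3), 128, dtype=np.uint8)
heat = np.linspace(0.0, 1.0, 14 * 14).reshape(14, 14)
mixed = overlay_heatmap(image, heat)
print(mixed.shape, mixed.dtype)
```

A lower `alpha` keeps more of the original image visible, while a higher value emphasizes the heatmap.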

Example Output

  • results/exp1-0-dinov2.encoder.layer.11.jpg: heatmap only
  • results/exp1-0-dinov2.encoder.layer.11-mix.jpg: heatmap overlaid on the input image

Notes

  • Make sure the target layer you pass matches the internal structure of the model.
  • Pretrained models from torchvision, timm, and Hugging Face are supported.
  • Heatmaps are saved as .jpg files in the results/ directory.
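
One way to check the first note before running Grad-CAM is to resolve each target name with PyTorch's `Module.get_submodule`, which raises `AttributeError` for a path that does not exist. The helper `check_targets` below is a hypothetical sketch and not part of easy_gradcam.

```python
import torch.nn as nn


def check_targets(model: nn.Module, targets) -> list[str]:
    """Return the target names that do not resolve to a submodule of `model`."""
    if isinstance(targets, str):
        targets = [targets]
    missing = []
    for name in targets:
        try:
            model.get_submodule(name)
        except AttributeError:
            missing.append(name)
    return missing


# Toy model: nn.Sequential names its children "0", "1", ...
toy = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
missing = check_targets(toy, ["0", "layer4"])
print(missing)
```

An empty result means every name resolves; any name that is returned should be compared against the output of `print(model)`.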

Bugs/Requests

Please send bug reports and feature requests through the GitHub issue tracker.
