Grad-CAM for torchvision, timm, and Hugging Face models.

easy_gradcam

A lightweight tool for generating Grad-CAM visualizations for image classification models. It works with models from torchvision (e.g. ResNet), timm (e.g. Vision Transformers, ViT), and Hugging Face Transformers (e.g. DINOv2).


Installation

pip install easy_gradcam

Quick Start

1. Import dependencies

import cv2
import torchvision.models as models
import torchvision.transforms as transforms
import timm
from transformers import AutoModelForImageClassification
from easy_gradcam.classification import EasyGradCAM
from easy_gradcam.visualization import save_heatmap, save_mix_heatmap

2. Load a model

Choose one of the following backbones:

# Example 1: ResNet-50 (torchvision)
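# weights=... replaces the deprecated pretrained=True (torchvision >= 0.13)
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)   # target layer: "layer4"

# Example 2: ViT (timm)
model = timm.create_model("vit_base_patch16_224_miil", pretrained=True)   # target layer: "blocks.10"

# Example 3: Hugging Face (DINOv2)
model = AutoModelForImageClassification.from_pretrained(
    "facebook/dinov2-small-imagenet1k-1-layer"
)   # target layer: "dinov2.encoder.layer.11"

model.eval()  # put whichever model you kept into inference mode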
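The three backbones produce differently shaped activations at their target layers. For a 224x224 input, ResNet-50's "layer4" yields a 2048x7x7 feature map, while ViT blocks and DINOv2 encoder layers yield a sequence of tokens (for ViT-B/16, one CLS token plus 14x14 patch tokens; DINOv2 uses 14-pixel patches, giving a 16x16 grid). A heatmap needs a spatial grid, so token activations have to be reshaped first. easy_gradcam presumably handles this internally for the supported models; the NumPy sketch below only illustrates the idea, and the function name tokens_to_grid and its num_prefix_tokens parameter are hypothetical.

```python
import math

import numpy as np


def tokens_to_grid(tokens: np.ndarray, num_prefix_tokens: int = 1) -> np.ndarray:
    """Reshape a (num_tokens, C) transformer activation into a (C, H, W) map.

    Assumes the first `num_prefix_tokens` rows are non-spatial tokens
    (e.g. the CLS token) and the remaining rows are patch tokens in
    row-major order that form a square grid.
    """
    patches = tokens[num_prefix_tokens:]          # drop the CLS (or other prefix) tokens
    num_patches = patches.shape[0]
    side = math.isqrt(num_patches)
    if side * side != num_patches:
        raise ValueError(f"{num_patches} patch tokens do not form a square grid")
    grid = patches.reshape(side, side, -1)        # (H, W, C)
    return grid.transpose(2, 0, 1)                # (C, H, W), like a CNN feature map


# ViT-B/16 at 224x224: 1 CLS token + 14 * 14 patch tokens, 768 channels each.
vit_tokens = np.random.rand(1 + 14 * 14, 768)
print(tokens_to_grid(vit_tokens).shape)  # (768, 14, 14)
```

Models with extra prefix tokens (for example register tokens) would need a larger num_prefix_tokens.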

3. Prepare an image

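img = cv2.imread("./exp1.jpg")               # OpenCV loads images as BGR uint8, shape (H, W, 3)
if img is None:                              # imread returns None instead of raising
    raise FileNotFoundError("could not read ./exp1.jpg")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)   # convert to the RGB order the models expect

totensor = transforms.ToTensor()             # HWC uint8 in [0, 255] -> CHW float in [0, 1]
resize = transforms.Resize((224, 224))
normalize = transforms.Normalize([0.485, 0.456, 0.406],   # ImageNet mean
                                 [0.229, 0.224, 0.225])   # ImageNet std

t = totensor(img)
t = resize(t)
t = normalize(t)
t = t.unsqueeze(0)  # add batch dimension -> (1, 3, 224, 224)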
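To make the preprocessing arithmetic concrete: ToTensor divides each 8-bit value by 255, and Normalize then computes (x - mean) / std per channel. This is also why the BGR-to-RGB conversion matters, since the three mean/std pairs are listed in RGB order. The stdlib-only sketch below reproduces that arithmetic for a single pixel; normalize_pixel is an illustrative helper, not part of torchvision or easy_gradcam.

```python
# ImageNet statistics used by the Normalize transform above (RGB order).
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb_uint8):
    """Mimic ToTensor + Normalize for a single RGB pixel."""
    scaled = [v / 255.0 for v in rgb_uint8]                      # ToTensor: [0, 255] -> [0.0, 1.0]
    return [(x - m) / s for x, m, s in zip(scaled, MEAN, STD)]   # Normalize: (x - mean) / std


print([round(v, 3) for v in normalize_pixel((255, 255, 255))])  # [2.249, 2.429, 2.64]
```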

4. Compute Grad-CAM

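# target_layer must name a submodule of the loaded model, e.g. "layer4" (ResNet-50),
# "blocks.10" (ViT) or "dinov2.encoder.layer.11" (DINOv2, used here)
gradcam = EasyGradCAM(model, target_layer="dinov2.encoder.layer.11")

# Extract features and gradients
feats, grads = gradcam.cal_feat_and_grad(t)

# Generate heatmaps (one entry per input image, keyed by layer name)
heats = gradcam.cal_heats(img, feats, grads)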
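cal_feat_and_grad and cal_heats hide the Grad-CAM arithmetic. For reference, the standard Grad-CAM computation weights each channel c of the target activation A by alpha_c, the spatial mean of the class-score gradient over that channel, then takes heatmap = ReLU(sum_c alpha_c * A_c). The NumPy sketch below implements that textbook formula for one (C, H, W) activation; it is not a transcription of easy_gradcam's internals, whose weighting or normalization may differ, and grad_cam is a hypothetical name.

```python
import numpy as np


def grad_cam(feats: np.ndarray, grads: np.ndarray) -> np.ndarray:
    """Standard Grad-CAM for one image.

    feats, grads: arrays of shape (C, H, W) holding the target layer's
    activations and the gradient of the class score w.r.t. them.
    Returns an (H, W) heatmap scaled to [0, 1].
    """
    weights = grads.mean(axis=(1, 2))             # alpha_c: global-average-pooled gradients
    cam = np.tensordot(weights, feats, axes=1)    # sum_c alpha_c * A_c -> (H, W)
    cam = np.maximum(cam, 0.0)                    # ReLU keeps only positive evidence
    peak = cam.max()
    return cam / peak if peak > 0 else cam        # scale to [0, 1] for display


feats = np.random.rand(2048, 7, 7)    # e.g. a ResNet-50 layer4 output for a 224x224 input
grads = np.random.randn(2048, 7, 7)
heat = grad_cam(feats, grads)
print(heat.shape, heat.min() >= 0.0, heat.max() <= 1.0)  # (7, 7) True True
```

Dividing by the peak is one common normalization; min-max scaling is another.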

5. Save results

import os

os.makedirs("results", exist_ok=True)  # make sure the output directory exists

for i, layer_heats in enumerate(heats):        # one entry per input image
    for name, heat in layer_heats.items():     # keyed by target-layer name
        # Save plain heatmap
        save_heatmap(
            save_path=f"results/{i}-{name}.jpg",
            heat=heat,
            cmap="jet",
            title="grad-cam"
        )

        # Save overlay with original image
        save_mix_heatmap(
            save_path=f"results/{i}-{name}-mix.jpg",
            heat=heat,
            ori_img=img,
            cmap="jet"
        )
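save_mix_heatmap receives the full-resolution RGB image as ori_img, not the 224x224 tensor, so the heatmap has to be resized before it can be blended onto the image. How easy_gradcam does this is not documented here; the NumPy sketch below only illustrates the general resize-then-alpha-blend idea. It substitutes nearest-neighbor upsampling and a red-only coloring for a real colormap such as "jet", and overlay_heatmap and its alpha parameter are hypothetical.

```python
import numpy as np


def overlay_heatmap(heat: np.ndarray, img: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Blend an (h, w) heatmap in [0, 1] onto an (H, W, 3) uint8 RGB image."""
    H, W = img.shape[:2]
    rows = np.arange(H) * heat.shape[0] // H      # nearest source row for each output row
    cols = np.arange(W) * heat.shape[1] // W      # nearest source column for each output column
    up = heat[rows[:, None], cols[None, :]]       # (H, W) upsampled heatmap
    colored = np.zeros_like(img, dtype=np.float64)
    colored[..., 0] = up * 255.0                  # red intensity follows the heat value
    mixed = alpha * colored + (1.0 - alpha) * img.astype(np.float64)
    return np.clip(mixed, 0, 255).astype(np.uint8)


img = np.full((480, 640, 3), 128, dtype=np.uint8)   # stand-in for the RGB input image
heat = np.random.rand(14, 14)                       # stand-in for a Grad-CAM heatmap
print(overlay_heatmap(heat, img).shape)  # (480, 640, 3)
```

In practice, bilinear resizing and a perceptual colormap give smoother overlays than this nearest-neighbor, single-channel version.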

Example Output

  • results/0-dinov2.encoder.layer.11.jpg: heatmap only
  • results/0-dinov2.encoder.layer.11-mix.jpg: heatmap overlay on the input image

Notes

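  • Make sure the target layer matches the model's internal structure: it should be a dotted submodule name, as listed by model.named_modules() (for example "layer4" for ResNet-50).
  • Pretrained models from torchvision, timm, and Hugging Face are supported.
  • In this example, heatmaps are saved as .jpg files in the results/ directory; pass a different save_path to write them elsewhere.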
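The first note is easiest to satisfy by printing the names the model actually exposes. Assuming easy_gradcam resolves target_layer against the dotted names produced by PyTorch's nn.Module.named_modules() (the example names "layer4", "blocks.10" and "dinov2.encoder.layer.11" all follow that scheme), a small helper such as the hypothetical list_candidate_layers below narrows the search.

```python
def list_candidate_layers(model, keyword: str = "") -> list:
    """Return dotted submodule names, optionally filtered by a substring.

    named_modules() yields (name, module) pairs; the root module has the
    empty name "", which is skipped.
    """
    return [name for name, _ in model.named_modules() if name and keyword in name]


# Usage with a real model (requires torch and torchvision):
# import torchvision.models as models
# resnet = models.resnet50()
# print(list_candidate_layers(resnet, "layer4")[:3])
```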

Download files

Source Distribution

easy_gradcam-0.0.1.tar.gz (7.4 kB)

Built Distribution

easy_gradcam-0.0.1-py3-none-any.whl (7.7 kB)

File details

Details for the file easy_gradcam-0.0.1.tar.gz.

File metadata

  • Download URL: easy_gradcam-0.0.1.tar.gz
  • Upload date:
  • Size: 7.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.19

File hashes

Hashes for easy_gradcam-0.0.1.tar.gz
  • SHA256: cd0c17161b21dc068298d5b19bdcba7bfa9fb28219b8bdc62a2a82916868e778
  • MD5: 1952c35d5f6bcebbdb39d8f875155482
  • BLAKE2b-256: d28dff59f16f7d00add34e10a27b047e0e19ecb1a438471fc33c53ef85c110b6

File details

Details for the file easy_gradcam-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: easy_gradcam-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 7.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.19

File hashes

Hashes for easy_gradcam-0.0.1-py3-none-any.whl
  • SHA256: fb42f7802c629896d7e1014355d4d13e06420acbb6d574797c122311d26e2787
  • MD5: 3c9f7813201518dc235cfcfa91caf7c8
  • BLAKE2b-256: 86ee962a8620aa7017f490cfd145fe9693502a84ac4126dce4f9ae7ce372a472
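To check a downloaded file against the published SHA256 digests, a stdlib-only sketch follows; sha256_of and verify_download are illustrative names, not part of easy_gradcam or pip.

```python
import hashlib


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Compute the SHA256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: str, expected_sha256: str) -> bool:
    """Return True if the file matches the published SHA256 digest."""
    return sha256_of(path) == expected_sha256.lower()


# Example: check the wheel against its published digest.
# verify_download("easy_gradcam-0.0.1-py3-none-any.whl",
#                 "fb42f7802c629896d7e1014355d4d13e06420acbb6d574797c122311d26e2787")
```

pip can also enforce this itself: put a line such as easy_gradcam==0.0.1 --hash=sha256:<digest> in a requirements file and run pip install --require-hashes -r requirements.txt (in that mode every dependency must be pinned and hashed as well).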
