A robust, plug-and-play Attention Rollout explainer for Vision Transformers (timm & Hugging Face).
Fast ViT Rollout 🚀
A robust, plug-and-play Attention Rollout extractor for Vision Transformers. Works seamlessly out of the box with timm and Hugging Face backbones.
Why use this?
Many existing Attention Rollout scripts fail on modern ViT architectures. This package handles:
- Fused Attention (SDPA / Flash Attention): Bypasses PyTorch 2.0+ fused-attention kernels, which never expose the attention weights that rollout needs.
- Register and Distillation Tokens: Handles extra prefix tokens, such as DINOv2 register tokens and the DeiT distillation token, without shape-mismatch errors.
- Auto-Detection: Automatically hooks the correct attention layers without manual indexing.
- Native Heatmaps: Generates OpenCV-based overlays with a built-in intensity colorbar.
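For reference, the commonly used Attention Rollout recipe can be sketched in NumPy as below: fuse the heads, drop the weakest attention entries, add the identity to account for residual connections, row-normalise, multiply across layers, then read the CLS row over the patch tokens. This is a minimal illustration under stated assumptions, not this package's implementation. `attention_rollout` is a hypothetical helper; it assumes you already have one `(heads, N, N)` attention array per layer and that all prefix tokens (CLS first, then any register or distillation tokens) come before the patch tokens. Its `discard_ratio` mirrors the Quick Start parameter of the same name, but the package may apply it differently.

```python
import numpy as np


def attention_rollout(attentions, discard_ratio=0.9, num_prefix_tokens=1):
    """Hypothetical helper: roll per-layer attention out into a patch-grid map.

    attentions: list of arrays of shape (num_heads, N, N), one per layer,
        with N = num_prefix_tokens + grid * grid and the CLS token at index 0.
    Returns a (grid, grid) array scaled so its maximum is 1.
    """
    n = attentions[0].shape[-1]
    identity = np.eye(n)
    result = identity.copy()
    for attn in attentions:
        # Fuse heads by averaging (max- or min-fusion are common alternatives).
        fused = attn.mean(axis=0)
        # Zero the weakest `discard_ratio` fraction of entries. flatten() returns
        # a copy, so the caller's arrays are left untouched.
        flat = fused.flatten()
        k = int(flat.size * discard_ratio)
        if k > 0:
            weakest = np.argsort(flat)[:k]
            weakest = weakest[weakest != 0]  # keep the CLS -> CLS entry
            flat[weakest] = 0.0
        fused = flat.reshape(n, n)
        # Model the residual connection as 0.5 * I + 0.5 * A, then renormalise rows.
        a = (fused + identity) / 2.0
        a = a / a.sum(axis=-1, keepdims=True)
        # Compose with the rollout of the layers below.
        result = a @ result
    # CLS row restricted to patch tokens; register/distillation tokens are skipped.
    mask = result[0, num_prefix_tokens:]
    grid = int(round(np.sqrt(mask.size)))
    if grid * grid != mask.size:
        raise ValueError("patch tokens do not form a square grid; check num_prefix_tokens")
    mask = mask.reshape(grid, grid)
    peak = mask.max()
    return mask / peak if peak > 0 else mask
```

For a DINOv2 model with four register tokens, you would pass `num_prefix_tokens=5` under the token-ordering assumption above. A higher `discard_ratio` yields a sparser, more focused map.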
Installation
pip install fast-vit-rollout
Quick Start
import torch
import timm
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms
# 1. Importing the explainer confirms the package is installed
from fast_vit_rollout import ViTAttentionRollout
print("Package imported successfully! Testing model...")
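# 2. Load a ViT backbone. Disable fused attention (SDPA) before creating the model:
#    timm reads this setting when each Attention layer is built, and fused kernels
#    do not expose the attention weights that rollout needs.
import timm.layers
timm.layers.set_fused_attn(False)
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()  # inference mode: disables dropout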
# 3. Get Image
img_size = 224  # must match the model's expected input resolution
filename = "image.png"  # path to your input image
img_pil = Image.open(filename).convert('RGB')
original_rgb = cv2.resize(np.array(img_pil), (img_size, img_size))
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = transform(img_pil).unsqueeze(0)
# 4. Compute the attention rollout and render the heatmap overlay
rollout = ViTAttentionRollout(model, discard_ratio=0.9)
matrix, heatmap = rollout(input_tensor, original_image=original_rgb, return_vis=True)
cv2.imwrite("final_test.jpg", cv2.cvtColor(heatmap, cv2.COLOR_RGB2BGR))
print("Test passed! Heatmap saved as final_test.jpg")
GitHub 🚀
This is an open-source codebase. We welcome contributions from the community!
If you encounter any bugs, have suggestions for improvements, or would like to add new features, feel free to open an issue or submit a pull request.
Contribute here: https://github.com/Souradeep2233/fast-vit-rollout
Download files
- Source distribution: fast_vit_rollout-0.1.1.tar.gz
- Built distribution: fast_vit_rollout-0.1.1-py3-none-any.whl
File details
Details for the file fast_vit_rollout-0.1.1.tar.gz.
File metadata
- Download URL: fast_vit_rollout-0.1.1.tar.gz
- Upload date:
- Size: 6.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b83b96724d1afa3b0b7906dd9edb4b9c4e174e76ecff9bafed44ef681b9b1321 |
| MD5 | 98a11c15cea16d0ebe3ec1fd83256d1a |
| BLAKE2b-256 | 58f02d3b834a8a510d082552c0c0da2342c1fd5f0c0b1f58829e96f45eb19ba8 |
File details
Details for the file fast_vit_rollout-0.1.1-py3-none-any.whl.
File metadata
- Download URL: fast_vit_rollout-0.1.1-py3-none-any.whl
- Upload date:
- Size: 6.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5b4f0970f24cf979c34bea0896575725aa988adc233c2d86271fb50e57150e48 |
| MD5 | a65acd70b191f808d7bbfdb082b47157 |
| BLAKE2b-256 | b6b83cfd3ccbcde372d8662c41ebcef32a151a9e2856834701baca6314d00119 |