
A robust, plug-and-play Attention Rollout explainer for Vision Transformers (timm & Hugging Face).


Fast ViT Rollout 🚀

A robust, plug-and-play Attention Rollout extractor for Vision Transformers. Works seamlessly out of the box with timm and Hugging Face backbones.
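For readers new to the technique: Attention Rollout (Abnar & Zuidema, 2020) estimates how information flows from input tokens to the output by multiplying the per-layer attention matrices, with the identity added to account for residual connections. The following is a minimal NumPy sketch of the textbook algorithm, not this package's implementation. `attention_rollout` is a hypothetical name, and reading `discard_ratio` as "zero out the weakest fraction of the head-fused attention weights" is an assumption borrowed from common open-source variants.

```python
import numpy as np


def attention_rollout(attentions, discard_ratio=0.0):
    """Minimal Attention Rollout sketch (hypothetical helper, not the package API).

    attentions: per-layer arrays shaped (heads, tokens, tokens), ordered from the
    first block to the last; each row is a softmax distribution over tokens.
    Returns a (tokens, tokens) matrix whose rows sum to 1.
    """
    num_tokens = attentions[0].shape[-1]
    identity = np.eye(num_tokens)
    result = identity.copy()
    for attn in attentions:
        # Fuse heads by averaging (other variants use max or min).
        fused = attn.mean(axis=0)
        if discard_ratio > 0:
            # Assumed meaning of discard_ratio: zero the weakest fraction of weights.
            flat = fused.flatten()
            k = int(flat.size * discard_ratio)
            flat[np.argsort(flat)[:k]] = 0.0
            fused = flat.reshape(num_tokens, num_tokens)
        # Account for the residual connection, then re-normalize each row.
        a = 0.5 * fused + 0.5 * identity
        a = a / a.sum(axis=-1, keepdims=True)
        # Compose with the layers below: R_l = A_l @ R_{l-1}.
        result = a @ result
    return result
```

Row 0 of the returned matrix (the CLS token's row) is the part that is usually turned into a heatmap.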

Why use this?

Many existing Attention Rollout scripts break on modern ViT implementations. This package natively handles:

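  • Fused Attention (SDPA / Flash Attention): Bypasses the PyTorch 2.0+ fused attention kernels, which never return the attention weights that rollout needs.
  • Extra Tokens (registers, distillation): Parses DINOv2 register tokens and the DeiT distillation token without shape-mismatch errors.
  • Auto-Detection: Automatically finds and hooks the attention layers, with no manual layer indexing.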
  • Native Heatmaps: Generates OpenCV-based overlays with built-in intensity colorbars.
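The register-token handling above matters because rollout produces one score per token, and only the patch tokens map back to image locations. Below is a minimal sketch of that final step, assuming the token layout timm uses ([CLS], then any register or distillation tokens, then the patches in row-major order); timm ViTs expose the number of leading non-patch tokens as `model.num_prefix_tokens`. `cls_rollout_to_grid` is a hypothetical helper, not part of this package.

```python
import numpy as np


def cls_rollout_to_grid(rollout, num_prefix_tokens, grid_size):
    """Map the CLS row of a rollout matrix onto the patch grid (hypothetical helper).

    Assumes token order [CLS, other prefix tokens..., patch_1, ..., patch_N],
    with patches in row-major order.
    """
    grid_h, grid_w = grid_size
    expected = num_prefix_tokens + grid_h * grid_w
    if rollout.shape != (expected, expected):
        raise ValueError(
            f"expected a ({expected}, {expected}) matrix, got {rollout.shape}"
        )
    # Row 0: how much each token contributes to the CLS token.
    # Drop the CLS/register/distillation columns; keep only patch tokens.
    scores = rollout[0, num_prefix_tokens:].reshape(grid_h, grid_w)
    # Min-max normalize to [0, 1] for display (a flat map becomes all zeros).
    span = scores.max() - scores.min()
    if span == 0:
        return np.zeros_like(scores, dtype=float)
    return (scores - scores.min()) / span
```

For example, a 224 px input with 16 px patches gives a 14 × 14 grid (196 patches), so a model with one CLS token plus four register tokens sees 201 tokens and needs `num_prefix_tokens=5`. Getting this count wrong is exactly what produces the shape-mismatch errors mentioned above.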

Installation

pip install fast-vit-rollout

Quick Start

import cv2
import numpy as np
import timm
import timm.layers
from timm.data import resolve_model_data_config
from PIL import Image
from torchvision import transforms

# 1. Import the explainer (this also confirms the package is installed)
from fast_vit_rollout import ViTAttentionRollout

print("Package imported successfully! Testing model...")

# 2. Load a ViT backbone. Fused attention (SDPA) must be disabled *before*
#    the model is created, otherwise the attention weights are never materialized.
timm.layers.set_fused_attn(False)
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()  # inference mode: disables dropout / stochastic depth

# 3. Load and preprocess the image
data_cfg = resolve_model_data_config(model)  # the model's own input size and normalization
img_size = data_cfg['input_size'][-1]        # 224 for this model
filename = "image.png"                       # path to your input image
img_pil = Image.open(filename).convert('RGB')
original_rgb = cv2.resize(np.array(img_pil), (img_size, img_size))

transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=data_cfg['mean'], std=data_cfg['std']),
])
input_tensor = transform(img_pil).unsqueeze(0)

# 4. Compute the rollout matrix and the heatmap overlay
rollout = ViTAttentionRollout(model, discard_ratio=0.9)
matrix, heatmap = rollout(input_tensor, original_image=original_rgb, return_vis=True)

# OpenCV writes BGR, so convert the RGB overlay before saving
cv2.imwrite("final_test.jpg", cv2.cvtColor(heatmap, cv2.COLOR_RGB2BGR))
print("Test passed! Heatmap saved as final_test.jpg")
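The `return_vis=True` overlay is produced by the package with OpenCV. To illustrate the general idea behind such an overlay, here is a NumPy-only sketch: upsample the patch-grid mask to the image resolution by block repetition, then alpha-blend it into the red channel. `overlay_mask` is a hypothetical helper, and the single red-channel palette is a simplification of an OpenCV colormap such as `cv2.COLORMAP_JET`.

```python
import numpy as np


def overlay_mask(image_rgb, mask, alpha=0.5):
    """Blend a [0, 1] patch-grid mask onto an RGB uint8 image (hypothetical helper).

    Upsampling is nearest-neighbour block repetition, so the image height and
    width must be integer multiples of the grid size (e.g. 224 / 14 = 16).
    """
    h, w = image_rgb.shape[:2]
    gh, gw = mask.shape
    if h % gh or w % gw:
        raise ValueError("image size must be a multiple of the mask grid size")
    # Repeat each grid cell into a (h // gh, w // gw) block of pixels.
    upsampled = np.kron(mask, np.ones((h // gh, w // gw)))
    # Single-channel "palette": the mask intensity drives the red channel only.
    heat = np.zeros((h, w, 3))
    heat[..., 0] = upsampled * 255.0
    blended = (1.0 - alpha) * image_rgb.astype(np.float64) + alpha * heat
    return np.clip(blended, 0, 255).astype(np.uint8)
```

Nearest-neighbour repetition keeps the patch boundaries visible; smoother overlays typically use bilinear resizing (for example `cv2.resize`) before applying a colormap.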

GitHub 🚀

This is an open-source codebase. We welcome contributions from the community!

If you encounter any bugs, have suggestions for improvements, or would like to add new features, feel free to open an issue or submit a pull request.

Contribute here: https://github.com/Souradeep2233/fast-vit-rollout


Download files


Source Distribution

fast_vit_rollout-0.1.1.tar.gz (6.1 kB)

Uploaded Source

Built Distribution


fast_vit_rollout-0.1.1-py3-none-any.whl (6.7 kB)

Uploaded Python 3

File details

Details for the file fast_vit_rollout-0.1.1.tar.gz.

File metadata

  • Download URL: fast_vit_rollout-0.1.1.tar.gz
  • Upload date:
  • Size: 6.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.12

File hashes

Hashes for fast_vit_rollout-0.1.1.tar.gz
Algorithm Hash digest
SHA256 b83b96724d1afa3b0b7906dd9edb4b9c4e174e76ecff9bafed44ef681b9b1321
MD5 98a11c15cea16d0ebe3ec1fd83256d1a
BLAKE2b-256 58f02d3b834a8a510d082552c0c0da2342c1fd5f0c0b1f58829e96f45eb19ba8


File details

Details for the file fast_vit_rollout-0.1.1-py3-none-any.whl.

File hashes

Hashes for fast_vit_rollout-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 5b4f0970f24cf979c34bea0896575725aa988adc233c2d86271fb50e57150e48
MD5 a65acd70b191f808d7bbfdb082b47157
BLAKE2b-256 b6b83cfd3ccbcde372d8662c41ebcef32a151a9e2856834701baca6314d00119

