LeGrad
An Explainability Method for Vision Transformers via Feature Formation Sensitivity
Walid Bousselham1, Angie Boggust2, Sofian Chaybouti1, Hendrik Strobelt3,4 and Hilde Kuehne1,3
1 University of Bonn & Goethe University Frankfurt, 2 MIT CSAIL, 3 MIT-IBM Watson AI Lab, 4 IBM Research.
Vision-language foundation models have shown remarkable performance in various zero-shot settings such as image retrieval, classification, and captioning. We propose LeGrad, an explainability method specifically designed for ViTs. With LeGrad, we explore the decision-making process of such models by leveraging their feature formation process. A by-product of understanding a VL model's decision-making is the ability to produce a localized heatmap for any text prompt.
This package provides a wrapper around the OpenCLIP library that equips VL models with LeGrad.
:hammer: Installation
The legrad library can be installed via pip:
$ pip install legrad_torch
Demo
- Try out our web demo on HuggingFace Spaces
- Run the demo on Google Colab
- Run playground.py for a usage example.
To run the Gradio app locally, first install gradio and then run app.py:
$ pip install gradio
$ python app.py
Usage
To see which pretrained models are available, use the following code snippet:
import legrad
legrad.list_pretrained()
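If list_pretrained follows OpenCLIP's convention of returning (model_name, pretrained_tag) pairs (an assumption; this README does not document the return type), you can filter the listing for a specific architecture:

import legrad

# Assumption: list_pretrained() yields (model_name, pretrained_tag) pairs,
# mirroring open_clip.list_pretrained(); print only the tags for ViT-B-16.
for model_name, pretrained_tag in legrad.list_pretrained():
    if model_name == 'ViT-B-16':
        print(pretrained_tag)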
Single Image
To process an image and a text prompt, use the following code snippet:
Note: the wrapper does not affect the original model, so all functionalities of OpenCLIP models can still be used seamlessly.
import requests
from PIL import Image
import open_clip
import torch
from legrad import LeWrapper, LePreprocess
from legrad.utils import visualize
# ------- model parameters -------
model_name = 'ViT-B-16'
pretrained = 'laion2b_s34b_b88k'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ------- init model -------
model, _, preprocess = open_clip.create_model_and_transforms(
model_name=model_name, pretrained=pretrained, device=device)
tokenizer = open_clip.get_tokenizer(model_name=model_name)
model.eval()
# ------- Equip the model with LeGrad -------
model = LeWrapper(model)
# ___ (Optional): Wrapper for Higher-Res input image ___
preprocess = LePreprocess(preprocess=preprocess, image_size=448)
# ------- init inputs: image + text -------
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = preprocess(Image.open(requests.get(url, stream=True).raw)).unsqueeze(0).to(device)
text = tokenizer(['a photo of a cat']).to(device)
# ------- compute the explainability map -------
text_embedding = model.encode_text(text, normalize=True)
explainability_map = model.compute_legrad_clip(image=image, text_embedding=text_embedding)
# ___ (Optional): Visualize overlay of the image + heatmap ___
visualize(heatmaps=explainability_map, image=image)
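Because the explainability map is computed from a text embedding, the same wrapped model can produce a separate heatmap for each prompt. Below is a minimal sketch that reuses only the calls shown above; the prompt strings are illustrative:

# Compare how the model localizes different concepts in the same image.
# Reuses model, tokenizer, image, and device from the snippet above.
prompts = ['a photo of a cat', 'a photo of a remote control']
for prompt in prompts:
    text_emb = model.encode_text(tokenizer([prompt]).to(device), normalize=True)
    heatmap = model.compute_legrad_clip(image=image, text_embedding=text_emb)
    visualize(heatmaps=heatmap, image=image)  # overlay for this prompt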
:star: Acknowledgement
This code is built as a wrapper around the OpenCLIP library from LAION; visit their repo for more vision-language models. This project also takes inspiration from Transformer-MM-Explainability and the timm library; please visit their repositories.
:books: Citation
If you find this repository useful, please consider citing our work :pencil: and giving a star :star2: :
@article{bousselham2024legrad,
author = {Bousselham, Walid and Boggust, Angie and Chaybouti, Sofian and Strobelt, Hendrik and Kuehne, Hilde},
title = {LeGrad: An Explainability Method for Vision Transformers via Feature Formation Sensitivity},
journal = {arXiv preprint arXiv:2404.03214},
year = {2024},
}