Text-to-Image Attention Visualization for Vision-Language Models.
Project description
Vision–Language Saliency Extraction
This library provides a simple, model-agnostic interface to compute and visualize text-to-image saliency maps, extending classic methods originally developed for Vision Transformers (ViTs) to modern vision-language architectures. Compatible with any Hugging Face Image-Text-to-Text model, this library makes it easy to interpret vision-language model output. Modular and extensible, novel saliency techniques can be easily integrated.
Table of Contents
Installation
This library is available through PyPI and can be installed using pip:
pip install vl-saliency
Features
See the quickstart notebook for a complete example of how to use the saliency extractor with a Gemma3 vision-language model.
Using SaliencyExtractor objects, you can easily compute and visualize saliency maps for any Hugging Face Image-Text-to-Text model.
from vl_saliency import SaliencyExtractor
# Initialize the model and input prompt
model = AutoModel.from_pretrained("model_name") # Replace with your model name
processor = AutoProcessor.from_pretrained("model_name") # Replace with your processor name
image = PIL.Image.open("path_to_image.jpg") # Load your image
inputs = processor(text="Your prompt", images=image, return_tensors="pt")
# Initialize the saliency extractor
extractor = SaliencyExtractor(model, processor)
# Generate response
with torch.inference_mode():
generated_ids = model.generate(**inputs, do_sample=True, max_new_tokens=200)
# Compute attention and gradients
trace = extractor.capture(**inputs, generated_ids=generated_ids)
# Compute the saliency map from a specific token to the image
saliency_map = trace.map(token=200) # Change token_index as needed
# Aggregate the saliency map's layers and heads
saliency_map = saliency_map.agg(layer_reduce="mean", head_reduce="mean")
# Visualize the saliency map
saliency_map.plot(image, title="Saliency Map")
Attention and Gradients
You can compute saliency maps based on either attention weights or gradients. By default, SaliencyExtractor stores both attention and gradient information during the forward and backward passes. If you only need one of these, you can disable the other to save memory and computation time.
# Initialize the saliency extractor to store only gradients
extractor = SaliencyExtractor(model, processor, store_attns=False) # Similarly, use store_grads=False to store only attention
saliency_map = extractor.capture(**inputs, generated_ids=generated_ids).map(token=200)
saliency_map.agg().plot(image, title="Gradient-based Saliency Map")
Some more advanced saliency methods may require access to both attention weights and gradients. You can apply such methods directly to traces using the mode argument in the map method, returning a new saliency map.
from vl_saliency.lib import gradcam
# Compute Grad-CAM saliency map
saliency_map = trace.map(token=200, mode=gradcam)
saliency_map.agg().plot(image, title="Grad-CAM Saliency Map")
To define your own such composite saliency methods, see the Defining Custom Transforms section below.
Transforms
The library includes several built-in Transform objects to process saliency maps. Saliency maps are immutable, so applying a transform returns a new saliency map. You can chain transforms using the >> operator, or call the apply method.
from vl_saliency import transforms as T
# Example: Normalize and plot a saliency map
saliency_map = saliency_map >> T.normalize()
saliency_map.agg().plot(image, title="Normalized Saliency Map")
# Example: Binarize a saliency map, setting values below the mean to zero
saliency_map = saliency_map.apply(T.Binarize(threshold="mean"))
saliency_map.agg().plot(image, title="Binarized Saliency Map")
# Example: Apply the sigmoid function to a saliency map, then aggregate across heads and layers
saliency_map = saliency_map >> T.Sigmoid() >> T.Aggregate(layer_reduce="mean", head_reduce="mean")
saliency_map.plot(image, title="Sigmoid Saliency Map")
Pipeline API
For more complex visualization workflows, you can combine multiple Transform objects into a reuseable Pipeline, allowing you to apply the same sequence of transforms to multiple saliency maps.
from vl_saliency import transforms as T
pipe = (
T.abs() >>
T.normalize() >>
T.Aggregate(layer_reduce="mean", head_reduce="sum")
)
# Apply the pipeline to a saliency map
saliency_map >>= pipe
saliency_map.plot(image, title="Pipeline Processed Saliency Map")
# Alternatively, you can directly create a pipeline using the constructor
pipe = T.Pipeline(
T.abs(),
T.normalize(),
T.Aggregate(layer_reduce="mean", head_reduce="sum")
)
saliency_map = saliency_map.apply(pipe).plot(image, title="Pipeline Processed Saliency Map")
Defining Custom Transforms
You can define your own custom transforms by subclassing the Chainable interface. Note that Chainable classes must implement the __call__ method with exactly the following signature:
from vl_saliency import SaliencyMap
from vl_saliency.transforms import Chainable
class MyTransform(Chainable):
def __call__(self, saliency_map: SaliencyMap) -> SaliencyMap:
# Custom transformation logic
return saliency_map
Alternatively, you can use the @chainable decorator to create simple transforms without subclassing. The decorated function must also adhere to the same signature:
from vl_saliency import SaliencyMap
from vl_saliency.transforms import chainable
@chainable
def my_custom_transform(saliency_map: SaliencyMap) -> SaliencyMap:
# Custom transformation logic
return saliency_map
For methods that require both attention weights and gradients, you can define a transform that processes both and returns a new saliency map. Such transforms are defined under the protocol TraceTransform and can be applied directly to Trace objects using the map method. They must implement the following signature:
from vl_saliency import SaliencyMap
def my_trace_transform(attn: SaliencyMap, grad: SaliencyMap) -> SaliencyMap:
# Custom transformation logic using both attention and gradients
return saliency_map
class MyTraceTransform:
def __call__(self, attn: SaliencyMap, grad: SaliencyMap) -> SaliencyMap:
# Custom transformation logic using both attention and gradients
return saliency_map
Note that TraceTransform objects aren't chainable like regular transforms, since they operate on two inputs.
Contributing
Contributions are welcome! Open an issue to discuss ideas or submit a PR directly.
Getting Started
-
Clone the repository and install the required dependencies.
git clone https://github.com/alexander-brady/vl-saliency cd vl-saliency
-
Create a virtual environment and activate it.
python -m venv .venv source .venv/bin/activate # On Windows use `.venv\Scripts\activate`
-
Install the development dependencies.
pip install -e .[dev]
Guidelines
Before submitting a pull request, ensure:
ruff check . --fix && ruff format . # Lint & format
pytest # Run tests
mypy . # Type check
License
This project is licensed under the MIT License – see the LICENSE file for details.
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file vl_saliency-1.1.6.tar.gz.
File metadata
- Download URL: vl_saliency-1.1.6.tar.gz
- Upload date:
- Size: 3.4 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
ec7fea50e5d2e254312011438e532d9e172206cf2e40d5d750cd143bde2ff172
|
|
| MD5 |
e08e964e4ae5f711a6a90f967002126f
|
|
| BLAKE2b-256 |
ba9218b4b80ece43646106de11ffcdf9d5826a4fd691747a9f8988feba6adbaf
|
File details
Details for the file vl_saliency-1.1.6-py3-none-any.whl.
File metadata
- Download URL: vl_saliency-1.1.6-py3-none-any.whl
- Upload date:
- Size: 28.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
a4494f5c9a305218639a95cfd2cab7c6a332b4acc0d4b8759c587737bb9a22d0
|
|
| MD5 |
f3b686d5e2dde15ae4bd7f83b8e6a3c1
|
|
| BLAKE2b-256 |
0a651bd69168f98f61ab34bbe77791d7ae75d9bd848c77fb2f15e642969d7c38
|