Transformer-Attention-Hooker
A lightweight, robust utility for extracting and visualizing attention weights from PyTorch Transformer models.
This tool simplifies the process of debugging and analyzing Transformer internals by automatically hooking into nn.MultiheadAttention modules, handling the need_weights=True flag, and managing multiple forward passes (e.g., in generation loops or shared layers).
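Under the hood, tools like this rely on PyTorch's forward-hook API: a hook registered on an nn.MultiheadAttention module receives its outputs after every call, and the attention weights are the second element of that output tuple. A minimal sketch of the mechanism in plain PyTorch (independent of this package, and using PyTorch's default head-averaged weights) looks like:

```python
import torch
import torch.nn as nn

# Capture attention weights via a forward hook, keyed by layer name.
captured = {}

def make_hook(name):
    def hook(module, args, output):
        # output is (attn_output, attn_weights) when need_weights=True
        captured.setdefault(name, []).append(output[1].detach())
    return hook

mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
handle = mha.register_forward_hook(make_hook("mha"))

x = torch.randn(1, 10, 32)
out, weights = mha(x, x, x, need_weights=True)

print(captured["mha"][0].shape)  # averaged over heads: (Batch, Seq, Seq)
handle.remove()  # always detach hooks when done
```

Note that with PyTorch defaults the returned weights are averaged over heads; per-head weights require passing average_attn_weights=False to the forward call.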
Features
- Automatic Hooking: Automatically detects nn.MultiheadAttention layers using regex.
- Force Weights: Automatically sets need_weights=True during the forward pass so you don't have to modify your model code.
- Layer Reuse Support: Correctly captures attention weights even if a layer is called multiple times (e.g., in a loop or with shared weights).
- Cross-Attention Support: Works with both square self-attention and rectangular cross-attention matrices.
- Visualization Tools: Includes a built-in visualizer to plot attention heads as heatmaps.
Requirements
- Python 3.6+
- PyTorch
- Matplotlib
```
pip install torch matplotlib
```
Quick Start
1. Extracting Attention Weights
Wrap your model with TransformerAttentionHooker before running the forward pass.
```python
import torch
import torch.nn as nn
from transformer_attention_hooker import TransformerAttentionHooker

# 1. Define or load your model
model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True),
    num_layers=2,
)

# 2. Set up the hooker
# By default, it hooks layers ending with 'self_attn'
hooker = TransformerAttentionHooker(model, layer_regex=r"self_attn$").setup()

# 3. Run a forward pass
x = torch.randn(1, 10, 32)  # (Batch, Seq, Feature)
output = model(x)

# 4. Access the captured attention weights
# hooker.values is a dict: {layer_name: [tensor_call_1, tensor_call_2, ...]}
print("Captured layers:", list(hooker.values.keys()))
for name, attn_list in hooker.values.items():
    print(f"Layer: {name}")
    # Get the tensor from the first call
    attn_tensor = attn_list[0]
    print(f"  Shape: {attn_tensor.shape}")  # (Batch, Heads, Seq, Seq)

# 5. Clean up
hooker.remove_hooks()
```
2. Visualizing Attention
Use the included plot_attention_grid function to generate heatmaps for all heads in a layer.
```python
from transformer_attention_hooker import plot_attention_grid

# Assuming 'attn_tensor' is captured from the example above
layer_name = "layers.0.self_attn"
attn_tensor = hooker.values[layer_name][0]

plot_attention_grid(
    attn_tensor,
    tokens=[f"Token_{i}" for i in range(10)],  # Optional: Add labels
    layer_name=layer_name,
    save_path=f"plots/{layer_name}.png",
)
```
Advanced Usage
Custom Layer Selection
If your model names its attention layers differently (e.g., attn1, cross_attention), you can pass a custom regex pattern.
```python
# Hook all layers containing "attn"
hooker = TransformerAttentionHooker(model, layer_regex=r".*attn.*").setup()
```
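The pattern is matched against the dotted module names that model.named_modules() yields. A quick standalone way to preview which layers a given regex would select (assuming re.search semantics, as the default r"self_attn$" pattern suggests; this snippet is illustrative and not part of the package API):

```python
import re
import torch.nn as nn

model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True),
    num_layers=2,
)

pattern = re.compile(r"self_attn$")
matches = [
    name
    for name, module in model.named_modules()
    if pattern.search(name) and isinstance(module, nn.MultiheadAttention)
]
print(matches)  # ['layers.0.self_attn', 'layers.1.self_attn']
```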
Handling Loops (Generation / Shared Layers)
If a layer is used multiple times during a forward pass (common in recurrent-style generation or weight sharing), hooker.values[layer_name] will contain a list of tensors, one for each call.
```python
# Example: A layer called 3 times
output = model(x)

attn_calls = hooker.values['my_layer']
print(len(attn_calls))      # 3
print(attn_calls[0].shape)  # Attention from 1st pass
print(attn_calls[1].shape)  # Attention from 2nd pass
```
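Why a list? A forward hook fires once per call, so a layer invoked several times in one forward pass naturally produces one entry per invocation. A self-contained illustration with plain PyTorch hooks (no package required):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)

# Append the attention weights (second output) on every call.
calls = []
handle = mha.register_forward_hook(lambda m, args, out: calls.append(out[1]))

x = torch.randn(1, 5, 16)
for _ in range(3):  # same layer invoked three times
    out, _ = mha(x, x, x, need_weights=True)

print(len(calls))  # 3
handle.remove()
```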
Project Structure
- src/attention_hooker.py: Core hooking logic.
- src/visualizer.py: Matplotlib plotting utilities.
- demo_viz.py: Runnable demo script.
- test_edge_cases.py: Tests ensuring robustness for loops and cross-attention.
License
This project is licensed under the MIT License - see the LICENSE file for details.
Copyright (c) 2025 Donghwee Yoon