
Transformer-Attention-Hooker

A lightweight, robust utility for extracting and visualizing attention weights from PyTorch Transformer models.

This tool simplifies debugging and analyzing Transformer internals: it automatically hooks into nn.MultiheadAttention modules, forces need_weights=True during the forward pass, and correctly handles layers that are called multiple times in a single pass (e.g., in generation loops or with shared layers).
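For context, the mechanism can be sketched with two standard PyTorch hooks. This is an illustration of the idea only, not the package's actual implementation, and it assumes PyTorch 2.0+ for the with_kwargs=True pre-hook API:

```python
import torch
import torch.nn as nn

# Illustrative sketch (not the package's code): a forward pre-hook forces
# need_weights=True, and a forward hook records the per-head weights from
# the module's (attn_output, attn_weights) output tuple.
captured = {}

def force_weights(module, args, kwargs):
    # Override whatever flags the caller passed so weights are returned,
    # unaveraged across heads.
    kwargs["need_weights"] = True
    kwargs["average_attn_weights"] = False
    return args, kwargs

def capture(module, args, output):
    captured.setdefault("self_attn", []).append(output[1].detach())

mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
pre = mha.register_forward_pre_hook(force_weights, with_kwargs=True)
post = mha.register_forward_hook(capture)

x = torch.randn(1, 10, 32)
mha(x, x, x, need_weights=False)  # the pre-hook overrides this flag

print(captured["self_attn"][0].shape)  # torch.Size([1, 4, 10, 10])

pre.remove()
post.remove()
```

With average_attn_weights=False and batch_first=True, each captured tensor has shape (batch, heads, query_len, key_len).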

Features

  • Automatic Hooking: Detects nn.MultiheadAttention layers by matching their qualified module names against a regex.
  • Force Weights: Automatically sets need_weights=True during the forward pass so you don't have to modify your model code.
  • Layer Reuse Support: Correctly captures attention weights even if a layer is called multiple times (e.g., in a loop or with shared weights).
  • Cross-Attention Support: Works with both square self-attention and rectangular cross-attention matrices.
  • Visualization Tools: Includes a built-in visualizer to plot attention heads as heatmaps.
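The automatic-hooking feature can be approximated in a few lines: walk named_modules() and keep the nn.MultiheadAttention instances whose dotted names match the pattern. A sketch of the idea, not the package's actual code:

```python
import re
import torch.nn as nn

model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True),
    num_layers=2,
)

pattern = re.compile(r"self_attn$")
# Keep only MultiheadAttention modules whose dotted name matches the regex.
matches = [
    name
    for name, module in model.named_modules()
    if isinstance(module, nn.MultiheadAttention) and pattern.search(name)
]
print(matches)  # ['layers.0.self_attn', 'layers.1.self_attn']
```

Filtering on the module type as well as the name avoids accidentally hooking containers or linear projections whose names happen to match.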

Requirements

  • Python 3.6+
  • PyTorch
  • Matplotlib
pip install torch matplotlib
pip install transformer-attention-hooker

Quick Start

1. Extracting Attention Weights

Wrap your model with TransformerAttentionHooker before running the forward pass.

import torch
import torch.nn as nn
from transformer_attention_hooker import TransformerAttentionHooker

# 1. Define or load your model
model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True),
    num_layers=2
)

# 2. Setup the hooker
# By default, it hooks layers ending with 'self_attn'
hooker = TransformerAttentionHooker(model, layer_regex=r"self_attn$").setup()

# 3. Run a forward pass
x = torch.randn(1, 10, 32) # (Batch, Seq, Feature)
output = model(x)

# 4. Access the captured attention weights
# hooker.values is a dict: {layer_name: [tensor_call_1, tensor_call_2, ...]}
print("Captured layers:", list(hooker.values.keys()))

for name, attn_list in hooker.values.items():
    print(f"Layer: {name}")
    # Get the tensor from the first call
    attn_tensor = attn_list[0]
    print(f"  Shape: {attn_tensor.shape}") # (Batch, Heads, Seq, Seq)

# 5. Cleanup
hooker.remove_hooks()

2. Visualizing Attention

Use the included plot_attention_grid function to generate heatmaps for all heads in a layer.

from transformer_attention_hooker import plot_attention_grid

# Assuming 'attn_tensor' is captured from the example above
layer_name = "layers.0.self_attn"
attn_tensor = hooker.values[layer_name][0]

plot_attention_grid(
    attn_tensor,
    tokens=[f"Token_{i}" for i in range(10)], # Optional: Add labels
    layer_name=layer_name,
    save_path=f"plots/{layer_name}.png"
)
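If you need output beyond what plot_attention_grid provides, an equivalent grid of per-head heatmaps can be drawn directly with Matplotlib. This sketch assumes a (batch, heads, seq, seq) tensor like the one captured above; the random softmax tensor here just stands in for real weights:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so this runs without a display
import matplotlib.pyplot as plt
import torch

# Stand-in for a captured (batch, heads, seq, seq) attention tensor.
attn = torch.softmax(torch.randn(1, 4, 10, 10), dim=-1)

n_heads = attn.shape[1]
fig, axes = plt.subplots(1, n_heads, figsize=(3 * n_heads, 3))
for h, ax in enumerate(axes):
    # Rows are query positions, columns are key positions.
    ax.imshow(attn[0, h], cmap="viridis", vmin=0.0)
    ax.set_title(f"Head {h}")
    ax.set_xlabel("Key position")
    ax.set_ylabel("Query position")
fig.tight_layout()
fig.savefig("attention_grid.png")
```

Each row of a softmaxed attention matrix sums to 1, so a fixed vmin keeps the color scale comparable across heads.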

Advanced Usage

Custom Layer Selection

If your model names its attention layers differently (e.g., attn1, cross_attention), you can pass a custom regex pattern.

# Hook all layers containing "attn"
hooker = TransformerAttentionHooker(model, layer_regex=r".*attn.*").setup()

Handling Loops (Generation / Shared Layers)

If a layer is used multiple times during a forward pass (common in recurrent-style generation or weight sharing), hooker.values[layer_name] will contain a list of tensors, one for each call.

# Example: A layer called 3 times
output = model(x)

attn_calls = hooker.values['my_layer']
print(len(attn_calls)) # 3
print(attn_calls[0].shape) # Attention from 1st pass
print(attn_calls[1].shape) # Attention from 2nd pass
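As a concrete, self-contained illustration of why per-call capture matters (hypothetical, not using the package), here is one shared attention module applied three times in a loop, with a plain forward hook appending each call's weights:

```python
import torch
import torch.nn as nn

calls = []
mha = nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)
# A bare nn.MultiheadAttention returns weights by default (need_weights=True),
# averaged across heads, with shape (batch, query_len, key_len).
handle = mha.register_forward_hook(lambda m, args, out: calls.append(out[1]))

x = torch.randn(1, 5, 16)
for _ in range(3):          # same module, three separate calls
    x, _ = mha(x, x, x)

print(len(calls))           # 3 -- one captured tensor per call
handle.remove()
```

A hook that stored a single tensor per layer would silently keep only the last call; appending to a list preserves all three.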

Project Structure

  • src/attention_hooker.py: Core hooking logic.
  • src/visualizer.py: Matplotlib plotting utilities.
  • demo_viz.py: Runnable demo script.
  • test_edge_cases.py: Tests ensuring robustness for loops and cross-attention.

License

This project is licensed under the MIT License - see the LICENSE file for details.

Copyright (c) 2025 Donghwee Yoon
