Gradient-based Interpretability Method for transformer feature attribution
GIM: Gradient Interaction Modifications
Installation
```shell
pip install gim-explain
# With TransformerLens support (quoted so shells such as zsh do not expand the brackets)
pip install "gim-explain[tlens]"
```
Quick Start
Feature Attribution with explain()
```python
import gim
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids

attributions = gim.explain(model, input_ids, tokenizer=tokenizer)
# attributions is a tensor of shape [seq_len] with importance scores per token
```
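To read the scores, it helps to pair each one with its token and look at the largest. The helper below is a hypothetical sketch, not part of gim; it assumes the attribution tensor has already been converted to a Python list (for example with attributions.tolist()) and that the token strings are available.

```python
from typing import List, Sequence, Tuple


def top_tokens(
    tokens: Sequence[str], scores: Sequence[float], k: int = 3
) -> List[Tuple[str, float]]:
    """Return the k (token, score) pairs with the largest absolute score.

    Hypothetical helper for illustration; it is not part of the gim package.
    Ranking by absolute value treats strongly negative attributions as
    important too; drop abs() if only positive evidence matters.
    """
    if len(tokens) != len(scores):
        raise ValueError("tokens and scores must have the same length")
    pairs = list(zip(tokens, scores))
    # Largest magnitude first; Python's sort is stable, so ties keep input order.
    pairs.sort(key=lambda pair: abs(pair[1]), reverse=True)
    return pairs[:k]
```

With the Quick Start variables, a call could look like top_tokens(tokenizer.convert_ids_to_tokens(input_ids[0].tolist()), attributions.tolist()). GPT-2 token strings carry a leading "Ġ" where the text had a space.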
Using the GIM Context Manager
For finer control, use the GIM context manager directly, for example to apply GIM to circuit discovery or network pruning. Wrap the model and run your gradient-based method as usual (e.g., Edge Attribution Patching); inside the with block, GIM modifies the backward pass automatically.
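```python
import torch
import torch.nn.functional as F
import gim
# model, tokenizer and input_ids as in the Quick Start; " Paris" is an example target
target = torch.tensor([tokenizer.encode(" Paris")[0]])

with gim.GIM(model):
    logits = model(input_ids).logits  # a HookedTransformer returns logits directly
    loss = F.cross_entropy(logits[:, -1], target)
    loss.backward()
    # Gradients are now modified by GIM
```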
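Edge Attribution Patching, mentioned above, scores an edge with a first-order Taylor estimate of what patching its activation would do: score ≈ (e_corrupt - e_clean) · ∂L/∂e, with the gradient taken on the clean input. Under GIM only the gradient term changes, because it comes from a backward pass run inside gim.GIM(model). The sketch below assumes each edge's activations and gradient have already been collected as flat lists of floats (for example with PyTorch hooks); eap_score and rank_edges are hypothetical helpers, not part of gim.

```python
from typing import Dict, List, Sequence


def eap_score(
    clean_act: Sequence[float],
    corrupt_act: Sequence[float],
    grad_clean: Sequence[float],
) -> float:
    """First-order estimate of the loss change from patching one edge.

    score = (corrupt_act - clean_act) . grad_clean, where grad_clean is
    dLoss/dActivation on the clean input (computed inside gim.GIM when
    GIM-modified gradients are wanted).
    """
    if not (len(clean_act) == len(corrupt_act) == len(grad_clean)):
        raise ValueError("all vectors must have the same length")
    return sum((c - a) * g for a, c, g in zip(clean_act, corrupt_act, grad_clean))


def rank_edges(scores: Dict[str, float]) -> List[str]:
    """Order edge names by |score|, largest first (candidate circuit edges)."""
    return sorted(scores, key=lambda name: abs(scores[name]), reverse=True)
```

A positive score means the patch is estimated to increase the loss. Ranking by magnitude is one simple way to choose which edges to keep when building a circuit.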
How It Works
GIM applies three gradient modifications during backpropagation:
- Norm Freezing: Detaches LayerNorm/RMSNorm statistics from the backward pass
- Softmax Temperature: Applies temperature scaling to the softmax gradients, which softens the attention distribution seen by the backward pass
- Q/K/V Scaling: Scales gradients for query, key, and value tensors in attention
As shown in the accompanying paper (see Citation), these modifications improve the quality of gradient-based feature attributions.
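The first two modifications can be illustrated on a single vector in plain Python. This is a sketch under stated assumptions, not gim's implementation: it assumes norm freezing treats both the mean and the standard deviation of a LayerNorm (without its learnable scale and shift) as constants, and it assumes the temperature rule evaluates the ordinary softmax vector-Jacobian product at softmax(x / T) instead of softmax(x), with no extra 1/T factor. All function names are illustrative.

```python
import math
from typing import List, Sequence


def _mean(v: Sequence[float]) -> float:
    return sum(v) / len(v)


def layernorm(x: Sequence[float], eps: float = 1e-5) -> List[float]:
    """LayerNorm without learnable parameters: y = (x - mean) / sqrt(var + eps)."""
    mu = _mean(x)
    var = _mean([(xi - mu) ** 2 for xi in x])
    return [(xi - mu) / math.sqrt(var + eps) for xi in x]


def layernorm_grad(
    x: Sequence[float], g: Sequence[float], eps: float = 1e-5, freeze_stats: bool = False
) -> List[float]:
    """Gradient of sum(g * layernorm(x)) with respect to x.

    freeze_stats=False: exact gradient, (g - mean(g) - y * mean(g * y)) / sigma.
    freeze_stats=True:  mean and sigma treated as constants, giving g / sigma.
    """
    mu = _mean(x)
    sigma = math.sqrt(_mean([(xi - mu) ** 2 for xi in x]) + eps)
    if freeze_stats:
        return [gi / sigma for gi in g]
    y = layernorm(x, eps)
    g_mean = _mean(g)
    gy_mean = _mean([gi * yi for gi, yi in zip(g, y)])
    return [(gi - g_mean - yi * gy_mean) / sigma for gi, yi in zip(g, y)]


def softmax(x: Sequence[float], temperature: float = 1.0) -> List[float]:
    m = max(x)  # subtract the max for numerical stability
    exps = [math.exp((xi - m) / temperature) for xi in x]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_grad(
    x: Sequence[float], g: Sequence[float], temperature: float = 1.0
) -> List[float]:
    """Softmax vector-Jacobian product s * (g - <g, s>) with s = softmax(x / T).

    temperature=1.0 is the exact gradient of sum(g * softmax(x)); larger
    values give the assumed temperature backward rule described above.
    """
    s = softmax(x, temperature)
    gs = sum(gi * si for gi, si in zip(g, s))
    return [si * (gi - gs) for gi, si in zip(g, s)]
```

With freeze_stats=True the centering and variance terms drop out, so each component is just the upstream gradient divided by sigma. For a peaked input such as softmax_grad([4.0, 1.0, 0.0], [1.0, 0.0, 0.0], 2.0), the T = 2 rule returns a larger gradient for the top logit than T = 1, because the flatter distribution is less saturated.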
API Reference
gim.explain()
```python
gim.explain(
    model,                     # PyTorch nn.Module or TransformerLens HookedTransformer
    input_ids,                 # Token IDs [batch, seq_len] or [seq_len]
    # --- keyword-only arguments ---
    target_token_id=None,      # Token to explain (default: argmax prediction)
    target_position=-1,        # Position to explain (default: last token)
    baseline_token_id=None,    # Baseline token for counterfactual
    tokenizer=None,            # Tokenizer to infer baseline from
    freeze_norm=True,          # Detach norm statistics
    softmax_temperature=2.0,   # Temperature for softmax backward
    q_scale=0.25,              # Query gradient scale
    k_scale=0.25,              # Key gradient scale
    v_scale=0.5,               # Value gradient scale
)
```
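The comments above document two defaults: the last position (target_position=-1) and, when target_token_id is None, the model's argmax prediction at that position; input_ids may also be given with or without a batch dimension. The pure-Python sketch below shows how those documented defaults resolve, using nested lists in place of tensors. It is a hypothetical illustration, not gim's internal code, and ensure_batched and resolve_target are invented names.

```python
from typing import List, Optional, Sequence, Tuple, Union

TokenIds = Union[Sequence[int], Sequence[Sequence[int]]]


def ensure_batched(input_ids: TokenIds) -> List[List[int]]:
    """Accept [seq_len] or [batch, seq_len] token IDs; always return [batch, seq_len]."""
    if len(input_ids) > 0 and isinstance(input_ids[0], int):
        return [list(input_ids)]
    return [list(row) for row in input_ids]


def resolve_target(
    logits: Sequence[Sequence[float]],
    target_token_id: Optional[int] = None,
    target_position: int = -1,
) -> Tuple[int, int]:
    """Return (position, token_id) to explain for one sequence.

    logits has shape [seq_len, vocab]. Mirrors the documented defaults:
    the last position, and the argmax prediction there if no token is given.
    """
    position = target_position % len(logits)  # map -1 to seq_len - 1
    if target_token_id is None:
        row = logits[position]
        target_token_id = max(range(len(row)), key=row.__getitem__)
    return position, target_token_id
```

For example, with two positions and a two-token vocabulary, resolve_target([[0.1, 0.9], [0.7, 0.2]]) picks the last position and its highest-scoring token, returning (1, 0).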
gim.GIM()
```python
with gim.GIM(
    model,  # PyTorch nn.Module or TransformerLens HookedTransformer
    # --- keyword-only arguments ---
    freeze_norm=True,
    softmax_temperature=2.0,
    q_scale=0.25,
    k_scale=0.25,
    v_scale=0.5,
):
    # Your forward/backward code here
    pass
```
Citation
```bibtex
@misc{edin2025gimimprovedinterpretabilitylarge,
  title={GIM: Improved Interpretability for Large Language Models},
  author={Joakim Edin and Róbert Csordás and Tuukka Ruotsalo and Zhengxuan Wu and Maria Maistro and Casper L. Christensen and Jing Huang and Lars Maaløe},
  year={2025},
  eprint={2505.17630},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2505.17630},
}
```
License
MIT
Download files
Source Distribution
Built Distribution
File details
Details for the file gim_explain-0.1.4.tar.gz.
File metadata
- Download URL: gim_explain-0.1.4.tar.gz
- Upload date:
- Size: 13.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1102196523f6f578bf18f898312f12aa2697c2a93875d19a1634e503bbc82d35 |
| MD5 | 3370eb35359547efc27f2c18c205711f |
| BLAKE2b-256 | 8ab8f689a80d9d60e0dd1e33e943a5adaff3b1199942eed23f43d7544b750426 |
File details
Details for the file gim_explain-0.1.4-py3-none-any.whl.
File metadata
- Download URL: gim_explain-0.1.4-py3-none-any.whl
- Upload date:
- Size: 11.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f8e409cdf886d745abeebd6f192e09dc5fc81a095f17c69bf20f6c32cf152c4c |
| MD5 | 7fa57f663f73eacf8f8f4fa49de91a6f |
| BLAKE2b-256 | cae824d3e1169061fc8d31f628bd4502588f10c505301533733909cd59c70023 |