Gradient-based Interpretability Method for transformer feature attribution

GIM: Gradient Interaction Modifications

Installation

pip install gim-explain

# With TransformerLens support
pip install gim-explain[tlens]

Quick Start

Feature Attribution with explain()

import gim
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
attributions = gim.explain(model, input_ids, tokenizer=tokenizer)

# attributions is a tensor of shape [seq_len] with importance scores per token
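
To read the scores, it can help to pair each one with its token and sort by magnitude. The sketch below assumes the scores are already a plain list (for example via `attributions.tolist()`) and the tokens come from something like `tokenizer.convert_ids_to_tokens(input_ids[0].tolist())`. `rank_tokens` is a hypothetical helper, not part of gim, and the tokens and scores are placeholder values, not real GIM output.

```python
from typing import List, Tuple


def rank_tokens(tokens: List[str], scores: List[float], k: int = 3) -> List[Tuple[str, float]]:
    """Return the k (token, score) pairs with the largest absolute score."""
    if len(tokens) != len(scores):
        raise ValueError("tokens and scores must have the same length")
    pairs = list(zip(tokens, scores))
    # Sort by magnitude so strongly negative attributions are not hidden.
    pairs.sort(key=lambda pair: abs(pair[1]), reverse=True)
    return pairs[:k]


# Placeholder values for illustration only.
tokens = ["The", " capital", " of", " France", " is"]
scores = [0.1, 0.6, -0.05, 0.9, -0.3]
top = rank_tokens(tokens, scores, k=2)
```

Sorting by absolute value is a choice made for this sketch; sort by the raw score instead if only positive evidence for the target token is of interest.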

Using the GIM Context Manager

For more control, use the GIM context manager directly, for example when applying GIM to circuit discovery or network pruning. Wrap the model and run your gradient-based method (e.g., Edge Attribution Patching) as usual inside the with block; the wrapper modifies backpropagation automatically.

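import gim
import torch.nn.functional as F

# Token to explain: the first token ID of " Paris", shape [batch]
target = tokenizer(" Paris", return_tensors="pt").input_ids[:, 0]

with gim.GIM(model):
    # Hugging Face models return an output object, so take .logits;
    # a TransformerLens HookedTransformer returns the logits tensor directly
    logits = model(input_ids).logits
    loss = F.cross_entropy(logits[:, -1], target)
    loss.backward()
    # Gradients are now modified by GIM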
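
As an illustration of what a "gradient-based method" can look like, here is a minimal gradient-times-input sketch on a toy model. `ToyLM` and `forward_from_embeddings` are hypothetical stand-ins invented for this example; nothing here calls gim. With a real transformer, the forward and backward calls would sit inside the `with gim.GIM(model):` block.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToyLM(nn.Module):
    """Tiny stand-in for a language model: embed, mean-pool, project to vocab."""

    def __init__(self, vocab_size: int = 10, d_model: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward_from_embeddings(self, emb: torch.Tensor) -> torch.Tensor:
        return self.head(emb.mean(dim=1))  # [batch, vocab]


model = ToyLM()
input_ids = torch.tensor([[1, 4, 2, 7]])

# Embed once and keep the embeddings as a leaf tensor that receives gradients.
emb = model.embed(input_ids).detach().requires_grad_(True)  # [1, seq, d_model]
logits = model.forward_from_embeddings(emb)
target_id = int(logits[0].argmax())
target_logit = logits[0, target_id]
target_logit.backward()

# Gradient x input: one scalar per token position.
grad_x_input = (emb.grad * emb).sum(dim=-1).squeeze(0).detach()  # [seq]
```

Because the toy model is linear in the embeddings, the per-token scores sum to the target logit minus the output bias. That identity does not hold in general for a real transformer.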

How It Works

GIM applies three gradient modifications during backpropagation:

  1. Norm Freezing: Detaches LayerNorm/RMSNorm statistics from the backward pass
  2. Softmax Temperature: Applies temperature scaling to softmax gradients (softer attention)
  3. Q/K/V Scaling: Scales gradients for query, key, and value tensors in attention

The paper (see Citation) reports that these modifications improve the quality of gradient-based feature attributions.
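
The first modification in the list above, norm freezing, can be sketched in plain PyTorch. This is an illustration of the idea, not gim's implementation: the `layer_norm` function and its `freeze_stats` flag are hypothetical, and detaching both the mean and the variance is an assumption about which statistics are meant.

```python
import torch


def layer_norm(x: torch.Tensor, eps: float = 1e-5, freeze_stats: bool = False) -> torch.Tensor:
    """LayerNorm over the last dimension, without learned scale or bias."""
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    if freeze_stats:
        # Treat the statistics as constants: the forward value is unchanged,
        # but no gradient flows through the mean or the variance.
        mean, var = mean.detach(), var.detach()
    return (x - mean) / torch.sqrt(var + eps)


x = torch.tensor([[1.0, 2.0, 4.0, 8.0]], requires_grad=True)
w = torch.tensor([1.0, 0.0, 0.0, 0.0])  # read out the first normalized feature

(grad_full,) = torch.autograd.grad((layer_norm(x) * w).sum(), x)
(grad_frozen,) = torch.autograd.grad((layer_norm(x, freeze_stats=True) * w).sum(), x)
```

With frozen statistics the normalization behaves like a fixed per-row shift and scale in the backward pass, so `grad_frozen` is nonzero only at the feature that was read out. `grad_full` also spreads gradient to the other features through the mean and variance.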

API Reference

gim.explain()

gim.explain(
    model,                          # PyTorch nn.Module or TransformerLens HookedTransformer
    input_ids,                      # Token IDs [batch, seq_len] or [seq_len]
    *,
    target_token_id=None,           # Token to explain (default: argmax prediction)
    target_position=-1,             # Position to explain (default: last token)
    baseline_token_id=None,         # Baseline token for counterfactual
    tokenizer=None,                 # Tokenizer to infer baseline from
    freeze_norm=True,               # Detach norm statistics
    softmax_temperature=2.0,        # Temperature for softmax backward
    q_scale=0.25,                   # Query gradient scale
    k_scale=0.25,                   # Key gradient scale
    v_scale=0.5,                    # Value gradient scale
)
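
The `softmax_temperature` and `q_scale`/`k_scale`/`v_scale` arguments describe changes to the backward pass only. A rough way to picture them is with custom autograd functions whose forward pass is unchanged. The classes `ScaleGrad` and `SoftmaxTempBackward` below are hypothetical names written for this sketch, not gim's API. Using the exact derivative of softmax(x / T), including the 1/T factor, is an assumption about how the temperature enters.

```python
import torch


class ScaleGrad(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient for the non-tensor `scale` argument.
        return grad_output * ctx.scale, None


class SoftmaxTempBackward(torch.autograd.Function):
    """Standard softmax forward; backward uses the derivative of softmax(x / T)."""

    @staticmethod
    def forward(ctx, x, temperature):
        ctx.temperature = temperature
        ctx.save_for_backward(torch.softmax(x / temperature, dim=-1))
        return torch.softmax(x, dim=-1)

    @staticmethod
    def backward(ctx, grad_output):
        (p,) = ctx.saved_tensors
        # Vector-Jacobian product of softmax: p * (g - sum(g * p)).
        inner = (grad_output * p).sum(dim=-1, keepdim=True)
        return p * (grad_output - inner) / ctx.temperature, None


# Gradient scaling: the forward value is unchanged, the gradient is multiplied by 0.25.
q = torch.ones(2, 3, requires_grad=True)
ScaleGrad.apply(q, 0.25).sum().backward()

# Temperature in the backward pass only.
x = torch.tensor([[2.0, 0.0, -1.0]], requires_grad=True)
g = torch.tensor([[1.0, 0.0, 0.0]])
probs = SoftmaxTempBackward.apply(x, 2.0)
(grad_T,) = torch.autograd.grad(probs, x, grad_outputs=g)
```

A temperature above 1 flattens the distribution used in the backward pass, so gradients through a saturated softmax are less likely to vanish while the forward predictions stay the same.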

gim.GIM()

with gim.GIM(
    model,                          # PyTorch nn.Module or TransformerLens HookedTransformer
    *,
    freeze_norm=True,
    softmax_temperature=2.0,
    q_scale=0.25,
    k_scale=0.25,
    v_scale=0.5,
):
    # Your forward/backward code here
    pass

Citation

@misc{edin2025gimimprovedinterpretabilitylarge,
      title={GIM: Improved Interpretability for Large Language Models}, 
      author={Joakim Edin and Róbert Csordás and Tuukka Ruotsalo and Zhengxuan Wu and Maria Maistro and Casper L. Christensen and Jing Huang and Lars Maaløe},
      year={2025},
      eprint={2505.17630},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2505.17630}, 
}

License

MIT
