
IIIII         tt                          KK  KK iii tt
 III  nn nnn  tt      eee  rr rr  pp pp   KK KK      tt
 III  nnn  nn tttt  ee   e rrr  r ppp  pp KKKK   iii tttt
 III  nn   nn tt    eeeee  rr     pppppp  KK KK  iii tt
IIIII nn   nn  tttt  eeeee rr     pp      KK  KK iii  tttt
                                  pp

Mech interp for any HuggingFace model.

PyPI version · License: MIT · Python 3.10+


Why InterpKit?

Mechanistic interpretability tooling today is fragmented. Each library supports a narrow set of architectures, and moving to a different model family usually means rewriting hook code from scratch.

InterpKit provides a single, consistent interface for mech interp operations across any HuggingFace model — transformers, SSMs, vision models, and more — with zero annotation required.


Install

pip install interpkit

# For linear probe support:
pip install interpkit[probe]

Or install from source for development:

git clone https://github.com/davidezani/InterpKit.git
cd InterpKit
pip install -e ".[dev]"

Quickstart

import interpkit

model = interpkit.load("gpt2")

# One-command model overview — runs DLA, logit lens, attention, attribution
# and surfaces the most interesting findings automatically
model.scan("The capital of France is")

# Or run individual operations:
model.inspect()                                        # module tree
model.dla("The capital of France is")                  # direct logit attribution
model.trace("...Paris...", "...Rome...", top_k=20)     # causal tracing
model.lens("The capital of France is")                 # logit lens (all positions)
model.attribute("The capital of France is")            # gradient saliency
model.decompose("The capital of France is")            # residual stream decomposition

Works the same on any HF architecture:

model = interpkit.load("state-spaces/mamba-370m")
model = interpkit.load("google/vit-base-patch16-224")
model = interpkit.load("bert-base-uncased")

Operations

Operation        | What it does                                                                                   | Works on
scan             | One-command model overview: runs DLA, lens, attention, attribution and surfaces key findings  | LMs
dla              | Direct Logit Attribution — decompose output logits by head and MLP contribution               | LMs
inspect          | Module tree with types, param counts, shapes                                                  | Any model
patch            | Activation patching at a module, head, or position                                            | Any model
trace            | Causal tracing — module-level or position-aware (Meng et al.) heatmap                         | Any model
attribute        | Gradient saliency over inputs (returns scores programmatically)                               | Any model
lens             | Logit lens — project activations to vocabulary at all positions                               | LMs (auto-detected)
activations      | Extract raw activation tensors at any module                                                  | Any model
head_activations | Decompose attention output into per-head contributions                                        | Transformers
ablate           | Zero/mean ablate a component and measure effect                                               | Any model
attention        | Visualize attention patterns per layer/head                                                   | Transformers
steer            | Extract and apply steering vectors                                                            | LMs
probe            | Linear probe on activations                                                                   | Any model
diff             | Compare activations between two models                                                        | Any model
features         | SAE feature decomposition                                                                     | Any model
decompose        | Residual stream decomposition — per-component norms                                           | Transformers
ov_scores        | OV circuit analysis — W_OV matrix per head                                                    | Transformers
qk_scores        | QK circuit analysis — W_QK matrix per head                                                    | Transformers
composition      | Q/K/V composition scores between heads in two layers                                          | Transformers
find_circuit     | Automated circuit discovery via iterative ablation                                            | Transformers
batch            | Run any operation over a dataset with result aggregation                                      | Any model

Scan — One-Command Model Overview

The fastest way to understand what a model is doing on an input. Runs DLA, logit lens, attention analysis, and gradient attribution, then surfaces the most interesting findings in a ranked summary:

model.scan("The capital of France is")
# Output:
#   Predictions: "the" (8.5%), "now" (4.8%), "a" (4.6%)
#   Key Findings (ranked by significance):
#     1. Top contributor to "the": L11.attn (+204.701)
#     2. Top attention head: L11.H0 (+149.850)
#     3. Most salient input token: "is" (score 12.435)
#     4. Answer "the" first appears at layer 9/12

model.scan("The capital of France is", save="scan")  # exports scan_dla.png, scan_lens.png, etc.

Direct Logit Attribution (DLA)

Answers the fundamental question: why does the model predict this token? Decomposes the output logit by component (attention block + MLP per layer) and by individual attention head:

result = model.dla("The capital of France is")
# result["contributions"]      — per-component logit contributions, sorted
# result["head_contributions"] — per-head breakdown
# result["target_token"]       — the token being attributed

# Attribute a specific token
model.dla("The capital of France is", token="Paris")

# Save a bar chart
model.dla("The capital of France is", save="dla.png")

Causal Tracing

# Module-level tracing (default) — rank modules by causal effect
model.trace("...Paris...", "...Rome...", top_k=20)

# Position-aware tracing (Meng et al. 2022) — (layer x position) heatmap
model.trace("...Paris...", "...Rome...", mode="position", save="trace.png")

Logit Lens

Now analyses all token positions by default, producing the classic (layers x positions) heatmap:

model.lens("The capital of France is")                      # all positions
model.lens("The capital of France is", position=-1)         # last position only
model.lens("The capital of France is", save="lens.png")     # 2D heatmap export

Attribution

Returns scores programmatically (no longer just prints):

result = model.attribute("The capital of France is")
result["tokens"]   # ["The", "capital", "of", "France", "is"]
result["scores"]   # [8.88, 11.15, 7.24, 7.37, 12.43]
result["target"]   # 262

Activation Patching

Supports module-level, head-level, and position-level patching:

# Module-level (original)
model.patch(clean, corrupted, at="transformer.h.8.mlp")

# Head-level — patch only attention head 3
model.patch(clean, corrupted, at="transformer.h.8", head=3)

# Position-level — patch only positions 3 and 4
model.patch(clean, corrupted, at="transformer.h.8", positions=[3, 4])

# Combined — patch head 3 at positions 3 and 4
model.patch(clean, corrupted, at="transformer.h.8", head=3, positions=[3, 4])
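
At its core, activation patching caches a module's output from the clean prompt and swaps it in during the corrupted forward pass. A minimal module-level sketch with plain PyTorch hooks (illustration only; the swap assumes both prompts tokenize to the same length):

# Activation patching sketch: clean layer-8 MLP output swapped into the corrupted run.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
hf = AutoModelForCausalLM.from_pretrained("gpt2").eval()
module = hf.transformer.h[8].mlp

clean = tok("The capital of France is", return_tensors="pt").input_ids
corrupted = tok("The capital of Italy is", return_tensors="pt").input_ids
assert clean.shape == corrupted.shape  # equal lengths so the cached activation fits

cache = {}
h1 = module.register_forward_hook(lambda m, i, o: cache.update(act=o.detach()))
with torch.no_grad():
    hf(clean)                          # cache the clean activation
h1.remove()

h2 = module.register_forward_hook(lambda m, i, o: cache["act"])  # returning a value replaces the output
with torch.no_grad():
    patched = hf(corrupted).logits
h2.remove()

print(tok.decode(patched[0, -1].argmax().item()))  # corrupted prompt, clean layer-8 MLP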

Head-Level Activations

Decompose an attention module's output into per-head contributions, optionally projected through W_O into residual-stream space:

result = model.head_activations("The capital of France is", at="transformer.h.8")
result["head_acts"]   # tensor (num_heads, batch, seq, d_model)
result["num_heads"]   # 12
result["head_dim"]    # 64

Activations, Ablation, Attention

# Extract raw activations
act = model.activations("The capital of France is", at="transformer.h.8.mlp")
acts = model.activations("...", at=["transformer.h.0", "transformer.h.8.mlp"])

# Ablation — zero or mean
result = model.ablate("The capital of France is", at="transformer.h.8.mlp")
result = model.ablate("...", at="transformer.h.8.mlp", method="mean")

# Attention patterns
model.attention("The capital of France is")                   # all layers
model.attention("The capital of France is", layer=8, head=3)  # single head

Residual Stream Decomposition

Break down the residual stream at any position into contributions from each attention block and MLP:

result = model.decompose("The capital of France is")
# result["components"] — list of {"name": "L8.attn", "type": "attn", "norm": 8.94, ...}
# result["residual"]   — final residual stream vector

OV / QK Circuit Analysis

Analyse the effective weight matrices of attention heads:

# OV circuit: what does each head write to the residual stream?
model.ov_scores(layer=8)
# Per-head Frobenius norm, top singular values, approximate rank of W_OV

# QK circuit: what does each head attend to?
model.qk_scores(layer=8)

# Composition: how much does head j in layer 4 compose with head i in layer 8?
model.composition(src_layer=4, dst_layer=8, comp_type="q")  # Q-composition
model.composition(src_layer=4, dst_layer=8, comp_type="k")  # K-composition
model.composition(src_layer=4, dst_layer=8, comp_type="v")  # V-composition

Circuit Discovery

Automatically find the minimal set of components that explain a behaviour:

circuit = model.find_circuit(
    "The Eiffel Tower is in Paris",
    "The Eiffel Tower is in Rome",
    threshold=0.05,
)
# circuit["circuit"]       — components in the circuit, sorted by effect
# circuit["excluded"]      — components not in the circuit
# circuit["verification"]  — faithfulness check (how much output is preserved
#                            when all non-circuit components are ablated)
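
The underlying idea is greedy iterative ablation: knock out one component at a time and keep those whose removal changes the behaviour by more than the threshold. A toy sketch restricted to MLP blocks for brevity (plain HF GPT-2; the cutoff below is an arbitrary raw-logit drop and not directly comparable to InterpKit's threshold, and InterpKit also scores attention components and verifies faithfulness):

# Toy circuit-discovery loop: zero-ablate each MLP and keep the ones that matter.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
hf = AutoModelForCausalLM.from_pretrained("gpt2").eval()
ids = tok("The Eiffel Tower is in", return_tensors="pt").input_ids
paris = tok(" Paris")["input_ids"][0]

def paris_logit():
    with torch.no_grad():
        return hf(ids).logits[0, -1, paris].item()

baseline = paris_logit()
circuit, excluded = [], []
for layer, block in enumerate(hf.transformer.h):
    hook = block.mlp.register_forward_hook(lambda m, i, o: torch.zeros_like(o))  # zero-ablate
    effect = baseline - paris_logit()
    hook.remove()
    (circuit if abs(effect) > 0.05 else excluded).append((f"L{layer}.mlp", effect))

print(sorted(circuit, key=lambda x: -abs(x[1])))  # components with the largest causal effect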

Batch / Dataset Operations

Run any operation over a dataset of examples with automatic result aggregation:

# Generic batch runner
results = model.batch("trace", [
    {"clean": "...Paris...", "corrupted": "...Rome..."},
    {"clean": "...Berlin...", "corrupted": "...Madrid..."},
], op_kwargs={"top_k": 10})
# results["summary"]["ranked_modules"] — modules ranked by mean effect across examples

# Convenience: trace over a dataset
results = model.trace_batch(dataset, clean_col="clean", corrupted_col="corrupted")

# Convenience: DLA over a list of texts
results = model.dla_batch(["The capital of France is", "The CEO of Apple is"])
# results["summary"]["ranked_components"] — components ranked by mean contribution

Steering

vector = model.steer_vector("Love", "Hate", at="transformer.h.8")
model.steer("The weather today is", vector=vector, at="transformer.h.8", scale=2.0)

Linear Probe

result = model.probe(
    texts=["The cat sat", "The dog ran", "A bird flew", "A fish swam"],
    labels=[0, 0, 1, 1],
    at="transformer.h.8",
)
print(result["accuracy"])

Model Diff

base = interpkit.load("gpt2")
finetuned = interpkit.load("my-finetuned-gpt2")
interpkit.diff(base, finetuned, "The capital of France is")
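
At its simplest, such a comparison runs both models on the same input and contrasts activations at matching modules. A rough per-layer sketch with plain HF models (the second gpt2 load below is a stand-in for a fine-tuned checkpoint):

# Model diff sketch: per-layer similarity of the residual stream at the last position.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
base_hf = AutoModelForCausalLM.from_pretrained("gpt2").eval()
tuned_hf = AutoModelForCausalLM.from_pretrained("gpt2").eval()  # replace with your fine-tuned model

ids = tok("The capital of France is", return_tensors="pt").input_ids
with torch.no_grad():
    a = base_hf(ids, output_hidden_states=True).hidden_states
    b = tuned_hf(ids, output_hidden_states=True).hidden_states

for layer, (ha, hb) in enumerate(zip(a, b)):
    cos = torch.nn.functional.cosine_similarity(ha[0, -1], hb[0, -1], dim=0).item()
    print(f"layer {layer:2d}: cosine {cos:.3f}, L2 diff {(ha - hb).norm().item():.2f}")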

SAE Features

Decompose activations into interpretable features using pre-trained Sparse Autoencoders from HuggingFace:

model.features(
    "The capital of France is",
    at="transformer.h.8",
    sae="jbloom/GPT2-Small-SAEs-Reformatted",
)

No SAELens dependency — weights are loaded directly via safetensors.
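
What the decomposition computes, in essence: encode the activation with the SAE's encoder and keep the sparse, non-negative feature coefficients. A schematic sketch with stand-in weights (the names W_enc, b_enc, b_dec follow common SAE conventions and are assumptions here, not the Hub repo's exact tensor names):

# SAE feature decomposition sketch with random stand-in weights.
import torch

d_model, d_sae = 768, 24576            # typical GPT-2-small SAE sizes (assumption)
W_enc = torch.randn(d_model, d_sae)    # in practice, loaded from the SAE's safetensors file
b_enc = torch.zeros(d_sae)
b_dec = torch.zeros(d_model)

act = torch.randn(d_model)             # e.g. the layer-8 residual stream at one position
features = torch.relu((act - b_dec) @ W_enc + b_enc)  # feature activations (sparse for a trained SAE)
top = features.topk(5)
print(top.indices, top.values)         # the most active features for this input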

Activation Cache

Avoid redundant forward passes when exploring the same input with multiple operations:

model.cache("The capital of France is")  # one forward pass, cache all layers
model.activations("The capital of France is", at="transformer.h.8.mlp")  # instant
model.activations("The capital of France is", at="transformer.h.0.mlp")  # instant

model.clear_cache()  # free memory

Visualizations

Pass save="path.png" to export a static matplotlib figure, or html="path.html" for an interactive visualization:

model.attention("hello world", layer=0, head=0, save="attention.png")
model.trace("...Paris...", "...Rome...", save="trace.png")
model.trace("...Paris...", "...Rome...", mode="position", save="position_trace.png")
model.lens("The capital of France is", save="lens.png")
model.steer("The weather is", vector=vector, at="transformer.h.8", save="steer.png")
model.attribute("The capital of France is", save="attribution.png")
model.dla("The capital of France is", save="dla.png")
model.scan("The capital of France is", save="scan")
interpkit.diff(base, finetuned, "...", save="diff.png")

# Interactive HTML — self-contained files with hover tooltips, filters, and sliders
model.attention("hello world", html="attention.html")
model.trace("...Paris...", "...Rome...", html="trace.html")
model.attribute("The capital of France is", html="attribution.html")

CLI

interpkit inspect gpt2
interpkit scan gpt2 "The capital of France is"
interpkit dla gpt2 "The capital of France is"
interpkit trace gpt2 --clean "...Paris..." --corrupted "...Rome..." --top-k 20
interpkit trace gpt2 --clean "...Paris..." --corrupted "...Rome..." --mode position --save trace.png
interpkit lens gpt2 "The capital of France is"
interpkit lens gpt2 "The capital of France is" --position -1
interpkit attention gpt2 "The capital of France is" --layer 8 --save attention.png
interpkit attribute gpt2 "The capital of France is"
interpkit steer gpt2 "The weather is" --positive Love --negative Hate --at transformer.h.8
interpkit ablate gpt2 "The capital of France is" --at transformer.h.8.mlp
interpkit decompose gpt2 "The capital of France is"
interpkit diff gpt2 my-finetuned-gpt2 "The capital of France is" --save diff.png
interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae jbloom/GPT2-Small-SAEs-Reformatted

# Interactive HTML output
interpkit attention gpt2 "hello world" --html attention.html
interpkit trace gpt2 --clean "...Paris..." --corrupted "...Rome..." --html trace.html
interpkit attribute gpt2 "The capital of France is" --html attribution.html

# Vision models — auto-preprocessed
interpkit attribute microsoft/resnet-50 cat.jpg --target 281

Run interpkit with no arguments for a full command reference.


TransformerLens interop

Already using TransformerLens? Pass your HookedTransformer directly into InterpKit — it auto-detects the model and extracts the tokenizer:

from transformer_lens import HookedTransformer
import interpkit

tl_model = HookedTransformer.from_pretrained("gpt2")
model = interpkit.load(tl_model)

# All InterpKit operations work on TL models
model.scan("The capital of France is")
model.dla("The capital of France is")
model.trace("The Eiffel Tower is in Paris", "The Eiffel Tower is in Rome", top_k=20)
model.attention("The capital of France is", save="attention.png")
model.steer("The weather is", vector=vector, at="blocks.8", scale=2.0)

Translate between native and TL hook point names:

interpkit.to_tl_name("transformer.h.8.mlp")       # -> "blocks.8.mlp"
interpkit.to_native_name("blocks.8.attn", model.arch_info)  # -> "transformer.h.8.attn"
interpkit.list_tl_hooks(tl_model)                  # -> ["blocks.0.hook_resid_pre", ...]

Local models

import torch.nn as nn
import interpkit

my_model = MyCustomModel()
interpkit.register(my_model, layers=["blocks.0", "blocks.1"], output_head="head")
model = interpkit.load(my_model, tokenizer=my_tokenizer)
model.trace(input_a, input_b, top_k=10)

Examples

See the examples/ directory for Jupyter notebooks:

Notebook               | Topics
01_quickstart          | Inspect, scan, DLA, trace, lens, attribution, patching, ablation
02_attention_patterns  | Per-head heatmaps, layer filtering, HTML export
03_steering_vectors    | Extract and apply steering vectors at different layers/scales
04_sae_features        | Sparse Autoencoder feature decomposition
05_caching_and_probing | Activation cache, linear probes across layers
06_model_comparison    | Diff two models, side-by-side tracing and logit lens
07_vision_models       | ResNet/ViT attribution, ablation, activations
08_dla_and_circuits    | DLA, head activations, residual decomposition, OV/QK analysis, composition, circuit discovery
09_scan_and_batch      | Auto-scan, batch operations, dataset workflows

License

MIT
