IRTK: Interpretability Research Toolkit — mechanistic interpretability with JAX and Equinox

IRTK — Interpretability Research Toolkit

A comprehensive mechanistic interpretability library built on JAX and Equinox. Originally inspired by TransformerLens, IRTK provides 170+ analysis modules for understanding transformer internals — from basic logit lenses and activation patching to circuit discovery, sparse autoencoders, causal scrubbing, and representation engineering.

Initially vibe-coded by Opus, so YMMV. PRs welcome.

Installation

git clone git@github.com:danielpcox/irtk.git
cd irtk
uv sync --extra dev

This installs all core dependencies (JAX, Equinox, Transformers, NumPy, Optax, Matplotlib) plus dev dependencies (pytest, torch for HuggingFace weight loading).

Requirements: Python 3.10+, uv.

Quick start

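import jax.numpy as jnp
from transformers import AutoTokenizer

import irtk

# Load a pretrained model (GPT-2, GPT-Neo, GPT-NeoX, LLaMA, Mistral)
model = irtk.HookedTransformer.from_pretrained("gpt2")

# Tokenize the prompts. Assumption: the model takes a 1-D array of token ids,
# since components are unbatched. The prompts are illustrative only.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokens = jnp.array(tokenizer("When Mary and John went to the store, John gave a drink to")["input_ids"])
clean_tokens = tokens
corrupt_tokens = jnp.array(tokenizer("When Mary and John went to the store, Mary gave a drink to")["input_ids"])

# Run with activation caching
logits, cache = model.run_with_cache(tokens)

# Logit lens: decode the residual stream at each layer
irtk.logit_lens.logit_lens(model, tokens, cache)

# Activation patching
irtk.patching.activation_patch(model, clean_tokens, corrupt_tokens, cache)

# Circuit analysis
irtk.circuits.direct_logit_attribution(model, tokens, cache)

# Attention head taxonomy
irtk.head_analysis.find_induction_heads(model, tokens, cache)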

Architecture

IRTK models are functional and immutable Equinox modules:

  • Unbatched components — individual components operate on single examples; use jax.vmap for batches.
  • HookState — a plain Python object (not a PyTree) threaded through __call__ for activation caching and intervention.
  • TransformerLens weight shapes: W_Q, W_K, W_V are [n_heads, d_model, d_head] for direct per-head access.
  • Functional weight loading — HuggingFace weights are loaded via eqx.tree_at chains (no mutation).

Core classes

Class                    Description
HookedTransformer        Main model class with hook points at every intermediate computation
HookedTransformerConfig  Model configuration (dimensions, heads, layers, vocab, etc.)
ActivationCache          Dictionary-like cache of intermediate activations
FactoredMatrix           Efficient representation of A @ B without materializing the product
HookPoint                Attachment point for activation access and intervention
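To show why a factored representation of A @ B is useful, the sketch below computes two quantities from the small factors alone and compares them with the dense product. It uses plain JAX arrays rather than IRTK's FactoredMatrix class, whose interface is not shown in this README; the shapes and names are assumptions.

```python
import jax
import jax.numpy as jnp

d_model, d_head = 64, 4
key_a, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
A = jax.random.normal(key_a, (d_model, d_head))  # e.g. one head's value weights
B = jax.random.normal(key_b, (d_head, d_model))  # e.g. one head's output weights
x = jax.random.normal(key_x, (d_model,))

# Apply the product to a vector without forming the [d_model, d_model] matrix.
y_factored = (x @ A) @ B

# Squared Frobenius norm of A @ B from two small [d_head, d_head] Gram matrices:
# ||AB||_F^2 = trace((A^T A)(B B^T)).
fro_sq_factored = jnp.trace((A.T @ A) @ (B @ B.T))

# Reference values from the materialized product, for comparison only.
AB = A @ B
y_dense = x @ AB
fro_sq_dense = jnp.sum(AB ** 2)
```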

Supported model families

Family           Architectures
GPT-2            gpt2, gpt2-medium, gpt2-large, gpt2-xl
GPT-Neo          EleutherAI/gpt-neo-125M, etc.
GPT-NeoX         EleutherAI/gpt-neox-20b, EleutherAI/pythia-*
LLaMA / Mistral  LLaMA 2, LLaMA 3, Mistral 7B (with GatedMLP, RoPE)

Module reference

IRTK's 170+ modules are grouped below by research area. Every module is importable from the top-level package (irtk.<module>).

Probing & representation analysis

Module              Key functions
probes              Linear/regression probes, train_linear_probe
activation_probing  Multiclass probe, nonlinear probe, concept localization
sparse_probing      L1 probes, sparse concept directions, feature selection
probing_dynamics    Probe accuracy by layer, emergence threshold, selectivity
distributed_reps    Representation rank, cross-layer concept tracking
embeddings          Embedding similarity, PCA, analogies, clustering, W_E/W_U alignment
embedding_dynamics  Identity decay, semantic drift, context mixing
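As an illustration of what a linear probe does, here is a minimal logistic-regression probe trained with plain gradient descent on synthetic activations. It is a sketch of the general technique, not the train_linear_probe implementation; the data, learning rate, and step count are arbitrary assumptions.

```python
import jax
import jax.numpy as jnp

# Synthetic "activations": two classes separated along a hidden direction.
key_d, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
d_model, n = 32, 400
direction = jax.random.normal(key_d, (d_model,))
direction = direction / jnp.linalg.norm(direction)
labels = jax.random.bernoulli(key_y, 0.5, (n,)).astype(jnp.float32)
acts = jax.random.normal(key_x, (n, d_model)) + 3.0 * (2 * labels[:, None] - 1) * direction


def probe_loss(params, x, y):
    w, b = params
    logits = x @ w + b
    # Numerically stable binary cross-entropy on logits.
    return jnp.mean(jnp.maximum(logits, 0) - logits * y + jnp.log1p(jnp.exp(-jnp.abs(logits))))


params = (jnp.zeros(d_model), jnp.array(0.0))
grad_fn = jax.jit(jax.grad(probe_loss))
for _ in range(200):
    grads = grad_fn(params, acts, labels)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.5 * g, params, grads)

w, b = params
accuracy = jnp.mean(((acts @ w + b) > 0) == (labels > 0.5))
# Cosine between the learned probe direction and the planted one.
cosine = jnp.dot(w, direction) / jnp.linalg.norm(w)
```

On real activations the probe would be evaluated on held-out data; training accuracy is reported here only to keep the sketch short.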

Logit lens & prediction

Module                     Key functions
logit_lens                 Standard logit lens, tuned lens
logit_lens_variants        Contrastive, causal, residual contribution lens
prediction_entropy         Layer prediction entropy, commit depth, surprisal
prediction_geometry        Vocabulary trajectory, sharpening rate
prediction_confidence      Confidence profile, layerwise evolution
logit_dynamics             Logit flips, prediction stability, commitment timing
token_prediction_analysis  Confidence, surprisal, difficulty, rank trajectory
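The standard logit lens decodes each layer's residual stream with the final LayerNorm and unembedding. The sketch below shows that computation on random arrays; the array names and the choice to omit learned LayerNorm parameters are assumptions, and it does not call IRTK.

```python
import jax
import jax.numpy as jnp

n_layers, seq, d_model, d_vocab = 3, 4, 16, 50
key_r, key_u = jax.random.split(jax.random.PRNGKey(0))
# Hypothetical cached residual stream after each layer: [n_layers, seq, d_model].
resid = jax.random.normal(key_r, (n_layers, seq, d_model))
W_U = jax.random.normal(key_u, (d_model, d_vocab))


def layer_norm(x, eps=1e-5):
    # Final LayerNorm without learned scale/bias (assumed folded into W_U).
    x = x - x.mean(axis=-1, keepdims=True)
    return x / jnp.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)


# Decode every layer's residual stream as if it were the final one.
lens_logits = layer_norm(resid) @ W_U                # [n_layers, seq, d_vocab]
lens_log_probs = jax.nn.log_softmax(lens_logits, axis=-1)
top_token_per_layer = lens_logits.argmax(axis=-1)    # [n_layers, seq]
```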

Activation patching & ablation

Module                        Key functions
patching                      Activation patching, ablation, path patching
activation_patching_variants  Denoising/noising patching, mean/resample ablation
ablation_study                Layer-by-layer ablation, head importance, position sensitivity
token_level_ablation          Per-token knockout, necessity/sufficiency, minimal set
token_erasure                 Erasure effects, pairwise interaction, layerwise
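Activation patching can be sketched on a toy model: cache activations from a clean run, overwrite one position in a corrupted run, and measure how much of the clean output is recovered. The toy network and the run helper below are hypothetical and only demonstrate the idea; they are not IRTK's patching API.

```python
import jax
import jax.numpy as jnp

seq, d_in, d_hidden = 5, 8, 16
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
W1 = jax.random.normal(k1, (d_in, d_hidden))
w_out = jax.random.normal(k2, (d_hidden,))


def run(x, patch=None):
    """Toy forward pass; `patch=(pos, value)` overwrites one position's hidden state."""
    h = jnp.tanh(x @ W1)                      # [seq, d_hidden]
    if patch is not None:
        pos, value = patch
        h = h.at[pos].set(value)
    return h.sum(axis=0) @ w_out, h           # scalar metric, cached activations


# Clean and corrupted inputs differ only at position 2.
clean_x = jax.random.normal(k3, (seq, d_in))
corrupt_x = clean_x.at[2].set(jax.random.normal(k4, (d_in,)))

clean_out, clean_cache = run(clean_x)
corrupt_out, _ = run(corrupt_x)

# Patch each position's clean activation into the corrupted run and measure
# how much of the clean-vs-corrupt gap is recovered (1.0 = fully restored).
effects = jnp.array([
    (run(corrupt_x, patch=(pos, clean_cache[pos]))[0] - corrupt_out) / (clean_out - corrupt_out)
    for pos in range(seq)
])
```

Because the two inputs differ only at position 2, that position should recover the whole gap and the others none, which is the localization signal patching is meant to produce.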

Circuit analysis & discovery

Module                      Key functions
circuits                    OV/QK circuits, composition scores, direct logit attribution
circuit_discovery           ACDC-style edge attribution, iterative pruning
circuit_evaluation          Faithfulness, completeness, minimality, circuit IoU
circuit_motifs              Skip trigram, negative mover, backup circuits
causal_scrubbing            Causal scrub, interchange intervention, path patching
causal_abstraction          Interchange intervention, DAS, multi-variable alignment
path_expansion              Path enumeration, virtual weights, contribution matrix
subnetwork_analysis         Circuit extraction, faithfulness, greedy search
induction_circuit_analysis  Full path tracing, matching scores, copy verification
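Direct logit attribution relies on the residual stream being a sum of component outputs, so a token's logit splits into one term per component. The sketch below demonstrates that identity on random arrays, ignoring the final LayerNorm for simplicity; the names are assumptions and it is not the irtk.circuits implementation.

```python
import jax
import jax.numpy as jnp

n_components, d_model, d_vocab = 6, 16, 50
key_c, key_u = jax.random.split(jax.random.PRNGKey(0))
# Hypothetical per-component writes (heads, MLPs, embeddings) at the final position.
component_out = jax.random.normal(key_c, (n_components, d_model))
W_U = jax.random.normal(key_u, (d_model, d_vocab))
target_token = 7

# The final residual stream is the sum of all component outputs.
resid_final = component_out.sum(axis=0)
logit_total = resid_final @ W_U[:, target_token]

# Each component's direct contribution to the target logit.
dla = component_out @ W_U[:, target_token]    # [n_components]
top_component = int(jnp.argmax(jnp.abs(dla)))
```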

Attention analysis

Module                     Key functions
attention_utils            Entropy, similarity, causal tracing, attention flow
attention_surgery          Knockout, pattern patching, force attention
attention_rollout          Rollout, flow, effective attention, per-head rollout
attention_composition      QK/V composition scores, path tracing, virtual patterns
head_analysis              Induction/prev-token head finding, scoring, classification
attention_head_taxonomy    Induction, copy, inhibition scoring, full taxonomy
attention_sparsity         Entropy profile, sparse/dense heads, window analysis
attention_motif_discovery  SVD-based motif extraction, function inference
attention_sink_analysis    Sink identification, magnitude tracking, removal impact
attention_distance         Mean distance, local/global heads, receptive field
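Two of the simplest attention statistics listed above, per-head entropy and a previous-token score, can be computed directly from an attention pattern. The following is a sketch on a random causal pattern with assumed shapes; the exact scoring used by IRTK's modules may differ.

```python
import jax
import jax.numpy as jnp

n_heads, seq = 4, 6
scores = jax.random.normal(jax.random.PRNGKey(0), (n_heads, seq, seq))

# Causal attention pattern: each query position attends to itself and earlier keys.
causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))
pattern = jax.nn.softmax(jnp.where(causal, scores, -jnp.inf), axis=-1)

# Shannon entropy per head and query position (0 * log 0 treated as 0).
safe = jnp.where(pattern > 0, pattern, 1.0)
entropy = -jnp.sum(pattern * jnp.log(safe), axis=-1)           # [n_heads, seq]

# Previous-token score: mean attention from position i to position i - 1.
prev_token_score = jnp.diagonal(pattern, offset=-1, axis1=-2, axis2=-1).mean(axis=-1)
```

Low entropy marks heads that focus on few keys, and a previous-token score near 1 marks heads that attend almost exclusively one position back.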

Sparse autoencoders & features

Module                 Key functions
sae                    SparseAutoencoder, training, feature analysis
sparse_features        Feature activation examples, feature circuits
sae_feature_steering   SAE feature-level steering, clamping
transcoder             Transcoder class, MLP feature logit attribution
cross_coder            CrossCoder, joint training, shared vs specific features
feature_geometry       Dictionary geometry: splitting, absorption, universality
polysemantic_features  Polysemanticity score, context clusters
attribution_graphs     Feature-to-feature computation graphs, pruning
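For readers new to sparse autoencoders, the sketch below shows one common formulation: a ReLU encoder, a linear decoder, and a reconstruction loss with an L1 sparsity penalty. The parameter names, initialization, and hyperparameters are assumptions, and this is not the SparseAutoencoder class from irtk.sae.

```python
import jax
import jax.numpy as jnp

d_model, d_sae, n = 16, 64, 256
k_x, k_e, k_d = jax.random.split(jax.random.PRNGKey(0), 3)
acts = jax.random.normal(k_x, (n, d_model))   # stand-in for cached activations

params = {
    "W_enc": 0.1 * jax.random.normal(k_e, (d_model, d_sae)),
    "b_enc": jnp.zeros(d_sae),
    "W_dec": 0.1 * jax.random.normal(k_d, (d_sae, d_model)),
    "b_dec": jnp.zeros(d_model),
}


def sae_forward(params, x):
    # Encode to non-negative features, then decode linearly.
    f = jax.nn.relu((x - params["b_dec"]) @ params["W_enc"] + params["b_enc"])
    x_hat = f @ params["W_dec"] + params["b_dec"]
    return x_hat, f


def sae_loss(params, x, l1_coeff=1e-3):
    x_hat, f = sae_forward(params, x)
    recon = jnp.mean(jnp.sum((x - x_hat) ** 2, axis=-1))
    sparsity = jnp.mean(jnp.sum(jnp.abs(f), axis=-1))
    return recon + l1_coeff * sparsity


loss_before = sae_loss(params, acts)
grad_fn = jax.jit(jax.grad(sae_loss))
for _ in range(50):  # a few plain gradient-descent steps
    grads = grad_fn(params, acts)
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-2 * g, params, grads)
loss_after = sae_loss(params, acts)
recon, features = sae_forward(params, acts)
```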

Steering & intervention

Module                      Key functions
steering                    Add/subtract/compute steering vectors, steer_generation
representation_engineering  RepE reading vectors, control vectors, suppression curves
activation_surgery          Clamp, scale, project, rotate, replace activations
layerwise_intervention      Activation addition, direction intervention
intervention_effects        Scaling sensitivity, knockout recovery, transferability
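A common way to build a steering vector is the difference of mean activations between two contrasting prompt sets, added to the residual stream at inference. The sketch below does this on synthetic data; the steer helper and all names are hypothetical, not IRTK functions.

```python
import jax
import jax.numpy as jnp

d_model, n, seq = 16, 100, 5
k_c, k_p, k_n, k_r = jax.random.split(jax.random.PRNGKey(0), 4)
concept = jax.random.normal(k_c, (d_model,))
concept = concept / jnp.linalg.norm(concept)

# Synthetic activations for prompts with and without the concept.
pos_acts = jax.random.normal(k_p, (n, d_model)) + 2.0 * concept
neg_acts = jax.random.normal(k_n, (n, d_model)) - 2.0 * concept

# Difference-of-means steering vector.
steering_vector = pos_acts.mean(axis=0) - neg_acts.mean(axis=0)


def steer(resid, vector, alpha):
    # Add the same vector at every position of an unbatched [seq, d_model] stream.
    return resid + alpha * vector


resid = jax.random.normal(k_r, (seq, d_model))
steered = steer(resid, steering_vector, alpha=4.0)

# Shift of each position along the steering direction.
unit = steering_vector / jnp.linalg.norm(steering_vector)
shift = (steered - resid) @ unit              # [seq]
cosine_to_concept = jnp.dot(unit, concept)
```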

Attribution & gradients

Module                           Key functions
gradients                        Grad x input, integrated gradients, Jacobian
feature_attribution              Token-to-neuron, logit decomposition, cross-layer
principled_attribution           Shapley values, KernelSHAP, interaction index
logit_attribution_decomposition  Full logit decomposition, promoted/demoted tokens
residual_attribution             Per-component contribution, interference, decomposition
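Integrated gradients averages the gradient along a straight path from a baseline to the input, and the attributions should sum to the change in output. The sketch below uses a toy scalar function in place of a model; the function, step count, and midpoint rule are assumptions rather than IRTK's implementation.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Stand-in for a scalar model output such as a logit difference.
    return jnp.sum(jnp.tanh(x) ** 2)


def integrated_gradients(fn, x, baseline, steps=64):
    # Midpoint Riemann sum over the straight path from baseline to x.
    alphas = (jnp.arange(steps) + 0.5) / steps
    path = baseline + alphas[:, None] * (x - baseline)
    grads = jax.vmap(jax.grad(fn))(path)      # [steps, d]
    return (x - baseline) * grads.mean(axis=0)


x = jnp.array([0.5, -1.2, 2.0])
baseline = jnp.zeros_like(x)

ig = integrated_gradients(f, x, baseline)
grad_x_input = jax.grad(f)(x) * x             # simpler first-order attribution

# Completeness: attributions should sum to f(x) - f(baseline).
completeness_gap = jnp.abs(ig.sum() - (f(x) - f(baseline)))
```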

MLP & neuron analysis

Module                 Key functions
neurons                Neuron stats, top activating tokens, neuron-to-logit
mlp_decomposition      Neuron contributions, knowledge storage, nonlinearity
mlp_knowledge_editing  Fact localization, rank-one edit, side effects
mlp_gating_analysis    Gate patterns, selectivity, contribution decomposition
knowledge              Knowledge neurons, causal tracing, fact editing vectors
model_editing          ROME-style fact editing, rank-one MLP updates
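The core of a rank-one weight edit is a single outer-product update that makes a chosen key map to a chosen value. The sketch below shows the unweighted least-change version; ROME additionally weights the update by an estimate of the key covariance, which is omitted here. All names are hypothetical.

```python
import jax
import jax.numpy as jnp

d_in, d_out = 12, 8
k_w, k_k, k_v, k_r = jax.random.split(jax.random.PRNGKey(0), 4)
W = jax.random.normal(k_w, (d_out, d_in))     # e.g. an MLP output projection
key_vec = jax.random.normal(k_k, (d_in,))     # activation that encodes the subject
value_vec = jax.random.normal(k_v, (d_out,))  # desired output for that key


def rank_one_edit(W, k, v):
    # Smallest (Frobenius-norm) change to W such that W_new @ k == v.
    residual = v - W @ k
    return W + jnp.outer(residual, k) / jnp.dot(k, k)


W_edited = rank_one_edit(W, key_vec, value_vec)

# A direction orthogonal to the key is unaffected by the edit.
r = jax.random.normal(k_r, (d_in,))
orth = r - (jnp.dot(r, key_vec) / jnp.dot(key_vec, key_vec)) * key_vec
```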

Geometry & similarity

Module                     Key functions
geometry                   CKA, subspace overlap, intrinsic dimensionality
activation_geometry        Manifold dimension, clustering, curvature, norms
residual_stream            Cosine to unembed, norm tracking, prediction trajectory
superposition              Feature directions, interference, dimensionality
low_rank                   Weight SVD, effective rank, spectrum similarity
representation_similarity  CKA, component similarity, drift
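Linear CKA compares two sets of activations on the same inputs and is invariant to rotations and isotropic rescaling of either representation. A minimal sketch of the standard formula follows; the linear_cka name and the test data are assumptions, not IRTK's function signature.

```python
import jax
import jax.numpy as jnp


def linear_cka(X, Y):
    """Linear CKA between [n, d1] and [n, d2] activations on the same n inputs."""
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    cross = jnp.linalg.norm(Y.T @ X) ** 2     # squared Frobenius norm
    return cross / (jnp.linalg.norm(X.T @ X) * jnp.linalg.norm(Y.T @ Y))


k_x, k_q, k_z = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(k_x, (200, 10))

# A rotated and rescaled copy of X should be judged identical.
Q, _ = jnp.linalg.qr(jax.random.normal(k_q, (10, 10)))
cka_rotated = linear_cka(X, 2.0 * X @ Q)

# Independent random activations should score much lower.
Z = jax.random.normal(k_z, (200, 10))
cka_independent = linear_cka(X, Z)
```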

Model comparison & cross-model

Module                 Key functions
model_diff             Weight diff, activation diff, logit diff on dataset
model_comparison       Weight distance, prediction agreement, importance ranking
cross_model_alignment  CKA correspondence, head matching, circuit universality
behavior_alignment     Multi-model comparison, interpretability transfer

Weight analysis

Module                Key functions
weight_processing     Fold LayerNorm, center writing weights, center unembed
weight_importance     Fisher information, magnitude pruning, lottery ticket mask
weight_structure      Spectral analysis, parameter utilization, weight norms
weight_decomposition  SVD, clustering, spectral analysis, norm distribution
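Centering the unembedding is safe because softmax ignores a constant added to every logit. The sketch below checks this on random arrays: subtracting the per-row mean of W_U over the vocabulary shifts each position's logits by a constant and leaves the probabilities unchanged. Names and shapes are assumptions, not the weight_processing API.

```python
import jax
import jax.numpy as jnp

d_model, d_vocab, seq = 16, 50, 4
k_u, k_r = jax.random.split(jax.random.PRNGKey(0))
W_U = jax.random.normal(k_u, (d_model, d_vocab))
resid = jax.random.normal(k_r, (seq, d_model))

# Subtract each row's mean over the vocabulary axis.
W_U_centered = W_U - W_U.mean(axis=1, keepdims=True)

logits = resid @ W_U
logits_centered = resid @ W_U_centered

# Probabilities are identical; centered logits have zero mean per position.
log_probs = jax.nn.log_softmax(logits, axis=-1)
log_probs_centered = jax.nn.log_softmax(logits_centered, axis=-1)
```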

Training & development

Module                          Key functions
training                        Algorithmic datasets, train_tiny_model, checkpoint analysis
training_dynamics               Phase transitions, grokking, circuit formation
developmental_interpretability  Crystallization, grokking dynamics

Safety & robustness

Module                         Key functions
safety_relevant_features       Refusal direction, deception detection, alignment circuits
memorization_detection         Memorization scores, extractability, trigger localization
mechanistic_anomaly_detection  Activation profiles, trojan signatures
robustness_perturbation        Weight noise tolerance, mode connectivity
counterfactual                 Contrastive activation diff, necessity/sufficiency

Visualization

Module         Key functions
visualization  19 plot functions: attention patterns, logit attribution, causal tracing, prediction trajectory, and more

Generation

Module      Key functions
generation  generate, generate_with_cache, top-k/nucleus sampling
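Top-k and nucleus (top-p) sampling both work by masking logits before drawing a token. The sketch below shows one way to write the two filters in JAX; the function names are hypothetical and ties in top-k are kept rather than broken, so this is not necessarily how irtk.generation behaves.

```python
import jax
import jax.numpy as jnp


def top_k_filter(logits, k):
    # Keep the k largest logits (ties at the threshold are all kept).
    kth = jnp.sort(logits)[-k]
    return jnp.where(logits >= kth, logits, -jnp.inf)


def top_p_filter(logits, p):
    # Keep the smallest set of top tokens whose cumulative probability reaches p.
    order = jnp.argsort(-logits)
    sorted_probs = jax.nn.softmax(logits[order])
    cum = jnp.cumsum(sorted_probs)
    # A token is kept if the mass before it is still below p,
    # so the most likely token is always kept.
    keep_sorted = (cum - sorted_probs) < p
    keep = jnp.zeros_like(keep_sorted).at[order].set(keep_sorted)
    return jnp.where(keep, logits, -jnp.inf)


logits = jnp.log(jnp.array([0.5, 0.25, 0.15, 0.07, 0.03]))
top2 = top_k_filter(logits, 2)
nucleus = top_p_filter(logits, 0.8)

# Sample one token id from the filtered distribution.
token = jax.random.categorical(jax.random.PRNGKey(0), nucleus)
```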

Notebooks

The notebooks/ directory contains 350+ Jupyter notebooks with worked examples and API demos covering every module.

Start here: 00_getting_started.ipynb — a complete mechinterp investigation walkthrough from loading a model through logit lens, activation patching, and circuit analysis.

To run notebooks:

uv run jupyter lab notebooks/

If you don't have Jupyter installed yet:

uv add --dev jupyterlab
uv run jupyter lab notebooks/

Recommended reading order:

Notebook                        Topic
00_getting_started              End-to-end mechinterp workflow
01_api_cheatsheet               Quick reference for every module
02_transformer_anatomy          Understanding transformer components
03_logit_lens                   Logit lens and tuned lens
04_ioi_patching                 Activation patching on IOI
05_linear_probes                Training linear probes
06_sparse_autoencoders          SAE feature discovery
07_circuit_analysis             OV/QK circuits and composition
08_gradient_interpretability    Gradient-based attribution
09_automatic_circuit_discovery  ACDC-style circuit finding
The remaining 340+ notebooks cover individual modules in depth — each module in the module reference above has a corresponding notebook.

Testing

uv run pytest                       # run full suite (4000+ tests)
uv run pytest irtk/tests/ -x -q     # quick smoke test, stop on first failure

License

Apache 2.0
