
Causal Head Gating (CHG)

Reproducible research code for interpreting attention head roles in transformer models via causal gating. Built with Python 3.10+, PyTorch, and HuggingFace Transformers for a clean, portable workflow.



Overview

Causal Head Gating (CHG) is a scalable method for understanding what attention heads do in transformer models. This package accompanies the paper:

Causal Head Gating: A Framework for Interpreting Roles of Attention Heads in Transformers. Andrew Nam, Henry Conklin, Yukang Yang, Thomas Griffiths, Jonathan Cohen, Sarah-Jane Leslie. NeurIPS 2025.

Unlike traditional interpretability approaches that require hypothesis-driven analysis or specific prompt templates, CHG:

  • Learns which heads matter by training soft gates over attention heads (see the gating sketch after this list)
  • Identifies necessary heads (minimal set required for a task)
  • Identifies sufficient heads (maximal set that enables performance)
  • Classifies head roles as facilitating, interfering, or irrelevant
  • Works on any task using standard next-token prediction
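The gating mechanism itself is lightweight: each head's output is scaled by a learned gate in (0, 1) before heads are recombined. A minimal sketch with dummy tensors (illustrative only; per the project layout, the package applies gates through forward hooks in core/chg.py rather than this inline form):

import torch

# One learnable logit per attention head; sigmoid maps it into (0, 1).
gate_logit = torch.zeros(1, requires_grad=True)
head_output = torch.randn(4, 16, 64)        # (batch, seq, head_dim), dummy values
gated = gate_logit.sigmoid() * head_output  # gate -> 0 ablates the head, -> 1 keeps it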

Our target audience includes:

  • Interpretability researchers who want a scalable, task-agnostic method for understanding attention head roles
  • ML engineers who want to identify which heads are critical for specific capabilities
  • Researchers who want mechanistic interpretability results without manual circuit analysis

Installation

To install the package, you can use uv (recommended), pip, or conda.

Install from PyPI

pip install nam-causal-head-gating

Install from Source with uv (recommended)

git clone https://github.com/jonhanke/nam_causal-head-gating
cd nam_causal-head-gating
make sync  # Uses uv.lock for reproducible installs

Install from Source with pip

git clone https://github.com/jonhanke/nam_causal-head-gating
cd nam_causal-head-gating
pip install -e .

Install with Conda

git clone https://github.com/jonhanke/nam_causal-head-gating
cd nam_causal-head-gating
conda env create -f environment.yml
conda activate causal-head-gating
pip install -e .

Optional Dependencies

# For visualization
pip install nam-causal-head-gating[viz]

# For HuggingFace datasets integration
pip install nam-causal-head-gating[datasets]

# For Jupyter notebooks
pip install nam-causal-head-gating[notebooks]

# For development
pip install nam-causal-head-gating[dev]

# Everything
pip install nam-causal-head-gating[all]

Running Jupyter Notebooks

After installation, you can run the example notebooks:

# Using make (recommended)
make notebook    # Starts Jupyter Notebook
make lab         # Starts JupyterLab

# Or manually
jupyter notebook --notebook-dir=notebooks

If using a virtual environment, ensure the kernel is registered:

# Register the virtual environment as a Jupyter kernel
python -m ipykernel install --user --name=causal-head-gating --display-name="Python (CHG)"

Then select "Python (CHG)" as the kernel when opening notebooks.

Quick Start

from causal_head_gating import CHGAnalyzer

# Load any HuggingFace model
analyzer = CHGAnalyzer.from_pretrained("meta-llama/Llama-3.2-1B")

# Analyze on your data
results = analyzer.fit(
    texts=["The capital of France is", "2 + 2 equals"],
    targets=["Paris", "4"],
)

# Get insights
necessary = results.necessary_heads()
print(f"Found {len(necessary)} necessary heads")

# View head taxonomy
taxonomy = results.head_taxonomy()
print(taxonomy)

Project Structure

nam_causal-head-gating/
├── pyproject.toml                    # Package configuration
├── uv.lock                           # Reproducible installs (uv)
├── environment.yml                   # Conda environment
├── Makefile                          # Convenience commands
├── README.md
├── LICENSE
├── notebooks/
│   ├── config.yaml                   # Local paths configuration
│   ├── chg_example.ipynb             # CHG training example
│   └── datasets/
│       ├── aba_abb.ipynb             # ABA/ABB dataset preparation
│       └── math.ipynb                # Math dataset preparation
├── slurm/
│   ├── job_example.slurm             # Slurm job template
│   ├── config_example.yaml           # Cluster config template
│   ├── run_chg_example.py            # Training script for HPC
│   └── setup_env.sh                  # Environment setup
├── src/
│   └── causal_head_gating/
│       ├── __init__.py               # Public API exports
│       ├── api.py                    # High-level CHGAnalyzer API
│       ├── core/
│       │   ├── chg.py                # Core CHG class (hook-based gating)
│       │   └── trainer.py            # Three-stage training pipeline
│       ├── data/
│       │   ├── datasets.py           # MaskedSequenceDataset, TensorDict
│       │   ├── formatters.py         # Few-shot prompt formatting
│       │   └── tokenization.py       # PromptTokenizer utilities
│       ├── models/
│       │   └── adapters.py           # Model architecture adapters
│       ├── analysis/
│       │   └── masks.py              # Mask analysis utilities
│       └── utils/
│           ├── helpers.py            # to_long_df and other helpers
│           └── tensor_dict.py        # TensorDict implementation
└── tests/
    └── test_imports.py               # Import tests

Note: Dataset files (e.g., aba_abb.tsv) are hosted on HuggingFace Hub and downloaded automatically on first use.

Core Concepts

Three-Stage Training

CHG training proceeds in three stages (a loss sketch follows the list):

  1. Unregularized: Learn task-relevant masks without sparsity bias
  2. Positive L1: Find minimal necessary heads (pushes masks → 0)
  3. Negative L1: Find maximal sufficient heads (pushes masks → 1)
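Conceptually, each stage minimizes the task NLL plus a signed L1 term on the gate values. A minimal sketch under that reading (hypothetical helper, not the package API; the 0.15 default mirrors the l1_weight used in the trainer example below):

import torch

def stage_loss(nll, gate_logits, stage, l1_weight=0.15):
    """Task NLL plus the stage-specific L1 penalty on the soft gates."""
    gates = gate_logits.sigmoid()
    if stage == "unregularized":
        return nll                                    # stage 1: no sparsity bias
    sign = 1.0 if stage == "positive" else -1.0
    return nll + sign * l1_weight * gates.mean()      # stage 2 pushes gates to 0, stage 3 to 1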

Head Taxonomy

Based on the necessary and sufficient masks, heads are classified as follows (a thresholding sketch appears after the table):

Classification   Necessary   Sufficient   Interpretation
Facilitating     High        High         Helps the task
Interfering      Low         Low          Hurts the task
Irrelevant       Low         High         No effect
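Reading the table as thresholds on the two mask values gives a simple classifier. A hedged sketch (the 0.5 cutoff mirrors results.necessary_heads(threshold=0.5); the helper itself is illustrative, not the package implementation):

def classify_head(necessary: float, sufficient: float, threshold: float = 0.5) -> str:
    """Assign a role from one head's necessary/sufficient mask values."""
    if sufficient < threshold:
        return "interfering"   # negative L1 could not push this gate open
    if necessary >= threshold:
        return "facilitating"  # survives positive L1 and stays open
    return "irrelevant"        # harmless to include, but not needed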

Loss Masking

CHG uses loss masks to specify which tokens to supervise (see the sketch after this list):

  • For pattern tasks: mask only the final prediction token
  • For generation tasks: mask the entire target sequence
  • For few-shot learning: mask only target tokens, not in-context examples
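As an illustration, a loss mask is just a boolean tensor aligned with the token sequence that marks which positions contribute to the loss. A minimal sketch (hypothetical helper; see CHGDataset.from_texts / from_tsv for the package's own handling):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer

def build_example(prompt: str, target: str, last_token_only: bool = False):
    """Concatenate prompt and target tokens; supervise only the target span."""
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids[0]
    target_ids = tokenizer(" " + target, add_special_tokens=False,
                           return_tensors="pt").input_ids[0]
    input_ids = torch.cat([prompt_ids, target_ids])
    loss_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    if last_token_only:
        loss_mask[-1] = True                 # pattern tasks: final token only
    else:
        loss_mask[len(prompt_ids):] = True   # generation tasks: whole target
    return input_ids, loss_mask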

Usage Examples

Pattern Task (Finding Induction Heads)

# ABA pattern tests for induction/copying behavior
results = analyzer.fit_pattern_task("aba", num_samples=5000)
print(results.necessary_heads())  # These are likely induction heads

HuggingFace Dataset

results = analyzer.fit_huggingface(
    dataset_name="openai/gsm8k",
    input_column="question",
    target_column="answer",
    max_samples=1000,
)

Custom Dataset

from causal_head_gating import CHGDataset

dataset = CHGDataset.from_texts(
    texts=my_prompts,
    targets=my_targets,
    tokenizer=analyzer.tokenizer,
)
results = analyzer.fit_dataset(dataset)

Loading the ABA/ABB Dataset

from causal_head_gating import CHGDataset
from causal_head_gating.data import get_aba_abb_path

# Download from HuggingFace (cached after first use)
data_path = get_aba_abb_path()

# Load with last-token-only supervision
dataset = CHGDataset.from_tsv(
    str(data_path),
    tokenizer=tokenizer,
    prompt_column="prompt",
    target_column="target",
    last_token_only=True,  # Only supervise final token
)

Few-Shot Math Dataset

from causal_head_gating.data import load_math_dataset

# Full workflow: load, filter, create few-shot prompts
df_prompts, input_ids, loss_masks = load_math_dataset(
    tokenizer=tokenizer,
    num_examples=50,        # Few-shot examples
    num_train=50000,        # Training samples
    num_validation=5000,    # Validation samples
)

Advanced: Low-Level API

from causal_head_gating import CHG, CHGTrainer

# Wrap your model
chg = CHG(model)

# Create and use masks directly
masks = chg.create_masks(num_masks=10)
logp = chg.get_logp(masks.sigmoid(), input_ids, loss_masks)

# Or use the trainer for full control
trainer = CHGTrainer(
    model, dataset,
    num_masks=10,
    l1_weight=0.15,
    gradient_accum_steps=4,
)

for mask, metrics in trainer.fit(num_updates=500, num_reg_updates=500):
    print(f"Stage: {metrics['regularization']}, NLL: {metrics['nll']:.3f}")

HPC Deployment

For running CHG on Slurm-managed HPC clusters, we provide ready-to-use job scripts in the slurm/ directory.

Quick Start

cd slurm

# 1. Configure for your cluster
cp job_example.slurm job.slurm
cp config_example.yaml config.yaml
# Edit job.slurm: set --partition and --account for your cluster
# Edit config.yaml: set huggingface cache path

# 2. Download model on login node (compute nodes often lack internet)
huggingface-cli download meta-llama/Llama-3.2-1B

# 3. Submit job
sbatch job.slurm

Features

  • Offline-ready: Handles clusters where compute nodes have no internet access
  • Configurable: Easy customization of model, dataset, and training parameters
  • GPU-optimized: Mixed precision training, tested on A100/H100/H200

See slurm/README.md for detailed setup instructions, troubleshooting, and cluster-specific configuration.

Supported Models

CHG supports most HuggingFace transformer architectures:

Architecture   Models
Llama          Llama-2, Llama-3, Llama-3.1, Llama-3.2
Mistral        Mistral, Mixtral
GPT-2          GPT-2 (all sizes)
GPT-NeoX       Pythia
Falcon         Falcon
Qwen           Qwen2, Qwen2.5
Gemma          Gemma, Gemma-2
Phi            Phi, Phi-3

To add support for other architectures, subclass ModelAdapter.
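The adapter interface lives in src/causal_head_gating/models/adapters.py. The sketch below is hypothetical, method names included, and is meant only to show the shape such a subclass might take; check the real ModelAdapter class before copying it:

from causal_head_gating.models.adapters import ModelAdapter  # path per the project tree

class MyArchAdapter(ModelAdapter):
    """Hypothetical adapter for a Llama-style decoder (method names illustrative)."""

    def num_layers(self, model):
        return model.config.num_hidden_layers

    def num_heads(self, model):
        return model.config.num_attention_heads

    def attention_modules(self, model):
        # Per-layer attention modules that CHG would hook to apply gates
        return [layer.self_attn for layer in model.model.layers]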

API Reference

CHGAnalyzer

High-level interface for CHG analysis.

CHGAnalyzer.from_pretrained(model_name, device=None, torch_dtype=None)
analyzer.fit(texts, targets, num_masks=10, num_updates=500, ...)
analyzer.fit_pattern_task(pattern="aba", num_samples=10000, ...)
analyzer.fit_huggingface(dataset_name, input_column, target_column, ...)
analyzer.fit_dataset(dataset, ...)

CHGResults

Container for analysis results.

results.necessary_heads(threshold=0.5)  # List of (layer, head) tuples
results.sufficient_heads(threshold=0.5)
results.head_taxonomy()  # DataFrame with classifications
results.summary()  # Dict with statistics
results.to_dataframe()  # Long-format DataFrame
results.save(path) / CHGResults.load(path)

CHGDataset

Factory for creating training datasets.

CHGDataset.from_texts(texts, targets, tokenizer)
CHGDataset.from_tokens(input_ids, loss_masks, pad_token_id)
CHGDataset.from_huggingface(dataset_name, tokenizer, input_column, target_column)
CHGDataset.from_pattern_task(pattern, tokenizer, num_samples)
CHGDataset.from_tsv(path, tokenizer, prompt_column, target_column, last_token_only=False)

Utility Functions

from causal_head_gating.utils import to_long_df
from causal_head_gating.data import load_math_dataset, create_fewshot_dataset, get_aba_abb_path

Citation

If you use this code in your research, please cite:

@inproceedings{nam2025causal,
  title={Causal Head Gating: A Framework for Interpreting Roles of Attention Heads in Transformers},
  author={Nam, Andrew and Conklin, Henry and Yang, Yukang and Griffiths, Thomas and Cohen, Jonathan and Leslie, Sarah-Jane},
  booktitle={Advances in Neural Information Processing Systems},
  year={2025}
}

License

MIT License. See LICENSE for details.
