A package for mechanistic interpretability in Neural IR
MechIR
An IR Library for Mechanistic Interpretability
MechIR is heavily inspired by the paper Axiomatic Causal Interventions for Reverse Engineering Relevance Computation in Neural Retrieval Models by Catherine Chen, Jack Merullo, and Carsten Eickhoff (Brown University, University of Tübingen). Their original code is also publicly available.
Demonstration
The demonstration notebooks can be found in the /notebooks folder. If you use our live versions and want to run our experiments, make sure to choose a GPU instance on Colab. You can easily modify our notebooks to observe different behaviour, so try your own experiments!
- experiment.ipynb: Demonstrates how to use the MechIR library to perform activation patching on a simple neural retrieval model. A live version is available on Colab.
- activation_patching_considerations.ipynb: Takes a more in-depth look at the activation patching process and the considerations that go into it. A live version is available on Colab.
Installation
Latest Release (Unstable)
pip install git+https://github.com/Parry-Parry/MechIR.git
PyPI (Stable)
pip install mechir
Usage
Models
Currently we support common bi- and cross-encoder models for neural retrieval. The following models are available:
- Dot: A simple dot-product model allowing multiple forms of pooling
- Cat: A cross-encoder model with joint query-document embeddings and a linear classification head
- MonoT5: A sequence-to-sequence cross-encoder model based on T5
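A bi-encoder such as Dot scores a query-document pair by the dot product of two pooled embeddings. The sketch below illustrates what "multiple forms of pooling" typically means, using plain PyTorch on toy hidden states; `cls_pool`, `mean_pool`, and `dot_score` are hypothetical names, and this is not MechIR's internal implementation.

```python
import torch


def cls_pool(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Use the first ([CLS]) token's hidden state as the text embedding."""
    return hidden[:, 0]


def mean_pool(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average hidden states over non-padding tokens only."""
    mask = mask.unsqueeze(-1).to(hidden.dtype)   # (batch, seq, 1)
    summed = (hidden * mask).sum(dim=1)          # (batch, dim)
    counts = mask.sum(dim=1).clamp(min=1.0)      # avoid division by zero
    return summed / counts


def dot_score(q_hidden, q_mask, d_hidden, d_mask, pool=cls_pool) -> torch.Tensor:
    """Bi-encoder relevance: dot product of pooled query and document embeddings."""
    q = pool(q_hidden, q_mask)
    d = pool(d_hidden, d_mask)
    return (q * d).sum(dim=-1)                   # one score per pair


# Toy hidden states of shape (batch=1, seq=2, dim=2)
q_hidden = torch.tensor([[[1.0, 0.0], [3.0, 2.0]]])
q_mask = torch.tensor([[1, 1]])
d_hidden = torch.tensor([[[2.0, 1.0], [0.0, 0.0]]])
d_mask = torch.tensor([[1, 0]])                  # second document token is padding

print(dot_score(q_hidden, q_mask, d_hidden, d_mask, pool=cls_pool))   # tensor([2.])
print(dot_score(q_hidden, q_mask, d_hidden, d_mask, pool=mean_pool))  # tensor([5.])
```

The two pooling choices give different scores for the same inputs, which is why the pooling strategy matters when interpreting a bi-encoder.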
To load a model, for example TAS-B, you can use the following code:
from mechir import Dot
model = Dot('sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco')
Datasets and Perturbations
To probe model behaviour we need queries and documents. We retrieve these from ir-datasets, though you can also use your own data (MechDataset). To load an IR dataset, for example MS MARCO, you can use the following code:
from mechir import MechIRDataset
dataset = MechIRDataset('msmarco-passage/dev')
The second step of probing is to create a perturbation of the text so that we can observe how model behaviour changes. This is done with the perturbation decorator:
from mechir.perturb import perturbation

@perturbation
def my_perturbation(text):
    # Append a marker word; the leading space keeps it separate from the last word
    return text + " MechIR"
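To illustrate the general pattern behind a perturbation decorator, here is a self-contained sketch of one that lets a single-string function be applied to either one text or a batch of texts. `batch_perturbation` is a hypothetical name; MechIR's own `perturbation` decorator may do more (or something different) than this.

```python
from functools import wraps
from typing import Callable, List, Union


def batch_perturbation(fn: Callable[[str], str]) -> Callable[[Union[str, List[str]]], Union[str, List[str]]]:
    """Hypothetical decorator: apply a str -> str perturbation to one text or a list of texts."""
    @wraps(fn)  # keep the wrapped function's name and docstring
    def wrapper(texts):
        if isinstance(texts, str):
            return fn(texts)
        return [fn(t) for t in texts]
    return wrapper


@batch_perturbation
def append_marker(text: str) -> str:
    # Same idea as my_perturbation above: add a separate marker word at the end
    return text + " MechIR"


print(append_marker("a short passage"))   # a short passage MechIR
print(append_marker(["doc one", "doc two"]))  # ['doc one MechIR', 'doc two MechIR']
```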
We can then apply this perturbation efficiently using our dataset and a PyTorch DataLoader:
from torch.utils.data import DataLoader
from mechir.data import DotDataCollator
collate_fn = DotDataCollator(model.tokenizer, transformation_fun=my_perturbation)
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=8)
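Conceptually, a collator for patching has to produce both the original and the perturbed version of each document so the model can be run on both. The sketch below shows that idea under two assumptions: the dataset yields `(query, document)` string pairs, and `tokenize` is any callable over a list of strings (for example a Hugging Face tokenizer). `PairedPerturbationCollator` and its output keys are hypothetical and do not describe `DotDataCollator`'s actual output.

```python
from typing import Callable, Dict, List, Tuple


class PairedPerturbationCollator:
    """Hypothetical collator: builds clean and perturbed views of each document."""

    def __init__(self, tokenize: Callable[[List[str]], object], perturb: Callable[[str], str]):
        self.tokenize = tokenize  # maps a list of strings to a batch encoding
        self.perturb = perturb    # str -> str perturbation, e.g. my_perturbation

    def __call__(self, batch: List[Tuple[str, str]]) -> Dict[str, object]:
        queries = [q for q, _ in batch]
        docs = [d for _, d in batch]
        perturbed = [self.perturb(d) for d in docs]
        return {
            "queries": self.tokenize(queries),
            "documents": self.tokenize(docs),
            "perturbed_documents": self.tokenize(perturbed),
        }


# Toy usage with a whitespace "tokenizer" standing in for a real one
collate = PairedPerturbationCollator(
    tokenize=lambda texts: [t.split() for t in texts],
    perturb=lambda text: text + " MechIR",
)
out = collate([("what is ir", "ir is retrieval")])
print(out["perturbed_documents"])  # [['ir', 'is', 'retrieval', 'MechIR']]
```

Because it is a plain callable, the same object can be passed as `collate_fn=` to a PyTorch `DataLoader` over a list of `(query, document)` pairs.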
Activation Patching
Activation patching is a method for localising model behaviour to particular components of the model, commonly attention heads. There are several ways to perform activation patching; a simple case is to patch all heads:
import torch

patch_output = []
for batch in dataloader:
    # Patch every attention head for this batch
    patch_output.append(model.patch(**batch, patch_type="head_all"))
# Average the per-batch patching results
patch_output = torch.mean(torch.stack(patch_output), dim=0)
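To make the idea concrete, here is a self-contained sketch of activation patching on a toy PyTorch network using forward hooks. Activations recorded on a "clean" input are copied into a run on a "corrupted" input, and the change in output score is normalised so that 0 means no effect and 1 means the clean score is fully restored. The toy model, the random inputs, and this particular normalisation are illustrative assumptions, not MechIR's internals; in a real transformer the same idea is applied per attention head, as in the `head_all` example above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for a relevance model: input features -> hidden layers -> scalar score
toy_model = nn.Sequential(
    nn.Linear(4, 8), nn.Tanh(),
    nn.Linear(8, 8), nn.Tanh(),
    nn.Linear(8, 1),
).eval()

clean_input = torch.randn(1, 4)      # e.g. the perturbed query-document pair
corrupted_input = torch.randn(1, 4)  # e.g. the original (baseline) pair


@torch.no_grad()
def run(x, module=None, replace=None):
    """Score x, caching `module`'s output; if `replace` is given, it overwrites that output."""
    cache = {}

    def hook(mod, inputs, output):
        cache["act"] = output.detach().clone()
        # Returning a tensor from a forward hook replaces the module's output
        return replace(output) if replace is not None else None

    handle = module.register_forward_hook(hook) if module is not None else None
    try:
        score = toy_model(x).item()
    finally:
        if handle is not None:
            handle.remove()
    return score, cache.get("act")


def patching_effect(module, units):
    """Normalised effect of copying clean activations of `units` into the corrupted run."""
    clean_score, clean_act = run(clean_input, module)
    corrupted_score, _ = run(corrupted_input)

    def replace(output):
        patched = output.clone()
        patched[:, units] = clean_act[:, units]  # swap in clean values for chosen units
        return patched

    patched_score, _ = run(corrupted_input, module, replace)
    return (patched_score - corrupted_score) / (clean_score - corrupted_score)


# Patching every unit of a layer restores the clean score exactly (effect 1.0);
# patching a subset isolates how much those units alone contribute.
print(patching_effect(toy_model[3], list(range(8))))
print(patching_effect(toy_model[3], [0, 1]))
```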
We can then visualise which attention heads respond most strongly to our perturbation:
from mechir.plotting import plot_components
plot_components(patch_output)
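If you want to customise the figure, a heatmap can also be drawn directly with matplotlib. This sketch assumes the patching result is a 2-D (layers, heads) array, which we have not confirmed for every `patch_type`; `plot_head_heatmap` is a hypothetical helper, not part of MechIR.

```python
import matplotlib
matplotlib.use("Agg")  # render without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_head_heatmap(effects, path="head_patching.png"):
    """Save a layer x head heatmap of patching effects (assumed shape: layers x heads)."""
    effects = np.asarray(effects, dtype=float)
    if effects.ndim != 2:
        raise ValueError(f"expected a 2-D (layers, heads) array, got shape {effects.shape}")
    limit = float(np.abs(effects).max()) or 1.0  # symmetric colour scale around zero
    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(effects, cmap="RdBu_r", vmin=-limit, vmax=limit, aspect="auto")
    ax.set_xlabel("Head")
    ax.set_ylabel("Layer")
    fig.colorbar(im, ax=ax, label="Patching effect")
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)
    return path
```

A diverging colour map centred on zero makes it easy to tell heads that push the score up from heads that push it down.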
Download files
File details
Details for the file mechir-0.0.4.tar.gz.
File metadata
- Download URL: mechir-0.0.4.tar.gz
- Upload date:
- Size: 58.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.22
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c001c64165d1c5a3b167676c3c8dfa562d5ea82802ac33bf6c087ea322f30ef7 |
| MD5 | 2253f6583a95f3e55957cbfee270234b |
| BLAKE2b-256 | 249c03b0c5e5b723369c0505cabdcb88e5c25fe699061823f6d818770f94e3b5 |
File details
Details for the file mechir-0.0.4-py3-none-any.whl.
File metadata
- Download URL: mechir-0.0.4-py3-none-any.whl
- Upload date:
- Size: 71.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.22
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | dc7913f127fa5a2d6df8821192b137aa72430a437d050defd6f5ff465b091e0c |
| MD5 | a9353ab855a8ca7a18bf2ef722d0c7d2 |
| BLAKE2b-256 | 3787787ccd6683cf1fe0126653602d6087311b5418ab75006e57053fdbdeefba |