A package for mechanistic interpretability in Neural IR

MechIR

An IR Library for Mechanistic Interpretability

Heavily inspired by the paper Axiomatic Causal Interventions for Reverse Engineering Relevance Computation in Neural Retrieval Models by Catherine Chen, Jack Merullo, and Carsten Eickhoff (Brown University, University of Tübingen). Their original code can be found here.

Demonstration

Primary files can be found in the /notebooks folder. If you use our live versions and want to run our experiments, make sure to choose a GPU instance of Colab. You can easily change our notebooks to observe different behaviour, so try your own experiments!

  • experiment.ipynb: This notebook demonstrates how to use the MechIR library to perform activation patching on a simple neural retrieval model. Here is a live version on Colab
  • activation_patching_considerations.ipynb: This notebook provides a more in-depth look at the activation patching process and the considerations that go into it. Here is a live version on Colab

Installation

Latest Release (Unstable)

pip install git+https://github.com/Parry-Parry/MechIR.git

PyPI (Stable)

pip install mechir

Usage

Models

Currently we support common bi- and cross-encoder models for neural retrieval. The following models are available:

  • Dot: A simple dot-product model allowing multiple forms of pooling
  • Cat: A cross-encoder model with joint query-document embeddings and a linear classification head
  • MonoT5: A sequence-to-sequence cross-encoder model based on T5
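As a rough illustration of what a dot-product model with pooling computes, the sketch below scores a query against a document from per-token vectors. It is not MechIR code: the function names are hypothetical, and CLS and mean pooling are assumed examples, since the list above does not say which pooling forms Dot supports.

```python
from typing import Callable, List

Vector = List[float]


def cls_pool(token_vectors: List[Vector]) -> Vector:
    """Use the first token's vector as the text representation."""
    return token_vectors[0]


def mean_pool(token_vectors: List[Vector]) -> Vector:
    """Average the token vectors dimension by dimension."""
    n = len(token_vectors)
    return [sum(column) / n for column in zip(*token_vectors)]


def dot(a: Vector, b: Vector) -> float:
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def score(query: List[Vector], document: List[Vector],
          pool: Callable[[List[Vector]], Vector]) -> float:
    """Bi-encoder relevance: pool each side separately, then take the dot product."""
    return dot(pool(query), pool(document))


# Toy two-token, two-dimensional inputs (made-up numbers)
query_vectors = [[1.0, 0.0], [0.0, 1.0]]
document_vectors = [[2.0, 0.0], [0.0, 2.0]]

cls_score = score(query_vectors, document_vectors, cls_pool)
mean_score = score(query_vectors, document_vectors, mean_pool)
```

The same inputs give different scores under different pooling, which is one reason the pooling choice matters when comparing patching results across models.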

To load a model, for example TAS-B, you can use the following code:

from mechir import Dot

model = Dot('sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco')

Datasets and Perturbations

To probe model behaviour we need queries and documents. We retrieve these from ir-datasets, though you can use your own (MechDataset). To load an IR dataset, for example MS MARCO, you can use the following code:

from mechir import MechIRDataset

dataset = MechIRDataset('msmarco-passage/dev')

The second step of probing is to create a perturbation of the text and observe how model behaviour changes. We can do this simply with the perturbation decorator:

from mechir.perturb import perturbation

@perturbation
def my_perturbation(text):
    return text + "MechIR"
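For intuition about what a perturbation might look like in practice, here is a minimal sketch of two plain Python functions that edit a document string. Both names are hypothetical and neither uses MechIR. Whether the perturbation decorator accepts a function that also takes the query is an assumption not confirmed here, so the second function is shown undecorated.

```python
def append_term(text: str, term: str = "MechIR") -> str:
    """Append `term` to the document as its own whitespace-separated token."""
    return f"{text} {term}"


def inject_query_term(text: str, query: str) -> str:
    """Append the first query term that the document does not already contain.

    Matching is case-insensitive on whitespace tokens; if every query term
    is already present, the document is returned unchanged.
    """
    document_terms = set(text.lower().split())
    for term in query.split():
        if term.lower() not in document_terms:
            return f"{text} {term}"
    return text


appended = append_term("neural retrieval")
injected = inject_query_term("neural retrieval", "neural ranking")
unchanged = inject_query_term("neural retrieval", "Neural Retrieval")
```

Keeping a perturbation small and targeted, such as adding one term, makes it easier to attribute any change in the relevance score to that single edit.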

We can then apply this perturbation efficiently using our dataset and a torch DataLoader:

from torch.utils.data import DataLoader
from mechir.data import DotDataCollator

collate_fn = DotDataCollator(model.tokenizer, transformation_fun=my_perturbation)
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=8)
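Conceptually, a collator of this kind pairs each original document with its perturbed counterpart so that both can be run through the model. The sketch below shows that idea with plain strings only. The function name and the dictionary keys are hypothetical, tokenization is omitted, and the real output format of DotDataCollator may differ.

```python
from typing import Callable, Dict, List, Tuple


def make_collate_fn(perturb: Callable[[str], str]):
    """Build a collate function that pairs documents with perturbed copies."""

    def collate(batch: List[Tuple[str, str]]) -> Dict[str, List[str]]:
        # Each batch item is assumed to be a (query, document) pair of strings
        queries = [query for query, _ in batch]
        documents = [document for _, document in batch]
        return {
            "queries": queries,
            "documents": documents,
            "perturbed_documents": [perturb(d) for d in documents],
        }

    return collate


collate = make_collate_fn(lambda text: text + " MechIR")
example_batch = collate([("q1", "first doc"), ("q2", "second doc")])
```

Applying the perturbation inside the collate function means the perturbed text is produced batch by batch as the data is loaded, rather than being stored alongside the dataset.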

Activation Patching

Activation patching is a method for isolating model behaviour to particular components of the model, commonly attention heads. There are several ways to perform activation patching; a simple case is to patch all heads:

import torch

# Collect the patching result for each batch
patch_output = []
for batch in dataloader:
    patch_output.append(model.patch(**batch, patch_type="head_all"))

# Average the per-batch results
patch_output = torch.mean(torch.stack(patch_output), dim=0)
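To make the idea of patching concrete, the following toy sketch replaces one component's activation at a time in a baseline run with the activation from a perturbed run, then reports how much of the score change that single component accounts for. Everything in it is an assumption for illustration: the two "heads", their weights, and the normalisation are invented, and this is not how model.patch is implemented.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical "heads": each maps a token list to a single activation value
HEADS = {
    "length": lambda tokens: 0.01 * len(tokens),
    "term_match": lambda tokens: float(tokens.count("mechir")),
}
# Hypothetical output weights that combine head activations into a score
WEIGHTS = {"length": 1.0, "term_match": 2.0}


def run(tokens: List[str],
        patch: Optional[Dict[str, float]] = None) -> Tuple[float, Dict[str, float]]:
    """Return (score, activations); entries in `patch` override computed activations."""
    activations = {name: head(tokens) for name, head in HEADS.items()}
    if patch:
        activations.update(patch)
    score = sum(WEIGHTS[name] * value for name, value in activations.items())
    return score, activations


def patch_each_head(baseline: List[str], perturbed: List[str]) -> Dict[str, float]:
    """Patch one head at a time and normalise by the total score change."""
    base_score, _ = run(baseline)
    perturbed_score, perturbed_acts = run(perturbed)
    total = perturbed_score - base_score
    effects = {}
    for name in HEADS:
        patched_score, _ = run(baseline, patch={name: perturbed_acts[name]})
        effects[name] = (patched_score - base_score) / total if total else 0.0
    return effects


baseline = "neural retrieval models".split()
perturbed = baseline + ["mechir"]
effects = patch_each_head(baseline, perturbed)
```

In this toy model the score is linear in the head activations, so the per-head effects sum to one. Real transformer components interact non-linearly, so their patching effects need not add up this neatly.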

We can then easily visualise the attention heads which activate strongly for our perturbation:

from mechir.plotting import plot_components

plot_components(patch_output)
