
A package for mechanistic interpretability in Neural IR


MechIR

An IR Library for Mechanistic Interpretability

MechIR is heavily inspired by the paper Axiomatic Causal Interventions for Reverse Engineering Relevance Computation in Neural Retrieval Models by Catherine Chen, Jack Merullo, and Carsten Eickhoff (Brown University, University of Tübingen). Their original code can be found here.

Demonstration

The primary demonstration files are in the /notebooks folder. If you run our experiments using the live versions, make sure to select a GPU runtime in Colab. The notebooks are easy to modify to observe different behaviour, so try your own experiments!

  • experiment.ipynb: This notebook demonstrates how to use the MechIR library to perform activation patching on a simple neural retrieval model. Here is a live version on Colab
  • activation_patching_considerations.ipynb: This notebook provides a more in-depth look at the activation patching process and the considerations that go into it. Here is a live version on Colab

Installation

Latest Release (Unstable)

pip install git+https://github.com/Parry-Parry/MechIR.git

PyPI (Stable)

pip install mechir

Usage

Models

Currently we support common bi- and cross-encoder models for neural retrieval. The following models are available:

  • Dot: A bi-encoder model that scores query-document pairs with a dot product, supporting multiple forms of pooling
  • Cat: A cross-encoder model with joint query-document embeddings and a linear classification head
  • MonoT5: A sequence-to-sequence cross-encoder model based on T5
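
To make the bi-encoder case concrete, here is a minimal pure-Python sketch of how a dot-product model scores a pair: query and document are encoded separately into per-token vectors, each is pooled into a single vector, and the two pooled vectors are compared with a dot product. The token vectors are made-up toy values and the helper names are hypothetical; this illustrates the general technique, not MechIR's internals.

```python
from typing import List

Vector = List[float]


def cls_pool(token_vectors: List[Vector]) -> Vector:
    """CLS pooling: use the first token's vector as the text representation."""
    return list(token_vectors[0])


def mean_pool(token_vectors: List[Vector]) -> Vector:
    """Mean pooling: average the token vectors dimension by dimension."""
    n = len(token_vectors)
    return [sum(dim) / n for dim in zip(*token_vectors)]


def dot(a: Vector, b: Vector) -> float:
    """Bi-encoder relevance score: dot product of the two pooled vectors."""
    return sum(x * y for x, y in zip(a, b))


# Hypothetical per-token encoder outputs (2 dimensions per token).
query_tokens = [[1.0, 0.0], [0.5, 0.5]]
doc_tokens = [[2.0, 1.0], [0.0, 1.0], [1.0, 1.0]]

cls_score = dot(cls_pool(query_tokens), cls_pool(doc_tokens))
mean_score = dot(mean_pool(query_tokens), mean_pool(doc_tokens))
print(cls_score, mean_score)
```

The choice of pooling changes the score even for identical token vectors, which is why it matters when interpreting a bi-encoder. A cross-encoder such as Cat has no separate vectors to compare: query and document pass through the encoder jointly and a linear head produces the score.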

To load a model, for example TAS-B, you can use the following code:

from mechir import Dot

model = Dot('sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco')

Datasets and Perturbations

To probe model behaviour we need queries and documents. We retrieve these from ir-datasets, though you can also use your own (MechDataset). To load an IR dataset, for example MS MARCO, you can use the following code:

from mechir import MechIRDataset

dataset = MechIRDataset('msmarco-passage/dev')
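
If you would rather probe your own data, keep in mind that a PyTorch DataLoader accepts any map-style dataset, i.e. any object implementing `__len__` and `__getitem__`. The class below is a hypothetical stand-in holding (query, document) pairs; it is not MechDataset's actual interface, so check the library source for the fields MechIR's collators expect.

```python
from typing import List, Tuple


class PairDataset:
    """Hypothetical map-style dataset of (query, document) pairs.

    Any object with __len__ and __getitem__ can be wrapped by
    torch.utils.data.DataLoader; this is NOT MechIR's MechDataset API.
    """

    def __init__(self, pairs: List[Tuple[str, str]]):
        self.pairs = list(pairs)

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Tuple[str, str]:
        return self.pairs[idx]


# Toy query-document pairs for illustration only.
my_dataset = PairDataset([
    ("first query", "a short first document"),
    ("second query", "another short document"),
])
print(len(my_dataset), my_dataset[0])
```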

The second step of probing is to perturb the text and observe how model behaviour changes. We can do this simply with the perturbation decorator:

from mechir.perturb import perturbation

@perturbation
def my_perturbation(text):
    # Append a fixed term; the leading space keeps it a separate word
    return text + " MechIR"
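
Perturbations are usually more targeted than appending a fixed string. For example, a diagnostic in the spirit of the term-frequency axiom (TFC1), of the kind studied in the paper cited above, injects a query term into the document. The plain-Python sketch below (no decorator; `make_term_injection` and `make_pairs` are hypothetical helpers) builds such a text-to-text function and pairs each original document with its perturbed version, which is the comparison activation patching relies on.

```python
from typing import Callable, List, Tuple


def make_term_injection(term: str, position: str = "end") -> Callable[[str], str]:
    """Return a hypothetical text -> text perturbation that injects `term`.

    `position` selects whether the term is prepended ("start") or appended ("end").
    """
    if position not in ("start", "end"):
        raise ValueError("position must be 'start' or 'end'")

    def perturb(text: str) -> str:
        if position == "start":
            return f"{term} {text}"
        return f"{text} {term}"

    return perturb


def make_pairs(docs: List[str], perturb: Callable[[str], str]) -> List[Tuple[str, str]]:
    """Pair each original document with its perturbed counterpart."""
    return [(doc, perturb(doc)) for doc in docs]


inject = make_term_injection("retrieval")
pairs = make_pairs(["neural models rank passages"], inject)
print(pairs[0][1])
```

In practice the injected term would come from each document's query; a fixed term keeps the sketch short.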

We can then apply this perturbation efficiently using our dataset and a PyTorch DataLoader:

from torch.utils.data import DataLoader
from mechir.data import DotDataCollator

collate_fn = DotDataCollator(model.tokenizer, transformation_fun=my_perturbation)
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=8)
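
Conceptually, a collator for patching turns a list of dataset items into both an original and a perturbed batch. One plausible reason a dedicated collator helps (an assumption about the design, not a description of DotDataCollator) is that per-position patching needs the two runs to have the same sequence length. The sketch below uses hypothetical names, whitespace splitting in place of `model.tokenizer`, and a made-up `[PAD]` token.

```python
from typing import Callable, Dict, List, Tuple

PAD = "[PAD]"  # hypothetical padding token


def toy_collate(
    batch: List[Tuple[str, str]],
    perturb: Callable[[str], str],
) -> Dict[str, list]:
    """Hypothetical collate_fn: build original and perturbed token batches.

    Whitespace splitting stands in for a real tokenizer. Both versions are
    padded to one shared length so positions line up across the two runs.
    """
    queries = [query for query, _doc in batch]
    clean = [doc.split() for _query, doc in batch]
    perturbed = [perturb(doc).split() for _query, doc in batch]
    width = max(len(tokens) for tokens in clean + perturbed)

    def pad(seqs: List[List[str]]) -> List[List[str]]:
        return [tokens + [PAD] * (width - len(tokens)) for tokens in seqs]

    return {"queries": queries, "clean": pad(clean), "perturbed": pad(perturbed)}


out = toy_collate([("q1", "a b"), ("q2", "c")], lambda text: text + " X")
print(out["clean"])
print(out["perturbed"])
```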

Activation Patching

Activation patching is a method for localising model behaviour to particular components of the model, commonly attention heads. There are several ways to perform activation patching; a simple case is to patch all heads:

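import torch

patch_output = []
for batch in dataloader:
    # Patch every attention head for this batch and keep the result
    patch_output.append(model(**batch, patch_type="head_all"))

# Average the per-batch results
patch_output = torch.mean(torch.stack(patch_output), dim=0)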
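
To show what head patching computes, here is a self-contained toy version in plain Python. A made-up model with a scalar residual stream is run on a baseline input and a perturbed input; then, one head at a time, that head's output from the perturbed run is written into the baseline run and the resulting score change is measured. The normalisation (patched - baseline) / (perturbed - baseline) is one common choice and an assumption here; `ToyModel` and `patch_all_heads` are hypothetical and are not MechIR's implementation.

```python
from typing import Dict, List, Optional, Tuple

Head = Tuple[int, int]  # (layer index, head index)


class ToyModel:
    """A made-up model with a scalar residual stream.

    At each layer every head outputs weight * residual; the head outputs are
    added back to the residual. The final residual value is the score.
    """

    def __init__(self, weights: List[List[float]]):
        self.weights = weights  # weights[layer][head]

    def run(self, x: float, patch: Optional[Dict[Head, float]] = None) -> Tuple[float, Dict[Head, float]]:
        """Score x, optionally overriding some head outputs; cache every head output."""
        patch = patch or {}
        cache: Dict[Head, float] = {}
        resid = x
        for layer, head_weights in enumerate(self.weights):
            outputs = []
            for head, w in enumerate(head_weights):
                out = patch.get((layer, head), w * resid)
                cache[(layer, head)] = out
                outputs.append(out)
            resid = resid + sum(outputs)
        return resid, cache


def patch_all_heads(model: ToyModel, baseline_x: float, perturbed_x: float) -> Dict[Head, float]:
    """Patch each head's perturbed-run output into the baseline run, one at a time.

    Returns (patched - baseline) / (perturbed - baseline) per head: 0 means
    patching that head alone recovers none of the score change, 1 means all of it.
    """
    base_score, _ = model.run(baseline_x)
    pert_score, pert_cache = model.run(perturbed_x)
    delta = pert_score - base_score
    if delta == 0:
        raise ValueError("the perturbation does not change the score")
    effects = {}
    for head, value in pert_cache.items():
        patched_score, _ = model.run(baseline_x, patch={head: value})
        effects[head] = (patched_score - base_score) / delta
    return effects


toy = ToyModel([[0.5, 0.0], [1.0, 0.25]])
effects = patch_all_heads(toy, baseline_x=1.0, perturbed_x=2.0)
for head, effect in sorted(effects.items()):
    print(head, round(effect, 3))
```

The per-head effects need not sum to 1: heads interact across layers, and the input also reaches the score directly through the residual stream.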

We can then visualise which attention heads most strongly mediate the effect of our perturbation:

from mechir.plotting import plot_components

plot_components(patch_output)
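
If you also want the strongest heads as numbers, you can rank the entries of a layer × head grid. This sketch assumes the averaged result can be turned into a nested list of shape [layers][heads] (for a tensor, `.tolist()` does this); that shape is an assumption about the `head_all` output, and `top_heads` is a hypothetical helper with made-up grid values.

```python
from typing import List, Tuple


def top_heads(
    grid: List[List[float]], k: int = 3, by_magnitude: bool = True
) -> List[Tuple[int, int, float]]:
    """Return the k (layer, head, effect) entries with the largest effects.

    Ranking by magnitude keeps heads that push the score down as well as up.
    """
    entries = [
        (layer, head, value)
        for layer, row in enumerate(grid)
        for head, value in enumerate(row)
    ]
    key = (lambda e: abs(e[2])) if by_magnitude else (lambda e: e[2])
    return sorted(entries, key=key, reverse=True)[:k]


# Hypothetical 3-layer x 4-head grid of normalised patching effects.
grid = [
    [0.01, 0.02, 0.00, 0.05],
    [0.10, 0.42, -0.30, 0.03],
    [0.00, 0.07, 0.15, 0.01],
]
print(top_heads(grid, k=3))
```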


Download files


Source Distribution

mechir-0.0.2.tar.gz (56.9 kB)

Uploaded Source

Built Distribution


mechir-0.0.2-py3-none-any.whl (74.1 kB)

Uploaded Python 3

File details

Details for the file mechir-0.0.2.tar.gz.

File metadata

  • Download URL: mechir-0.0.2.tar.gz
  • Size: 56.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.9.21

File hashes

Hashes for mechir-0.0.2.tar.gz
Algorithm Hash digest
SHA256 227b14cefbd8fa0abf31f3659741f25482c3506cc4becd15afef9222bb3e0630
MD5 dcd6303cbfbedc207e31998891e9f2b4
BLAKE2b-256 d14c7a395873f110f1c9fdabf82b41c41ce4edb107d87ae1f32ce571bbc2bc3a

File details

Details for the file mechir-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: mechir-0.0.2-py3-none-any.whl
  • Size: 74.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.9.21

File hashes

Hashes for mechir-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 3ed137db605f7e0b8688d464f751ddb9c04e39ca55e3d6de2d793002fbce6aaf
MD5 a56dc4b5aa3024a4f336ca192286e433
BLAKE2b-256 9f6ae64e0806e9c5ae4b9d6d0b6e0c4a41e4f6146ad683dccb1fb3c410fead40
