A package for mechanistic interpretability in Neural IR

MechIR

An IR Library for Mechanistic Interpretability

Heavily inspired by the paper "Axiomatic Causal Interventions for Reverse Engineering Relevance Computation in Neural Retrieval Models" by Catherine Chen, Jack Merullo, and Carsten Eickhoff (Brown University, University of Tübingen). Their original code can be found here.

Demonstration

Primary files can be found in the /notebooks folder. If you use our live versions and want to run our experiments, make sure to choose a GPU instance of Colab. You can easily change our notebooks to observe different behaviour, so try your own experiments!

  • experiment.ipynb: This notebook demonstrates how to use the MechIR library to perform activation patching on a simple neural retrieval model. Here is a live version on Colab
  • activation_patching_considerations.ipynb: This notebook provides a more in-depth look at the activation patching process and the considerations that go into it. Here is a live version on Colab

Installation

Latest Release (Unstable)

pip install git+https://github.com/Parry-Parry/MechIR.git

PyPI (Stable)

pip install mechir

Usage

Models

Currently we support common bi- and cross-encoder models for neural retrieval. The following models are available:

  • Dot: A simple dot-product model allowing multiple forms of pooling
  • Cat: A cross-encoder model with joint query-document embeddings and a linear classification head
  • MonoT5: A sequence-to-sequence cross-encoder model based on T5
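
To make the bi-encoder case concrete, the sketch below shows dot-product scoring with two pooling strategies in plain Python. It is a toy illustration only and does not use MechIR: the function names and the embedding values are hypothetical, and the pooling options that Dot actually exposes are not listed here.

```python
from typing import List

Vector = List[float]


def mean_pool(token_embeddings: List[Vector]) -> Vector:
    """Average token vectors into a single text representation."""
    n = len(token_embeddings)
    dim = len(token_embeddings[0])
    return [sum(tok[i] for tok in token_embeddings) / n for i in range(dim)]


def cls_pool(token_embeddings: List[Vector]) -> Vector:
    """Use the first token's vector as the text representation."""
    return list(token_embeddings[0])


def dot_score(query_vec: Vector, doc_vec: Vector) -> float:
    """Bi-encoder relevance score: the dot product of the two pooled vectors."""
    return sum(q * d for q, d in zip(query_vec, doc_vec))


# Toy token embeddings (hypothetical values, 2 tokens x 2 dimensions each)
query_tokens = [[1.0, 0.0], [0.0, 1.0]]
doc_tokens = [[2.0, 0.0], [0.0, 4.0]]

mean_score = dot_score(mean_pool(query_tokens), mean_pool(doc_tokens))
cls_score = dot_score(cls_pool(query_tokens), cls_pool(doc_tokens))
```

A cross-encoder such as Cat or MonoT5 differs in that the query and document are encoded jointly, so there is no separate pooled vector per text to take a dot product over.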

To load a model, for example TAS-B, you can use the following code:

from mechir import Dot

model = Dot('sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco')

Datasets and Perturbations

To probe model behaviour we need queries and documents. We retrieve these from ir-datasets, though you can use your own (MechDataset). To load an IR dataset, for example MS MARCO, you can use the following code:

from mechir import MechIRDataset

dataset = MechIRDataset('msmarco-passage/dev')

The second step of probing is to create a perturbation of text to observe how model behaviour changes. We can do this simply with the perturbation decorator:

from mechir.perturb import perturbation

@perturbation
def my_perturbation(text):
    return text + "MechIR"
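
As an aside, a perturbation can also depend on the query, for instance by adding a query term to the document. The sketch below is a plain Python function with a hypothetical name; it is left undecorated because whether the perturbation decorator accepts a two-argument function is an assumption that is not confirmed here.

```python
def append_query_term(document: str, query: str) -> str:
    """Append the first query term missing from the document (hypothetical helper)."""
    doc_terms = set(document.lower().split())
    for term in query.split():
        if term.lower() not in doc_terms:
            return f"{document} {term}"
    # Every query term is already present, so leave the document unchanged
    return document


perturbed = append_query_term("Paris is the capital of France", "capital city France")
```

The matching is a simple whitespace split, so punctuation attached to a word would count as part of the term; a real experiment would likely want proper tokenisation.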

We can then apply this perturbation efficiently using our dataset and a torch DataLoader:

from torch.utils.data import DataLoader
from mechir.data import DotDataCollator

collate_fn = DotDataCollator(model.tokenizer, transformation_fun=my_perturbation)
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=8)

Activation Patching

Activation patching is a method for isolating model behaviour to particular components of the model, commonly attention heads. There are several ways to perform activation patching. A simple case is to patch all heads:

import torch

# Collect the patching result for each batch
patch_output = []
for batch in dataloader:
    patch_output.append(model.patch(**batch, patch_type="head_all"))

# Average the results over all batches
patch_output = torch.mean(torch.stack(patch_output), dim=0)
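
For intuition about what patching a head means, here is a minimal sketch on a toy model in plain Python. It is not how model.patch is implemented: the two "heads", the direction of patching (perturbed activation into the baseline run) and the metric (change in score) are all assumptions chosen for illustration.

```python
from typing import Callable, Dict, List, Optional

# A "head" maps an input (a list of floats) to one scalar activation.
Head = Callable[[List[float]], float]

# Hypothetical toy model: two heads whose activations are summed into a score.
HEADS: List[Head] = [
    lambda x: sum(x),  # head 0 responds to the whole input
    lambda x: x[0],    # head 1 only looks at the first feature
]


def run(x: List[float], patch: Optional[Dict[int, float]] = None) -> float:
    """Score x, overriding the head activations listed in patch (index -> value)."""
    patch = patch or {}
    activations = [patch.get(i, head(x)) for i, head in enumerate(HEADS)]
    return sum(activations)


def patch_each_head(baseline: List[float], perturbed: List[float]) -> List[float]:
    """Per head, the score change from patching in its perturbed activation."""
    base_score = run(baseline)
    effects = []
    for i, head in enumerate(HEADS):
        patched_score = run(baseline, patch={i: head(perturbed)})
        effects.append(patched_score - base_score)
    return effects


baseline = [1.0, 2.0]
perturbed = [1.0, 5.0]  # only the second feature changed
effects = patch_each_head(baseline, perturbed)
```

In this toy setting only head 0 reads the feature that changed, so it is the only head whose patched activation moves the score.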

We can then easily visualise the attention heads which activate strongly for our perturbation:

from mechir.plotting import plot_components

plot_components(patch_output)
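
Alongside the plot, a ranked list of heads can be handy. The helper below is a hypothetical sketch in plain Python, not part of MechIR, and it assumes the patching result has been converted to a nested list indexed as [layer][head]; the actual shape returned by model.patch is not documented here.

```python
from typing import List, Tuple


def top_heads(scores: List[List[float]], k: int = 3) -> List[Tuple[int, int, float]]:
    """Return the k (layer, head, score) triples with the largest absolute score."""
    flat = [
        (layer, head, value)
        for layer, row in enumerate(scores)
        for head, value in enumerate(row)
    ]
    return sorted(flat, key=lambda t: abs(t[2]), reverse=True)[:k]


# Hypothetical result for a model with 2 layers and 3 heads per layer
example = [[0.1, -0.9, 0.0], [0.4, 0.2, 0.05]]
ranking = top_heads(example, k=2)
```

If the result is a two-dimensional torch tensor, calling .tolist() on it would produce a nested list in this form.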
