
Use Activation Intervention to Interpret Causal Mechanism of Model


pyvene supports customizable interventions on different neural architectures (e.g., RNNs or Transformers). It supports complex intervention schemes (e.g., parallel or serialized interventions) and a wide range of intervention modes (e.g., static or trained interventions) at scale, so you can gain interpretability insights.

Getting Started: [pyvene 101]

Installation

pip install pyvene

Wrap, Intervene and Share

You can intervene on supported models as follows:

import pyvene
from pyvene import IntervenableRepresentationConfig, IntervenableConfig, IntervenableModel

# provided wrapper for huggingface gpt2 model
_, tokenizer, gpt2 = pyvene.create_gpt2()

# turn gpt2 into intervenable_gpt2
intervenable_gpt2 = IntervenableModel(
    intervenable_config = IntervenableConfig(
        intervenable_representations=[
            IntervenableRepresentationConfig(
                0,            # intervening layer 0
                "mlp_output", # intervening mlp output
                "pos",        # intervening based on positional indices of tokens
                1             # maximally intervening one token
            ),
        ],
    ), 
    model = gpt2
)

# intervene on the base input with the source input at token position 4 (the fifth and last token)
original_outputs, intervened_outputs = intervenable_gpt2(
    tokenizer("The capital of Spain is", return_tensors="pt"),
    [tokenizer("The capital of Italy is", return_tensors="pt")],
    {"sources->base": ([[[4]]], [[[4]]])}
)
original_outputs.last_hidden_state - intervened_outputs.last_hidden_state

which returns,

tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0008, -0.0078, -0.0066,  ...,  0.0007, -0.0018,  0.0060]]])

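This shows that, as expected, the intervention only affects the last token: GPT-2 attends only to earlier positions, so a change at position 4 cannot reach positions 0 to 3. You can share your interventions with others through the Hugging Face Hub with a single call,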

intervenable_gpt2.save(
    save_directory="./your_gpt2_mounting_point/",
    save_to_hf_hub=True,
    hf_repo_name="your_gpt2_mounting_point",
)

Interventions are thus knobs that can be mounted on models, and people can share their knobs with others to share knowledge about how to steer models. You can try this at [Intervention Sharing].

You can also use intervenable_gpt2 like a regular torch module inside another model or pipeline:

import torch.nn as nn
from typing import Dict, List, Optional

class ModelWithIntervenables(nn.Module):
    def __init__(self):
        super().__init__()
        self.intervenable_gpt2 = intervenable_gpt2
        self.relu = nn.ReLU()
        # 768 is the hidden size of the base GPT-2 model
        self.fc = nn.Linear(768, 1)
        # Your other downstream components go here

    def forward(
        self,
        base,
        sources: Optional[List] = None,
        unit_locations: Optional[Dict] = None,
        activations_sources: Optional[Dict] = None,
        subspaces: Optional[List] = None,
    ):
        # the wrapped model returns (original_outputs, intervened_outputs);
        # only the intervened (counterfactual) outputs are used downstream
        _, counterfactual_x = self.intervenable_gpt2(
            base,
            sources,
            unit_locations,
            activations_sources,
            subspaces
        )
        counterfactual_x = counterfactual_x.last_hidden_state

        counterfactual_x = self.relu(counterfactual_x)
        counterfactual_x = self.fc(counterfactual_x)
        return counterfactual_x

Selected Tutorials

  • Beginner, Getting Started: introduces basic static intervention on factual recall examples.
  • Beginner, Intervened Model Generation: shows how to intervene on a model during generation.
  • Intermediate, Intervene Your Local Models: illustrates how to run this library with your own models.
  • Advanced, Trainable Interventions for Causal Abstraction: illustrates how to train an intervention to discover the causal mechanisms of a neural model.

Causal Abstraction: From Interventions to Interpretability Insights

Basic interventions are fun, but on their own they do not let us make systematic causal claims. To gain actual interpretability insights, we want to measure the counterfactual behaviors of a model in a data-driven fashion. In other words, if the model responds systematically to your interventions, you can start to associate certain regions in the network with a high-level concept. We call this process alignment search over model internals.

Understanding Causal Mechanisms with Static Interventions

Here is a more concrete example,

def add_three_numbers(a, b, c):
    var_x = a + b
    return var_x + c

The function sums three numbers. Let's say we trained a neural network to solve this problem perfectly. Can we find the representation of (a + b) in the neural network? We can use this library to answer this question. Specifically, we can do the following,

  • Step 1: Form an Interpretability (Alignment) Hypothesis: We hypothesize that a set of neurons N aligns with (a + b).
  • Step 2: Counterfactual Testing: If our hypothesis is correct, then swapping the values of N between examples should produce the expected counterfactual behaviors. For instance, if we swap the values of N for (1+2)+3 with those for (2+3)+4, the output should be (2+3)+3 = 8 or (1+2)+4 = 7, depending on the direction of the swap.
  • Step 3: Rejection Sampling of Hypotheses: Run the tests many times and aggregate statistics on how often the counterfactual behavior matches. Propose a new hypothesis based on the results.
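To make the three steps concrete without any library, here is a minimal toy sketch in plain Python. The "network" is hypothetical and hand-built so that neuron 0 holds a + b and neuron 1 holds c; the helper names (hidden_layer, output_layer, interchange, interchange_accuracy) are illustrative and not part of pyvene. It tests two candidate hypotheses for N by running many random interchange interventions and measuring how often the output matches the counterfactual that add_three_numbers predicts.

```python
import random
from typing import List, Sequence, Tuple

Example = Tuple[int, int, int]  # (a, b, c)


def hidden_layer(x: Example) -> List[int]:
    """Hypothetical toy network layer: neuron 0 holds a + b, neuron 1 holds c."""
    a, b, c = x
    return [a + b, c]


def output_layer(h: List[int]) -> int:
    return h[0] + h[1]


def interchange(base: Example, source: Example, neurons: Sequence[int]) -> int:
    """Run `base`, but overwrite the hypothesized neurons N with their values on `source`."""
    h = hidden_layer(base)
    h_source = hidden_layer(source)
    for i in neurons:
        h[i] = h_source[i]
    return output_layer(h)


def expected_counterfactual(base: Example, source: Example) -> int:
    """What add_three_numbers predicts if N really encodes (a + b): source's a + b, base's c."""
    return (source[0] + source[1]) + base[2]


def interchange_accuracy(neurons: Sequence[int], n_trials: int = 1000, seed: int = 0) -> float:
    """Step 3: aggregate how often the intervened output matches the expected counterfactual."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(n_trials):
        base = (rng.randint(0, 9), rng.randint(0, 9), rng.randint(0, 9))
        source = (rng.randint(0, 9), rng.randint(0, 9), rng.randint(0, 9))
        hits += interchange(base, source, neurons) == expected_counterfactual(base, source)
    return hits / n_trials


# Step 2 on the example from the text: swap N from (2+3)+4 into (1+2)+3.
print(interchange((1, 2, 3), (2, 3, 4), neurons=[0]))  # 8, i.e. (2+3)+3
# Step 3: hypothesis N = {0} matches every trial, while N = {1} is rejected.
print(interchange_accuracy([0]), interchange_accuracy([1]) < 1.0)
```

Swapping in the other direction, interchange((2, 3, 4), (1, 2, 3), neurons=[0]), gives (1+2)+4 = 7. In a real trained network the right N is unknown, which is why Step 3 loops over candidate hypotheses.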

With the library, the steps above translate into a single API call,

intervenable.evaluate(
    train_dataloader=test_dataloader,
    compute_metrics=compute_metrics,
    inputs_collator=inputs_collator
)

where you provide test data (essentially the interventional data and the counterfactual behavior you are looking for) along with your metric functions. The library then evaluates the alignment using the intervention you specified in the config.
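As an illustration of what a metrics function might compute here, the sketch below scores interchange-intervention accuracy: the fraction of intervened predictions that match the counterfactual labels. The signature is an assumption (a list of per-example logits and a list of integer labels); the exact arguments pyvene passes to compute_metrics may differ, so check the tutorials before reusing it.

```python
from typing import Dict, Sequence


def compute_metrics(eval_logits: Sequence[Sequence[float]], eval_labels: Sequence[int]) -> Dict[str, float]:
    """Hypothetical metric: accuracy of intervened predictions against counterfactual labels."""
    if len(eval_logits) != len(eval_labels) or not eval_labels:
        raise ValueError("need the same, non-zero number of predictions and labels")
    correct = 0
    for logits, label in zip(eval_logits, eval_labels):
        # predicted class = index of the largest logit
        predicted = max(range(len(logits)), key=lambda i: logits[i])
        correct += int(predicted == label)
    return {"accuracy": correct / len(eval_labels)}


print(compute_metrics([[0.1, 2.0], [1.5, -0.3], [0.2, 0.4]], [1, 0, 0]))
```

A high accuracy under the intervention supports the hypothesized alignment; a low one is evidence against it.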


Understanding Causal Mechanisms with Trainable Interventions

The alignment search process outlined above can be tedious when your neural network is large: for a single hypothesized alignment, you need to set up different intervention configs targeting different layers and positions to verify it. Instead of this brute-force search, you can turn alignment search into an optimization problem, which has other benefits as well, such as finding distributed alignments.

At its core, the idea is to train an intervention that produces our desired counterfactual behaviors. If we can indeed train such an intervention, we claim that causally relevant information lives in the intervened representations. Below, we show one type of trainable intervention, models.interventions.RotatedSpaceIntervention, as,

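class RotatedSpaceIntervention(TrainableIntervention):
    """Intervention in the rotated space."""

    def forward(self, base, source):
        # rotate both representations with the learned (orthogonal) rotation
        rotated_base = self.rotate_layer(base)
        rotated_source = self.rotate_layer(source)
        # interchange: copy the first interchange_dim rotated dimensions
        # (along the last, hidden dimension) from source into base
        rotated_base[..., :self.interchange_dim] = rotated_source[..., :self.interchange_dim]
        # un-rotate: the rotation is orthogonal, so its transpose is its inverse
        output = torch.matmul(rotated_base, self.rotate_layer.weight.T)
        return output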

Instead of swapping activations in the original representation space, we first rotate them, then do the swap, and finally un-rotate the intervened representation. We use SGD to learn a rotation under which the swap produces the expected counterfactual behavior; if we can find such a rotation, we claim there is an alignment. The "If the cost is between X and Y.ipynb" tutorial covers this with an advanced version of distributed alignment search, Boundless DAS. Recent work has also outlined potential limitations of distributed alignment search.
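The rotate, swap, un-rotate recipe, and the idea of searching for a rotation under which the swap yields the expected counterfactuals, can be sketched in plain Python on a 2-D toy problem. Everything here is hypothetical: a hand-built "network" stores [a + b, c] in a basis rotated by PHI, so no single neuron holds a + b on its own. The sketch also swaps in a simpler search than SGD: because a 2-D rotation has a single angle, it grid-searches that angle in 1-degree steps instead of training it by gradient descent.

```python
import math
import random
from typing import List, Sequence, Tuple

Matrix = List[List[float]]
Example = Tuple[int, int, int]  # (a, b, c)


def rotation(theta: float) -> Matrix:
    """2x2 rotation matrix, applied to row vectors as x @ R."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]


def vecmat(x: Sequence[float], m: Matrix) -> List[float]:
    """Row vector times matrix, x @ m."""
    return [sum(x[i] * m[i][j] for i in range(len(x))) for j in range(len(m[0]))]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def rotated_interchange(base: Sequence[float], source: Sequence[float], w: Matrix, k: int) -> List[float]:
    """Rotate both vectors, copy the first k rotated coordinates from source, then rotate back."""
    rotated_base = vecmat(base, w)
    rotated_source = vecmat(source, w)
    rotated_base[:k] = rotated_source[:k]  # 1-D vectors here, so this slices the hidden dimension
    return vecmat(rotated_base, transpose(w))  # w is orthogonal, so w^T undoes the rotation


# Hypothetical toy network: the hidden vector is [a + b, c] expressed in a basis rotated by PHI.
PHI = math.pi / 6
MIX = rotation(PHI)


def hidden(x: Example) -> List[float]:
    a, b, c = x
    return vecmat([a + b, c], MIX)


def readout(h: Sequence[float]) -> float:
    z = vecmat(h, transpose(MIX))  # undo the mixing, then add the two parts
    return z[0] + z[1]


def loss(theta: float, pairs: Sequence[Tuple[Example, Example]]) -> float:
    """Mean squared error between intervened outputs and the counterfactuals
    expected if the first rotated coordinate encodes (a + b)."""
    w = rotation(theta)
    total = 0.0
    for base, source in pairs:
        out = readout(rotated_interchange(hidden(base), hidden(source), w, k=1))
        target = (source[0] + source[1]) + base[2]
        total += (out - target) ** 2
    return total / len(pairs)


rng = random.Random(0)
pairs = [
    (
        (rng.randint(0, 9), rng.randint(0, 9), rng.randint(0, 9)),
        (rng.randint(0, 9), rng.randint(0, 9), rng.randint(0, 9)),
    )
    for _ in range(200)
]
# Stand-in for SGD: grid search over the single angle in 1-degree steps on [-90, 90).
grid = [-math.pi / 2 + i * math.pi / 180 for i in range(180)]
best_theta = min(grid, key=lambda t: loss(t, pairs))
print(round(math.degrees(best_theta)), loss(best_theta, pairs))  # about -30 degrees (= -PHI), near-zero loss
```

The search recovers the rotation that undoes the toy network's mixing, after which swapping one rotated coordinate swaps exactly a + b. In the library the rotation is instead a full hidden-size matrix learned with SGD and the swapped block is interchange_dim wide; a grid search only works here because the toy has one free angle.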

You can also train your intervention with a single API call,

intervenable.train(
    train_dataloader=train_dataloader,
    compute_loss=compute_loss,
    compute_metrics=compute_metrics,
    inputs_collator=inputs_collator
)

where you pass in a training dataset along with your custom loss and metric functions. The trained interventions can later be saved to disk. You can also use intervenable.evaluate() to evaluate your interventions against custom objectives.
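For training, the loss typically compares the intervened model's predictions with the counterfactual labels. Below is a minimal plain-Python sketch of such a loss: softmax cross-entropy averaged over examples. Its signature (per-example logits and integer labels) is an assumption, not the exact interface pyvene expects; in practice you would compute the same quantity with torch so that gradients reach the intervention's parameters.

```python
import math
from typing import Sequence


def compute_loss(logits_batch: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Hypothetical loss: mean softmax cross-entropy of intervened logits vs. counterfactual labels."""
    if len(logits_batch) != len(labels) or not labels:
        raise ValueError("need the same, non-zero number of logit rows and labels")
    total = 0.0
    for logits, label in zip(logits_batch, labels):
        # log-sum-exp with the max subtracted for numerical stability
        m = max(logits)
        log_z = m + math.log(sum(math.exp(v - m) for v in logits))
        total += log_z - logits[label]  # equals -log(softmax(logits)[label])
    return total / len(labels)


print(compute_loss([[0.0, 0.0], [2.0, -1.0]], [0, 0]))
```

Minimizing this loss pushes the intervened model toward the counterfactual outputs, which is what lets training discover where the causal variable lives.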

Contributing to This Library

Please see our guidelines about how to contribute to this repository.

Pull requests, bug reports, and all other forms of contribution are welcome and highly encouraged!

Other Ways of Installation

Method 2: Install from the Repo

pip install git+https://github.com/stanfordnlp/pyvene.git

Method 3: Clone and Import

git clone https://github.com/stanfordnlp/pyvene.git

and then, from a folder next to the cloned repository, import it into your project as,

from pyvene import pyvene
_, tokenizer, gpt2 = pyvene.create_gpt2()

Related Works in Discovering Causal Mechanisms of LLMs

If you would like to read more work in this area, here is a list of papers that try to align or discover the causal mechanisms of LLMs.

Citation

A library paper is forthcoming. For now, if you use this repository, please consider citing the relevant papers:

  @article{geiger-etal-2023-DAS,
        title={Finding Alignments Between Interpretable Causal Variables and Distributed Neural Representations},
        author={Geiger, Atticus and Wu, Zhengxuan and Potts, Christopher and Icard, Thomas and Goodman, Noah},
        year={2023},
        journal={arXiv}
  }

  @inproceedings{wu-etal-2023-Boundless-DAS,
        title={Interpretability at Scale: Identifying Causal Mechanisms in Alpaca},
        author={Wu, Zhengxuan and Geiger, Atticus and Icard, Thomas and Potts, Christopher and Goodman, Noah},
        year={2023},
        booktitle={NeurIPS}
  }


Download files


Source Distribution

pyvene-0.0.5.tar.gz (49.3 kB)

Uploaded Source

Built Distribution


pyvene-0.0.5-py3-none-any.whl (52.4 kB)

Uploaded Python 3

File details

Details for the file pyvene-0.0.5.tar.gz.

File metadata

  • Download URL: pyvene-0.0.5.tar.gz
  • Size: 49.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.18

File hashes

Hashes for pyvene-0.0.5.tar.gz
  • SHA256: 191b56e1acc677da6908d143df706e24c3f5de3ca078d14783748fc7847d9e4b
  • MD5: 577dd2bdf8a3f526bad4b1310562240f
  • BLAKE2b-256: 8791ceaaddd977949507cc4a9c132d11330be6d0ec8d09003ef4eb6afd7e4f5a


File details

Details for the file pyvene-0.0.5-py3-none-any.whl.

File metadata

  • Download URL: pyvene-0.0.5-py3-none-any.whl
  • Size: 52.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.18

File hashes

Hashes for pyvene-0.0.5-py3-none-any.whl
  • SHA256: f624b01d9e8a658da93b42ea1918c66da42f01aff0ac0d409b78e04ad825eac7
  • MD5: 1e63da92fa6fc9c80a30c5635bdff87f
  • BLAKE2b-256: 8565c711270a514aaf193d1a783be9b65e766ea5396752d25c8c9c2812a75dd6

