

Project description

graphpatch 0.2.3

Documentation is hosted on Read the Docs.

Overview

graphpatch is a library for activation patching (often also referred to as “ablation”) on PyTorch neural network models. You use it by first wrapping your model in a PatchableGraph and then running operations in a context created by PatchableGraph.patch():

pg = PatchableGraph(model, **inputs, use_cache=False)
# Applies patches to the multiplication result within the activation function of the
# MLP in the 18th transformer layer. ProbePatch records the last observed value at the
# given node, while ZeroPatch zeroes out the value seen by downstream computations.
with pg.patch({"transformer.h_17.mlp.act.mul_3": [probe := ProbePatch(), ZeroPatch()]}):
   output = pg(**inputs)
# Patches are applied in order. probe.activation holds the value prior
# to ZeroPatch zeroing it out.
print(probe.activation)

In contrast to other approaches, graphpatch can patch (or record) any intermediate tensor value without manual modification of the underlying model’s code. See Working with graphpatch for some tips on how to use the generated graphs.

Note that graphpatch activation patches are compatible with AutoGrad! This means that, for example, you can perform optimizations over the value parameter to AddPatch:

delta = torch.zeros(size, requires_grad=True, device="cuda")
optimizer = torch.optim.Adam([delta], lr=0.5)
for _ in range(num_steps):
   # Reset accumulated gradients before each step.
   optimizer.zero_grad()
   with graph.patch({node_name: AddPatch(value=delta)}):
      logits = graph(**prompt_inputs)
   loss = my_loss_function(logits)
   loss.backward()
   optimizer.step()

For a practical usage example, see the demo of using graphpatch to replicate ROME.

Prerequisites

The only mandatory requirements are torch>=2 and numpy>=1.17. Version 2+ of torch is required because graphpatch leverages torch.compile(), which was introduced in 2.0.0, to extract computational graphs from models. CUDA support is not required. numpy is required for full compile() support.

Python 3.8–3.12 are supported. Note that torch versions prior to 2.1.0 do not support compilation on Python 3.11, and versions prior to 2.4.0 do not support compilation on Python 3.12; you will get an exception when trying to use graphpatch with such a configuration. No version of torch yet supports compilation on Python 3.13.
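As a quick sanity check of a given environment, something like the following sketch (assuming the packaging library is available; it is not part of graphpatch) verifies the torch/Python combinations described above before attempting to build a PatchableGraph:

# Environment sanity check (a sketch based on the version notes above).
import sys
import torch
from packaging.version import Version

torch_version = Version(torch.__version__.split("+")[0])
python_version = sys.version_info[:2]
assert torch_version >= Version("2.0.0"), "graphpatch requires torch>=2"
if python_version == (3, 11):
   assert torch_version >= Version("2.1.0"), "compilation on Python 3.11 requires torch>=2.1.0"
if python_version == (3, 12):
   assert torch_version >= Version("2.4.0"), "compilation on Python 3.12 requires torch>=2.4.0"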

Installation

graphpatch is available on PyPI, and can be installed via pip:

pip install graphpatch

Note that you will likely want to do this in an environment that already has torch, since pip may not resolve torch to a CUDA-enabled version by default. You don’t need to do anything special to make graphpatch compatible with transformers, accelerate, and bitsandbytes; their presence is detected at run-time. However, for convenience, you can install graphpatch with the “transformers” extra, which will install known compatible versions of these libraries along with some of their optional dependencies that are otherwise mildly inconvenient to set up:

pip install graphpatch[transformers]

Model compatibility

For full functionality, graphpatch depends on being able to call torch.compile() on your model. compile() currently supports only a subset of possible Python operations; for example, it does not support context managers. graphpatch implements some workarounds for situations that a native compile() can't handle, but this coverage isn't complete. To deal with this, graphpatch has a graceful fallback that should provide a user experience no worse than using module hooks: in that case, you will only be able to patch an uncompilable submodule's inputs, outputs, parameters, and buffers. See Notes on compilation for more discussion.
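As an illustration, suppose a submodule falls back to the uncompiled path: its module-level nodes remain patchable even though its intermediate operations are not in the graph. The sketch below continues from the overview example (pg, inputs, ProbePatch) and uses a hypothetical submodule name:

# Sketch: "fallback_block" is a hypothetical submodule that could not be compiled.
# Intermediate nodes inside it are unavailable, but its inputs, outputs,
# parameters, and buffers can still be targeted by name.
with pg.patch({"fallback_block.output": [probe := ProbePatch()]}):
   output = pg(**inputs)
print(probe.activation)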

transformers integration

graphpatch is theoretically compatible with any model in Hugging Face's transformers library, but note that there may be edge cases in specific model code that it can't yet handle. For example, it is not (yet!) compatible with the key-value caching implementation, so if you want full compilation of such models you should pass use_cache=False as part of the example inputs.

graphpatch is compatible with models loaded via accelerate and with 8-bit parameters quantized by bitsandbytes. This means that you can very easily run graphpatch on multiple GPUs and/or with quantized inference on models provided by transformers:

model = LlamaForCausalLM.from_pretrained(
   model_path,
   device_map="auto",
   quantization_config=BitsAndBytesConfig(load_in_8bit=True),
   torch_dtype=torch.float16,
)
pg = PatchableGraph(model, **example_inputs, use_cache=False)

For transformers models supporting the GenerationMixin protocol, you will also be able to use convenience functions like generate() in combination with activation patching:

# Prevent Llama from outputting "Paris"
with pg.patch({"lm_head.output": ZeroPatch(slice=(slice(None), slice(None), 3681))}):
   output_tokens = pg.generate(**inputs, max_length=20, use_cache=False)

Version compatibility

graphpatch should be compatible with all versions of optional libraries matching the minimum version requirements, but this is a highly ambitious claim to make for a Python library. If you end up with errors that seem related to graphpatch's integration with these libraries, you might try changing their versions to those listed below (for example, by pinning them as shown after the list). This list was automatically generated as part of the graphpatch release process. It reflects the versions used while testing graphpatch 0.2.3:

accelerate==1.0.0
bitsandbytes==0.44.1
numpy==1.24.4 (Python 3.8)
numpy==2.0.2 (Python 3.9)
numpy==2.1.1 (later Python versions)
sentencepiece==0.2.0
transformer-lens==2.4.1
transformers==4.45.2
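
If you want to reproduce that tested environment rather than debug version interactions, one option is to pin the optional libraries explicitly (a sketch; choose the numpy pin matching your Python version from the list above):

pip install "accelerate==1.0.0" "bitsandbytes==0.44.1" "sentencepiece==0.2.0" "transformer-lens==2.4.1" "transformers==4.45.2"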

Alternatives

Module hooks are built into torch and can be used for activation patching. You can even add them to existing models without modifying their code. However, this will only give you access to module inputs and outputs; accessing or patching intermediate values still requires a manual rewrite.
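For comparison, a hook-based patch might look like the sketch below, assuming a GPT-2-style layout in which model.transformer.h[17].mlp is the module of interest; only that module's output is reachable, not the values computed inside it:

# Sketch of the hook-based approach: zero out one MLP's entire output.
# The intermediate multiplication patched in the overview example above is
# not reachable this way.
import torch

def zero_output(module, args, output):
   return torch.zeros_like(output)

handle = model.transformer.h[17].mlp.register_forward_hook(zero_output)
output = model(**inputs)
handle.remove()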

TransformerLens provides the HookPoint class, which can record and patch intermediate activations. However, this requires manually rewriting your model’s code to wrap the values you want to make patchable.

TorchLens records and outputs visualizations for every intermediate activation. However, it is currently unable to perform any activation patching.

nnsight offers a nice activation patching API, but is limited to module inputs and outputs.

pyvene offers fine-grained control over activation patches (for example, down to a specific attention head), and a description language/serialization format to allow specification of reproducible experiments.




Download files

Download the file for your platform.

Source Distribution

graphpatch-0.2.3.tar.gz (59.1 kB)

Uploaded Source

Built Distribution


graphpatch-0.2.3-py3-none-any.whl (64.8 kB)

Uploaded Python 3

File details

Details for the file graphpatch-0.2.3.tar.gz.

File metadata

  • Download URL: graphpatch-0.2.3.tar.gz
  • Upload date:
  • Size: 59.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.10.11 Darwin/23.6.0

File hashes

Hashes for graphpatch-0.2.3.tar.gz

  • SHA256: cc966043ed32ae0bd7d321438e8d19642667be14b4975f8d8842b3bbd7d4063b
  • MD5: 1c994595717afd489c2f8c952e42d348
  • BLAKE2b-256: e569a1e3e44ca9043441ed0c928b2a7e531daf245a1bf521dc855c7b36d5ecb8


File details

Details for the file graphpatch-0.2.3-py3-none-any.whl.

File metadata

  • Download URL: graphpatch-0.2.3-py3-none-any.whl
  • Upload date:
  • Size: 64.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.10.11 Darwin/23.6.0

File hashes

Hashes for graphpatch-0.2.3-py3-none-any.whl

  • SHA256: a736f607121bc4b9348bc95126eda60d5583fcf097935bb9da95979d88c137b5
  • MD5: e853fb475ef8b999c5cf2e69e3857957
  • BLAKE2b-256: 8956044cdd2b8d18872119e27fbc2d702926421818d1940a36c9bba4a711ead6

