graphpatch is a library for activation patching on PyTorch neural network models.
graphpatch 0.2.3
Documentation is hosted on Read the Docs.
Overview
graphpatch is a library for activation patching (often
also referred to as “ablation”) on PyTorch neural network models. You use
it by first wrapping your model in a PatchableGraph and then running operations in a context
created by PatchableGraph.patch():
from graphpatch import PatchableGraph, ProbePatch, ZeroPatch

# `model` is your torch.nn.Module and `inputs` is a dict of example inputs for it.
pg = PatchableGraph(model, **inputs, use_cache=False)
# Applies patches to the multiplication result within the activation function of the
# MLP in the 18th transformer layer. ProbePatch records the last observed value at the
# given node, while ZeroPatch zeroes out the value seen by downstream computations.
with pg.patch({"transformer.h_17.mlp.act.mul_3": [probe := ProbePatch(), ZeroPatch()]}):
    output = pg(**inputs)
# Patches are applied in order. probe.activation holds the value prior
# to ZeroPatch zeroing it out.
print(probe.activation)
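To make the ordering concrete, here is a minimal plain-PyTorch sketch of the semantics described above: a list of patches is applied in sequence to one intermediate value, so a probe placed first sees the value before a later zeroing step. `ToyProbe`, `toy_zero`, and `apply_patches` are hypothetical stand-ins written for illustration, not graphpatch's ProbePatch/ZeroPatch implementations.

```python
import torch


class ToyProbe:
    """Hypothetical probe: records the value it sees and passes it through unchanged."""

    def __init__(self):
        self.activation = None

    def __call__(self, value: torch.Tensor) -> torch.Tensor:
        # Clone so later modifications cannot alter the recorded copy.
        self.activation = value.detach().clone()
        return value


def toy_zero(value: torch.Tensor) -> torch.Tensor:
    """Hypothetical zeroing patch: downstream code sees an all-zero tensor."""
    return torch.zeros_like(value)


def apply_patches(value, patches):
    # Apply each patch in list order; each one receives the previous one's output.
    for patch in patches:
        value = patch(value)
    return value


hidden = torch.tensor([1.0, -2.0, 3.0])
probe = ToyProbe()
patched = apply_patches(hidden, [probe, toy_zero])
print(probe.activation)  # the values before zeroing
print(patched)           # all zeros
```

Swapping the list order to `[toy_zero, probe]` would make the probe record zeros instead.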
In contrast to other approaches, graphpatch can patch (or record) any
intermediate tensor value without manual modification of the underlying model’s code. See Working with graphpatch for
some tips on how to use the generated graphs.
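The underlying idea, reaching an intermediate tensor without editing the model's source, can be illustrated with PyTorch's own `torch.fx` toolkit. This sketch is not how graphpatch works internally (graphpatch extracts its graphs with `torch.compile()`); it substitutes `torch.fx.symbolic_trace` and a `torch.fx.Interpreter` subclass to record and zero the output of one named graph node in a toy MLP. `TinyMLP` and `PatchingInterpreter` are hypothetical names.

```python
import torch
import torch.fx as fx
from torch import nn


class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = nn.Linear(4, 8)
        self.down = nn.Linear(8, 2)

    def forward(self, x):
        h = self.up(x)
        act = torch.relu(h) * 2.0  # intermediate value the model never exposes
        return self.down(act)


class PatchingInterpreter(fx.Interpreter):
    """Runs a traced graph, letting a callback replace any node's output by name."""

    def __init__(self, gm, patches):
        super().__init__(gm)
        self.patches = patches  # {node_name: fn(tensor) -> tensor}
        self.recorded = {}

    def run_node(self, n):
        result = super().run_node(n)
        if n.name in self.patches:
            # Record the unpatched value, then hand the patched one downstream.
            self.recorded[n.name] = result.detach().clone()
            result = self.patches[n.name](result)
        return result


torch.manual_seed(0)
model = TinyMLP()
gm = fx.symbolic_trace(model)
print([n.name for n in gm.graph.nodes])  # includes "relu" and "mul"

x = torch.randn(3, 4)
interp = PatchingInterpreter(gm, {"mul": torch.zeros_like})
out = interp.run(x)  # self.down now only sees zeros, so out equals its bias
```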
Note that graphpatch activation patches are compatible with autograd!
This means that, for example, you can run an optimization over the value parameter of
AddPatch:
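import torch
from graphpatch import AddPatch

# `graph` is a PatchableGraph; `size`, `num_steps`, `node_name`, `prompt_inputs`,
# and `my_loss_function` are placeholders for your own values.
delta = torch.zeros(size, requires_grad=True, device="cuda")
optimizer = torch.optim.Adam([delta], lr=0.5)
for _ in range(num_steps):
    optimizer.zero_grad()  # clear gradients left over from the previous step
    with graph.patch({node_name: AddPatch(value=delta)}):
        logits = graph(**prompt_inputs)
        loss = my_loss_function(logits)
        loss.backward()
        optimizer.step()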
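The same principle can be seen without graphpatch. In this plain-PyTorch sketch, a forward hook (standing in for AddPatch purely for illustration) adds a trainable `delta` to a hidden layer's output, and Adam optimizes `delta` alone while the model's own weights stay frozen. The toy model, target, learning rate, and step count are all hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
for p in model.parameters():
    p.requires_grad_(False)  # only delta should be optimized

x = torch.randn(16, 4)
target = torch.full((16, 1), 3.0)

delta = torch.zeros(8, requires_grad=True)
optimizer = torch.optim.Adam([delta], lr=0.1)

# A forward hook that returns a value replaces the module's output: here, output + delta.
handle = model[0].register_forward_hook(lambda module, inputs, output: output + delta)


def loss_fn():
    return nn.functional.mse_loss(model(x), target)


initial_loss = loss_fn().item()
for _ in range(200):
    optimizer.zero_grad()  # clear gradients from the previous step
    loss = loss_fn()
    loss.backward()  # gradients flow through the additive patch into delta
    optimizer.step()
final_loss = loss_fn().item()
handle.remove()

print(initial_loss, final_loss)  # the loss should drop as delta is learned
```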
For a practical usage example, see the demo of using graphpatch to replicate ROME.
Prerequisites
The only mandatory requirements are torch>=2 and numpy>=1.17. Version 2+ of torch is required
because graphpatch leverages torch.compile(), which was introduced in 2.0.0, to extract computational graphs from models.
CUDA support is not required. numpy is required for full compile() support.
Python 3.8–3.12 are supported. Note that torch versions prior to 2.1.0 do not support compilation
on Python 3.11, and versions prior to 2.4.0 do not support compilation on Python 3.12;
you will get an exception when trying to use graphpatch with such a configuration. No version of
torch yet supports compilation on Python 3.13.
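The interpreter and torch constraints above can be written down as a small check. This is a hypothetical helper that encodes only the rules stated in this section (Python 3.8 to 3.12, torch 2.1.0+ on Python 3.11, torch 2.4.0+ on Python 3.12, no torch support on 3.13); it is not part of graphpatch, and later torch releases may change the picture.

```python
def _parse(version):
    """Turn a version string like "2.4.1+cu121" into a 3-tuple like (2, 4, 1)."""
    parts = []
    for piece in version.split("+")[0].split(".")[:3]:
        digits = ""
        for ch in piece:  # keep only the leading digits, e.g. "0rc1" -> "0"
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)  # pad "2.4" to (2, 4, 0) so tuple comparisons are fair
    return tuple(parts)


def compile_supported(python_version, torch_version):
    """Return True if the rules stated above allow compilation for this combination."""
    py = tuple(python_version[:2])
    torch_v = _parse(torch_version)
    if torch_v < (2, 0, 0):
        return False  # torch.compile() was introduced in 2.0.0
    if py < (3, 8) or py > (3, 12):
        return False  # outside the supported Python range (this includes 3.13)
    if py == (3, 11):
        return torch_v >= (2, 1, 0)
    if py == (3, 12):
        return torch_v >= (2, 4, 0)
    return True
```

For your current environment you could call it as `compile_supported(sys.version_info, torch.__version__)`.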
Installation
graphpatch is available on PyPI, and can be installed via pip:
pip install graphpatch
Note that you will likely want to do this in an environment that already has torch, since pip may not resolve
torch to a CUDA-enabled version by default. You don’t need to do anything special to make graphpatch compatible
with transformers, accelerate, and bitsandbytes; their presence is detected at run-time. However, for convenience,
you can install graphpatch with the “transformers” extra, which will install known compatible versions of these libraries along
with some of their optional dependencies that are otherwise mildly inconvenient to set up:
pip install "graphpatch[transformers]"
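graphpatch detects these libraries on its own, so no action is needed. If you want to see which of them your environment already has before deciding on the extra, a standard-library sketch like the following reports which modules are importable. `detect_optional` is a hypothetical helper and is not how graphpatch performs its own detection.

```python
import importlib.util

OPTIONAL_LIBRARIES = ("transformers", "accelerate", "bitsandbytes")


def detect_optional(names=OPTIONAL_LIBRARIES):
    """Map each top-level module name to True if it is on the import path (without importing it)."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


print(detect_optional())
```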
Model compatibility
For full functionality, graphpatch depends on being able to call torch.compile() on your
model. torch.compile() currently supports only a subset of possible Python operations; for example, it doesn’t support
context managers. graphpatch implements some workarounds for situations that a native
compile() can’t handle, but this coverage isn’t complete. To deal with this, graphpatch
has a graceful fallback that should be no worse to use than module hooks.
In that case, you will only be able to patch an uncompilable submodule’s inputs, outputs,
parameters, and buffers. See Notes on compilation for more discussion.
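One way to picture this fallback: try to extract a graph for a submodule and, where that fails, treat the submodule as opaque and intercept only its boundary. This sketch substitutes `torch.fx.symbolic_trace` for `torch.compile()`, so what counts as untraceable differs (here the failure case is data-dependent Python control flow); the `Branchy` module and `extract_or_fallback` helper are hypothetical.

```python
import torch
import torch.fx as fx
from torch import nn


class Branchy(nn.Module):
    def forward(self, x):
        # Data-dependent branching: symbolic tracing cannot decide which path to record.
        if x.sum() > 0:
            return x * 2
        return x - 1


def extract_or_fallback(module):
    """Return ("graph", GraphModule) when tracing works, else ("opaque", module)."""
    try:
        return "graph", fx.symbolic_trace(module)
    except Exception:
        # Fallback: only the module's inputs, outputs, parameters, and buffers
        # remain reachable, e.g. through forward hooks.
        return "opaque", module


kind_linear, _ = extract_or_fallback(nn.Linear(3, 3))
kind_branchy, opaque = extract_or_fallback(Branchy())
print(kind_linear, kind_branchy)  # graph opaque

# Boundary-only patching on the opaque module via a forward hook.
opaque.register_forward_hook(lambda module, inputs, output: torch.zeros_like(output))
print(opaque(torch.ones(2)))  # zeros, regardless of which branch ran
```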
transformers integration
graphpatch is theoretically compatible with any model in Hugging Face’s transformers
library, but note that there may be edge cases in specific model code that it can’t yet handle. For
example, it is not (yet!) compatible with the key-value caching implementation, so if you want full
compilation of such models you should pass use_cache=False as part of the example inputs.
graphpatch is compatible with models loaded via accelerate and with 8-bit parameters
quantized by bitsandbytes. This means that you can easily run graphpatch on
multiple GPUs and/or with quantized inference on models provided by transformers:
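import torch
from graphpatch import PatchableGraph
from transformers import BitsAndBytesConfig, LlamaForCausalLM

# `model_path` and `example_inputs` are placeholders for your checkpoint and inputs.
model = LlamaForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)
pg = PatchableGraph(model, **example_inputs, use_cache=False)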
For transformers models that inherit from GenerationMixin, you will
also be able to use convenience functions like generate() in
combination with activation patching:
# Prevent Llama from outputting "Paris"
with pg.patch({"lm_head.output": ZeroPatch(slice=(slice(None), slice(None), 3681))}):
    output_tokens = pg.generate(**inputs, max_length=20, use_cache=False)
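The `slice` argument above is an ordinary PyTorch indexing tuple: `(slice(None), slice(None), 3681)` selects every batch element and every sequence position at vocabulary index 3681 of `(batch, sequence, vocab)` logits, the same as `logits[:, :, 3681]`. The plain-tensor sketch below, using a hypothetical random logits tensor, shows that only that column changes. Keep in mind that zeroing a logit lowers that token's score to 0 rather than excluding it outright, so whether it can still be chosen depends on the other logits.

```python
import torch

torch.manual_seed(0)
batch, seq_len, vocab = 2, 5, 32000  # hypothetical shapes
logits = torch.randn(batch, seq_len, vocab)

index = (slice(None), slice(None), 3681)
patched = logits.clone()
patched[index] = 0.0  # same effect as patched[:, :, 3681] = 0.0

# Find which vocabulary columns differ anywhere in the batch or sequence.
changed = (patched != logits).any(dim=0).any(dim=0)
print(changed.nonzero().flatten())  # tensor([3681])
```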
Version compatibility
graphpatch should be compatible with all versions of optional libraries matching the minimum
version requirements, but this is a highly ambitious claim to make for a Python library. If you end
up with errors that seem related to graphpatch’s integration with these libraries, you might try
changing their versions to those listed below. This list was automatically generated as part of the
graphpatch release process. It reflects the versions used while testing graphpatch 0.2.3:
accelerate==1.0.0
bitsandbytes==0.44.1
numpy==1.24.4 (Python 3.8)
numpy==2.0.2 (Python 3.9)
numpy==2.1.1 (later Python versions)
sentencepiece==0.2.0
transformer-lens==2.4.1
transformers==4.45.2
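To compare your environment against this list, a hypothetical standard-library helper like the one below can report mismatches. The `TESTED` mapping copies the non-numpy pins above (numpy is left out because its tested version depends on the Python version); neither helper is part of graphpatch.

```python
from importlib.metadata import PackageNotFoundError, version

TESTED = {
    "accelerate": "1.0.0",
    "bitsandbytes": "0.44.1",
    "sentencepiece": "0.2.0",
    "transformer-lens": "2.4.1",
    "transformers": "4.45.2",
}


def installed_versions(names):
    """Return {distribution: installed version, or None if it is not installed}."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def mismatches(installed, tested=TESTED):
    """Return {name: (installed, tested)} for installed packages whose version differs."""
    return {
        name: (installed[name], want)
        for name, want in tested.items()
        if installed.get(name) is not None and installed[name] != want
    }


print(mismatches(installed_versions(TESTED)))
```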
Alternatives
Module hooks are built into torch and can be used for activation
patching. You can even add them to existing models without modifying their code. However, this will only give you
access to module inputs and outputs; accessing or patching intermediate values still requires a manual rewrite.
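For comparison, this is roughly what hook-based patching looks like in plain PyTorch: forward hooks can record or replace a submodule's output, but a value computed inside a module's `forward()` (such as the product inside an activation function) has no hook of its own. The toy model and hook names are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.GELU(), nn.Linear(8, 2))
x = torch.randn(3, 4)
baseline = model(x)

recorded = {}


def record_hook(module, inputs, output):
    # Returning None leaves the output unchanged; we only keep a copy.
    recorded["gelu_out"] = output.detach().clone()


def zero_hook(module, inputs, output):
    # Returning a tensor replaces the module's output for downstream layers.
    return torch.zeros_like(output)


# Hooks on one module run in registration order, so the recorder sees the unpatched value.
h1 = model[1].register_forward_hook(record_hook)
h2 = model[1].register_forward_hook(zero_hook)
patched = model(x)
h1.remove()
h2.remove()

print(torch.allclose(patched, model[2].bias.expand(3, 2)))  # True: the last layer saw zeros
```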
TransformerLens provides the
HookPoint class, which can record and patch intermediate
activations. However, this requires manually rewriting your model’s code to wrap the values you want to make
patchable.
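To illustrate the kind of manual rewrite this requires, here is a minimal identity module in the spirit of HookPoint. It is a hypothetical stand-in written for this sketch, not TransformerLens's class: the model's `forward()` must be edited by hand to route the intermediate value through it, after which an ordinary forward hook can reach that value.

```python
import torch
from torch import nn


class IdentityHook(nn.Module):
    """Hypothetical HookPoint-like module: returns its input so hooks can see it."""

    def forward(self, x):
        return x


class RewrittenMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = nn.Linear(4, 8)
        self.down = nn.Linear(8, 2)
        self.hook_act = IdentityHook()  # added by hand to expose the activation

    def forward(self, x):
        act = torch.relu(self.up(x)) * 2.0
        act = self.hook_act(act)  # manual edit: wrap the value we want to patch
        return self.down(act)


torch.manual_seed(0)
model = RewrittenMLP()
model.hook_act.register_forward_hook(lambda module, inputs, output: output * 0.0)
out = model(torch.randn(3, 4))  # the intermediate activation was zeroed via the hook
```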
TorchLens records and outputs visualizations for every intermediate activation. However, it is currently unable to perform any activation patching.
nnsight offers a nice activation patching API, but is limited to module inputs and outputs.
pyvene offers fine-grained control over activation patches (for example, down to a specific attention head), and a description language/serialization format to allow specification of reproducible experiments.
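As an illustration of what head-level granularity means, this plain-PyTorch sketch zeroes a single attention head's slice of a tensor laid out as `(batch, sequence, num_heads * head_dim)` by viewing it as `(batch, sequence, num_heads, head_dim)`. The sizes, the layout, and the `zero_head` helper are assumptions for illustration and do not reflect pyvene's API.

```python
import torch

batch, seq_len, num_heads, head_dim = 2, 4, 8, 16  # hypothetical sizes
torch.manual_seed(0)
attn_out = torch.randn(batch, seq_len, num_heads * head_dim)


def zero_head(hidden, head, num_heads):
    """Return a copy of `hidden` with one head's slice of the last dimension zeroed."""
    b, s, d = hidden.shape
    per_head = hidden.reshape(b, s, num_heads, d // num_heads).clone()
    per_head[:, :, head, :] = 0.0
    return per_head.reshape(b, s, d)


patched = zero_head(attn_out, head=3, num_heads=num_heads)
```

Under this layout, head `h` occupies the contiguous block `[h * head_dim, (h + 1) * head_dim)` of the last dimension.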
File details
Details for the file graphpatch-0.2.3.tar.gz.
File metadata
- Download URL: graphpatch-0.2.3.tar.gz
- Size: 59.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.8.3 CPython/3.10.11 Darwin/23.6.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | cc966043ed32ae0bd7d321438e8d19642667be14b4975f8d8842b3bbd7d4063b |
| MD5 | 1c994595717afd489c2f8c952e42d348 |
| BLAKE2b-256 | e569a1e3e44ca9043441ed0c928b2a7e531daf245a1bf521dc855c7b36d5ecb8 |
File details
Details for the file graphpatch-0.2.3-py3-none-any.whl.
File metadata
- Download URL: graphpatch-0.2.3-py3-none-any.whl
- Size: 64.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.8.3 CPython/3.10.11 Darwin/23.6.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a736f607121bc4b9348bc95126eda60d5583fcf097935bb9da95979d88c137b5 |
| MD5 | e853fb475ef8b999c5cf2e69e3857957 |
| BLAKE2b-256 | 8956044cdd2b8d18872119e27fbc2d702926421818d1940a36c9bba4a711ead6 |