graphpatch
graphpatch is a library for activation patching on PyTorch
neural network models. You use it by first wrapping your model in a PatchableGraph and then running
operations in a context created by
PatchableGraph.patch():
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

from graphpatch import PatchableGraph, ProbePatch, ZeroPatch

model = GPT2LMHeadModel.from_pretrained(
    "gpt2-xl",
    device_map="auto",
    load_in_8bit=True,
    torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
inputs = tokenizer(
    "The Eiffel Tower, located in", return_tensors="pt", padding=False
).to(torch.device("cuda"))
# Note that all arguments after the first are forwarded as example inputs
# to the model during compilation; use_cache and return_dict are arguments
# to GPT2LMHeadModel, not graphpatch-specific.
pg = PatchableGraph(model, **inputs, use_cache=False, return_dict=False)
# Applies two patches to the multiplication result within the activation function
# of the MLP in the 18th transformer layer. ProbePatch records the last observed value
# at the given node, while ZeroPatch zeroes out the value seen by downstream computations.
with pg.patch({"transformer.h_17.mlp.act.mul_3": [probe := ProbePatch(), ZeroPatch()]}):
    output = pg(**inputs)
# Patches are applied in order. probe.activation holds the value prior
# to ZeroPatch zeroing it out.
print(probe.activation)
graphpatch can patch (or record) any intermediate Tensor value without manually modifying the
underlying model's code. See the full documentation on Read the Docs.
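For instance, a patched run can be compared against an unpatched baseline to measure the effect of an intervention. The following is a minimal sketch building on the GPT-2 example above; because the graph was compiled with return_dict=False, the first element of the output tuple is assumed to be the logits, and the node name is the same illustrative one used earlier:
# Unpatched baseline forward pass.
baseline_logits = pg(**inputs)[0]
# The same forward pass with the activation zeroed out.
with pg.patch({"transformer.h_17.mlp.act.mul_3": [ZeroPatch()]}):
    patched_logits = pg(**inputs)[0]
# How much did zeroing this activation change the output logits?
print((baseline_logits - patched_logits).abs().max())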
Requirements
graphpatch requires torch>=2, since it uses torch.compile() to build the
computational graph it uses for activation patching. As of torch>=2.1.0,
Python 3.8–3.11 are supported. torch==2.0.* does not support compilation on Python 3.11; you
will get an exception if you try to use graphpatch with that configuration.
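A quick environment sanity check before building a PatchableGraph might look like the following (a minimal sketch; this check is not part of graphpatch itself):
import sys

import torch

# torch.compile() requires torch>=2; on Python 3.11, compilation needs torch>=2.1.
major, minor = (int(part) for part in torch.__version__.split(".")[:2])
assert major >= 2, "graphpatch requires torch>=2"
if sys.version_info >= (3, 11):
    assert (major, minor) >= (2, 1), "torch==2.0.* cannot compile on Python 3.11"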
graphpatch automatically supports models loaded with features supplied by accelerate and
bitsandbytes. For example, you can easily use graphpatch on multiple GPUs and with quantized
inference:
model = LlamaForCausalLM.from_pretrained(
    model_path, device_map="auto", load_in_8bit=True, torch_dtype=torch.float16
)
pg = PatchableGraph(model, **example_inputs)
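The resulting PatchableGraph is then used exactly as in the GPT-2 example above; only the node names differ, since they follow the structure of the compiled Llama graph. The node path below is hypothetical and only meant to illustrate the pattern:
from graphpatch import ProbePatch

# Hypothetical node path; real names depend on the compiled graph of your model.
with pg.patch({"model.layers_0.mlp.down_proj.output": [probe := ProbePatch()]}):
    output = pg(**example_inputs)
print(probe.activation)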
Installation
graphpatch is available on PyPI, and can be installed via pip:
pip install graphpatch
Optionally, you can install graphpatch with the "transformers" extra, which selects known-compatible versions of transformers, accelerate, bitsandbytes, and their miscellaneous optional requirements, so you can quickly get started with multi-GPU and quantized inference on real-world models:
pip install graphpatch[transformers]
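Note that some shells (zsh, for example) treat the square brackets as a glob pattern, so you may need to quote the extra:
pip install 'graphpatch[transformers]'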
Demos
See the demos for some practical usage examples.
Documentation
See the full documentation on Read the Docs.
File details
Details for the file graphpatch-0.1.0.tar.gz.
File metadata
- Download URL: graphpatch-0.1.0.tar.gz
- Upload date:
- Size: 34.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.7.1 CPython/3.11.2 Darwin/22.6.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e52113fa0a2de2d1482e7a3695fb385a67b1ff432067552e828bb522eeb93da1 |
| MD5 | 1ca6cfde4d8f8bce746bce368021e793 |
| BLAKE2b-256 | 8852b83e0c15379fccdc765924c7dfd545d929a3f8b239c0de94976c34b43b4a |
File details
Details for the file graphpatch-0.1.0-py3-none-any.whl.
File metadata
- Download URL: graphpatch-0.1.0-py3-none-any.whl
- Upload date:
- Size: 37.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.7.1 CPython/3.11.2 Darwin/22.6.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2fe6aa35a46a62527c8cce34b6f8366992023c1ab02ac5d4e600ea424539997c |
| MD5 | 204eb2d399102bd9ecb7dc823d2efd31 |
| BLAKE2b-256 | fcc61b5b8db67bdb0a053db163d226768404715e2f97224f20195b1d2ac8e5a2 |