graphpatch
graphpatch is a library for activation patching on PyTorch neural network models. You use it by first wrapping your model in a PatchableGraph and then running operations in a context created by PatchableGraph.patch():
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

from graphpatch import PatchableGraph, ProbePatch, ZeroPatch

model = GPT2LMHeadModel.from_pretrained(
"gpt2-xl",
device_map="auto",
load_in_8bit=True,
torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
inputs = tokenizer(
"The Eiffel Tower, located in", return_tensors="pt", padding=False
).to(torch.device("cuda"))
# Note that all arguments after the first are forwarded as example inputs
# to the model during compilation; use_cache and return_dict are arguments
# to GPT2LMHeadModel, not graphpatch-specific.
pg = PatchableGraph(model, **inputs, use_cache=False, return_dict=False)
# Applies two patches to the multiplication result within the activation function
# of the MLP in the 18th transformer layer. ProbePatch records the last observed value
# at the given node, while ZeroPatch zeroes out the value seen by downstream computations.
with pg.patch({"transformer.h_17.mlp.act.mul_3": [probe := ProbePatch(), ZeroPatch()]}):
output = pg(**inputs)
# Patches are applied in order. probe.activation holds the value prior
# to ZeroPatch zeroing it out.
print(probe.activation)
graphpatch can patch (or record) any intermediate Tensor value without manual modification of the underlying model's code. See the full documentation here.
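To illustrate the patching semantics without any framework, here is a minimal pure-Python toy (not using graphpatch itself): it records an intermediate value and then zeroes it before downstream computation, mirroring what ProbePatch and ZeroPatch do at a real graph node. The "model" and patch callables are hypothetical stand-ins for illustration only.

```python
def forward(x, patches=None):
    # Toy "model": two stages with one intermediate activation.
    act = x * 2  # intermediate value (analogous to a graph node)
    if patches:
        for patch in patches:  # patches are applied in order
            act = patch(act)
    return act + 1  # downstream computation sees the patched value

recorded = []
probe = lambda v: (recorded.append(v), v)[1]  # record, then pass through
zero = lambda v: 0                            # zero out the value

out = forward(3, patches=[probe, zero])
print(recorded[0], out)  # the probe saw 6 before zeroing; output is 0 + 1 = 1
```

Because the probe runs before the zeroing patch, it captures the unmodified activation, exactly the ordering behavior described in the example above.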
Requirements
graphpatch requires torch>=2, as it uses torch.compile() to build the computational graph it uses for activation patching. As of torch>=2.1.0, Python 3.8–3.11 are supported. torch==2.0.* does not support compilation on Python 3.11; you will get an exception if you try to use graphpatch on such a configuration.
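The constraints above can be checked up front before importing graphpatch. This is a convenience sketch, not part of graphpatch's API; the check_compat helper is hypothetical.

```python
import sys


def check_compat(torch_version: str, py: tuple) -> None:
    """Raise if this torch/Python combination cannot run torch.compile(),
    per the version constraints stated above."""
    major, minor = (int(p) for p in torch_version.split(".")[:2])
    if (major, minor) < (2, 0):
        raise RuntimeError("graphpatch requires torch>=2")
    if (major, minor) == (2, 0) and py >= (3, 11):
        raise RuntimeError(
            "torch==2.0.* does not support compilation on Python 3.11"
        )


# torch 2.1 on Python 3.11 is fine; torch 2.0 on Python 3.11 is not.
check_compat("2.1.0", (3, 11))
```

In practice you would pass torch.__version__ and sys.version_info[:2] as the arguments.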
graphpatch automatically supports models loaded with features supplied by accelerate and bitsandbytes. For example, you can easily use graphpatch on multiple GPUs and with quantized inference:
model = LlamaForCausalLM.from_pretrained(
model_path, device_map="auto", load_in_8bit=True, torch_dtype=torch.float16
)
pg = PatchableGraph(model, **example_inputs)
Installation
graphpatch is available on PyPI and can be installed via pip:
pip install graphpatch
Optionally, you can install graphpatch with the "transformers" extra, which selects known-compatible versions of transformers, accelerate, bitsandbytes, and miscellaneous optional requirements of these packages, to quickly get started with multi-GPU and quantized inference on real-world models:
pip install graphpatch[transformers]
Demos
See the demos for some practical usage examples.
Documentation
See the full documentation on Read the Docs.