
Project description

graphpatch

graphpatch is a library for activation patching on PyTorch neural network models. You use it by first wrapping your model in a PatchableGraph and then running operations in a context created by PatchableGraph.patch():

import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

from graphpatch import PatchableGraph, ProbePatch, ZeroPatch

model = GPT2LMHeadModel.from_pretrained(
    "gpt2-xl",
    device_map="auto",
    load_in_8bit=True,
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
inputs = tokenizer(
    "The Eiffel Tower, located in", return_tensors="pt", padding=False
).to(torch.device("cuda"))
# Note that all arguments after the first are forwarded as example inputs
# to the model during compilation; use_cache and return_dict are arguments
# to GPT2LMHeadModel, not graphpatch-specific.
pg = PatchableGraph(model, **inputs, use_cache=False, return_dict=False)
# Applies two patches to the multiplication result within the activation function
# of the MLP in the 18th transformer layer. ProbePatch records the last observed value
# at the given node, while ZeroPatch zeroes out the value seen by downstream computations.
with pg.patch({"transformer.h_17.mlp.act.mul_3": [probe := ProbePatch(), ZeroPatch()]}):
    output = pg(**inputs)
# Patches are applied in order: probe.activation holds the value prior
# to ZeroPatch zeroing it out.
print(probe.activation)

graphpatch can patch (or record) any intermediate Tensor value without manual modification of the underlying model's code.
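Because pg.patch() takes a mapping from node names to lists of patches, several nodes can be patched in one context. A minimal sketch, reusing pg and inputs from above (the node names for layers other than 17 are assumptions; the actual names depend on the graph torch.compile extracts from your model):

# Probe the same activation across the first three transformer layers.
probes = [ProbePatch() for _ in range(3)]
with pg.patch({f"transformer.h_{i}.mlp.act.mul_3": [probes[i]] for i in range(3)}):
    output = pg(**inputs)
for i, probe in enumerate(probes):
    print(i, probe.activation.shape)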

Requirements

graphpatch requires torch>=2, as it uses torch.compile() to build the computational graph it uses for activation patching. As of torch>=2.1.0, Python 3.8–3.11 are supported; torch==2.0.* does not support compilation on Python 3.11, and you will get an exception if you try to use graphpatch on such a configuration.
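If you want to fail fast on an unsupported combination, a small guard along these lines (a sketch, not part of graphpatch) mirrors the constraints above:

import sys

import torch

# "2.0.1+cu117" -> (2, 0); only the major/minor components matter here.
major, minor = (int(part) for part in torch.__version__.split(".")[:2])
assert major >= 2, "graphpatch requires torch>=2"
if sys.version_info >= (3, 11):
    # torch==2.0.* cannot compile models on Python 3.11.
    assert (major, minor) >= (2, 1), "Python 3.11 requires torch>=2.1.0 for torch.compile()"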

graphpatch automatically supports models loaded with features supplied by accelerate and bitsandbytes. For example, you can easily use graphpatch on multiple GPUs and with quantized inference:

from transformers import LlamaForCausalLM

# model_path is a placeholder pointing at local Llama weights.
model = LlamaForCausalLM.from_pretrained(
    model_path, device_map="auto", load_in_8bit=True, torch_dtype=torch.float16
)
pg = PatchableGraph(model, **example_inputs)
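
For completeness, a minimal sketch of building example_inputs for this snippet, mirroring the GPT-2 example above (the prompt is arbitrary):

tokenizer = AutoTokenizer.from_pretrained(model_path)
example_inputs = tokenizer(
    "The Eiffel Tower, located in", return_tensors="pt"
).to(torch.device("cuda"))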

Installation

graphpatch is available on PyPI and can be installed via pip:

pip install graphpatch

Optionally, you can install graphpatch with the "transformers" extra, which selects known-compatible versions of transformers, accelerate, bitsandbytes, and miscellaneous optional requirements of those packages, so you can quickly get started with multi-GPU and quantized inference on real-world models:

pip install "graphpatch[transformers]"

Demos

See the demos for some practical usage examples.

Documentation

See the full documentation on Read the Docs.

Download files

Download the file for your platform.

Source Distribution

graphpatch-0.1.0.tar.gz (34.2 kB)

Uploaded Source

Built Distribution


graphpatch-0.1.0-py3-none-any.whl (37.7 kB)

Uploaded Python 3

File details

Details for the file graphpatch-0.1.0.tar.gz.

File metadata

  • Download URL: graphpatch-0.1.0.tar.gz
  • Size: 34.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.7.1 CPython/3.11.2 Darwin/22.6.0

File hashes

Hashes for graphpatch-0.1.0.tar.gz:

  • SHA256: e52113fa0a2de2d1482e7a3695fb385a67b1ff432067552e828bb522eeb93da1
  • MD5: 1ca6cfde4d8f8bce746bce368021e793
  • BLAKE2b-256: 8852b83e0c15379fccdc765924c7dfd545d929a3f8b239c0de94976c34b43b4a


File details

Details for the file graphpatch-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: graphpatch-0.1.0-py3-none-any.whl
  • Size: 37.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.7.1 CPython/3.11.2 Darwin/22.6.0

File hashes

Hashes for graphpatch-0.1.0-py3-none-any.whl:

  • SHA256: 2fe6aa35a46a62527c8cce34b6f8366992023c1ab02ac5d4e600ea424539997c
  • MD5: 204eb2d399102bd9ecb7dc823d2efd31
  • BLAKE2b-256: fcc61b5b8db67bdb0a053db163d226768404715e2f97224f20195b1d2ac8e5a2

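The digests above can be used with pip's hash-checking mode to verify the downloaded artifact. A minimal sketch (the requirements file name is arbitrary, and in this mode every other pinned requirement must carry hashes as well):

# requirements.txt
graphpatch==0.1.0 \
    --hash=sha256:e52113fa0a2de2d1482e7a3695fb385a67b1ff432067552e828bb522eeb93da1 \
    --hash=sha256:2fe6aa35a46a62527c8cce34b6f8366992023c1ab02ac5d4e600ea424539997c

pip install --require-hashes -r requirements.txt

pip checks whichever artifact it selects (sdist or wheel) against any of the listed hashes.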
