make any model compatible with transformer_lens

Reason this release was yanked:

error

Project description

Automatic_Hook

Automatic_Hook is a Python library that makes it possible to use arbitrary models in transformer_lens. It provides an auto_hook function that wraps your PyTorch model and attaches a HookPoint to every major operation (both nn.Module and nn.Parameter).

Features

  • Works with both nn.Module and nn.Parameter operations
  • Can be used either as a class decorator or on an already instantiated model

Installation

pip install Automatic_Hook

Usage

Usage as a decorator

from Automatic_Hook import auto_hook
import torch.nn as nn

@auto_hook
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        # self.fc1_hook_point = HookPoint()  # no longer needed

    def forward(self, x):
        # return self.fc1_hook_point(self.fc1(x))  # no longer needed
        return self.fc1(x)

model = MyModel()
print(model.hook_dict.items())  # dict_items([('hook_point', HookPoint()), ('fc1.hook_point', HookPoint())])
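To build intuition for what each entry in hook_dict does, the HookPoint mechanic can be sketched in plain Python. This is a simplified stand-in, not transformer_lens's actual HookPoint class: conceptually, a hook point is an identity operation that lets registered callbacks observe or edit the value passing through it.

```python
class HookPoint:
    """Simplified stand-in for a hook point: an identity op that
    lets registered callbacks observe or edit the passing value."""
    def __init__(self):
        self.hooks = []

    def add_hook(self, fn):
        self.hooks.append(fn)

    def __call__(self, value):
        for fn in self.hooks:
            result = fn(value)
            if result is not None:   # a hook may edit the value...
                value = result       # ...or just observe (return None)
        return value

hp = HookPoint()
captured = []
hp.add_hook(lambda v: captured.append(v))  # observe the activation
hp.add_hook(lambda v: v * 2)               # edit the activation
out = hp(3)
print(out)       # 6
print(captured)  # [3]
```

With no hooks registered, the hook point is a no-op, which is why auto_hook can insert one after every operation without changing the model's behavior.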

Wrap an instance

auto_hook can also wrap models that use nn.Parameter directly, such as this AutoEncoder example:

from Automatic_Hook import auto_hook
import torch
from torch import nn

# taken from Neel Nanda's excellent autoencoder tutorial: https://colab.research.google.com/drive/1u8larhpxy8w4mMsJiSBddNOzFGj7_RTn#scrollTo=MYrIYDEfBtbL
class AutoEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        d_hidden = cfg["d_mlp"] * cfg["dict_mult"]
        d_mlp = cfg["d_mlp"]
        dtype = torch.float32
        torch.manual_seed(cfg["seed"])
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(d_mlp, d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(d_hidden, d_mlp, dtype=dtype)))
        self.b_enc = nn.Parameter(
            torch.zeros(d_hidden, dtype=dtype)
        )
        self.b_dec = nn.Parameter(
            torch.zeros(d_mlp, dtype=dtype)
        )

    def forward(self, x):
        x_cent = x - self.b_dec
        acts = torch.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        return x_reconstruct

autoencoder = auto_hook(AutoEncoder({"d_mlp": 10, "dict_mult": 10, "l1_coeff": 10, "seed": 1}))
print(autoencoder.hook_dict.items())
# dict_items([('hook_point', HookPoint()), ('W_enc.hook_point', HookPoint()), ('W_dec.hook_point', HookPoint()), ('b_enc.hook_point', HookPoint()), ('b_dec.hook_point', HookPoint())])

If this were done manually, the code would be far less clean:

from transformer_lens.hook_points import HookPoint

class AutoEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        d_hidden = cfg['d_mlp'] * cfg['dict_mult']
        d_mlp = cfg['d_mlp']
        dtype = torch.float32
        torch.manual_seed(cfg['seed'])
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(d_mlp, d_hidden, dtype=dtype)
            )
        )
        self.W_enc_hook_point = HookPoint()
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(d_hidden, d_mlp, dtype=dtype)
            )
        )
        self.W_dec_hook_point = HookPoint()
        self.b_enc = nn.Parameter(
            torch.zeros(d_hidden, dtype=dtype)
        )
        self.b_enc_hook_point = HookPoint()
        self.b_dec = nn.Parameter(
            torch.zeros(d_mlp, dtype=dtype)
        )
        self.b_dec_hook_point = HookPoint()

    def forward(self, x):
        x_cent = self.b_dec_hook_point(x - self.b_dec)
        acts = torch.relu(self.b_enc_hook_point(self.W_enc_hook_point(x_cent @ self.W_enc) + self.b_enc))
        x_reconstruct = self.b_dec_hook_point(self.W_dec_hook_point(acts @ self.W_dec) + self.b_dec)
        return x_reconstruct
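
Under the hood, auto_hook has to discover the wrapped model's submodules and parameters and register one HookPoint per name, plus one at the top level. A rough sketch of that discovery step, using plain Python attribute inspection rather than the library's actual implementation, might look like this:

```python
class HookPoint:
    """Minimal placeholder hook point (identity)."""
    def __call__(self, value):
        return value

def build_hook_dict(model):
    # Sketch only: attach a top-level hook point, then one per
    # public attribute (standing in for nn.Module / nn.Parameter
    # members). The real library walks the module tree properly.
    hook_dict = {"hook_point": HookPoint()}
    for name in vars(model):
        if not name.startswith("_"):
            hook_dict[f"{name}.hook_point"] = HookPoint()
    return hook_dict

class Toy:
    def __init__(self):
        self.W = [[1.0]]
        self.b = [0.0]

hooks = build_hook_dict(Toy())
print(sorted(hooks))  # ['W.hook_point', 'b.hook_point', 'hook_point']
```

This mirrors the naming scheme seen in the hook_dict output above: one `<attribute>.hook_point` entry per parameter-like attribute plus a bare `hook_point`.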

Note

Some edge cases may not be supported, so a check_auto_hook function is provided to run the model class against the library's internal tests.

Note, however, that these results are not always conclusive, but they can give useful hints.

from Automatic_Hook import check_auto_hook
import torch

input_kwargs = {'x': torch.randn(10, 10)}
init_kwargs = {'cfg': {'d_mlp': 10, 'dict_mult': 10, 'l1_coeff': 10, 'seed': 1}}
check_auto_hook(AutoEncoder, input_kwargs, init_kwargs)

If strict is set to True, a runtime error is raised when the tests fail; otherwise only a warning is issued.
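The strict flag's raise-or-warn behavior can be sketched as follows. report_failures is a hypothetical helper invented here for illustration; it is not part of the library's API:

```python
import warnings

def report_failures(failures, strict=False):
    """Hypothetical sketch of the strict flag described above:
    raise on failure when strict, otherwise emit a warning."""
    if not failures:
        return
    msg = f"{len(failures)} internal test(s) failed: {failures}"
    if strict:
        raise RuntimeError(msg)
    warnings.warn(msg)

report_failures([])  # no failures: silent
try:
    report_failures(["forward_equivalence"], strict=True)
except RuntimeError as e:
    print("raised:", e)
```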

Backward (bwd) hooks

Trouble can occur when a model or one of its inner components returns a non-tensor object that is then passed to a hook. Resolving this is work in progress; in the meantime, the model still works if those hooks are simply disabled.
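The workaround of disabling problematic hooks can be illustrated with a toy guard. This is a conceptual sketch with an assumed `enabled` switch, not the library's real mechanism; lists stand in for tensors so the snippet has no dependencies:

```python
class GuardedHookPoint:
    """Toy hook point with an on/off switch. Non-sequence values
    (standing in for non-tensor outputs) pass through untouched."""
    def __init__(self):
        self.enabled = True
        self.hooks = []

    def __call__(self, value):
        # Skip hooks when disabled, or when the value is not
        # tensor-like, so hooks never see objects they can't handle.
        if not self.enabled or not isinstance(value, list):
            return value
        for fn in self.hooks:
            value = fn(value)
        return value

hp = GuardedHookPoint()
hp.hooks.append(lambda v: [x * 2 for x in v])
print(hp([1, 2]))   # [2, 4]  -- hooks run on tensor-like values
print(hp("meta"))   # 'meta'  -- non-tensor output passes through
hp.enabled = False
print(hp([1, 2]))   # [1, 2]  -- disabled: value untouched
```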
