Make any model compatible with transformer_lens
Reason this release was yanked: error
Auto_HookPoint
Auto_HookPoint is a Python library that makes it easy to integrate arbitrary PyTorch models with transformer_lens. It provides an auto_hook function that wraps your PyTorch model and adds a HookPoint for every nn.Module and for most nn.Parameter objects in the model.
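The core idea can be approximated in plain PyTorch. The sketch below is not Auto_HookPoint's implementation: SimpleHookPoint and attach_hook_points are hypothetical names. It uses PyTorch's own register_forward_hook rather than transformer_lens's HookPoint, although transformer_lens's HookPoint is likewise an identity module. The sketch gives every module, including the root, its own hook point. It names the hook points the way the hook_dict printouts in this README do. A hook registered on a hook point can then replace that module's output.

```python
# Hypothetical sketch, NOT Auto_HookPoint's implementation: plain PyTorch hooks only.
import torch
import torch.nn as nn


class SimpleHookPoint(nn.Module):
    """Hypothetical stand-in for a HookPoint: an identity module whose
    forward hooks can inspect or replace the activation passing through it."""

    def forward(self, x):
        return x


def attach_hook_points(model: nn.Module) -> dict:
    """Give every module in `model` (root included) its own hook point.

    Returns a dict mapping names such as 'hook_point' and 'fc1.hook_point'
    to the corresponding SimpleHookPoint.
    """
    hook_dict = {}
    for name, module in list(model.named_modules()):
        hp = SimpleHookPoint()
        key = f"{name}.hook_point" if name else "hook_point"
        # A forward hook that returns a value replaces the module's output,
        # so each module's output is routed through its hook point.
        module.register_forward_hook(lambda mod, args, out, hp=hp: hp(out))
        hook_dict[key] = hp
    return hook_dict


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc1(x)


model = MyModel()
hook_dict = attach_hook_points(model)
print(sorted(hook_dict))  # ['fc1.hook_point', 'hook_point']

# Zero-ablate fc1's output by registering a hook on its hook point.
hook_dict["fc1.hook_point"].register_forward_hook(
    lambda mod, args, out: torch.zeros_like(out)
)
print(model(torch.randn(2, 10)).abs().sum().item())  # 0.0
```

With no hooks registered, each hook point simply passes its input through, so the model's outputs are unchanged.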
Features
- Works with both nn.Module and nn.Parameter operations
- Can be used either as a class decorator or on an already instantiated model
- Keeps model code cleaner, since HookPoints no longer have to be declared and called by hand
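Supporting both call styles comes down to checking whether the argument is a class or an instance. The pure-Python sketch below shows only that dispatch. auto_wrap and instrument are hypothetical names, and instrument merely sets a flag where the real library would add hook points.

```python
def instrument(obj):
    """Hypothetical placeholder for the real work (adding hook points)."""
    obj.instrumented = True
    return obj


def auto_wrap(target):
    """Accept either a class (decorator use) or an instance (wrapping use)."""
    if isinstance(target, type):
        # Decorator use: return a subclass that instruments each new instance
        # after the original __init__ has run, so all attributes exist.
        class Wrapped(target):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                instrument(self)

        Wrapped.__name__ = target.__name__
        Wrapped.__qualname__ = target.__qualname__
        return Wrapped
    # Already-built instance: instrument it in place and return the same object.
    return instrument(target)


@auto_wrap
class Decorated:
    pass


class Plain:
    pass


print(Decorated().instrumented)         # True
print(auto_wrap(Plain()).instrumented)  # True
```

Subclassing in the decorator path keeps isinstance checks against the original class working. The instance path mutates and returns the object it was given.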
Installation
pip install Auto_HookPoint
Usage
Usage as decorator
```python
from Auto_HookPoint import auto_hook
import torch.nn as nn

@auto_hook
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        # self.fc1_hook_point = HookPoint()  # now not needed

    def forward(self, x):
        # self.fc1_hook_point(self.fc1(x))  # now not needed
        return self.fc1(x)

model = MyModel()
print(model.hook_dict.items())  # dict_items([('hook_point', HookPoint()), ('fc1.hook_point', HookPoint())])
```
Wrap an instance
auto_hook can also wrap an already instantiated model, including models that use nn.Parameter directly, such as this AutoEncoder example:
```python
from Auto_HookPoint import auto_hook
import torch
from torch import nn

# Taken from Neel Nanda's excellent autoencoder tutorial: https://colab.research.google.com/drive/1u8larhpxy8w4mMsJiSBddNOzFGj7_RTn#scrollTo=MYrIYDEfBtbL
class AutoEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        d_hidden = cfg["d_mlp"] * cfg["dict_mult"]
        d_mlp = cfg["d_mlp"]
        dtype = torch.float32
        torch.manual_seed(cfg["seed"])
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(d_mlp, d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(d_hidden, d_mlp, dtype=dtype)))
        self.b_enc = nn.Parameter(
            torch.zeros(d_hidden, dtype=dtype)
        )
        self.b_dec = nn.Parameter(
            torch.zeros(d_mlp, dtype=dtype)
        )

    def forward(self, x):
        x_cent = x - self.b_dec
        acts = torch.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        return x_reconstruct

autoencoder = auto_hook(AutoEncoder({"d_mlp": 10, "dict_mult": 10, "l1_coeff": 10, "seed": 1}))
print(autoencoder.hook_dict.items())
# dict_items([('hook_point', HookPoint()), ('W_enc.hook_point', HookPoint()), ('W_dec.hook_point', HookPoint()), ('b_enc.hook_point', HookPoint()), ('b_dec.hook_point', HookPoint())])
```
Adding the same hook points by hand makes the code much more verbose:
```python
import torch
from torch import nn
from transformer_lens.hook_points import HookPoint

class AutoEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        d_hidden = cfg['d_mlp'] * cfg['dict_mult']
        d_mlp = cfg['d_mlp']
        dtype = torch.float32
        torch.manual_seed(cfg['seed'])
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(d_mlp, d_hidden, dtype=dtype)
            )
        )
        self.W_enc_hook_point = HookPoint()
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(d_hidden, d_mlp, dtype=dtype)
            )
        )
        self.W_dec_hook_point = HookPoint()
        self.b_enc = nn.Parameter(
            torch.zeros(d_hidden, dtype=dtype)
        )
        self.b_enc_hook_point = HookPoint()
        self.b_dec = nn.Parameter(
            torch.zeros(d_mlp, dtype=dtype)
        )
        self.b_dec_hook_point = HookPoint()

    def forward(self, x):
        x_cent = self.b_dec_hook_point(x - self.b_dec)
        acts = torch.relu(self.b_enc_hook_point(self.W_enc_hook_point(x_cent @ self.W_enc) + self.b_enc))
        x_reconstruct = self.b_dec_hook_point(self.W_dec_hook_point(acts @ self.W_dec) + self.b_dec)
        return x_reconstruct
```
Note
Some edge cases may not be supported. To help detect them, the check_auto_hook function runs a model class through all of the library's internal tests. The results are not always conclusive, but they can point to likely problems.
```python
import torch
from Auto_HookPoint import auto_hook, check_auto_hook

# AutoEncoder is the class from the example above
init_kwargs = {'cfg': {'d_mlp': 10, 'dict_mult': 10, 'l1_coeff': 10, 'seed': 1}}
input_kwargs = {'x': torch.randn(10, 10)}

hooked_model = auto_hook(AutoEncoder(**init_kwargs))
check_auto_hook(AutoEncoder, input_kwargs, init_kwargs)
```
If strict is set to True, check_auto_hook raises a runtime error when the tests fail; otherwise, it only issues a warning.
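The strict switch follows a common pattern: collect the failures, then either raise or warn. The sketch below only illustrates that pattern. report_failures is a hypothetical helper, not part of Auto_HookPoint, and the failure text is made up for illustration. check_auto_hook's actual messages and exception type may differ.

```python
import warnings


def report_failures(failures, strict=False):
    """Hypothetical helper: raise if strict, otherwise warn, when any check failed.

    `failures` is a list of human-readable failure descriptions.
    Returns True when every check passed, False otherwise (non-strict mode).
    """
    if not failures:
        return True
    message = "Checks failed:\n" + "\n".join(f"- {f}" for f in failures)
    if strict:
        raise RuntimeError(message)
    warnings.warn(message)  # emitted as a UserWarning
    return False


# Non-strict: execution continues and a warning is recorded.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ok = report_failures(["output shape changed after hooking"])
print(ok, len(caught))  # False 1

# Strict: the same failure becomes an exception.
try:
    report_failures(["output shape changed after hooking"], strict=True)
except RuntimeError as err:
    print("raised:", type(err).__name__)  # raised: RuntimeError
```

Raising is useful in CI, where a silent warning would go unnoticed. Warning is friendlier during interactive exploration.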
Note on Backward Hooks (bwd_hooks)
Backward hooks can cause problems. Because auto_hook adds a hook point to every nn.Module instance, modules that return non-tensor objects (for example, tuples) are hooked as well. Only attach backward hooks to hook points whose input is a tensor.
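One way to follow this advice is to run a single forward pass and record which modules actually return a tensor. You can then attach backward hooks only to those modules. The sketch uses plain PyTorch forward hooks, and tensor_output_modules is a hypothetical helper. The names it returns are module paths from named_modules (the root is ''), not Auto_HookPoint hook names. nn.LSTM serves as an example of a module whose output is a tuple.

```python
import torch
import torch.nn as nn


def tensor_output_modules(model: nn.Module, *inputs) -> dict:
    """Run one forward pass and map each module name to whether it returned a tensor."""
    returns_tensor = {}
    handles = []
    for name, module in model.named_modules():
        # Bind `name` as a default so each hook records under its own module name.
        def record(mod, args, out, name=name):
            returns_tensor[name] = isinstance(out, torch.Tensor)

        handles.append(module.register_forward_hook(record))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        # Remove the temporary hooks so the model is left unchanged.
        for handle in handles:
            handle.remove()
    return returns_tensor


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(4, 4, batch_first=True)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        out, _ = self.lstm(x)  # nn.LSTM returns (output, (h_n, c_n))
        return self.fc(out)


net = Net()
info = tensor_output_modules(net, torch.randn(1, 3, 4))
print(info)  # {'lstm': False, 'fc': True, '': True}

# Module paths that are reasonable candidates for backward hooks.
safe_for_bwd = [name for name, is_tensor in info.items() if is_tensor]
print(safe_for_bwd)  # ['fc', '']
```

The dictionary is filled in execution order, which is why the root module ('') appears last.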
File details
Details for the file auto_hookpoint-0.2.0.tar.gz.
File metadata
- Download URL: auto_hookpoint-0.2.0.tar.gz
- Upload date:
- Size: 6.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 313abf1590f58f5d37a79bb45d3864b3ddc15147ac50047efd22004dfaba4ae2 |
| MD5 | b5b438870fde96ec112356b5bd19417e |
| BLAKE2b-256 | 9d86cf8db784d417b334e8044f155b3b7af51a6fd82c81d6b13b6c27e65d4d1e |
File details
Details for the file auto_hookpoint-0.2.0-py3-none-any.whl.
File metadata
- Download URL: auto_hookpoint-0.2.0-py3-none-any.whl
- Upload date:
- Size: 7.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7b5c49543b8292025fb78c9fda88b32a97bdf184e936f6337051b430cc1be0cd |
| MD5 | 8ce33907fe4211312a4a5c2cfdae4b78 |
| BLAKE2b-256 | 329c8b1c4d9659e59d6ca116e757e76fec7dcdc137c641982be985b4a2b2f983 |