Repo for training sparse autoencoders end-to-end
e2e_sae
This library is used to train and evaluate Sparse Autoencoders (SAEs). It handles the following training types:
- e2e (end-to-end): Loss function includes sparsity and the KL divergence between the final outputs of the original model and the model with SAEs inserted.
- e2e + downstream reconstruction: Loss function includes sparsity, the same final-output KL divergence, and MSE at downstream layers.
- local (i.e. vanilla SAEs): Loss function includes sparsity and MSE at the SAE layer.
- Any combination of the above.
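The training types above differ only in which loss terms are switched on. The pure-Python sketch below shows the loss as a weighted sum in which a zero coefficient disables a term. The function and coefficient names (`sae_loss`, `kl_coeff`, `local_coeff`, `downstream_coeff`) are hypothetical and are not the library's config keys; the KL direction KL(original || with SAE) is an assumption; and the library itself works on batched tensors rather than Python lists.

```python
import math
from typing import Sequence


def l1_sparsity(codes: Sequence[float]) -> float:
    """Sparsity penalty: L1 norm of the SAE feature activations."""
    return sum(abs(c) for c in codes)


def kl_div(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for discrete distributions (p: original model, q: model with SAE)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error between two activation vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def sae_loss(
    codes,
    sparsity_coeff,
    kl_coeff=0.0, orig_probs=None, sae_probs=None,
    local_coeff=0.0, sae_in=None, sae_out=None,
    downstream_coeff=0.0, downstream_pairs=(),
):
    """Hypothetical weighted sum of loss terms; a zero coefficient turns a term off."""
    loss = sparsity_coeff * l1_sparsity(codes)
    if kl_coeff:
        # e2e: compare final output distributions with and without the SAE
        loss += kl_coeff * kl_div(orig_probs, sae_probs)
    if local_coeff:
        # local: reconstruct the activation at the SAE layer itself
        loss += local_coeff * mse(sae_out, sae_in)
    if downstream_coeff:
        # e2e + downstream: match later-layer activations, averaged over layers
        terms = [mse(new, orig) for orig, new in downstream_pairs]
        loss += downstream_coeff * sum(terms) / len(terms)
    return loss
```

In this sketch, an e2e run sets `kl_coeff`, a local run sets `local_coeff`, and e2e + downstream sets both `kl_coeff` and `downstream_coeff`; "any combination" corresponds to setting several at once.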
See our paper, which argues for training SAEs e2e rather than locally. All SAEs presented in the paper can be found at https://wandb.ai/sparsify/gpt2 and can be loaded using this library.
Usage
Installation
```bash
pip install e2e_sae
```
Train SAEs on any TransformerLens model
If you would like to track your run with Weights and Biases, place your API key and entity name in a new file called .env. An example is provided in .env.example.
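Before launching a tracked run, it can help to confirm that .env actually contains the values Weights and Biases needs. The sketch below is a minimal stdlib parser for simple `KEY=VALUE` lines; the variable names `WANDB_API_KEY` and `WANDB_ENTITY` are assumed (they are standard Weights and Biases environment variables, but check .env.example for the names this library expects), and the helpers `read_env_file` and `missing_keys` are hypothetical, not how the library loads the file.

```python
from pathlib import Path

# Assumed variable names; confirm against .env.example.
REQUIRED_KEYS = ("WANDB_API_KEY", "WANDB_ENTITY")


def read_env_file(path):
    """Parse simple KEY=VALUE lines, skipping blank lines and # comments."""
    values = {}
    for raw in Path(path).read_text().splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Drop surrounding whitespace and optional single or double quotes
        values[key.strip()] = value.strip().strip('"').strip("'")
    return values


def missing_keys(path, required=REQUIRED_KEYS):
    """Return the required keys that are absent or empty in the .env file."""
    values = read_env_file(path)
    return [key for key in required if not values.get(key)]
```

For example, `missing_keys(".env")` returning an empty list means both assumed keys are present and non-empty.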
Create a config file (see gpt2 configs here for examples). Then run:
```bash
python e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py <path_to_config>
```
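To train several configs in sequence from Python, one option is to wrap the documented command in `subprocess`. This is a hypothetical helper (`train_command`, `run_all`), not part of the library; it assumes it is run from the repository root and that config files use a `.yaml` extension.

```python
import subprocess
import sys
from pathlib import Path

# Script path from the command in this section; assumes the repo root as working directory.
SCRIPT = Path("e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py")


def train_command(config_path):
    """Build the argv list for one training run, using the current Python interpreter."""
    return [sys.executable, str(SCRIPT), str(config_path)]


def run_all(config_dir, pattern="*.yaml"):
    """Run every matching config in config_dir one after another; return the configs run."""
    configs = sorted(Path(config_dir).glob(pattern))
    for cfg in configs:
        # check=True stops the loop if a run exits with an error
        subprocess.run(train_command(cfg), check=True)
    return configs
```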
If using a Colab notebook, see this example.
Sample wandb sweep configs are provided in e2e_sae/scripts/train_tlens_saes/.
The library also contains scripts for training MLPs and SAEs on MLPs, as well as for training custom TransformerLens models and SAEs on these models (see here).
Load a Pre-trained SAE
You can load any pre-trained SAE (and accompanying TransformerLens model) trained using this library from Weights and Biases or locally by running:
```python
from e2e_sae import SAETransformer
model = SAETransformer.from_wandb("<entity/project/run_id>")
# or, if stored locally
model = SAETransformer.from_local_path("/path/to/checkpoint/dir")
```
All runs in our paper can be loaded this way (e.g. sparsify/gpt2/tvj2owza).
This will instantiate an SAETransformer, which contains a TransformerLens model with SAEs attached. To do a forward pass without SAEs, use the forward_raw method; to do a forward pass with SAEs, use the forward method (or simply call the SAETransformer instance).
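For intuition about what an attached SAE does in a forward pass: at its hook point, the activation is encoded into sparse feature activations and decoded back, and that reconstruction is what the rest of the model sees. The pure-Python sketch below is a standard ReLU SAE (encoder, ReLU, decoder); the function names and weight layout are assumptions for illustration, not the library's internals.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a row-major matrix by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def sae_forward(x: Vector, enc_w: Matrix, enc_b: Vector,
                dec_w: Matrix, dec_b: Vector) -> Tuple[Vector, Vector]:
    """Return (reconstruction, feature activations) for one activation vector x.

    Assumed shapes: enc_w is (n_dict, d), dec_w is (d, n_dict).
    """
    # Feature activations: ReLU(W_e x + b_e)
    codes = [max(0.0, a + b) for a, b in zip(matvec(enc_w, x), enc_b)]
    # Reconstruction: W_d c + b_d, which replaces x downstream when SAEs are active
    recon = [a + b for a, b in zip(matvec(dec_w, codes), dec_b)]
    return recon, codes
```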
The dictionary elements of an SAE can be accessed via SAE.dict_elements. This will normalize the decoder elements to have norm 1.
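As a sketch of that normalization, assuming the decoder weight is stored as an (input_size, n_dict_components) matrix with one dictionary element per column (the layout of a PyTorch `nn.Linear(n_dict_components, input_size)` decoder), each column is divided by its L2 norm. In PyTorch this is `torch.nn.functional.normalize(weight, dim=0)`; the pure-Python helper `normalize_columns` below is hypothetical and mirrors that function's small epsilon so all-zero columns stay zero.

```python
import math


def normalize_columns(dec_w, eps=1e-12):
    """Scale each column (one dictionary element) of dec_w to unit L2 norm."""
    n_rows, n_cols = len(dec_w), len(dec_w[0])
    norms = [math.sqrt(sum(dec_w[r][c] ** 2 for r in range(n_rows))) for c in range(n_cols)]
    # Divide by max(norm, eps) so a zero column is left as zeros instead of dividing by 0
    return [[dec_w[r][c] / max(norms[c], eps) for c in range(n_cols)] for r in range(n_rows)]
```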
Analysis
To reproduce all of the analysis in our paper, use the scripts in e2e_sae/scripts/analysis/.
Contributing
Developer dependencies are installed with make install-dev, which will also install pre-commit
hooks.
Suggested extensions and settings for VSCode are provided in .vscode/. To use the suggested
settings, copy .vscode/settings-example.json to .vscode/settings.json.
There are various make commands that may be helpful:
```bash
make check  # Run pre-commit checks on all files (i.e. pyright, ruff linter, and ruff formatter)
make type  # Run pyright on all files
make format  # Run ruff linter and formatter on all files
make test  # Run tests that aren't marked `slow`
make test-all  # Run all tests
```
This library is maintained by Dan Braun.
Join the Open Source Mechanistic Interpretability Slack to chat about this library and other projects in the space!
Download files
File details
Details for the file e2e_sae-2.1.1.tar.gz.
File metadata
- Download URL: e2e_sae-2.1.1.tar.gz
- Upload date:
- Size: 38.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ef0518a1b4c73712493acab5490733f4b2baa32f68855e9017fa5526b76bd9eb |
| MD5 | 86ba77692ab810c468c864f7fc430f1f |
| BLAKE2b-256 | 334a79eebf0e16886349ad4a31e633d2c2e1299b4bc2025f09919a2a88bfb9b3 |
File details
Details for the file e2e_sae-2.1.1-py3-none-any.whl.
File metadata
- Download URL: e2e_sae-2.1.1-py3-none-any.whl
- Upload date:
- Size: 37.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4f6f3705456e1f45b422e23ae45e7e3dcf0038535341f012e83499647f533bc3 |
| MD5 | ab502ebf633a71c785586c7308d087e3 |
| BLAKE2b-256 | 70d2cb9946749da7e04fbbdd2e1f081f55209e2301623be26f1825e4a5c915e0 |
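To check a downloaded file against the listed digests, Python's standard `hashlib` is enough. The helpers below (`file_digest`, `verify`) are a hypothetical sketch; note that BLAKE2b-256 means BLAKE2b with a 32-byte digest, which `hashlib.blake2b` produces via `digest_size=32`.

```python
import hashlib
from pathlib import Path

# Map the algorithm names used in the tables to hashlib constructors.
HASHERS = {
    "sha256": hashlib.sha256,
    "md5": hashlib.md5,
    "blake2b-256": lambda: hashlib.blake2b(digest_size=32),
}


def file_digest(path, algorithm="sha256", chunk_size=1 << 20):
    """Hash a file in chunks so large downloads need not fit in memory."""
    h = HASHERS[algorithm.lower()]()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected_hex, algorithm="sha256"):
    """True if the file's digest matches the expected hex string (case-insensitive)."""
    return file_digest(path, algorithm) == expected_hex.lower()
```

For example, `verify("e2e_sae-2.1.1.tar.gz", "ef0518a1b4c73712493acab5490733f4b2baa32f68855e9017fa5526b76bd9eb")` should return True for an intact copy of the source distribution.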