
Repo for training sparse autoencoders end-to-end

Project description

e2e_sae

This library is used to train and evaluate Sparse Autoencoders (SAEs). It handles the following training types:

  • e2e (end-to-end): Loss function includes sparsity and the KL divergence at the final model output.
  • e2e + downstream reconstruction: Loss function includes sparsity, the final-output KL divergence, and MSE at downstream layers.
  • local (i.e. vanilla SAEs): Loss function includes sparsity and MSE at the SAE layer.
  • Any combination of the above.
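The training types above differ only in which loss terms are switched on. Below is a minimal, framework-free sketch of how such a combined objective could be assembled. It assumes an L1 sparsity penalty and KL(original || with-SAE) on the output logits. The function names and coefficient keys (`sparsity`, `kl`, `downstream_mse`, `local_mse`) are hypothetical and are not this library's config schema. Plain Python lists stand in for batched tensors.

```python
import math
from typing import Optional, Sequence

Vector = Sequence[float]


def l1_sparsity(feature_acts: Vector) -> float:
    """Sum of absolute SAE feature activations (an L1 sparsity penalty)."""
    return sum(abs(a) for a in feature_acts)


def mse(a: Vector, b: Vector) -> float:
    """Mean squared error between two activation vectors of equal length."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def log_softmax(logits: Vector) -> list[float]:
    """Numerically stable log-softmax."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_div(orig_logits: Vector, sae_logits: Vector) -> float:
    """KL(P_orig || P_sae) between the two output distributions."""
    log_p = log_softmax(orig_logits)
    log_q = log_softmax(sae_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def sae_loss(
    feature_acts: Vector,
    coeffs: dict[str, float],
    orig_logits: Optional[Vector] = None,
    sae_logits: Optional[Vector] = None,
    downstream: Sequence[tuple[Vector, Vector]] = (),
    local: Optional[tuple[Vector, Vector]] = None,
) -> float:
    """Weighted sum of loss terms; a term whose coefficient is 0 or absent is disabled."""
    loss = coeffs.get("sparsity", 0.0) * l1_sparsity(feature_acts)
    if coeffs.get("kl", 0.0) and orig_logits is not None and sae_logits is not None:
        # e2e term: compare final outputs of the original and SAE-spliced model
        loss += coeffs["kl"] * kl_div(orig_logits, sae_logits)
    if coeffs.get("downstream_mse", 0.0):
        # (original activation, activation with SAE spliced in) at each downstream layer
        loss += coeffs["downstream_mse"] * sum(mse(o, s) for o, s in downstream)
    if coeffs.get("local_mse", 0.0) and local is not None:
        # local term: (SAE input, SAE reconstruction) at the SAE's own layer
        loss += coeffs["local_mse"] * mse(*local)
    return loss
```

In this sketch, "any combination" just means setting more than one coefficient to a nonzero value.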

See our paper, which argues for training SAEs e2e rather than locally. All SAEs presented in the paper can be found at https://wandb.ai/sparsify/gpt2 and can be loaded using this library.

Usage

Installation

pip install e2e_sae

Train SAEs on any TransformerLens model

If you would like to track your run with Weights and Biases, place your API key and entity name in a new file called .env. An example is provided in .env.example.
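The .env file is a list of KEY=VALUE lines. Below is a standard-library sketch of its format and of loading it. The variable names `WANDB_API_KEY` and `WANDB_ENTITY` are the standard Weights and Biases environment variables and are an assumption here, so check .env.example for the names this library actually expects. `parse_env` and `load_env_file` are hypothetical helpers.

```python
import os

# Hypothetical .env contents; the variable names are assumed (see .env.example).
EXAMPLE_ENV = """\
# Weights and Biases credentials
WANDB_API_KEY=your_api_key
WANDB_ENTITY=your_entity_name
"""


def parse_env(text: str) -> dict[str, str]:
    """Parse KEY=VALUE lines, skipping blank lines and # comments."""
    values: dict[str, str] = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Drop surrounding quotes, e.g. KEY="value" -> value
        values[key.strip()] = value.strip().strip("\"'")
    return values


def load_env_file(path: str = ".env") -> dict[str, str]:
    """Read a .env file and export its values without overriding existing variables."""
    with open(path, encoding="utf-8") as f:
        values = parse_env(f.read())
    for key, value in values.items():
        os.environ.setdefault(key, value)
    return values
```

Because `load_env_file` uses `os.environ.setdefault`, variables already set in your shell take precedence over the file.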

Create a config file (see the gpt2 configs here for examples), then run:

python e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py <path_to_config>
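To queue several runs, you can wrap the command above in Python. Below is a minimal standard-library sketch. `train_command`, `train_all`, and the config file names are hypothetical, and the sketch assumes it is run from the repository root so that the script path resolves.

```python
import subprocess
import sys
from pathlib import Path

# Script path as given above; assumes the current directory is the repo root.
TRAIN_SCRIPT = Path("e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py")


def train_command(config_path: str) -> list[str]:
    """argv for one training run, using the current Python interpreter."""
    return [sys.executable, str(TRAIN_SCRIPT), config_path]


def train_all(config_paths: list[str], dry_run: bool = True) -> list[list[str]]:
    """Run each config in turn, or only return the commands when dry_run is True."""
    commands = [train_command(p) for p in config_paths]
    if not dry_run:
        for cmd in commands:
            # check=True raises CalledProcessError and stops at the first failed run
            subprocess.run(cmd, check=True)
    return commands
```

For example, `train_all(["configs/a.yaml", "configs/b.yaml"], dry_run=False)` would train on two hypothetical config files sequentially.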

If using a Colab notebook, see this example.

Sample wandb sweep configs are provided in e2e_sae/scripts/train_tlens_saes/.

The library also contains scripts for training MLPs and SAEs on MLPs, as well as for training custom TransformerLens models and SAEs on these models (see here).

Load a Pre-trained SAE

You can load any SAE trained with this library (together with its accompanying TransformerLens model) from Weights and Biases or from a local directory by running:

from e2e_sae import SAETransformer
model = SAETransformer.from_wandb("<entity/project/run_id>")
# or, if stored locally
model = SAETransformer.from_local_path("/path/to/checkpoint/dir") 

All runs in our paper can be loaded this way (e.g. sparsify/gpt2/tvj2owza).
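`from_wandb` takes a run path of the form entity/project/run_id. Below is a small helper that validates such a path before it is passed to the loader, shown with the paper's example run. The helper and `WandbRunPath` are hypothetical and are not part of e2e_sae.

```python
from typing import NamedTuple


class WandbRunPath(NamedTuple):
    entity: str
    project: str
    run_id: str


def parse_run_path(path: str) -> WandbRunPath:
    """Split an 'entity/project/run_id' string, rejecting malformed paths."""
    parts = path.strip("/").split("/")
    # Exactly three non-empty components are required
    if len(parts) != 3 or not all(parts):
        raise ValueError(f"expected 'entity/project/run_id', got {path!r}")
    return WandbRunPath(*parts)
```

Since `WandbRunPath` is a tuple of strings, `"/".join(parse_run_path(run))` rebuilds the validated path for `SAETransformer.from_wandb`.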

This instantiates an SAETransformer, which contains a TransformerLens model with SAEs attached. To do a forward pass without SAEs, use the forward_raw method; to do a forward pass with SAEs, use the forward method (or simply call the SAETransformer instance).
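Conceptually, forward_raw runs the underlying model untouched, while forward replaces the activation at each SAE position with that SAE's reconstruction. Below is a toy, pure-Python illustration of that difference. `ToySAE` and `ToySAETransformer` are hypothetical stand-ins, not the library's classes. The ReLU encoder and linear decoder are assumptions about the SAE architecture.

```python
from dataclasses import dataclass, field
from typing import Callable

Vector = list[float]


@dataclass
class ToySAE:
    """Hypothetical SAE: features = ReLU(W_enc x + b_enc), reconstruction = W_dec features."""
    w_enc: list[Vector]  # shape (n_features, d)
    b_enc: Vector        # shape (n_features,)
    w_dec: list[Vector]  # shape (d, n_features)

    def encode(self, x: Vector) -> Vector:
        return [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b)
                for row, b in zip(self.w_enc, self.b_enc)]

    def decode(self, f: Vector) -> Vector:
        return [sum(w * fi for w, fi in zip(row, f)) for row in self.w_dec]

    def __call__(self, x: Vector) -> Vector:
        return self.decode(self.encode(x))


@dataclass
class ToySAETransformer:
    layers: list[Callable[[Vector], Vector]]
    # layer index -> SAE applied to that layer's output
    saes: dict[int, ToySAE] = field(default_factory=dict)

    def forward_raw(self, x: Vector) -> Vector:
        """Run the model without any SAEs."""
        for layer in self.layers:
            x = layer(x)
        return x

    def forward(self, x: Vector) -> Vector:
        """Run the model, replacing activations at SAE positions with reconstructions."""
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i in self.saes:
                x = self.saes[i](x)
        return x

    # Calling the instance does a forward pass with SAEs
    __call__ = forward
```

With two layers (double, then add one) and an identity-weight SAE after layer 0, input `[1.0, -1.0]` gives `[3.0, -1.0]` from forward_raw but `[3.0, 1.0]` from forward, because the ReLU zeroes the negative component. In the e2e setting, that kind of output discrepancy is what the KL term penalizes.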

The dictionary elements of an SAE can be accessed via SAE.dict_elements, which normalizes the decoder elements to have norm 1.
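Dictionary elements are the decoder's feature directions. Below is a plain-Python sketch of unit-norm normalization. It assumes a decoder matrix of shape (d_model, n_features) whose columns are the dictionary elements. The function name is hypothetical, and the small `eps` floor, which avoids dividing by zero for dead features, is an assumption of this sketch.

```python
import math

Matrix = list[list[float]]


def normalized_dict_elements(w_dec: Matrix, eps: float = 1e-12) -> Matrix:
    """Return w_dec (shape d x n_features) with every column rescaled to L2 norm 1.

    Each column is one dictionary element (the decoder direction of one feature).
    """
    d, n = len(w_dec), len(w_dec[0])
    # L2 norm of each column, floored at eps so all-zero columns stay zero
    norms = [max(math.sqrt(sum(w_dec[r][c] ** 2 for r in range(d))), eps) for c in range(n)]
    return [[w_dec[r][c] / norms[c] for c in range(n)] for r in range(d)]
```

For example, the column (3, 4) becomes (0.6, 0.8), while its direction is unchanged.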

Analysis

To reproduce all of the analysis in our paper, use the scripts in e2e_sae/scripts/analysis/.

Contributing

Developer dependencies are installed with make install-dev, which will also install pre-commit hooks.

Suggested extensions and settings for VSCode are provided in .vscode/. To use the suggested settings, copy .vscode/settings-example.json to .vscode/settings.json.

There are various make commands that may be helpful:

make check  # Run pre-commit checks on all files (i.e. pyright, ruff linter, and ruff formatter)
make type  # Run pyright on all files
make format  # Run ruff linter and formatter on all files
make test  # Run tests that aren't marked `slow`
make test-all  # Run all tests

This library is maintained by Dan Braun.

Join the Open Source Mechanistic Interpretability Slack to chat about this library and other projects in the space!

Download files

Source Distribution

e2e_sae-2.1.1.tar.gz (38.9 kB)

Uploaded Source

Built Distribution

e2e_sae-2.1.1-py3-none-any.whl (37.4 kB)

Uploaded Python 3

File details

Details for the file e2e_sae-2.1.1.tar.gz.

File metadata

  • Download URL: e2e_sae-2.1.1.tar.gz
  • Upload date:
  • Size: 38.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.4

File hashes

Hashes for e2e_sae-2.1.1.tar.gz
Algorithm Hash digest
SHA256 ef0518a1b4c73712493acab5490733f4b2baa32f68855e9017fa5526b76bd9eb
MD5 86ba77692ab810c468c864f7fc430f1f
BLAKE2b-256 334a79eebf0e16886349ad4a31e633d2c2e1299b4bc2025f09919a2a88bfb9b3

File details

Details for the file e2e_sae-2.1.1-py3-none-any.whl.

File metadata

  • Download URL: e2e_sae-2.1.1-py3-none-any.whl
  • Upload date:
  • Size: 37.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.4

File hashes

Hashes for e2e_sae-2.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 4f6f3705456e1f45b422e23ae45e7e3dcf0038535341f012e83499647f533bc3
MD5 ab502ebf633a71c785586c7308d087e3
BLAKE2b-256 70d2cb9946749da7e04fbbdd2e1f081f55209e2301623be26f1825e4a5c915e0
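To check that a downloaded file matches its published digest, compare its SHA256 against the hash tables above. Below is a standard-library sketch. The digests are copied from those tables, and the helper names are hypothetical. The BLAKE2b-256 digests can be checked the same way with `hashlib.blake2b(digest_size=32)`.

```python
import hashlib
from pathlib import Path

# Published SHA256 digests for e2e_sae 2.1.1, copied from the tables above.
EXPECTED_SHA256 = {
    "e2e_sae-2.1.1.tar.gz": "ef0518a1b4c73712493acab5490733f4b2baa32f68855e9017fa5526b76bd9eb",
    "e2e_sae-2.1.1-py3-none-any.whl": "4f6f3705456e1f45b422e23ae45e7e3dcf0038535341f012e83499647f533bc3",
}


def file_sha256(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 of a file, read in chunks so large files need not fit in memory."""
    h = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_download(path: Path) -> bool:
    """True only if the file name is known and its SHA256 matches the published digest."""
    expected = EXPECTED_SHA256.get(path.name)
    return expected is not None and file_sha256(path) == expected
```

For example, `verify_download(Path("e2e_sae-2.1.1.tar.gz"))` should return True for an intact download of the source distribution.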
