LatentSAE: Training and inference for SAEs on embeddings
WARNING: This repo is experimental and under active development. The API for the SAE models, as well as the organization of the models, will likely change. The models that have been trained will also be re-trained as I prepare more training data.
Most of the SAE code comes from https://github.com/EleutherAI/sae

TODO: fully document data usage.

Training currently uses https://huggingface.co/datasets/enjalot/fineweb-edu-sample-10BT-chunked-500-nomic-text-v1.5. For testing the code locally I downloaded a sample of the dataset. For training, I downloaded the whole dataset to disk in a Modal volume, then processed it into sharded torch .pt files using this script: https://github.com/enjalot/fineweb-modal/blob/main/torched.py
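For local experimentation, a slice of that dataset can be pulled with the Hugging Face datasets library. This is a minimal sketch, not the training pipeline (which consumes the sharded .pt files produced by torched.py); the split and column names are assumptions, so check the dataset card.

```python
# Sketch: stream a few rows of the chunked + embedded FineWeb-Edu dataset
# for local testing, without downloading the whole thing.
from datasets import load_dataset

ds = load_dataset(
    "enjalot/fineweb-edu-sample-10BT-chunked-500-nomic-text-v1.5",
    split="train",   # assumption: a single "train" split
    streaming=True,  # avoid materializing the full dataset on disk
)
row = next(iter(ds))
print(row.keys())  # expect text chunks plus a precomputed embedding column
```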
Inference
```python
from latentsae import Sae

model = Sae.load_from_hub("enjalot/sae-nomic-text-v1.5-FineWeb-edu-10BT", "64_32")
# or from disk
model = Sae.load_from_disk("models/sae_64_32.3mq7ckj7")
```
See notebooks/eval.ipynb for an example of how to use the model for extracting features from an embedding dataset.
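As a rough sketch of what feature extraction looks like, the snippet below assumes the EleutherAI-style `encode()` API (returning the top-k activations and their latent indices), since most of the SAE code comes from that repo; the notebook is the authoritative reference.

```python
# Sketch of feature extraction, assuming an EleutherAI/sae-style encode().
import torch
from latentsae import Sae

model = Sae.load_from_hub("enjalot/sae-nomic-text-v1.5-FineWeb-edu-10BT", "64_32")

# A batch of nomic-embed-text-v1.5 embeddings (768-dimensional).
# Random data here, just to show the shapes.
embeddings = torch.randn(8, 768)

with torch.no_grad():
    out = model.encode(embeddings)

# out.top_acts:    activation strengths of the k active latents per embedding
# out.top_indices: which of the SAE's latent features fired
print(out.top_acts.shape, out.top_indices.shape)  # e.g. (8, 64) each
```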
Training
The main way to train (the one I've gotten working) is on Modal Labs infrastructure:

```bash
modal run src/run_modal.py --batch-size 512 --grad-acc-steps 4 --k 64 --expansion-factor 128
```
There is also some initial code for training locally:

```bash
python latentsae/run.py --batch-size 512 --grad-acc-steps 4 --k 64 --expansion-factor 128
```
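For intuition about what `--k` and `--expansion-factor` control, here is a minimal, self-contained top-k SAE forward pass. This is a sketch of the general technique, not this repo's training code (which follows EleutherAI/sae): the latent dimension is expansion_factor × the input dimension, and only the k largest latent activations are kept per example. The demo numbers below assume 768-dimensional embeddings and, per the hub model name "64_32", k=64 with expansion factor 32.

```python
# Minimal top-k SAE forward pass (illustrative sketch, not the repo's code).
import torch
import torch.nn as nn

class TopKSae(nn.Module):
    def __init__(self, d_in: int, expansion_factor: int, k: int):
        super().__init__()
        d_sae = d_in * expansion_factor      # e.g. 768 * 32 = 24576 latents
        self.k = k                           # e.g. 64 active latents per input
        self.encoder = nn.Linear(d_in, d_sae)
        self.decoder = nn.Linear(d_sae, d_in, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        acts = torch.relu(self.encoder(x))
        # Keep only the k largest activations; zero out the rest.
        top_vals, top_idx = acts.topk(self.k, dim=-1)
        sparse = torch.zeros_like(acts).scatter_(-1, top_idx, top_vals)
        return self.decoder(sparse)          # reconstruction of x

sae = TopKSae(d_in=768, expansion_factor=32, k=64)
x = torch.randn(4, 768)
recon = sae(x)
print(recon.shape)  # torch.Size([4, 768])
```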
Hashes for latentsae-0.0.1-py3-none-any.whl

Algorithm | Hash digest
---|---
SHA256 | c8073995c98c3024bf754959e8c246f935ad49021eaebe613e4d0a1e4394f693
MD5 | 63a0f087ca8069c8abe29f64d3f79c52
BLAKE2b-256 | 1c1c48b612b86393e96e32ba660b6aea1859e78c1271d55ee07c68f505726149