
A TensorFlow/Keras Gated Sparse Autoencoder (Gated SAE) for dictionary learning and interpretability.


🧩 gated-sae-tf

sparse gated autoencoders for TensorFlow & Keras, batteries included

Almost every open-source sparse-autoencoder (SAE) toolkit speaks PyTorch. If you live in TensorFlow/Keras, you've mostly been left to port things yourself. gated-sae-tf closes that gap: a clean, tested, pip-installable gated SAE for dictionary learning and mechanistic interpretability, with the whole training recipe and the interpretability tooling already wired up.

🧠 gated SAE · 📉 warmup→cosine LR · 🔬 sparsity + sharpness reports · 🖼️ feature galleries · ✅ 15 tests · 📦 pip install

[Figure: Gated SAE overview]

💭 Why this exists

I kept wanting to train sparse autoencoders on Keras models without reimplementing the gated SAE from scratch or dragging everything over to PyTorch. So I packaged the version I kept rewriting: the gated formulation from Rajamanoharan et al. (2024), the training tricks that actually make it converge (warmup→cosine LR, gradient clipping, decoder normalization, the auxiliary loss), and the little interpretability utilities you reach for thirty seconds later — sparsity stats, decoder sharpness, feature galleries — all in one importable place.

🧠 What's a gated sparse autoencoder?

A sparse autoencoder learns an overcomplete dictionary of features that reconstruct an input while keeping only a few of them active at a time. The gated variant splits two decisions that a plain ReLU SAE tangles together:

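  • a gate path decides which features fire: f_gate = 1[π_gate > 0], with gate pre-activation π_gate = W_gate·x_c + b_gate
  • a magnitude path decides how much they fire: f_mag = relu(W_mag·x_c + b_mag)

Here x_c = x − b_dec is the input centered on the decoder bias, and the magnitudes are tied to the gate via a per-feature rescale, W_mag = exp(r_mag) ⊙ W_gate (one scalar in r_mag per feature, scaling that feature's row of W_gate). The sparse code is f̃ = f_gate ⊙ f_mag and the reconstruction is x̂ = W_dec·f̃ + b_dec, with decoder columns kept unit-norm.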
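To make the shapes concrete, here is a minimal NumPy sketch of the gated forward pass and the decoder normalization. It is an illustration, not the gated_sae source: the function names (gated_sae_forward, normalize_decoder) are hypothetical, and the row/column layout of the weight matrices is a convention chosen for this sketch. The input is centered as x_c = x − b_dec, as in the paper.

```python
import numpy as np


def gated_sae_forward(x, W_gate, b_gate, r_mag, b_mag, W_dec, b_dec):
    """Gated SAE forward pass on a batch x of shape (batch, d).

    Shape convention for this sketch: W_gate (m, d); b_gate, r_mag, b_mag (m,);
    W_dec (d, m); b_dec (d,). m is the number of dictionary features.
    """
    x_c = x - b_dec                                 # center on the decoder bias
    pi_gate = x_c @ W_gate.T + b_gate               # gate pre-activations, (batch, m)
    f_gate = (pi_gate > 0).astype(x.dtype)          # Heaviside step: which features fire
    W_mag = np.exp(r_mag)[:, None] * W_gate         # per-feature rescale of W_gate's rows
    f_mag = np.maximum(x_c @ W_mag.T + b_mag, 0.0)  # how much they fire
    f = f_gate * f_mag                              # sparse code
    x_hat = f @ W_dec.T + b_dec                     # reconstruction
    return f, x_hat, pi_gate


def normalize_decoder(W_dec):
    """Rescale each decoder column (one dictionary atom) to unit L2 norm."""
    return W_dec / np.linalg.norm(W_dec, axis=0, keepdims=True)


# Tiny random example: 8-dim inputs, 4x overcomplete dictionary.
rng = np.random.default_rng(0)
d, m = 8, 32
W_gate = rng.normal(size=(m, d)) / np.sqrt(d)
b_gate = np.zeros(m)
r_mag = np.zeros(m)
b_mag = np.zeros(m)
W_dec = normalize_decoder(rng.normal(size=(d, m)))
b_dec = np.zeros(d)
x = rng.normal(size=(5, d))
f, x_hat, pi_gate = gated_sae_forward(x, W_gate, b_gate, r_mag, b_mag, W_dec, b_dec)
```

One sanity check that falls out of the weight tying: with r_mag = 0 and b_mag = b_gate, the magnitude path equals the gate pre-activation, so the code collapses to relu(π_gate). Training moves r_mag and b_mag away from that point, which is what lets the gate and the magnitude disagree.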

Training minimizes:

L = L_reconstruct + λ · L_sparsity + α · L_aux

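L_sparsity = Σ relu(π_gate) drives the sparsity, and L_aux reconstructs the input from relu(π_gate) through a frozen (stop_gradient) decoder. The aux term is needed because the Heaviside gate passes no gradient, so the reconstruction loss alone can't tell the gate which features to switch on; L_aux gives the gate pre-activations a reconstruction signal of their own. Because the sparsity penalty lands on π_gate rather than on the magnitudes that reach the decoder, the magnitude estimates aren't dragged toward zero: decoupling the two avoids the shrinkage bias of an L1-penalized ReLU SAE and gives sharper, more monosemantic features. ✨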
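The three loss terms can be sketched in NumPy as below. This is a hedged illustration rather than the library's train_step: the function name gated_sae_loss is hypothetical, reducing each term as a per-example sum averaged over the batch is an assumption, and lam / alpha default to the lambda_sparse and aux_weight values listed for GatedSAE. NumPy has no autodiff, so the "frozen" decoder in the aux term is only marked in a comment.

```python
import numpy as np


def gated_sae_loss(x, W_gate, b_gate, r_mag, b_mag, W_dec, b_dec, lam=1e-3, alpha=1e-2):
    """Return (total, parts) for L = L_reconstruct + lam * L_sparsity + alpha * L_aux."""
    x_c = x - b_dec
    pi_gate = x_c @ W_gate.T + b_gate
    relu_pi = np.maximum(pi_gate, 0.0)
    f_gate = (pi_gate > 0).astype(x.dtype)
    f_mag = np.maximum(x_c @ (np.exp(r_mag)[:, None] * W_gate).T + b_mag, 0.0)
    x_hat = (f_gate * f_mag) @ W_dec.T + b_dec

    # Per-example squared error / L1, averaged over the batch (an assumed reduction).
    l_recon = np.mean(np.sum((x - x_hat) ** 2, axis=-1))
    l_sparse = np.mean(np.sum(relu_pi, axis=-1))
    # Aux: decode relu(pi_gate) with the same decoder weights. In TensorFlow you
    # would wrap W_dec and b_dec in tf.stop_gradient here so this term does not
    # update the decoder; NumPy computes no gradients, so nothing changes below.
    x_aux = relu_pi @ W_dec.T + b_dec
    l_aux = np.mean(np.sum((x - x_aux) ** 2, axis=-1))

    total = l_recon + lam * l_sparse + alpha * l_aux
    return total, {"reconstruct": l_recon, "sparsity": l_sparse, "aux": l_aux}


# Small random example with unit-norm decoder columns.
rng = np.random.default_rng(1)
d, m = 6, 24
params = dict(
    W_gate=rng.normal(size=(m, d)) / np.sqrt(d), b_gate=np.zeros(m),
    r_mag=np.zeros(m), b_mag=np.zeros(m),
    W_dec=rng.normal(size=(d, m)), b_dec=np.zeros(d),
)
params["W_dec"] /= np.linalg.norm(params["W_dec"], axis=0, keepdims=True)
x = rng.normal(size=(4, d))
total, parts = gated_sae_loss(x, **params)
```

A quick way to read the terms: if every gate is closed, the sparsity cost is zero and both reconstructions fall back to b_dec, so L_reconstruct and L_aux coincide.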

📦 Install

pip install gated-sae-tf            # core (TensorFlow + Keras 3)
pip install "gated-sae-tf[viz]"     # + matplotlib for the gallery helpers

Needs Python ≥ 3.10 and TensorFlow ≥ 2.16 (where Keras 3 is the default). The package doesn't pin a TensorFlow build variant, so an existing tensorflow[and-cuda] install is respected.

🚀 Quickstart

import keras, numpy as np
from gated_sae import GatedSAE, WarmupCosineDecay, sparsity_report

(x_train, _), _ = keras.datasets.fashion_mnist.load_data()
X = x_train.reshape(-1, 784).astype("float32") / 255.0

sae = GatedSAE(input_dim=784, encoding_dim=784 * 8,   # 8x overcomplete
               lambda_sparse=1e-3, aux_weight=0.1, clip_norm=1.0)
sae(X[:2])                          # build the weights
sae.b_dec.assign(X.mean(axis=0))    # init decoder bias to the data mean

steps = -(-len(X) // 256) * 20        # ceil division: fit() also runs the last partial batch
sae.compile(optimizer=keras.optimizers.Adam(
    WarmupCosineDecay(1e-3, warmup_steps=steps // 10, total_steps=steps),
    beta_1=0.0, beta_2=0.999))       # beta_1=0 per the paper

sae.fit(X, epochs=20, batch_size=256)
print(sparsity_report(sae, X))       # L0, alive/dead features, top-k share

That trains an 8× overcomplete gated SAE on Fashion-MNIST and prints the L0 mean/median, alive/dead feature counts, and the top-k activation share. A full end-to-end walkthrough notebook ships in the repository (examples/).
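The WarmupCosineDecay schedule in the quickstart ramps the learning rate linearly up to the peak and then follows a half cosine down to zero. Below is a pure-Python sketch of that shape, assuming warmup starts from 0 and the rate is held at 0 past total_steps; the function name warmup_cosine_lr is hypothetical and the library's exact edge behavior may differ. For the quickstart's setting, Keras runs ceil(60,000 / 256) = 235 batches per epoch, so 20 epochs is 4,700 steps.

```python
import math


def warmup_cosine_lr(step, peak_lr, warmup_steps, total_steps):
    """Linear warmup from 0 to peak_lr, then cosine decay to 0 at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)  # hold at 0 after total_steps
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))


# The quickstart's shape: peak 1e-3, 10% warmup over 4,700 steps.
total = 4700
lrs = [warmup_cosine_lr(s, 1e-3, total // 10, total) for s in range(total + 1)]
```

The rate peaks at step 470, is back to half the peak midway through the decay phase, and reaches zero at step 4,700.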

🧰 The API, at a glance

| Symbol | What it does |
|---|---|
| GatedSAE(input_dim, encoding_dim, lambda_sparse=1e-3, aux_weight=1e-2, clip_norm=1.0) | The model. Custom train_step/test_step, encode/decode, set_lambda for annealing, full get_config serialization. |
| WarmupCosineDecay(peak_lr, warmup_steps, total_steps) | Linear warmup → cosine-decay-to-zero LR schedule. |
| sparsity_report(model, X, batch_size=512, k=20) | Dict of L0 mean/median, alive/dead counts, alive fraction, top-k activation share. |
| decoder_sharpness(model) | Returns (per_feature_kurtosis, mean); higher kurtosis means sharper, more localized features. |
| plot_decoder_gallery(model, codes, top_n=10) | Grid of the top decoder directions. Needs the [viz] extra. |
| plot_feature_gallery(model, codes, X, labels, class_names, ...) | Decoder-direction + top-activating-images view with MONO/POLY tagging. Needs the [viz] extra. |
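The statistics behind sparsity_report and decoder_sharpness can be approximated in a few lines of NumPy. This sketch works on raw arrays rather than a model, and every definition in it is an assumption, not the library's: a feature is active when its code is > 0, "alive" means active on at least one sample, "top-k share" is the fraction of total activation mass carried by the k features with the largest summed activation, and sharpness is the plain (non-excess) kurtosis of each decoder column. The names sparsity_stats and decoder_kurtosis are hypothetical.

```python
import numpy as np


def sparsity_stats(codes, k=20):
    """codes: (n_samples, n_features) sparse codes from the encoder."""
    active = codes > 0
    l0 = active.sum(axis=1)                 # active features per sample
    alive = active.any(axis=0)              # fired at least once in the dataset
    mass = codes.sum(axis=0)                # total activation per feature
    total = mass.sum()
    topk = np.sort(mass)[::-1][:k].sum()    # mass of the k busiest features
    return {
        "l0_mean": float(l0.mean()),
        "l0_median": float(np.median(l0)),
        "alive": int(alive.sum()),
        "dead": int((~alive).sum()),
        "alive_fraction": float(alive.mean()),
        "topk_share": float(topk / total) if total > 0 else 0.0,
    }


def decoder_kurtosis(W_dec):
    """Per-column kurtosis of a (d, m) decoder; peaky (localized) atoms score higher."""
    z = W_dec - W_dec.mean(axis=0, keepdims=True)
    m2 = (z ** 2).mean(axis=0)
    m4 = (z ** 4).mean(axis=0)
    kurt = m4 / np.maximum(m2 ** 2, 1e-12)  # guard against constant columns
    return kurt, float(kurt.mean())


codes = np.array([[1.0, 0.0, 0.0], [2.0, 3.0, 0.0], [0.0, 0.0, 0.0]])
stats = sparsity_stats(codes, k=1)
# Column 0 is one-hot (localized); column 1 alternates +1/-1 (spread out).
W = np.array([[1.0, 1.0], [0.0, -1.0], [0.0, 1.0], [0.0, -1.0]])
kurt, mean_kurt = decoder_kurtosis(W)
```

In the toy example the third feature never fires, so it counts as dead, and the one-hot decoder column gets a higher kurtosis than the evenly spread one, matching the "higher means more localized" reading in the table.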

🗺️ Roadmap

  • More SAE variants (vanilla ReLU, JumpReLU, TopK)
  • Activation-store helpers for SAEs on transformer activations
  • Pretrained dictionaries / model-hub integration

💛 Contributing

Issues and PRs are genuinely welcome: a new SAE variant, a docs fix, a bug report, all of it. The full walkthrough notebook (examples/), the contributing guide (CONTRIBUTING.md), and the extended bibliography (GUIDE.md) live alongside this file in the repository.

🤖 How it was made

Built with Claude, shaped and reviewed by me. The implementation follows the gated SAE paper; training, serialization, and every interpretability utility are covered by a 15-test suite that runs in CI.

📚 References & citing

This library implements the gated SAE of Rajamanoharan et al. (2024), Improving Dictionary Learning with Gated Sparse Autoencoders, arXiv:2404.16014 — https://arxiv.org/abs/2404.16014

It sits in the interpretability lineage of Anthropic's Transformer Circuits work on dictionary learning:

Full bibliography, BibTeX, and related work (Elhage et al., Gao et al.) are in GUIDE.md in the repository.

⚖️ License

MIT © 2026 Aishwarya Natesh.


Download files


Source Distribution

gated_sae_tf-0.1.1.tar.gz (17.0 kB)

Uploaded Source

Built Distribution


gated_sae_tf-0.1.1-py3-none-any.whl (13.4 kB)

Uploaded Python 3

File details

Details for the file gated_sae_tf-0.1.1.tar.gz.

File metadata

  • Download URL: gated_sae_tf-0.1.1.tar.gz
  • Size: 17.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for gated_sae_tf-0.1.1.tar.gz
Algorithm Hash digest
SHA256 c9983ce2772b201417efb353f85db542c90aacae4d1701cc7f2864858433547d
MD5 9cee0a0d0fc3c7a7c83e7aed23897ca0
BLAKE2b-256 f89f5259a9aca8f172765703df7089a4145bc3c66b9c76b474938e15e223ccbc


Provenance

The following attestation bundles were made for gated_sae_tf-0.1.1.tar.gz:

Publisher: publish.yml on aishwaryanatesh-hub/gated-sae-tf

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file gated_sae_tf-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: gated_sae_tf-0.1.1-py3-none-any.whl
  • Size: 13.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for gated_sae_tf-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 51d36338d9a4e29eee4ce4857c78a17331f4bc437ca55abb7d28e14b5ac7a149
MD5 8e44a2ddbed982e07944dccf33317754
BLAKE2b-256 bd463fec0f05d6e044d2a152c14ce607ae13fa8a7fe10ee756ef0e6d36a58814


Provenance

The following attestation bundles were made for gated_sae_tf-0.1.1-py3-none-any.whl:

Publisher: publish.yml on aishwaryanatesh-hub/gated-sae-tf

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
