Probes
A lightweight, modular library for training linear probes and steering vectors on neural network activations.
Installation
probekit is not a minimal-dependency package: it pulls in heavy ML dependencies, including torch, scikit-learn, and sae-lens.
Install it in an environment where large binary wheels and ML runtime deps are expected.
# PyPI install
pip install probekit
# Local editable install (from a cloned repo)
pip install -e .
Core Design (V2)
This library separates Semantics (the probe model) from Fitting (how it's learned).
1. The Models: LinearProbe and ProbeCollection
- LinearProbe (probekit.core.probe): A container for a single probe (plus normalization stats).
- ProbeCollection (probekit.core.collection): A container for a batch of probes.
  - to_tensor(): Stacks weights into [B, D] and biases into [B].
  - best_layer(metric): Finds the probe with the best score on the given metric (e.g. validation accuracy).
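Under the hood, a collection of linear probes is just a stacked weight matrix and bias vector. A minimal numpy sketch of scoring activations against stacked probes (variable names here are illustrative, not probekit API):

```python
import numpy as np

# Illustrative only: W and b stand in for what ProbeCollection.to_tensor()
# returns; X stands in for a batch of activations.
rng = np.random.default_rng(0)
B, N, D = 3, 5, 8                  # probes, samples, feature dim
W = rng.normal(size=(B, D))        # stacked probe weights [B, D]
b = rng.normal(size=(B,))          # stacked biases [B]
X = rng.normal(size=(N, D))        # activations shared across probes

scores = W @ X.T + b[:, None]      # raw margins, one row per probe: [B, N]
preds = (scores > 0.0).astype(int) # binary predictions at threshold 0
```

Each probe is applied independently; stacking just lets a single matmul score all of them at once.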
2. The Fitters
Functional solvers in probekit.fitters take training data and return a LinearProbe (or ProbeCollection).
- fit_logistic: Standard L2-regularized logistic regression.
- fit_elastic_net: Elastic net (L1 + L2), useful for sparse features (SAEs, neurons).
- fit_dim: Difference-in-means (class 1 mean - class 0 mean).
Method choice: use fit_dim for strictly linear, overfitting-resistant separation and use fit_logistic for standard L2-regularized classification.
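The difference-in-means rule above fits in a few lines. A hedged numpy sketch (probekit's fit_dim may differ in details such as normalization or thresholding; fit_dim_sketch is a hypothetical name):

```python
import numpy as np

def fit_dim_sketch(X, y):
    """Difference-in-means: w = mean(class 1) - mean(class 0).
    The bias places the decision threshold at the midpoint of the
    two projected class means."""
    mu1 = X[y == 1].mean(axis=0)
    mu0 = X[y == 0].mean(axis=0)
    w = mu1 - mu0
    b = -w @ (mu0 + mu1) / 2.0
    return w, b

# Two well-separated clusters in 4D
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(+2.0, 1.0, (50, 4)),   # class 1 cluster
               rng.normal(-2.0, 1.0, (50, 4))])  # class 0 cluster
y = np.array([1] * 50 + [0] * 50)
w, b = fit_dim_sketch(X, y)
acc = ((X @ w + b > 0).astype(int) == y).mean()  # near-perfect on separable data
```

Because the direction is just a difference of means, it cannot overfit beyond what the class means themselves encode, which is why it is the overfitting-resistant option.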
Batched GPU Fitters
Optimized PyTorch implementations in probekit.fitters.batch handle 3D inputs [B, N, D] efficiently on GPU:
- fit_logistic_batch: Batched IRLS/Newton solver with auto-switch between dense Newton and memory-safe Newton-CG.
- fit_dim_batch: Vectorized DiM with median thresholding.
- fit_elastic_net_path: Efficiently fits a regularization path (multiple alphas) using warm-starting.
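The shape handling of a vectorized batched DiM can be sketched in plain numpy (the torch version is analogous); this is an illustrative sketch, not the library's implementation, and fit_dim_batch_sketch is a hypothetical name:

```python
import numpy as np

def fit_dim_batch_sketch(Xb, yb):
    """Batched difference-in-means: Xb is [B, N, D], yb is [B, N] in {0, 1}.
    Each probe is thresholded at its median projected score, mirroring the
    median thresholding described above."""
    m1 = yb[..., None].astype(float)               # [B, N, 1] class-1 mask
    m0 = 1.0 - m1
    mu1 = (Xb * m1).sum(axis=1) / m1.sum(axis=1)   # [B, D] class-1 means
    mu0 = (Xb * m0).sum(axis=1) / m0.sum(axis=1)   # [B, D] class-0 means
    W = mu1 - mu0                                  # [B, D] probe directions
    scores = np.einsum("bnd,bd->bn", Xb, W)        # [B, N] projections
    b = -np.median(scores, axis=1)                 # [B] median thresholds
    return W, b

rng = np.random.default_rng(0)
Xb = rng.normal(size=(2, 10, 3))
yb = np.tile(np.array([1] * 5 + [0] * 5), (2, 1))
W, b = fit_dim_batch_sketch(Xb, yb)
```

Everything is mean, einsum, and median over the N axis, so the whole batch fits in a handful of vectorized ops with no per-probe loop.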
Quick Start
The high-level API supports explicit backend control (backend="torch" / backend="sklearn"),
and in backend="auto" mode it prefers torch when inputs are already torch tensors.
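The auto rule can be pictured as a small dispatcher. This is a hypothetical sketch of the stated behavior (resolve_backend is not probekit's actual code):

```python
import numpy as np

def resolve_backend(X, backend="auto"):
    """Sketch of the documented rule: an explicit backend wins, and
    "auto" prefers torch when the input is already a torch tensor."""
    if backend != "auto":
        return backend
    # Inspect the input's module instead of importing torch unconditionally.
    if type(X).__module__.split(".")[0] == "torch":
        return "torch"
    return "sklearn"

resolve_backend(np.zeros((4, 3)))  # numpy input resolves to "sklearn"
```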
from probekit import sae_probe, dim_probe
# 1. Single Probe (X: [N, D], y: [N])
probe = sae_probe(X_2d, y_1d)
# 2. Batched Probes (X: [B, N, D], y: [B, N] or [N])
# Uses torch batch fitters and returns a ProbeCollection
probes = sae_probe(X_3d, y)
weights, biases = probes.to_tensor() # [B, D], [B]
# 3. Inference with a trained single probe
scores = probe.predict_score(X_2d) # raw margins/logits
pred = probe.predict(X_2d, threshold=0.0) # binary predictions
# 4. Inference with a trained probe collection
batch_scores = probes.predict_score(X_3d) # [B, N]
batch_pred = probes.predict(X_3d, threshold=0.0) # [B, N]
# 5. Force backend explicitly
probe_torch = sae_probe(X_2d_torch, y_1d_torch, backend="torch")
probe_cpu = sae_probe(X_3d_numpy, y_2d_numpy, backend="sklearn")
Copyable Skill Snippet
# Probekit Quick Skill
Goal: Train and run linear probes on activations.
## Core Imports
from probekit import sae_probe, logistic_probe, dim_probe
## Train
# x: [N, D], y: [N]
probe = sae_probe(x, y)
## Inference
scores = probe.predict_score(x)
pred = probe.predict(x, threshold=0.0)
## Batched Training
# xb: [B, N, D], yb: [B, N] or broadcast-compatible
probes = sae_probe(xb, yb)
weights, biases = probes.to_tensor()
batch_scores = probes.predict_score(xb)
batch_pred = probes.predict(xb, threshold=0.0)
## Method choice
# Use dim_probe(...) for strictly linear, overfitting-resistant separation.
# Use logistic_probe(...) for standard L2-regularized classification.
Steering Vectors
You can build steering vectors for individual probes or entire collections:
from probekit import build_steering_vector, build_steering_vectors
# Single
vec = build_steering_vector(probe, sae_model, layer=10)
# Batched (Maps layers to probes)
vecs = build_steering_vectors(probe_collection, sae_model, layers=[8, 9, 10])
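A built vector is typically injected additively into a layer's activations. A minimal sketch of that use, assuming an additive intervention h' = h + alpha * v (the hook mechanics depend on your model framework, and apply_steering / alpha are illustrative names, not probekit parameters):

```python
import numpy as np

def apply_steering(activations, vec, alpha=1.0):
    """Shift activations along the steering direction: h' = h + alpha * v.
    Broadcasts over leading (batch, sequence) dimensions."""
    return activations + alpha * vec

h = np.zeros((2, 4, 8))   # [batch, seq, hidden] activations at one layer
v = np.ones(8)            # steering vector for that layer
steered = apply_steering(h, v, alpha=0.5)
```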
Structure
- probekit/core/: LinearProbe and ProbeCollection definitions.
- probekit/fitters/:
  - logistic.py, elastic.py, dim.py: Single-probe (CPU/sklearn) fitters.
  - batch/: Optimized GPU-batched fitters (IRLS, ISTA, DiM).
- probekit/api.py: High-level aliases and dimension routing.
- probekit/steering/: Tools for building steering vectors.