Probes

A lightweight, modular library for training linear probes and steering vectors on neural network activations.

Installation

probekit is not a minimal-dependency package: it pulls in heavy ML dependencies, including torch, scikit-learn, and sae-lens. Install it in an environment where large binary wheels and ML runtime dependencies are expected.

# PyPI install
pip install probekit

# Local editable install (from a cloned repo)
pip install -e .

Core Design (V2)

This library separates Semantics (the probe model) from Fitting (how it's learned).

1. The Models: LinearProbe and ProbeCollection

  • LinearProbe (probekit.core.probe): A container for a single probe (+ normalization stats).
  • ProbeCollection (probekit.core.collection): A container for a batch of probes.
    • to_tensor(): Stacks weights into [B, D] and biases into [B].
    • best_layer(metric): Returns the probe that scores best on the given validation metric (e.g. accuracy).
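To make the container shapes concrete, here is an illustrative NumPy sketch of what to_tensor() stacking produces and why it is useful for batched scoring. This is not the library's implementation; only the [B, D] / [B] shapes come from the README.

```python
import numpy as np

# Four hypothetical probes over D=16 features, each with a weight vector and bias.
weights_per_probe = [np.random.randn(16) for _ in range(4)]
biases_per_probe = [0.0, 0.1, -0.2, 0.3]

W = np.stack(weights_per_probe)   # stacked weights: [B, D] = [4, 16]
b = np.asarray(biases_per_probe)  # stacked biases:  [B]    = [4]

# With stacked parameters, all B probes score N examples in one matmul.
X = np.random.randn(8, 16)        # [N, D]
scores = X @ W.T + b              # [N, B]
```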

2. The Fitters

Functional solvers in probekit.fitters take training data and return a LinearProbe (or ProbeCollection).

  • fit_logistic: Standard L2-regularized Logistic Regression.
  • fit_elastic_net: ElasticNet (L1 + L2), useful for sparse features (SAEs, Neurons).
  • fit_dim: Difference-in-Means (Class 1 Mean - Class 0 Mean).

Method choice: use fit_dim for strictly linear, overfitting-resistant separation and use fit_logistic for standard L2-regularized classification.
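For intuition, the difference-in-means direction that fit_dim is described as computing can be written out in a few lines of NumPy. This is a sketch of the math on synthetic data, not probekit's code; the library's bias/threshold handling may differ, and the midpoint threshold below is an assumption.

```python
import numpy as np

# Two synthetic classes of "activations", shifted apart along every dimension.
rng = np.random.default_rng(0)
X0 = rng.normal(loc=-1.0, size=(100, 8))  # class 0
X1 = rng.normal(loc=+1.0, size=(100, 8))  # class 1
X = np.concatenate([X0, X1])
y = np.concatenate([np.zeros(100), np.ones(100)])

# Difference-in-means: class 1 mean minus class 0 mean.
w = X[y == 1].mean(axis=0) - X[y == 0].mean(axis=0)

# Project onto the direction and threshold at the midpoint of the class means.
scores = X @ w
threshold = 0.5 * (scores[y == 1].mean() + scores[y == 0].mean())
preds = scores > threshold
acc = (preds == (y == 1)).mean()
```

Because the direction is just a difference of class means, it has no regularization hyperparameters to tune, which is why it resists overfitting on small datasets.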

Batched GPU Fitters

Optimized PyTorch implementations in probekit.fitters.batch handle 3D inputs [B, N, D] efficiently on GPU:

  • fit_logistic_batch: Batched IRLS/Newton solver with auto-switch between dense Newton and memory-safe Newton-CG.
  • fit_dim_batch: Vectorized DiM with median thresholding.
  • fit_elastic_net_path: Efficiently fits a regularization path (multiple alphas) using warm-starting.
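To show what "vectorized DiM with median thresholding" means on a [B, N, D] input, here is a NumPy sketch of the batched computation: one difference-in-means direction per batch element, thresholded at the per-batch median score. Illustrative only; the real fitters run in PyTorch on GPU and their internals may differ.

```python
import numpy as np

rng = np.random.default_rng(1)
B, N, D = 3, 50, 8
y = np.tile(np.array([0] * 25 + [1] * 25), (B, 1))       # labels [B, N]
X = rng.normal(size=(B, N, D)) + (2 * y - 1)[..., None]  # shift classes apart

# Per-batch class means via boolean masks: mu1, mu0 are [B, D].
mask1 = (y == 1)[..., None]
mask0 = (y == 0)[..., None]
mu1 = (X * mask1).sum(axis=1) / mask1.sum(axis=1)
mu0 = (X * mask0).sum(axis=1) / mask0.sum(axis=1)
W = mu1 - mu0                                            # directions [B, D]

# Batched projection and median thresholding.
scores = np.einsum('bnd,bd->bn', X, W)                   # [B, N]
thresh = np.median(scores, axis=1, keepdims=True)        # [B, 1]
preds = scores > thresh
acc = (preds == (y == 1)).mean()
```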

Quick Start

The high-level API supports explicit backend control (backend="torch" / backend="sklearn"), and in backend="auto" mode it prefers torch when inputs are already torch tensors.

from probekit import sae_probe, dim_probe

# 1. Single Probe (X: [N, D], y: [N])
probe = sae_probe(X_2d, y_1d)

# 2. Batched Probes (X: [B, N, D], y: [B, N] or [N])
# Uses torch batch fitters and returns a ProbeCollection
probes = sae_probe(X_3d, y)
weights, biases = probes.to_tensor() # [B, D], [B]

# 3. Inference with a trained single probe
scores = probe.predict_score(X_2d)          # raw margins/logits
pred = probe.predict(X_2d, threshold=0.0)   # binary predictions

# 4. Inference with a trained probe collection
batch_scores = probes.predict_score(X_3d)         # [B, N]
batch_pred = probes.predict(X_3d, threshold=0.0)  # [B, N]

# 5. Force backend explicitly
probe_torch = sae_probe(X_2d_torch, y_1d_torch, backend="torch")
probe_cpu = sae_probe(X_3d_numpy, y_2d_numpy, backend="sklearn")

Copyable Skill Snippet

# Probekit Quick Skill

Goal: Train and run linear probes on activations.

## Core Imports
from probekit import sae_probe, logistic_probe, dim_probe

## Train
# x: [N, D], y: [N]
probe = sae_probe(x, y)

## Inference
scores = probe.predict_score(x)
pred = probe.predict(x, threshold=0.0)

## Batched Training
# xb: [B, N, D], yb: [B, N] or broadcast-compatible
probes = sae_probe(xb, yb)
weights, biases = probes.to_tensor()
batch_scores = probes.predict_score(xb)
batch_pred = probes.predict(xb, threshold=0.0)

## Method choice
# Use dim_probe(...) for strictly linear, overfitting-resistant separation.
# Use logistic_probe(...) for standard L2-regularized classification.

Steering Vectors

You can build steering vectors for individual probes or entire collections:

from probekit import build_steering_vector, build_steering_vectors

# Single
vec = build_steering_vector(probe, sae_model, layer=10)

# Batched (Maps layers to probes)
vecs = build_steering_vectors(probe_collection, sae_model, layers=[8, 9, 10])
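As background, the common steering pattern these builders serve looks like the sketch below: normalize a probe-derived direction and add a scaled copy to the activations. This is a generic illustration, not probekit's implementation (build_steering_vector additionally involves the SAE model), and the strength knob alpha is a hypothetical name.

```python
import numpy as np

rng = np.random.default_rng(2)
w = rng.normal(size=16)            # a probe weight vector [D]
direction = w / np.linalg.norm(w)  # unit-norm steering direction

acts = rng.normal(size=(4, 16))    # [N, D] activations at some hook point
alpha = 3.0                        # steering strength (hypothetical)
steered = acts + alpha * direction # shift every position along the direction
```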

Structure

  • probekit/core/: LinearProbe and ProbeCollection definitions.
  • probekit/fitters/:
    • logistic.py, elastic.py, dim.py: Single-probe (CPU/sklearn) fitters.
    • batch/: Optimized GPU-batched fitters (IRLS, ISTA, DiM).
  • probekit/api.py: High-level aliases and dimension routing.
  • probekit/steering/: Tools for building steering vectors.
