rl_interrogate
Mechanistic interpretability toolkit for RL policies. Probe, ablate, and interrogate what your policy has learned, not just how well it performs.
Installation
pip install -e .
Dependencies: torch, numpy, scikit-learn, matplotlib, seaborn, gymnasium,
stable-baselines3. MuJoCo environments require the mujoco extra:
pip install -e ".[mujoco]"
Minimal Example
Load a checkpoint, run a probe, and run an ablation:
import torch
import numpy as np
from rl_interrogate import LinearProbe, AblationHook
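# 1. Load your policy network (any torch.nn.Sequential).
# weights_only=False is required to unpickle a full nn.Module on PyTorch >= 2.6;
# only load checkpoints you trust.
policy_net = torch.load("my_policy.pt", weights_only=False)
policy_net.eval()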
# 2. Build a synthetic batch of observations (500 samples, 28-dim observation space)
obs_grid = np.random.randn(500, 28).astype(np.float32)
labels = obs_grid[:, 10]  # probe target: obs dim 10 (lateral position in this example)
# 3. Linear probe at layer 5
probe = LinearProbe()
probe.fit(policy_net, layer_idx=5, obs_dataset=obs_grid, labels=labels)
print(f"Layer 5 R² = {probe.score():.4f}")
# 4. Ablation: zero the probe direction at layer 5, then measure the performance change
hook = AblationHook(policy_net, layer_idx=5, direction=probe._probe.coef_)
with hook.apply(alpha=0.0):
    # Run your environment here; the probe direction is zeroed inside this block.
    pass
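For readers who want to see what a layer-wise probe involves, here is a minimal sketch that does not use rl_interrogate. Like the example above, it assumes the policy is an nn.Sequential so that layer_idx indexes model[layer_idx]; it captures that layer's output with a PyTorch forward hook and fits a scikit-learn Ridge regression. The helper names (layer_activations, ridge_probe_r2), the 80/20 held-out split, and the toy network are illustrative assumptions; LinearProbe may differ in its regularization and in how score() evaluates.

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score


def layer_activations(model: nn.Sequential, layer_idx: int, obs: np.ndarray) -> np.ndarray:
    """Return the output of model[layer_idx] for a batch of observations."""
    captured = {}

    def hook(_module, _inputs, output):
        # Returning None leaves the layer's output unchanged.
        captured["acts"] = output.detach()

    handle = model[layer_idx].register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(torch.as_tensor(obs, dtype=torch.float32))
    finally:
        handle.remove()  # always detach the hook, even if the forward pass fails
    return captured["acts"].cpu().numpy()


def ridge_probe_r2(model, layer_idx, obs, labels, alpha=1.0, train_frac=0.8, seed=0):
    """Fit a Ridge probe on a random train split and report held-out R²."""
    acts = layer_activations(model, layer_idx, obs)
    idx = np.random.default_rng(seed).permutation(len(obs))
    n_train = int(train_frac * len(obs))
    train, test = idx[:n_train], idx[n_train:]
    probe = Ridge(alpha=alpha).fit(acts[train], labels[train])
    return probe, r2_score(labels[test], probe.predict(acts[test]))


# Toy stand-in for a policy (hypothetical): index 5 is the third Tanh layer.
torch.manual_seed(0)
net = nn.Sequential(
    nn.Linear(28, 64), nn.Tanh(),
    nn.Linear(64, 64), nn.Tanh(),
    nn.Linear(64, 64), nn.Tanh(),
)
obs_grid = np.random.default_rng(0).standard_normal((500, 28)).astype(np.float32)
labels = obs_grid[:, 10]
probe, r2 = ridge_probe_r2(net, layer_idx=5, obs=obs_grid, labels=labels)
print(f"Layer 5 held-out R² = {r2:.4f}")
```

Scoring on held-out samples guards against a wide layer simply memorizing the labels, which would inflate in-sample R².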
Experiments
The library was developed for the WakeRider paper (TMLR submission). Key experiments:
- Formation flight probe (examples/formation_flight_probe.py): reproduces the Actor L5 R² = 0.973 from the seed-42 checkpoint.
- HalfCheetah ablation (examples/halfcheetah_ablation.py): runs ablation on a HalfCheetah-v4 policy, showing that ablating PC1 degrades performance by ~10%.
API Reference
Probing
from rl_interrogate import LinearProbe, MLPProbe, LassoProbe
# Ridge regression probe (recommended)
probe = LinearProbe()
probe.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
r2 = probe.score()
# MLP probe (non-linear)
mlp_probe = MLPProbe()
mlp_probe.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
# Sparse Lasso probe
lasso = LassoProbe()
lasso.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
r2, n_nonzero = lasso.score()
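The three probe families can be compared directly with scikit-learn. This sketch does not call rl_interrogate: it uses a synthetic activation matrix (a stand-in for real layer activations) whose label depends on only two units, so a sparse Lasso probe should keep about two nonzero coefficients, which is the kind of information the (r2, n_nonzero) pair reports. Using Ridge, MLPRegressor, and Lasso as backends, and every hyperparameter below, are assumptions rather than the library's documented defaults.

```python
import numpy as np
from sklearn.linear_model import Lasso, Ridge
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor

# Synthetic "activations": 64 units, label driven by units 3 and 17 only.
rng = np.random.default_rng(0)
acts = rng.standard_normal((600, 64))
y = 2.0 * acts[:, 3] - 1.5 * acts[:, 17] + 0.1 * rng.standard_normal(600)
X_tr, X_te, y_tr, y_te = train_test_split(acts, y, test_size=0.25, random_state=0)

ridge = Ridge(alpha=1.0).fit(X_tr, y_tr)      # dense linear probe
lasso = Lasso(alpha=0.05).fit(X_tr, y_tr)     # sparse linear probe
mlp = MLPRegressor(hidden_layer_sizes=(64,), max_iter=2000, random_state=0).fit(X_tr, y_tr)

ridge_r2 = ridge.score(X_te, y_te)            # held-out R²
lasso_r2 = lasso.score(X_te, y_te)
n_nonzero = int(np.count_nonzero(lasso.coef_))  # how many units the Lasso probe uses
mlp_r2 = mlp.score(X_te, y_te)
print(f"ridge {ridge_r2:.3f} | lasso {lasso_r2:.3f} ({n_nonzero} units) | mlp {mlp_r2:.3f}")
```

A large gap between the MLP and linear R² would suggest the variable is encoded non-linearly; a small n_nonzero suggests it is concentrated in a few units.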
Ablation
from rl_interrogate import AblationHook
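hook = AblationHook(policy_net, layer_idx=5, direction=probe_direction)
with hook.apply(alpha=0.0):  # alpha=0 zeros the direction
    # run_episodes is your own evaluation loop (it is not imported from rl_interrogate);
    # pass the hooked network so the ablation takes effect.
    rewards = run_episodes(policy_net, env, n=100)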
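One common way to implement a directional ablation is a forward hook that rescales the activation component along a unit vector u, h' = h - (1 - alpha)(h·u)u, so alpha=0 removes the component and alpha=1 is the identity. The sketch below assumes AblationHook behaves roughly like this; the context manager name scale_direction and the alpha=1 convention are hypothetical, not taken from the library.

```python
import contextlib

import torch
import torch.nn as nn


@contextlib.contextmanager
def scale_direction(model: nn.Sequential, layer_idx: int, direction, alpha: float = 0.0):
    """Temporarily rescale the component of model[layer_idx]'s output along `direction`.

    h' = h - (1 - alpha) * (h . u) * u, with u the unit-normalized direction:
    alpha=0 removes the component, alpha=1 leaves activations unchanged.
    """
    u = torch.as_tensor(direction, dtype=torch.float32).flatten()  # accepts (n,) or (1, n)
    u = u / u.norm()

    def hook(_module, _inputs, output):
        u_dev = u.to(device=output.device, dtype=output.dtype)
        coeff = output @ u_dev  # projection of each activation vector onto u
        # A forward hook that returns a tensor replaces the module's output.
        return output - (1.0 - alpha) * coeff.unsqueeze(-1) * u_dev

    handle = model[layer_idx].register_forward_hook(hook)
    try:
        yield
    finally:
        handle.remove()  # restore the unmodified network on exit


# Toy check: inside the block, the hooked layer's output has no component along u.
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(28, 64), nn.Tanh(), nn.Linear(64, 8))
x = torch.randn(32, 28)
direction = torch.randn(64)
with torch.no_grad(), scale_direction(net, layer_idx=1, direction=direction, alpha=0.0):
    ablated = net[:2](x)  # output of the hooked Tanh layer
    actions = net(x)      # the full forward pass also sees the ablated activations
with torch.no_grad():
    clean = net[:2](x)    # hook removed: original activations
```

Because the hook lives only inside the with block, rollouts run there are ablated while everything outside uses the original policy.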
PCA Utilities
from rl_interrogate import fit_pca, project_subspace
pca = fit_pca(activations, n_components=20)
acts_k = project_subspace(activations, pca, k=1) # rank-1 projection
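A rank-k projection keeps the top-k principal components of the activations and discards the rest. This sketch uses scikit-learn's PCA and assumes that project_subspace returns reconstructed activations in the original space (so they could be fed to later layers) rather than k-dimensional scores; project_onto_top_k is a hypothetical helper, not part of rl_interrogate.

```python
import numpy as np
from sklearn.decomposition import PCA


def project_onto_top_k(acts: np.ndarray, pca: PCA, k: int) -> np.ndarray:
    """Reconstruct `acts` from only its top-k principal components."""
    scores = pca.transform(acts)[:, :k]               # coordinates along the top-k PCs
    return scores @ pca.components_[:k] + pca.mean_   # map back to activation space


rng = np.random.default_rng(0)
activations = rng.standard_normal((1000, 64))
pca = PCA(n_components=20).fit(activations)
acts_1 = project_onto_top_k(activations, pca, k=1)    # rank-1 after removing the mean
acts_20 = project_onto_top_k(activations, pca, k=20)  # all fitted components
```

With k equal to n_components, this matches PCA.inverse_transform; smaller k gives progressively coarser reconstructions.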
Visualization
from rl_interrogate import plot_probe_heatmap, plot_ablation_curve
plot_probe_heatmap(activations, labels, title="Layer 5 probe")
plot_ablation_curve(alphas=[0.0, 0.5, 1.0], means=[1.05, 0.97, 0.90])
Running Tests
pytest rl_interrogate/tests/ -v
Link to Paper
This library implements the interrogation protocol described in:
WakeRider: Emergent V-Formation Flight via Wake Exploitation (Section 3.3, "The rl_interrogate Library")
The protocol consists of four steps:
- Linear probing: fit a Ridge regression from hidden activations to a field label.
- Polarity inversion: negate the sensor input and verify that R² drops, evidence that the representation is causally driven by the sensor rather than merely correlated with it.
- Single-direction ablation: zero the probe direction and measure the change in performance.
- Subspace variance: greedily select PCA components to find the minimal sufficient subspace.
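Steps 2 and 4 of the protocol are the least standard, so here is one plausible reading of each, as a sketch. Polarity inversion fits a probe on normal observations and re-scores it after negating the sensor dimension: if the layer genuinely tracks the sensor, the probe's predictions flip sign and R² falls, often below zero. The greedy subspace search adds principal components one at a time, keeping whichever most improves a cross-validated Ridge probe, until it reaches a chosen fraction of the full-activation R². The function names, the in-sample R² in polarity_inversion, the 95% sufficiency threshold, and the synthetic data are all assumptions; Section 3.3 of the paper is the authoritative description.

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.decomposition import PCA
from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_val_score


def layer_activations(model, layer_idx, obs):
    """Capture model[layer_idx]'s output with a temporary forward hook."""
    captured = {}
    handle = model[layer_idx].register_forward_hook(
        lambda _m, _i, out: captured.__setitem__("acts", out.detach())
    )
    try:
        with torch.no_grad():
            model(torch.as_tensor(obs, dtype=torch.float32))
    finally:
        handle.remove()
    return captured["acts"].cpu().numpy()


def polarity_inversion(model, layer_idx, obs, sensor_idx, alpha=1.0):
    """Fit a probe for obs[:, sensor_idx], then re-score it with that sensor negated."""
    labels = obs[:, sensor_idx]
    acts = layer_activations(model, layer_idx, obs)
    probe = Ridge(alpha=alpha).fit(acts, labels)
    obs_inv = obs.copy()
    obs_inv[:, sensor_idx] *= -1.0  # flip the sensor's polarity, keep the original labels
    acts_inv = layer_activations(model, layer_idx, obs_inv)
    return probe.score(acts, labels), probe.score(acts_inv, labels)


def greedy_pc_subspace(acts, labels, n_components=8, target_frac=0.95, cv=5):
    """Add PCs one at a time until a Ridge probe on them reaches target_frac of full R²."""
    def cv_r2(X):
        return cross_val_score(Ridge(alpha=1.0), X, labels, cv=cv, scoring="r2").mean()

    scores = PCA(n_components=n_components).fit_transform(acts)
    full_r2 = cv_r2(acts)
    chosen, remaining, best = [], list(range(n_components)), -np.inf
    while remaining and best < target_frac * full_r2:
        r2_by_pc = {j: cv_r2(scores[:, chosen + [j]]) for j in remaining}
        j_best = max(r2_by_pc, key=r2_by_pc.get)
        chosen.append(j_best)
        remaining.remove(j_best)
        best = r2_by_pc[j_best]
    return chosen, best, full_r2


torch.manual_seed(0)
rng = np.random.default_rng(0)
obs = rng.standard_normal((1000, 28)).astype(np.float32)
net = nn.Sequential(nn.Linear(28, 64), nn.Tanh(), nn.Linear(64, 8))
r2_orig, r2_inv = polarity_inversion(net, layer_idx=1, obs=obs, sensor_idx=10)

# Synthetic activations whose label lies along the 5th-largest-variance direction.
Z = rng.standard_normal((1000, 8)) * np.array([10, 8, 6, 4, 2, 1, 0.5, 0.25])
Q, _ = np.linalg.qr(rng.standard_normal((8, 8)))  # random rotation
acts = Z @ Q.T
chosen, best, full_r2 = greedy_pc_subspace(acts, Z[:, 4])
print(f"R² {r2_orig:.3f} -> {r2_inv:.3f} after inversion; chosen PCs {chosen}")
```

The synthetic example shows why greedy selection is preferred over simply taking the top-k PCs: the label lives in a low-variance component, which a top-k cut would need five components to reach.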
Download files
File details
Details for the file rl_interrogate-0.1.0.tar.gz.
File metadata
- Download URL: rl_interrogate-0.1.0.tar.gz
- Upload date:
- Size: 93.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 25509bc0d87d76e94d5568fcde6ae541c4973e41ff5cc78b7b85cd9cd13aa249 |
| MD5 | 21d5be866c75e606675c892568d3f600 |
| BLAKE2b-256 | b8cbd05597f4c064858753eca3cd8cdbb25c871b7b84efd7137c364ccda41ddc |
File details
Details for the file rl_interrogate-0.1.0-py3-none-any.whl.
File metadata
- Download URL: rl_interrogate-0.1.0-py3-none-any.whl
- Upload date:
- Size: 53.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 310f2c16135dfafdd9c962ee56b4392041fa1e945eecd07ba6cb6f208d8efba4 |
| MD5 | ad42652edcd36b6f426f9922a7558ac6 |
| BLAKE2b-256 | cc8c0e9fc8c9b5a0e25d69c7c0e0dba22df3979bfde16de25f43ffb7070b95f4 |