
rl_interrogate

Mechanistic interpretability toolkit for RL policies. Probe, ablate, and interrogate what your policy has learned — not just how well it performs.

Installation

From a local clone of the repository:

pip install -e .

Dependencies: torch, numpy, scikit-learn, matplotlib, seaborn, gymnasium, stable-baselines3. MuJoCo environments require the mujoco extra:

pip install -e ".[mujoco]"

Minimal Example

Load a checkpoint, fit a probe, and run an ablation:

import torch
import numpy as np
from rl_interrogate import LinearProbe, AblationHook

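# 1. Load your policy network (any torch.nn.Sequential).
#    The file must hold the pickled module, not a state_dict. Recent PyTorch
#    versions default to weights_only=True, so opt out explicitly (trusted files only).
policy_net = torch.load("my_policy.pt", weights_only=False)
policy_net.eval()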

# 2. Build a synthetic observation dataset (500 samples, 28-dim observations)
obs_grid = np.random.randn(500, 28).astype(np.float32)
labels = obs_grid[:, 10]  # probe target: observation index 10 (lateral position)

# 3. Linear probe at layer 5
probe = LinearProbe()
probe.fit(policy_net, layer_idx=5, obs_dataset=obs_grid, labels=labels)
print(f"Layer 5 R² = {probe.score():.4f}")

# 4. Ablation: zero the probe direction, measure performance change
hook = AblationHook(policy_net, layer_idx=5, direction=probe._probe.coef_)
with hook.apply(alpha=0.0):
    # run your environment here — the probe direction is zeroed
    pass
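To make "linear probe at layer 5" concrete, here is a NumPy-only sketch of what fitting such a probe involves: run observations through the network up to `layer_idx`, regress the label on those activations with ridge regression, and score R² on held-out samples. The toy network, `activations_at`, `fit_ridge`, and the convention that `layer_idx` selects the output of the module at that index are assumptions for illustration, not rl_interrogate's internals.

```python
import numpy as np

rng = np.random.default_rng(0)


def linear(in_dim, out_dim):
    """Toy stand-in for nn.Linear: returns a callable x -> x @ W + b."""
    W = rng.normal(scale=1.0 / np.sqrt(in_dim), size=(in_dim, out_dim))
    b = np.zeros(out_dim)
    return lambda x: x @ W + b


def relu(x):
    return np.maximum(x, 0.0)


# Hypothetical 28 -> 64 -> 64 -> 64 -> 8 policy laid out like an nn.Sequential:
# index 0=Linear, 1=ReLU, 2=Linear, 3=ReLU, 4=Linear, 5=ReLU, 6=Linear
layers = [linear(28, 64), relu, linear(64, 64), relu, linear(64, 64), relu, linear(64, 8)]


def activations_at(layers, x, layer_idx):
    """Return the output of layers[layer_idx] (assumed indexing convention)."""
    h = x
    for layer in layers[: layer_idx + 1]:
        h = layer(h)
    return h


def fit_ridge(H, y, lam=1.0):
    """Closed-form ridge regression with an unpenalized intercept."""
    H_mean, y_mean = H.mean(axis=0), y.mean()
    Hc, yc = H - H_mean, y - y_mean
    w = np.linalg.solve(Hc.T @ Hc + lam * np.eye(H.shape[1]), Hc.T @ yc)
    b = y_mean - H_mean @ w
    return w, b


def r2_score(y, y_hat):
    """Coefficient of determination, 1 - SS_res / SS_tot."""
    return 1.0 - np.sum((y - y_hat) ** 2) / np.sum((y - y.mean()) ** 2)


obs = rng.standard_normal((500, 28)).astype(np.float32)
labels = obs[:, 10]  # same synthetic target as the minimal example

H = activations_at(layers, obs, layer_idx=5)  # shape (500, 64)
train, test = slice(0, 400), slice(400, 500)
w, b = fit_ridge(H[train], labels[train], lam=1.0)
print(f"held-out R^2 at layer 5: {r2_score(labels[test], H[test] @ w + b):.3f}")
```

In a Linear/ReLU stack like this one, index 5 is an activation module rather than a Linear layer, so check how your own `nn.Sequential` is laid out before comparing layer indices across networks.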

Experiments

The library was developed for the WakeRider paper (TMLR submission). Key experiments:

  • Formation flight probe (examples/formation_flight_probe.py): Reproduces the actor network's layer-5 probe result (R² = 0.973) from the seed-42 checkpoint.
  • HalfCheetah ablation (examples/halfcheetah_ablation.py): Runs ablation on a HalfCheetah-v4 policy, showing that ablating the first principal component (PC1) degrades performance by ~10%.

API Reference

Probing

from rl_interrogate import LinearProbe, MLPProbe, LassoProbe

# Ridge regression probe (recommended)
probe = LinearProbe()
probe.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
r2 = probe.score()

# MLP probe (non-linear)
mlp_probe = MLPProbe()
mlp_probe.fit(model, layer_idx=5, obs_dataset=obs, labels=y)

# Sparse Lasso probe
lasso = LassoProbe()
lasso.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
r2, n_nonzero = lasso.score()
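`LassoProbe.score()` returns both R² and the number of non-zero coefficients. As a hedged illustration of what a sparse probe computes, the sketch below solves L1-regularized linear regression with ISTA (iterative soft-thresholding) in plain NumPy. The objective matches scikit-learn's `Lasso` with `alpha = lam`, although scikit-learn solves it by coordinate descent; which solver `LassoProbe` actually uses is not stated here, and `lasso_probe` and `lam` are our own names.

```python
import numpy as np


def soft_threshold(z, t):
    """Proximal operator of t * ||.||_1: shrink each entry toward zero by t."""
    return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)


def lasso_probe(H, y, lam=0.05, n_iter=2000):
    """Minimize (1/2n)||yc - Hc w||^2 + lam * ||w||_1 with ISTA.

    Returns (r2, n_nonzero), mirroring the tuple returned by LassoProbe.score().
    R^2 is computed on the training data for brevity.
    """
    n = H.shape[0]
    H_mean, y_mean = H.mean(axis=0), y.mean()
    Hc, yc = H - H_mean, y - y_mean  # centering handles the intercept
    # Step size 1/L, L = Lipschitz constant of the smooth term's gradient.
    L = np.linalg.eigvalsh(Hc.T @ Hc / n).max()
    w = np.zeros(H.shape[1])
    for _ in range(n_iter):
        grad = Hc.T @ (Hc @ w - yc) / n
        w = soft_threshold(w - grad / L, lam / L)
    y_hat = Hc @ w + y_mean
    r2 = 1.0 - np.sum((y - y_hat) ** 2) / np.sum((y - y_mean) ** 2)
    return r2, int(np.count_nonzero(w))


# Demo: 50 activation features, only 3 of which carry the label.
rng = np.random.default_rng(0)
H = rng.standard_normal((300, 50))
w_true = np.zeros(50)
w_true[[3, 17, 42]] = [2.0, -1.5, 1.0]
y = H @ w_true + 0.1 * rng.standard_normal(300)
r2, n_nonzero = lasso_probe(H, y, lam=0.05)
print(f"R^2 = {r2:.3f}, non-zero coefficients = {n_nonzero}")
```

A small `n_nonzero` at high R² suggests the label is readable from a handful of units rather than spread across the whole layer.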

Ablation

from rl_interrogate import AblationHook

hook = AblationHook(policy_net, layer_idx=5, direction=probe_direction)
with hook.apply(alpha=0.0):   # alpha=0 zeros the direction
    # run_episodes: a rollout helper you supply (not imported here)
    rewards = run_episodes(policy_net, env, n=100)
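The comment `alpha=0 zeros the direction` suggests that `alpha` scales the component of the layer's activations along `direction`. Under that assumption (ours, not confirmed by the library), the edit is h' = h - (1 - alpha)(h·u)u, with u the unit-normalized probe direction. The NumPy sketch below implements that projection; `ablate_direction` is a hypothetical name. In PyTorch, an edit like this is typically installed with `nn.Module.register_forward_hook`, whose hook may return a replacement output; whether `AblationHook` works that way is likewise an assumption.

```python
import numpy as np


def ablate_direction(h, direction, alpha=0.0):
    """Scale the component of each row of h (shape: batch x hidden) along `direction`.

    alpha=0.0 removes that component entirely, alpha=1.0 leaves h unchanged,
    and values in between shrink it proportionally (assumed semantics).
    """
    u = np.asarray(direction, dtype=float).ravel()  # accepts coef_ of shape (1, d)
    u = u / np.linalg.norm(u)                        # unit-normalize the direction
    coeff = h @ u                                    # signed projection of each row onto u
    return h - (1.0 - alpha) * np.outer(coeff, u)


rng = np.random.default_rng(0)
h = rng.standard_normal((4, 6))
direction = rng.standard_normal(6)
h_ablated = ablate_direction(h, direction, alpha=0.0)
u = direction / np.linalg.norm(direction)
print(np.abs(h_ablated @ u).max())  # numerically zero: no component left along u
```

Everything orthogonal to u is left untouched, so any change in episode return under `alpha=0.0` can be attributed to that single direction.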

PCA Utilities

from rl_interrogate import fit_pca, project_subspace

pca = fit_pca(activations, n_components=20)
acts_k = project_subspace(activations, pca, k=1)  # rank-1 projection
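The PCA helpers above are shown only by name. One plausible reading, assumed here, is that `project_subspace(..., k=1)` keeps the activations' component along the top-k principal axes and maps the result back into the original activation space. A NumPy sketch via the SVD of the centered activations follows; `fit_pca_svd` and `project_rank_k` are hypothetical stand-ins, not the library's functions.

```python
import numpy as np


def fit_pca_svd(X, n_components):
    """Return (mean, components, explained_variance); components are rows,
    ordered by decreasing variance."""
    mean = X.mean(axis=0)
    # Right singular vectors of the centered data are the principal axes.
    _, s, Vt = np.linalg.svd(X - mean, full_matrices=False)
    explained_var = s[:n_components] ** 2 / (X.shape[0] - 1)
    return mean, Vt[:n_components], explained_var


def project_rank_k(X, mean, components, k):
    """Keep only the top-k principal directions and map back to activation space."""
    V = components[:k]  # (k, d), orthonormal rows
    return mean + (X - mean) @ V.T @ V


rng = np.random.default_rng(0)
activations = rng.standard_normal((200, 64))
mean, comps, ev = fit_pca_svd(activations, n_components=20)
acts_k = project_rank_k(activations, mean, comps, k=1)  # rank-1 projection
print(acts_k.shape, f"variance captured by PC1: {ev[0] / ev.sum():.3f} of the top 20")
```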

Visualization

from rl_interrogate import plot_probe_heatmap, plot_ablation_curve

plot_probe_heatmap(activations, labels, title="Layer 5 probe")
plot_ablation_curve(alphas=[0.0, 0.5, 1.0], means=[1.05, 0.97, 0.90])

Running Tests

pytest rl_interrogate/tests/ -v

Link to Paper

This library implements the interrogation protocol described in:

WakeRider: Emergent V-Formation Flight via Wake Exploitation, Section 3.3 ("The rl_interrogate Library")

The protocol consists of four steps:

  1. Linear probing — fit Ridge regression from hidden activations to a field label
  2. Polarity inversion — negate the sensor input; verify that probe R² drops, as evidence that the representation is causally driven by the sensor rather than merely correlated with it
  3. Single-direction ablation — zero the probe direction; measure performance change
  4. Subspace variance — greedy PCA selection to find the minimal sufficient subspace
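Step 4 does not spell out the selection criterion. One plausible reading, assumed in this sketch, is: project activations onto their principal components, then greedily add whichever component most increases linear-probe R² until R² reaches a chosen fraction of the all-components value. The paper's version may instead score candidate subspaces by policy performance; `greedy_pc_subspace`, `r2_linear`, and `target_frac` are hypothetical.

```python
import numpy as np


def r2_linear(Z, y):
    """In-sample R^2 of ordinary least squares of y on Z (with intercept)."""
    A = np.column_stack([Z, np.ones(len(y))])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    resid = y - A @ coef
    return 1.0 - resid @ resid / np.sum((y - y.mean()) ** 2)


def greedy_pc_subspace(H, y, n_components=20, target_frac=0.95):
    """Greedily add principal components until probe R^2 reaches target_frac of
    the R^2 obtained with all n_components. Returns (selected indices, R^2 history)."""
    Hc = H - H.mean(axis=0)
    _, _, Vt = np.linalg.svd(Hc, full_matrices=False)
    scores = Hc @ Vt[:n_components].T  # PC scores, (n_samples, n_components)
    full_r2 = r2_linear(scores, y)
    selected, history = [], []
    remaining = list(range(scores.shape[1]))
    while remaining:
        # Pick the component whose addition gives the highest probe R^2.
        best = max(remaining, key=lambda j: r2_linear(scores[:, selected + [j]], y))
        selected.append(best)
        remaining.remove(best)
        history.append(r2_linear(scores[:, selected], y))
        if history[-1] >= target_frac * full_r2:
            break
    return selected, history


rng = np.random.default_rng(0)
H = rng.standard_normal((400, 32)) * np.linspace(3.0, 0.5, 32)
y = H[:, :3] @ np.array([1.0, -0.5, 0.25]) + 0.1 * rng.standard_normal(400)
selected, history = greedy_pc_subspace(H, y, n_components=20)
print(f"{len(selected)} PCs selected, R^2 path: {np.round(history, 3)}")
```

Because in-sample OLS R² never decreases when a regressor is added, the history is monotone and the loop always terminates; a held-out split would give a more honest sufficiency estimate.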

Download files

Source Distribution

rl_interrogate-0.1.0.tar.gz (93.6 kB)

Built Distribution

rl_interrogate-0.1.0-py3-none-any.whl (53.2 kB)

File details

Details for the file rl_interrogate-0.1.0.tar.gz.

File metadata

  • Download URL: rl_interrogate-0.1.0.tar.gz
  • Upload date:
  • Size: 93.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.11

File hashes

Hashes for rl_interrogate-0.1.0.tar.gz
  • SHA256: 25509bc0d87d76e94d5568fcde6ae541c4973e41ff5cc78b7b85cd9cd13aa249
  • MD5: 21d5be866c75e606675c892568d3f600
  • BLAKE2b-256: b8cbd05597f4c064858753eca3cd8cdbb25c871b7b84efd7137c364ccda41ddc

File details

Details for the file rl_interrogate-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: rl_interrogate-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 53.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.11

File hashes

Hashes for rl_interrogate-0.1.0-py3-none-any.whl
  • SHA256: 310f2c16135dfafdd9c962ee56b4392041fa1e945eecd07ba6cb6f208d8efba4
  • MD5: ad42652edcd36b6f426f9922a7558ac6
  • BLAKE2b-256: cc8c0e9fc8c9b5a0e25d69c7c0e0dba22df3979bfde16de25f43ffb7070b95f4