A package for extracting activations from PyTorch models and saving portable bundles
TorchLens
See, save, and steer any PyTorch model. TorchLens captures every activation and gradient -- across the forward and backward pass -- auto-visualizes the full computational graph, exposes rich per-op metadata, and lets you intervene on the network as it runs. Any architecture, even dynamic and recurrent ones.
Tested on over 700 models (image, video, audio, multimodal, language; feedforward, recurrent, transformer, GNN). TorchLens records 180+ metadata fields per operation and 550+ fields in total across every record type: operations, modules, parameters, buffers, gradients, and the model itself.
import torch, torchvision.models as models, torchlens as tl
model = models.alexnet(weights=None)
x = torch.randn(1, 3, 224, 224)
log = tl.trace(model, x) # one call -- full graph + all activations
print(log.summary()) # module table, op count, FLOPs
print(log['relu_1_2'].out.shape) # grab any activation by name ...
print(log['features.6'].out.shape) # ... or by module path
print(log[7].func_name) # ... or by ordinal
log.draw() # PDF of the computational graph
Quick Links
- Paper | 10-minute tutorial notebook | Facets tutorial | 5-minute gallery | 50-minute gallery
- Performance guide | AI-agent quick reference | Limitations | Migration tables
Installation
Install Graphviz first (required for graph visualizations), then TorchLens:
sudo apt install graphviz # Debian/Ubuntu; see graphviz.org for other platforms
pip install torchlens
Compatible with PyTorch 1.8.0+.
Quickstart
import torch
import torchvision.models as models
import torchlens as tl
model = models.alexnet(weights=None)
x = torch.randn(1, 3, 224, 224)
log = tl.trace(model, x)
print(log.summary())
Model: AlexNet
+-----------------------------+---------------+--------+-------+
| Layer                       | Output Shape  | Params | Train |
+-----------------------------+---------------+--------+-------+
| input                       | [1,3,224,224] | 0      | -     |
| features (Sequential)       | [1,256,6,6]   | 2.5 M  | yes   |
| avgpool (AdaptiveAvgPool2d) | [1,256,6,6]   | 0      | -     |
| classifier (Sequential)     | [1,1000]      | 58.6 M | yes   |
| output                      | [1,1000]      | -      | -     |
+-----------------------------+---------------+--------+-------+
Params: 61,100,840 unique; trainable: 61,100,840
Ops: 22 total
Edges: 23 total
Forward FLOPs: 1.4 GFLOPs MACs: 718.9 MFLOPs
Index any operation by name, module path, or ordinal:
log['relu_1_2'].out.shape # torch.Size([1, 64, 55, 55])
log['features.6'].out.shape # op computed by module features.6 (conv2d_3_7)
log[7].func_name # 'conv2d'
log['conv2d_3'].out.shape # short name (ordinal suffix optional)
log[-1].layer_label # 'output_1'
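The labels above appear to follow a `<function>_<nth layer of that function>_<overall op ordinal>` pattern with an optional `:<pass>` suffix (compare relu_1_2, conv2d_3_7, and, in the recurrent example below, linear_1:2). The sketch below is a hypothetical parser inferred only from those examples, not TorchLens's own key resolver; `parse_label` and `Label` are illustrative names.

```python
import re
from typing import NamedTuple, Optional

# Hypothetical grammar inferred from labels shown in this README:
#   <func>_<nth distinct layer of that func>[_<op ordinal>][:<pass>]
_LABEL_RE = re.compile(
    r"^(?P<func>.+?)_(?P<type_idx>\d+)(?:_(?P<op_idx>\d+))?(?::(?P<pass_num>\d+))?$"
)


class Label(NamedTuple):
    func: str
    type_idx: int            # 3 in conv2d_3_7: the third conv2d layer
    op_idx: Optional[int]    # 7 in conv2d_3_7: the seventh operation overall
    pass_num: Optional[int]  # 2 in linear_1:2: the second pass through that layer


def parse_label(label: str) -> Label:
    m = _LABEL_RE.match(label)
    if m is None:
        raise ValueError(f"not an op label: {label!r}")
    op_idx, pass_num = m.group("op_idx"), m.group("pass_num")
    return Label(
        func=m.group("func"),
        type_idx=int(m.group("type_idx")),
        op_idx=int(op_idx) if op_idx is not None else None,
        pass_num=int(pass_num) if pass_num is not None else None,
    )


print(parse_label("conv2d_3_7:1"))  # Label(func='conv2d', type_idx=3, op_idx=7, pass_num=1)
print(parse_label("linear_1:2"))    # Label(func='linear', type_idx=1, op_idx=None, pass_num=2)
```

Keys that do not match this pattern, such as the module path features.6 or the integer ordinal 7, are separate lookup forms in the indexing examples above.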
Visualize the graph as a PDF:
log.draw() # unrolled by default
log.draw(vis_mode='rolled') # rolled (compact for recurrent)
log.draw(vis_mode='unrolled') # every pass as a distinct node
What You Can Do
1. Flexible feature extraction
Save everything, or select exactly what you need:
# Save only relu activations
log = tl.trace(model, x, save=tl.func('relu'))
# Save all ops inside the 'encoder' submodule
log = tl.trace(model, x, save=tl.in_module('encoder'))
# Save conv2d ops that are immediately followed by a relu, keeping a 4-op lookback window
conv_before_relu = tl.func('conv2d') & tl.followed_by(tl.func('relu'))
log = tl.trace(model, x, save=conv_before_relu,
lookback=4, lookback_payload_policy='detached_raw')
# Stop capture early (can be faster than a plain forward pass)
log = tl.trace(model, x, save=tl.in_module('layer2'), halt=tl.in_module('layer2'))
# Lightweight sparse recording for tight loops -- materialize structure later
recording = tl.record(model, x, save=tl.func('relu'))
trace = recording.to_trace()
# One-line activation pull
act = tl.pluck(model, x, 'relu_1_2') # returns tensor directly
# Batch extraction across a dataset
tl.extract_dataset(model, dataset, layers=['relu_1_2', 'conv2d_3_7'],
batch_size=32, output_dir='activations/')
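Selectors such as tl.func('conv2d') & tl.followed_by(tl.func('relu')) read as composable predicates over operations. A minimal sketch of that idea, assuming each op knows its function name and its immediate children; `Op`, `Pred`, `func`, and `followed_by` here are our own stand-ins, not TorchLens's classes, and the op labels are toy values.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass(eq=False)
class Op:
    label: str
    func: str
    children: List["Op"] = field(default_factory=list)


@dataclass(frozen=True)
class Pred:
    test: Callable[[Op], bool]

    def __call__(self, op: Op) -> bool:
        return self.test(op)

    def __and__(self, other: "Pred") -> "Pred":
        return Pred(lambda op: self(op) and other(op))

    def __or__(self, other: "Pred") -> "Pred":
        return Pred(lambda op: self(op) or other(op))

    def __invert__(self) -> "Pred":
        return Pred(lambda op: not self(op))


def func(name: str) -> Pred:
    """Match ops whose function name equals `name`."""
    return Pred(lambda op: op.func == name)


def followed_by(pred: Pred) -> Pred:
    """Match ops with at least one immediate child satisfying `pred`."""
    return Pred(lambda op: any(pred(child) for child in op.children))


# Toy chain: conv -> relu -> maxpool -> conv -> maxpool
ops = [Op("conv2d_1_1", "conv2d"), Op("relu_1_2", "relu"),
       Op("maxpool2d_1_3", "maxpool2d"), Op("conv2d_2_4", "conv2d"),
       Op("maxpool2d_2_5", "maxpool2d")]
for parent, child in zip(ops, ops[1:]):
    parent.children.append(child)

conv_before_relu = func("conv2d") & followed_by(func("relu"))
print([op.label for op in ops if conv_before_relu(op)])  # ['conv2d_1_1']
```

Because each combinator returns another predicate, selections nest freely; only the first conv qualifies here because the second one feeds a max pool.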
Performance note: with halt= and tl.record, capture can run faster than the
raw forward pass. Measured wall time was 0.84x the raw forward pass on ResNet-18
and 0.83x on GPT-2 (HookedTransformer) when halting at 25% depth. Full
exhaustive capture runs at roughly 14x the raw forward pass, an overhead that
amortizes on large models. See docs/performance.md for the full benchmark table.
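Ratios like 0.84x are the capturing call's wall time divided by a plain forward pass, so values below 1.0 mean the capture finished sooner. To get a comparable number for your own model, a minimal harness might look like this; `median_seconds` and `time_ratio` are hypothetical helpers, not part of TorchLens, and the sleeps merely stand in for a raw forward pass and an early-halting capture.

```python
import statistics
import time
from typing import Callable


def median_seconds(fn: Callable[[], object], repeats: int = 5, warmup: int = 1) -> float:
    """Median wall-clock time of fn() after a few untimed warm-up calls."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def time_ratio(captured: Callable[[], object], raw: Callable[[], object], **kw) -> float:
    """Captured time / raw time: below 1.0 means the capturing call finished sooner."""
    return median_seconds(captured, **kw) / median_seconds(raw, **kw)


# Stand-ins: "captured" stops about a tenth of the way into the "raw" workload
ratio = time_ratio(lambda: time.sleep(0.005), lambda: time.sleep(0.05), repeats=3)
print(f"{ratio:.2f}x raw")
```

For GPU models, call torch.cuda.synchronize() at the end of each callable: CUDA kernels launch asynchronously, so the timer would otherwise stop before the work is done.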
Save and load traces portably:
tl.save(log, 'my_trace')
loaded = tl.load('my_trace')
2. Forward AND backward pass
Capture per-op gradients with the same API:
x = torch.randn(1, 3, 224, 224, requires_grad=True)
log = tl.trace(model, x, save_grads=True)
log.log_backward(log[log.output_layers[0]].out.sum())
grad = log['relu_1_2'].grad # gradient tensor flowing through that op
print(grad.shape) # torch.Size([1, 64, 55, 55])
Narrow gradient saving to specific ops with the same selector predicates:
log = tl.trace(model, x, save_grads=tl.func('relu'))
log.log_backward(log[log.output_layers[0]].out.sum())
Backward capture is PyTorch-only. Non-torch backends expose derived leaf-level gradients through a second AD pass. See docs/backward.md.
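What "the gradient flowing through an op" means can be shown by hand with the chain rule on a three-op toy network. Assuming log['relu_1_2'].grad is the gradient of the backward target with respect to that op's output (consistent with its shape matching .out above, though we have not confirmed which side of the op TorchLens reports), it corresponds to `grad_h` below. `forward_backward` is a hypothetical illustration in plain Python.

```python
def forward_backward(x, a, b):
    """Toy net: pre = a*x -> h = relu(pre) -> y = b*h; loss = sum(y).

    Returns the loss plus hand-derived gradients of the loss with respect
    to the relu's output, the relu's input, and the network input.
    """
    pre = [a * v for v in x]            # op 1: scale
    h = [max(v, 0.0) for v in pre]      # op 2: relu
    y = [b * v for v in h]              # op 3: scale
    loss = sum(y)                       # backward target, like out.sum()

    grad_y = [1.0 for _ in y]                                      # d loss / d y
    grad_h = [b * g for g in grad_y]                               # d loss / d relu output
    grad_pre = [g if p > 0 else 0.0 for g, p in zip(grad_h, pre)]  # relu passes gradient only where its input was positive
    grad_x = [a * g for g in grad_pre]                             # d loss / d input
    return loss, grad_h, grad_pre, grad_x


loss, grad_h, grad_pre, grad_x = forward_backward([1.0, -2.0, 3.0], a=2.0, b=0.5)
print(loss, grad_h, grad_x)  # 4.0 [0.5, 0.5, 0.5] [1.0, 0.0, 1.0]
```

Each gradient has the same shape as the value it refers to, which is why a per-op gradient can be stored alongside that op's activation.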
3. Vast metadata per operation
Every operation records shape, dtype, device, timing, FLOPs, parameter info, module containment, graph distances, conditional context, RNG state, and more. The full print of any op includes all of this:
print(log['conv2d_3_7'])
Layer conv2d_3_7, operation 7/22:
Output tensor: shape=(1, 384, 13, 13), dtype=torch.float32, size=253.5 KB
tensor([[-0.0198, 0.0946, 0.1109, ...
Related Layers:
- parent layers: maxpool2d_2_6
- child layers: relu_3_8
Params: Computed from params with shape (384, 192, 3, 3), (384,); 663936 params total (2.5 MB)
Function: conv2d (grad_fn_handle: ConvolutionBackward0)
Computed inside module: features.6:1
Config: out_channels=384, in_channels=192, kernel_size=(3, 3), padding=(1, 1)
Time elapsed: 1.4 ms
Lookup keys: -17, 7, conv2d_3, conv2d_3:1, conv2d_3_7, conv2d_3_7:1, features.6, features.6:1
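The printed numbers can be checked by hand. The sketch below recomputes conv2d_3_7's output size, parameter count, memory sizes, and multiply-accumulate count from the config shown above plus torchvision's AlexNet layer settings (first conv: kernel 11, stride 4, padding 2; 3x3 stride-2 max pools; second conv: kernel 5, padding 2). `out_size` is our own helper. The results match 1024-based KB/MB. The summary's FLOPs line is likewise consistent with counting two FLOPs per MAC (2 x 718.9 M is about 1.44 G), though we have not confirmed that this is exactly how TorchLens counts.

```python
import math


def out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial dim for Conv2d/MaxPool2d (floor mode)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Spatial sizes through the start of torchvision's AlexNet `features`
s = out_size(224, 11, stride=4, padding=2)  # features.0 conv -> 55 (matches relu_1_2's shape)
s = out_size(s, 3, stride=2)                # features.2 maxpool -> 27
s = out_size(s, 5, padding=2)               # features.3 conv -> 27
s = out_size(s, 3, stride=2)                # features.5 maxpool (maxpool2d_2_6) -> 13
s_out = out_size(s, 3, padding=1)           # features.6 conv (conv2d_3_7) -> 13

in_ch, out_ch, k = 192, 384, 3
n_params = math.prod((out_ch, in_ch, k, k)) + math.prod((out_ch,))  # weight + bias

F32_BYTES = 4
out_numel = math.prod((1, out_ch, s_out, s_out))
out_kb = out_numel * F32_BYTES / 1024        # 1024-based units
param_mb = n_params * F32_BYTES / 1024**2

macs = out_numel * in_ch * k * k             # one multiply-accumulate per weight tap per output element
flops = 2 * macs                             # common 2-FLOPs-per-MAC convention

print(s_out, n_params, out_kb, round(param_mb, 1), macs)
# 13 663936 253.5 2.5 112140288
```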
Every op also records the Python call stack that produced it, with file and line number:
loc = log['conv2d_3_7'].code_context[0]
print(loc.file, loc.line_number, loc.func_name)
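One common way to obtain this information in Python is to inspect the interpreter stack when an op runs, for example with the standard traceback module. Whether TorchLens does it this way is our assumption; `CallSite` mimics the attribute names used above (file, line_number, func_name), and `user_call_site` and `relu` are hypothetical.

```python
import traceback
from dataclasses import dataclass


@dataclass
class CallSite:
    file: str
    line_number: int
    func_name: str


def user_call_site() -> CallSite:
    """Site in user code that called the op which called this helper."""
    stack = traceback.extract_stack()
    # stack[-1] is this helper, stack[-2] the op, stack[-3] the user code calling the op
    frame = stack[-3]
    return CallSite(frame.filename, frame.lineno, frame.name)


sites = []


def relu(values):
    sites.append(user_call_site())  # record where this op was invoked from
    return [max(v, 0.0) for v in values]


def forward(values):
    return relu(values)             # this line is what gets recorded


forward([-1.0, 2.0])
print(sites[0].func_name, sites[0].line_number)
```

Keeping the whole extracted stack instead of one frame would give a list of call sites, which matches the indexed code_context[0] access above.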
Metadata is available as pandas DataFrames:
df = log.to_pandas() # one row per op
params_df = log.params.to_pandas()
modules_df = log.modules.to_pandas()
4. Automatic visualization
log.draw() # default: unrolled with sibling ordering
log.draw(vis_mode='rolled') # compact rolled layout
log.draw(vis_mode='unrolled') # every pass as a distinct node
Control nesting depth to zoom in on submodules:
For recurrent models, the rolled view collapses repeated structure cleanly:
class SimpleRecurrent(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc = torch.nn.Linear(in_features=5, out_features=5)
def forward(self, x):
for r in range(4):
x = self.fc(x)
x = x + 1
x = x * 2
return x
model = SimpleRecurrent()
x = torch.randn(6, 5)
log = tl.trace(model, x)
print(log['linear_1:2'].out) # second pass of the linear layer
log.draw(vis_mode='rolled')
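How might a label like 'linear_1:2' arise? A simplified sketch of pass numbering: treat ops from the same function at the same call site as one layer and count repeats as passes. This keying rule is our simplifying assumption (TorchLens's own logic for identifying repeated layers is not described here and is likely more involved), the function names are our choice, and these toy labels omit the op ordinal used elsewhere in this README.

```python
from collections import Counter


def label_passes(events):
    """Assign '<func>_<n>:<pass>' labels to ops in execution order.

    events: list of (func_name, site) pairs. Two events count as the same layer
    when they share (func_name, site); each repeat bumps that layer's pass number.
    """
    layer_index = {}             # (func, site) -> n, the nth distinct layer of that func
    layers_per_func = Counter()
    passes = Counter()           # (func, site) -> passes seen so far
    labels = []
    for func, site in events:
        key = (func, site)
        if key not in layer_index:
            layers_per_func[func] += 1
            layer_index[key] = layers_per_func[func]
        passes[key] += 1
        labels.append(f"{func}_{layer_index[key]}:{passes[key]}")
    return labels


# Mirrors SimpleRecurrent.forward: 4 iterations of fc, +1, *2
body = [("linear", "self.fc(x)"), ("add", "x + 1"), ("mul", "x * 2")]
labels = label_passes([step for _ in range(4) for step in body])

print(labels[3])                                   # linear_1:2
print(len(labels))                                 # 12 nodes when unrolled
print(len({lab.split(":")[0] for lab in labels}))  # 3 nodes when rolled
```

The unrolled view draws one node per pass, while the rolled view draws one node per layer, which is why it stays compact for loops.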
5. Interventions
Ablate, steer, scale, or replace activations during the forward pass:
# Zero-ablate all relu activations inline during capture
ablated = tl.trace(model, x, save=tl.func('relu'),
intervene=tl.when(tl.func('relu'), tl.zero_ablate()))
print(ablated['relu_1_2'].out.abs().max()) # tensor(0.)
# Scale relus to 50%
scaled = tl.trace(model, x, save=tl.func('relu'),
intervene=tl.when(tl.func('relu'), tl.scale(0.5)))
Available helpers: tl.zero_ablate, tl.mean_ablate, tl.resample_ablate,
tl.steer, tl.scale, tl.clamp, tl.noise, tl.project_onto,
tl.project_off, tl.swap_with, tl.splice_module.
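Conceptually, intervene=tl.when(selector, helper) pairs a site selector with a function that rewrites the matching op's output before downstream ops consume it. A minimal plain-Python sketch of that pattern, assuming a hook point after every op; `when`, `zero_ablate`, `scale`, `clamp`, and `run` below are our own toy versions, not TorchLens's implementations.

```python
from typing import Callable, List, Optional

Vec = List[float]
Action = Callable[[Vec], Vec]
Hook = Callable[[str, Vec], Vec]


def zero_ablate() -> Action:
    return lambda out: [0.0] * len(out)


def scale(factor: float) -> Action:
    return lambda out: [factor * v for v in out]


def clamp(lo: float, hi: float) -> Action:
    return lambda out: [min(max(v, lo), hi) for v in out]


def when(predicate: Callable[[str], bool], action: Action) -> Hook:
    """Hook that rewrites an op's output only when the predicate matches the op."""
    def hook(func_name: str, out: Vec) -> Vec:
        return action(out) if predicate(func_name) else out
    return hook


def run(x: Vec, hook: Optional[Hook] = None) -> Vec:
    """Toy forward pass (mul -> relu -> add), applying the hook after every op."""
    hook = hook or (lambda name, out: out)
    h = hook("mul", [2.0 * v for v in x])
    h = hook("relu", [max(v, 0.0) for v in h])
    return hook("add", [v + 1.0 for v in h])  # sees the possibly edited relu output


is_relu = lambda name: name == "relu"
x = [1.0, -1.0, 3.0]
print(run(x))                               # [3.0, 1.0, 7.0]
print(run(x, when(is_relu, zero_ablate()))) # [1.0, 1.0, 1.0]
print(run(x, when(is_relu, scale(0.5))))    # [2.0, 1.0, 4.0]
print(run(x, when(is_relu, clamp(0, 1))))   # [2.0, 1.0, 2.0]
```

The important property is that the edit happens in-line: everything after the matched op computes on the modified value, not on the original activation.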
For post-hoc DAG replay and isolated experiments, capture with
intervention_ready=True and use log.fork() + log.replay() /
log.rerun(model, x). Live hooks during rerun require capture-time selectors
(e.g. tl.func(...), tl.module(...)); finalized labels resolve via
log.find_sites(...). See docs/intervention_api.md
for the full reference.
Compare multiple runs side by side with tl.bundle:
bundle = tl.bundle({'clean': clean_log, 'patched': patched_log}, baseline='clean')
bundle.compare_at(tl.func('relu'))
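A toy version of the comparison a bundle enables: align activations by label across runs and measure each run's deviation from the baseline at the selected sites. The metric (max absolute difference), the dict layout, and `compare_at` itself are our assumptions; the real bundle.compare_at may report different statistics.

```python
def compare_at(runs, baseline, select):
    """Max absolute difference from the baseline run, per selected label.

    runs: {run_name: {label: list_of_floats}}; select: predicate on labels.
    """
    base = runs[baseline]
    report = {}
    for label, base_values in base.items():
        if not select(label):
            continue
        report[label] = {
            name: max(abs(v - b) for v, b in zip(acts[label], base_values))
            for name, acts in runs.items()
            if name != baseline
        }
    return report


runs = {
    "clean":   {"relu_1_2": [0.0, 1.0, 2.0], "linear_1_3": [0.5, 0.5]},
    "patched": {"relu_1_2": [0.0, 1.5, 1.0], "linear_1_3": [0.4, 0.9]},
}
print(compare_at(runs, "clean", lambda label: label.startswith("relu")))
# {'relu_1_2': {'patched': 1.0}}
```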
Facets provide named sub-views for attention heads, LSTM outputs, and fused projections (for models with those structures):
# ViT / transformer model with attention blocks
log = tl.trace(vit_model, x)
q = log.modules['blocks.0.attn'].facets['q'] # query vectors for head 0
h_n = log.modules['lstm'].facets['h_n'] # LSTM final hidden state (in a trace of a model with an 'lstm' submodule)
See docs/facets.md for the full facets reference, including activation patching helpers, SDPA reconstruction, and TransformerLens aliases.
See docs/intervention_api.md for the full selector and helper reference.
6. Works on anything, including dynamic and recurrent models
TorchLens uses eager-mode Python-level function wrapping rather than graph tracing. This means it captures whatever actually runs, including:
- Dynamic control flow (if/else branching, loops, early exits)
- Recurrent architectures (RNNs, LSTMs, state-space models)
- Transformer variants including fused attention
- Graph neural networks
- Mixed architectures
This is the key differentiator from static-graph extractors like
torchvision.feature_extraction, which require static computational graphs
and cannot handle dynamic architectures.
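A stripped-down illustration of the eager function-wrapping idea: temporarily swap the functions in a namespace for recording wrappers, run ordinary Python, and log whatever actually executes. This is not TorchLens's implementation (which wraps PyTorch's own functions and tracks tensors); `ops`, `traced`, and `model` are hypothetical.

```python
import contextlib
import functools
import types


def _recording(name, fn, log):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)  # run the real function unchanged
        log.append(name)           # then note that it executed
        return out
    return wrapper


@contextlib.contextmanager
def traced(namespace, names):
    """Temporarily replace namespace.<name> with recording wrappers."""
    log = []
    originals = {name: getattr(namespace, name) for name in names}
    for name, fn in originals.items():
        setattr(namespace, name, _recording(name, fn, log))
    try:
        yield log
    finally:
        for name, fn in originals.items():  # always restore the originals
            setattr(namespace, name, fn)


ops = types.SimpleNamespace(
    relu=lambda v: max(v, 0.0),
    neg=lambda v: -v,
    sub=lambda v, d: v - d,
)


def model(x):
    h = ops.relu(x)
    if h > 2.0:        # branch taken depends on the input value
        h = ops.neg(h)
    while h > 0.0:     # loop length depends on the input value
        h = ops.sub(h, 1.0)
    return h


for x in (3.5, 1.5, -1.0):
    with traced(ops, ["relu", "neg", "sub"]) as log:
        model(x)
    print(x, log)
# 3.5 ['relu', 'neg']
# 1.5 ['relu', 'sub', 'sub']
# -1.0 ['relu']
```

Each input yields a different op sequence. A tracer that builds one fixed graph ahead of time has to commit to a single path, whereas wrapping records the path each input actually took.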
Multi-backend. The same tl.trace API works across frameworks via
backend=:
| Capability | PyTorch | JAX (preview) | tinygrad (preview) | MLX (preview) | Paddle (preview) | TensorFlow (preview) |
|---|---|---|---|---|---|---|
| Forward capture + graph/metadata | yes | yes | yes | yes | yes | yes |
| Module hierarchy | torch_module | Equinox/Flax NNX pytree_module; raw function_root | object_module; raw function_root | object_module; raw function_root | object_module; raw function_root | Keras/tf.Module object_module; raw function_root |
| Control-flow unroll | eager Python | lax.scan/cond/while_loop | lazy UOp graph | limited | dygraph/eager Python only | eager Python control flow |
| Static-label save= | yes | yes | yes | yes | yes | yes |
| Portable array .tlspec payloads | full | forward/derived arrays | forward/derived arrays | forward/derived arrays | forward/derived arrays | forward arrays |
| Gradients | full backward graph | leaf-level + zero-tap T1 intermediate derived | leaf-level + T1 intermediate derived | leaf-level + custom-VJP-tap T1 intermediate derived | leaf-level + T1 intermediate derived | deferred |
| Interventions / halt / fastlog | yes | -- | -- | -- | -- | -- |
log = tl.trace(torch_model, x) # PyTorch (default)
log = tl.trace(jax_fn, inputs, backend='jax') # JAX preview
log = tl.trace(tg_fn, inputs, backend='tinygrad')
log = tl.trace(paddle_model, x, backend='paddle')
log = tl.trace(tf_model, x, backend='tf')
PyTorch remains the full-feature backend. Preview backends are pinned and
documented in docs/.
Gallery
TorchLens visualizes any architecture -- no matter how exotic. Below is a sample across families. The full menagerie has 650+ graphs across 44 architecture families.
Classic CNN + Vision Transformer
| GoogLeNet (inception + buffer edges) | Stable Diffusion (U-Net denoiser) | CLIP (vision + language towers) |
|---|---|---|
State-Space + Recurrence
| Mamba (selective SSM) | Recurrent Gemma (linear recurrence) | Whisper (audio encoder-decoder) |
|---|---|---|
Mixture-of-Experts + Generative
| Mixtral (sparse MoE) | Hierarchical VAE | Perceiver |
|---|---|---|
Graph Networks + Exotic
| DimeNet (molecular GNN) | CORnet-S (visual cortex, unrolled) | LLaMA (decoder-only LLM) |
|---|---|---|
Reinforcement Learning + Quantum ML + Scale
| Decision Transformer (offline RL) | Quantum ML circuit | 3,000-node graph (SFDP layout) |
|---|---|---|
Compatibility
Before filing a bug for a model-specific failure, run the runtime compatibility report:
compat = tl.compat.report(model, x)
print(compat.to_markdown())
tl.compat.report inspects the model wrapper, modules, parameter sharing,
input tensors, CUDA visibility, and common framework markers, then reports
each row as pass, known_broken, scope, or not_tested.
TorchLens is not compatible with torch.compile'd models, TorchScript,
or torch.export -- the forward pass does not run as ordinary Python, so the
wrappers cannot intercept ops. It also has specific behaviors around FSDP,
sparse tensors, meta tensors, quantization, and torch.func.vmap.
See LIMITATIONS.md for the full matrix: what fails, what works, and the recommended workaround for each context.
Tutorials and Docs
| Resource | Description |
|---|---|
| torchlens_in_10_minutes.ipynb | Core workflow: trace, index, visualize |
| facets_tutorial.ipynb | Attention heads, LSTM facets, patching |
| backward_tutorial.ipynb | Gradient capture and backward visualization |
| training_tutorial.ipynb | Training with captured activations |
| huggingface_tutorial.ipynb | HuggingFace transformer models |
| fastlog_tutorial.ipynb | High-throughput sparse recording |
| docs/intervention_api.md | Full selector and helper reference |
| docs/backward.md | Backward capture details and limitations |
| docs/facets.md | Facets, patching, and SDPA reconstruction |
| docs/performance.md | Speed knobs and benchmark numbers |
Security
Portable bundles contain a pickle file in metadata.pkl. Only load bundles
from trusted sources. Loading an untrusted bundle with tl.load() can execute
arbitrary code.
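Why a pickle file matters: pickle rebuilds objects by calling whatever callable the file names, via an object's __reduce__ method. The harmless demo below makes loading evaluate a Python expression; a malicious file could name os.system instead. The restricting Unpickler follows the "Restricting Globals" recipe from Python's pickle documentation. It is only an illustration of the mechanism, not a safe way to open TorchLens bundles (their metadata presumably references classes and would be rejected too); the practical defense remains loading only trusted bundles. `Payload`, `NoGlobalsUnpickler`, and `safe_loads` are our own names.

```python
import io
import pickle


class Payload:
    # __reduce__ tells pickle how to rebuild this object: "call this callable
    # with these arguments". A crafted file can name any importable callable.
    def __reduce__(self):
        return (eval, ("sum(range(10))",))


blob = pickle.dumps(Payload())
result = pickle.loads(blob)  # eval() runs during loading; result is 45, not a Payload


class NoGlobalsUnpickler(pickle.Unpickler):
    """Refuse to resolve any global (class or function) while unpickling."""

    def find_class(self, module, name):
        raise pickle.UnpicklingError(f"blocked global {module}.{name}")


def safe_loads(data: bytes):
    return NoGlobalsUnpickler(io.BytesIO(data)).load()


try:
    safe_loads(blob)
    blocked = False
except pickle.UnpicklingError:
    blocked = True

print(result, blocked)  # 45 True
```

Plain containers such as dicts, lists, and strings still load through the restricted unpickler, because they need no global lookup.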
Other Packages You Should Check Out
TorchLens focuses on activation extraction, graph visualization, and intervention; it intentionally omits model loading, stimulus management, and analysis pipelines. These packages cover that ground well:
- Cerbrec: interactive visualization and debugging for deep neural networks (uses TorchLens under the hood for PyTorch graph extraction)
- ThingsVision: model loading, stimulus management, and representational analysis for vision models
- Net2Brain: end-to-end pipeline for comparing DNN representations to neural data
- surgeon-pytorch: lightweight activation extraction with training-loss hooks
- deepdive: model loading and benchmarking across many model families
- torchvision feature_extraction: fast activation extraction for models with static computational graphs
- rsatoolbox: representational similarity analysis for DNN activations and brain data
Acknowledgments
The development of TorchLens benefitted greatly from discussions with Nikolaus Kriegeskorte, George Alvarez, Alfredo Canziani, Tal Golan, and the Visual Inference Lab at Columbia University. Thank you to Kale Kundert for helpful discussion and code contributions enabling PyTorch Lightning compatibility. Network visualizations are generated with Graphviz. Logo created by Nikolaus Kriegeskorte.
Citing TorchLens
To cite TorchLens, please cite this paper:
Taylor, J., Kriegeskorte, N. Extracting and visualizing hidden activations and computational graphs of PyTorch models with TorchLens. Sci Rep 13, 14375 (2023). https://doi.org/10.1038/s41598-023-40807-0
If you find TorchLens useful, a star on this repo is appreciated.
Contact
TorchLens is in active development. Questions, bug reports, and suggestions are welcome via email, Twitter, the issues page, or the discussion board.
Download files
File details
Details for the file torchlens-2.26.0.tar.gz.
File metadata
- Download URL: torchlens-2.26.0.tar.gz
- Upload date:
- Size: 1.8 MB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 88e8e16d2de9a4cff934a3f63b777f37176b8d1248ea4d378326cba047ed2502 |
| MD5 | 8334d7388ca7e17ce83e7784539ed725 |
| BLAKE2b-256 | 06e035545c49af317c95895425f7269f279aa52c47dfced6362b4590756070d1 |
File details
Details for the file torchlens-2.26.0-py3-none-any.whl.
File metadata
- Download URL: torchlens-2.26.0-py3-none-any.whl
- Upload date:
- Size: 1.4 MB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f9b37cb006b4f5a953c64f00014e6c92d6a8434ab97c65bf5014f07c0cdaa3cd |
| MD5 | 72c14d2055e8816bb73c9a094a253031 |
| BLAKE2b-256 | eab86941d132511561cd50b2c348dc2d45e2a795a62b75dff75227257e4aa1ec |
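To check a downloaded file against the digests listed above, a minimal standard-library sketch follows. `HASHERS`, `stream_digest`, and `verify_file` are our own helpers; the wheel's SHA256 is copied from the table above, and the file path assumes the wheel sits in the current directory.

```python
import hashlib
import io

# Constructors for the three digests listed; BLAKE2b-256 is BLAKE2b with a 32-byte digest.
HASHERS = {
    "sha256": hashlib.sha256,
    "md5": hashlib.md5,
    "blake2b-256": lambda: hashlib.blake2b(digest_size=32),
}


def stream_digest(fileobj, algorithm="sha256", chunk_size=1 << 20):
    """Hex digest of a binary file object, read in chunks to bound memory use."""
    h = HASHERS[algorithm]()
    for chunk in iter(lambda: fileobj.read(chunk_size), b""):
        h.update(chunk)
    return h.hexdigest()


def verify_file(path, expected_hex, algorithm="sha256"):
    with open(path, "rb") as f:
        return stream_digest(f, algorithm) == expected_hex.strip().lower()


# Known-answer check: SHA-256 of b"abc" is a standard test vector.
assert stream_digest(io.BytesIO(b"abc")) == (
    "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")

WHEEL = "torchlens-2.26.0-py3-none-any.whl"
WHEEL_SHA256 = "f9b37cb006b4f5a953c64f00014e6c92d6a8434ab97c65bf5014f07c0cdaa3cd"
# verify_file(WHEEL, WHEEL_SHA256) should return True for an intact download.
```

pip can enforce the same check at install time: list the requirement as torchlens==2.26.0 followed by --hash=sha256:<digest> in a requirements file and install with pip install --require-hashes -r requirements.txt.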