traceTorch
A strict, ergonomic, and powerful Spiking Neural Network (SNN) library for PyTorch.
traceTorch is built around a single, highly compositional neuron superclass, the LeakyIntegrator, which replaces the
restrictive "layer zoo" of countless disjoint neuron types. This design encapsulates a wide range of SNN dynamics:
- synaptic and recurrent filtering
- rank-based parameter scoping for scalar, per-neuron or matrix weights
- optional Exponential Moving Average (EMA) on any hidden state
- arbitrary recurrence routing to any hidden state
- flexible polarity for spike outputs: positive and/or negative
All of this is expressed through declarative configuration on one class. By abstracting this complexity, traceTorch
offers both the simplicity needed for fast prototyping, via familiar wrappers (LIF, RLIF, SLIF, Readout, etc.), and the
flexibility required for real research.
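For readers new to SNNs, the core dynamic behind a leaky integrator neuron can be sketched in plain PyTorch. This is a minimal, hypothetical illustration of a standard leaky integrate-and-fire (LIF) update with a subtractive reset; the function name lif_step and the reset-by-subtraction choice are assumptions, and it is not traceTorch's actual LeakyIntegrator implementation.

```python
import torch


def lif_step(x: torch.Tensor, mem: torch.Tensor, beta: float = 0.9, threshold: float = 1.0):
    """One timestep of a leaky integrate-and-fire neuron (hypothetical sketch, not traceTorch code)."""
    # Leak the previous membrane potential, then integrate the new input current.
    mem = beta * mem + x
    # Emit a spike wherever the membrane reaches the threshold (Heaviside step).
    spikes = (mem >= threshold).to(x.dtype)
    # Reset by subtraction: remove the threshold from neurons that spiked.
    mem = mem - spikes * threshold
    return spikes, mem


# Drive a single neuron with a constant input of 0.3 for 10 timesteps.
mem = torch.zeros(1)
train = []
for _ in range(10):
    s, mem = lif_step(torch.tensor([0.3]), mem)
    train.append(int(s.item()))
print(train)
```

The neuron integrates the constant input until the membrane crosses the threshold, fires, and starts accumulating again. Training through the hard threshold requires a surrogate gradient, which this sketch omits.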
Why traceTorch?
Existing SNN libraries often feel restrictive or require verbose state management. Beyond its technical features and capabilities, traceTorch follows a different philosophy, centered on ergonomics:
- Architectural Flexibility: All existing traceTorch layers are thin wrappers around the LeakyIntegrator superclass, so it's easy to add your own alterations or combinations of the features you like.
- Automatic State Management: No need to manually pass hidden states through .forward(). Each layer manages its own hidden states, and calling .zero_states() on a traceTorch model recursively clears every hidden state the model uses, no matter how deeply nested. Similarly, .detach_states() detaches the states from the current computation graph.
- Lazy Initialization: Hidden states are initialized as None and allocated dynamically based on the input shape. This eliminates "Batch Size Mismatch" errors during inference.
- Dimension Agnostic: Whether you are working with [Time, Batch, Features] or [Batch, Channels, Height, Width] tensors, layers just work. Set a single dim argument at layer initialization to indicate the dimension the layer acts on. It defaults to -1 for MLPs; -3 works for CNNs (channels are 3rd-to-last in [B, C, H, W] or [C, H, W]).
- Smooth Constraints: Parameters like decays and thresholds are constrained via Sigmoid and Softplus, respectively. There is no hard clamping, so gradients flow smoothly everywhere.
- Rank-Based Parameters: Instead of messy flags like *_is_vector or all_to_all, traceTorch uses a single *_rank integer to define the parameter scope: 0 for a scalar (shared across the layer), 1 for a vector (per-neuron parameter), and 2 for a matrix (dense all-to-all connections for recurrent layer weights).
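Several of the ideas in the list above can be illustrated together with a minimal, hypothetical layer written in plain PyTorch: rank-based parameter shapes, Sigmoid/Softplus constraints, lazily allocated state, and a dim argument. The class ToyLeaky and its parameters (beta_rank, threshold_rank, dim) are illustrative assumptions and do not reproduce traceTorch's internals; rank 2 is left out because it only applies to recurrent weight matrices.

```python
import math

import torch
from torch import nn


class ToyLeaky(nn.Module):
    """Hypothetical leaky layer: rank scoping, smooth constraints, lazy state, dim argument."""

    def __init__(self, num_neurons: int, beta: float = 0.9, threshold: float = 1.0,
                 beta_rank: int = 1, threshold_rank: int = 1, dim: int = -1):
        super().__init__()
        if beta_rank not in (0, 1) or threshold_rank not in (0, 1):
            raise ValueError("this sketch only supports rank 0 (scalar) or rank 1 (per-neuron)")
        self.dim = dim
        # Rank 0 -> one scalar shared by the whole layer; rank 1 -> one value per neuron.
        beta_shape = () if beta_rank == 0 else (num_neurons,)
        threshold_shape = () if threshold_rank == 0 else (num_neurons,)
        # Store unconstrained raw values; the inverse transforms (logit, inverse softplus)
        # make the constrained values start at the requested beta and threshold.
        self.raw_beta = nn.Parameter(torch.full(beta_shape, math.log(beta / (1 - beta))))
        self.raw_threshold = nn.Parameter(torch.full(threshold_shape, math.log(math.expm1(threshold))))
        # Lazy initialization: no state exists until the first input reveals its shape.
        self.mem = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Move the neuron dimension last so per-neuron parameters broadcast against it.
        x = x.movedim(self.dim, -1)
        if self.mem is None:
            self.mem = torch.zeros_like(x)
        beta = torch.sigmoid(self.raw_beta)                    # smoothly kept in (0, 1)
        threshold = nn.functional.softplus(self.raw_threshold)  # smoothly kept in (0, inf)
        self.mem = beta * self.mem + x
        # Hard threshold for illustration only; real SNN layers use a surrogate gradient here.
        spikes = (self.mem >= threshold).to(x.dtype)
        self.mem = self.mem - spikes * threshold
        # Restore the caller's original dimension order.
        return spikes.movedim(-1, self.dim)

    def zero_states(self):
        self.mem = None

    def detach_states(self):
        if self.mem is not None:
            self.mem = self.mem.detach()


layer = ToyLeaky(num_neurons=32, dim=-3)
print(layer(torch.randn(4, 32, 26, 26)).shape)   # torch.Size([4, 32, 26, 26])
layer.zero_states()                              # drop state so a new shape can be allocated
print(layer(torch.randn(32, 26, 26)).shape)      # torch.Size([32, 26, 26])
print(ToyLeaky(10, beta_rank=0).raw_beta.shape)  # torch.Size([])
```

Storing raw, unconstrained parameters and mapping them through Sigmoid and Softplus keeps every value in its valid range without clamping, so the optimizer can move them freely.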
Installation
traceTorch is available on PyPI. Requirements are listed in
requirements.txt.
pip install tracetorch
Quick Start
Making a traceTorch model is barely any different from making a regular PyTorch model. Here's how:
1. The "zero-boilerplate" module
Inherit from tracetorch.snn.TTModule instead of torch.nn.Module. This gives your model powerful recursive methods
like zero_states() and detach_states() for free, while still integrating with other PyTorch nn.Module instances.
import torch
from torch import nn
import tracetorch as tt
from tracetorch import snn


class ConvSNN(snn.TTModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 3),
            # dim=-3 tells the layer that the 3rd-to-last dimension is the channel dim.
            # This works for (B, C, H, W) AND unbatched (C, H, W) inputs automatically.
            snn.LIF(num_neurons=32, beta=0.9, dim=-3),
            nn.Flatten(),
            # A 3x3 conv without padding maps 28x28 inputs to 26x26 feature maps.
            nn.Linear(32 * 26 * 26, 10),
            # Readout layer with learnable scalar decay
            snn.Readout(num_neurons=10, beta=0.8, beta_rank=0)
        )

    def forward(self, x):
        return self.net(x)
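The recursive zero_states() and detach_states() behaviour described above can be approximated in plain PyTorch by walking nn.Module.modules(), which yields a module and every submodule nested inside it at any depth. This is a hypothetical sketch, not TTModule's actual source: the names RecursiveStates, Counter, reset_state, and detach_state are assumptions chosen for illustration.

```python
import torch
from torch import nn


class Counter(nn.Module):
    """Toy stateful layer that accumulates its input over time (hypothetical)."""

    def __init__(self):
        super().__init__()
        self.state = None  # lazily allocated on the first forward call

    def forward(self, x):
        self.state = x if self.state is None else self.state + x
        return self.state

    def reset_state(self):
        self.state = None

    def detach_state(self):
        if self.state is not None:
            self.state = self.state.detach()


class RecursiveStates(nn.Module):
    """Hypothetical stand-in for a TTModule-like base class."""

    def zero_states(self):
        # modules() visits self and all descendants, however deeply nested.
        for m in self.modules():
            if callable(getattr(m, "reset_state", None)):
                m.reset_state()

    def detach_states(self):
        for m in self.modules():
            if callable(getattr(m, "detach_state", None)):
                m.detach_state()


class Net(RecursiveStates):
    def __init__(self):
        super().__init__()
        # The stateful layer is buried two containers deep.
        self.net = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.ReLU(), Counter()))

    def forward(self, x):
        return self.net(x)


model = Net()
x = torch.randn(2, 4)
model(x)
model(x)
inner = model.net[1][1]             # the nested Counter
print(inner.state.requires_grad)    # True: state is still attached to the graph
model.detach_states()
print(inner.state.requires_grad)    # False: values kept, graph history dropped
model.zero_states()
print(inner.state)                  # None
```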
2. The Training Loop
State management is easily handled outside the forward pass. Simply call .zero_states() on the model to reset all
hidden states to None, or call .detach_states() to detach the current hidden states (used in truncated BPTT or for
online learning).
model = ConvSNN().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.005)
loss_fn = tt.loss.soft_cross_entropy  # Handles non-onehot targets gracefully
num_timesteps = 25  # example value; choose to suit your data
# `loader` is assumed to be a torch.utils.data.DataLoader yielding (x, y) batches

model.train()
# Training Step
for x, y in loader:
    x, y = x.cuda(), y.cuda()
    # Time loop
    spikes = []
    for step in range(num_timesteps):
        # Just pass x. No state tuples to manage.
        spikes.append(model(x))
    # Stack output and compute loss
    output = torch.stack(spikes)
    loss = loss_fn(output.mean(0), y)  # Rate coding example
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Crucial: Reset hidden states for the next batch
    model.zero_states()
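The state-management paragraph mentions truncated BPTT, which the loop above does not show. A minimal sketch of that pattern in plain PyTorch follows: gradients flow within each window of k timesteps, then detach_states() cuts the graph while keeping the state values, and zero_states() runs only at the start of a new sequence. The model ToyRecurrent, the window length k, and the random data are hypothetical stand-ins, not traceTorch code.

```python
import torch
from torch import nn


class ToyRecurrent(nn.Module):
    """Hypothetical stateful model: a leaky integrator on top of a linear layer."""

    def __init__(self, n_in: int, n_out: int, beta: float = 0.9):
        super().__init__()
        self.fc = nn.Linear(n_in, n_out)
        self.beta = beta
        self.mem = None

    def forward(self, x):
        cur = self.fc(x)
        self.mem = cur if self.mem is None else self.beta * self.mem + cur
        return self.mem

    def zero_states(self):
        self.mem = None

    def detach_states(self):
        if self.mem is not None:
            self.mem = self.mem.detach()


torch.manual_seed(0)
model = ToyRecurrent(8, 3)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.005)
loss_fn = nn.CrossEntropyLoss()

seq = torch.randn(100, 16, 8)        # [Time, Batch, Features], random stand-in data
target = torch.randint(0, 3, (16,))  # one class label per sequence
k = 20                               # truncation window (illustrative choice)
T = seq.shape[0]

model.zero_states()                  # fresh state at the start of the sequence
for start in range(0, T, k):
    outputs = [model(seq[t]) for t in range(start, min(start + k, T))]
    loss = loss_fn(torch.stack(outputs).mean(0), target)
    loss.backward()                  # backpropagates through at most k timesteps
    optimizer.step()
    optimizer.zero_grad()
    model.detach_states()            # keep state values, drop the graph behind them
```

Without the detach call, the second window's backward pass would try to traverse the first window's graph, which was already freed by the previous backward() and raises an error.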
Documentation
The online documentation contains introductory lessons on SNNs, a reference for the traceTorch API and available
layers, and a couple of tutorials that recreate the code found in
examples/.
Authors
Contributing
Contributions are always welcome. Feel free to fork, submit pull requests, or report issues; I will check in occasionally.
Download files
File details
Details for the file tracetorch-0.5.1.tar.gz.
File metadata
- Download URL: tracetorch-0.5.1.tar.gz
- Upload date:
- Size: 20.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8be469b04060f2529eb67cfd3baa2dc58465b7db9d6ff351924abc026d4b971e |
| MD5 | f9fc2843b429a07ddae9bac2165fe032 |
| BLAKE2b-256 | 3650a299cc604f56c41c570db923603b4f396a423f25227d6086016dc0f05b3a |
File details
Details for the file tracetorch-0.5.1-py3-none-any.whl.
File metadata
- Download URL: tracetorch-0.5.1-py3-none-any.whl
- Upload date:
- Size: 26.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6be2b13bdb6032af4f85df0c08a3472ae521275a2c5298c26487972de06f41d2 |
| MD5 | 3e99ae37b3f2f0ac9265234732b75603 |
| BLAKE2b-256 | c99ee37375fd2bf17c1fd88ac5e54f24374cc55d94545c473136ff77599b3984 |