A strict, ergonomic, and powerful Spiking Neural Network (SNN) library for PyTorch.
traceTorch
A strict, ergonomic, and powerful recurrent & spiking neural network library for PyTorch.
Introduction
traceTorch is a unified library for recurrent networks in PyTorch that rethinks how these networks are built, from the ground up. It enforces one simple rule that should have been the default all along: hidden states stay hidden. That doesn't make them inaccessible. On the contrary, traceTorch is designed with ergonomics at the forefront, and state management is easier than ever. Hidden states are created lazily in the forward pass, work with any target dimension, and, most importantly, are easy to clear, detach, and even save and load. traceTorch makes it easy to mix and match recurrent layers with any other PyTorch layer, as it always should have been. Take a look at the Quickstart section to see what the code looks like.
The library started out focused on Spiking Neural Networks (SNNs). Using a slightly unorthodox but consistent and self-explanatory naming scheme, traceTorch provides 32 distinct SNN layer types built around the Leaky Integrator. Together they cover a wide range of dynamics: duality (splitting positive and negative signals); recurrence; a synapse (an extra EMA accumulator before the membrane); and binary, ternary, scaled ternary, or no spiking at all for the output. The resulting 32 layers are: LI, DLI, SLI, DSLI, LIEMA, DLIEMA, SLIEMA, DSLIEMA, LIB, DLIB, SLIB, RLIB, DSLIB, DRLIB, SRLIB, DSRLIB, LIT, DLIT, SLIT, RLIT, DSLIT, DRLIT, SRLIT, DSRLIT, LITS, DLITS, SLITS, RLITS, DSLITS, DRLITS, SRLITS, DSRLITS.
Thinking a bit outside the box, it becomes clear that State Space Models (SSMs) such as Mamba are strikingly similar to the Leaky Integrator that all the SNN layers are built around, albeit more complex. The same philosophy was therefore extended to non-SNN recurrent layers: SimpleRNN, LSTM, GRU, SelectiveSSM, and (probably) more to come. The result is an opinionated but extremely ergonomic extension to PyTorch that rethinks how RNNs are made: no matter the architecture, each one is just another PyTorch-esque layer that can be placed anywhere.
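To make the analogy concrete, the two update rules can be written side by side. This is an illustration under stated assumptions, not traceTorch's exact equations: the Leaky Integrator line assumes the plain (non-EMA) membrane update with decay β, and the SSM line is the diagonal selective SSM popularized by Mamba, with zero-order hold for the state matrix and the common simplified input discretization. traceTorch's SelectiveSSM may differ in its details.

$$
\begin{aligned}
\text{Leaky Integrator:}\quad & m_t = \beta \odot m_{t-1} + x_t \\
\text{Selective SSM:}\quad & h_t = \exp(\Delta_t A) \odot h_{t-1} + \Delta_t B_t x_t, \qquad y_t = C_t h_t
\end{aligned}
$$

With a diagonal A whose entries are negative and a step size Δ_t > 0, every entry of exp(Δ_t A) lies in (0, 1), so it plays the role of an input-dependent β. The main additions are that Δ_t, B_t, and C_t are computed from the input, and that the output is read out through C_t.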
The main selling point of traceTorch is how it manages hidden states. Inheriting from tt.Model grants access to powerful recursive methods that handle all the state management boilerplate, no matter how deeply the states are hidden: zero_states() and detach_states(), save_states() and load_states(). Some parameters are never used in their raw form but are first passed through an activation function; to skip this redundant computation in a trained model, traceTorch also provides TTcompile() and TTdecompile().
And if the range of layers isn't enough, making your own is also easy. Inheriting from tt.Layer (or the downstream tt.rnn.Layer or tt.snn.Layer) lets you create layers that integrate with the rest of the traceTorch ecosystem: their hidden states are accessible and created with the proper shape; parameters can be compiled, and initialization handles learnability, rank, and/or a custom tensor; and helper methods move a target dimension in and out so the layer works with any tensor layout.
All in all, traceTorch exists to make writing, reading, debugging, and most importantly experimenting with recurrent networks in PyTorch feel significantly more natural and less frustrating, while preserving (and in many cases enhancing) the expressive power needed for real models and research. traceTorch ultimately rewards users who value minimalism, composition, and long-term extensibility.
Features
As mentioned before, traceTorch currently has two main focal points for recurrent networks: RNNs, found in tt.rnn, and SNNs, found in tt.snn. Regardless of where a layer comes from, it is always a child of tt.Layer, which makes it integrate with tt.Model and all other PyTorch modules in a layer-like way. This means that each layer expects one input and produces one output. All hidden states stay hidden, internal to the layer. And each one is just a single layer, not a full multi-layer model. As a consequence, the design approach changes a bit: the model processes one timestep at a time, and the looping over time is expected to be done externally.
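A minimal, pure-Python sketch of the "one timestep at a time" contract. ToyLeakyIntegrator is a hypothetical stand-in, not a traceTorch class: it keeps its state internal, creates it lazily from the shape of the first input, and leaves the time loop to the caller. It assumes the plain leaky-integrator update m_t = beta * m_{t-1} + x_t and uses lists instead of tensors so it runs without PyTorch.

```python
class ToyLeakyIntegrator:
    """Hypothetical recurrent layer: m_t = beta * m_{t-1} + x_t, one timestep per call."""

    def __init__(self, beta: float = 0.9) -> None:
        self.beta = beta
        self.mem = None  # hidden state, created lazily in forward()

    def forward(self, x: list[float]) -> list[float]:
        if self.mem is None:
            # Lazy creation: the state takes the size of the first input it sees.
            self.mem = [0.0] * len(x)
        self.mem = [self.beta * m + xi for m, xi in zip(self.mem, x)]
        return list(self.mem)  # one input in, one output out; the state stays inside

    def zero_states(self) -> None:
        self.mem = None  # forget the state; it is re-created on the next call


layer = ToyLeakyIntegrator(beta=0.5)
sequence = [[1.0, 2.0], [0.0, 0.0], [1.0, 0.0]]  # 3 timesteps, 2 features each
outputs = [layer.forward(x_t) for x_t in sequence]  # the caller owns the time loop
print(outputs)  # [[1.0, 2.0], [0.5, 1.0], [1.25, 0.5]]
```

Because the state is created lazily, calling zero_states() between sequences also lets the next sequence use a different batch or feature size.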
As stated earlier, the main selling point of traceTorch is that it handles all the state management boilerplate. A model inheriting from tt.Model gets access primarily to the zero_states() and detach_states() methods. Both recursively search everywhere tt.Layer layers can be hidden and either set their states to None or detach them, respectively. At the time of writing, the save_states() and load_states() methods are experimental, but they allow saving and loading the hidden states to .pt or .safetensors in the same way you could save the entire model, only as a separate file. There are also the experimental TTcompile and TTdecompile methods, which take specific parameters that are always passed through an activation function and store their effective values directly instead: use them once a model is trained, so that you don't waste compute recalculating the effective values each time.
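The recursive search can be illustrated with a toy tree of hypothetical StatefulLayer and Container classes. This is a sketch of the idea, not traceTorch's implementation. In PyTorch, the same walk would typically iterate over nn.Module.modules(), and a detach_states() equivalent would replace each tensor state with state.detach() instead of None.

```python
class StatefulLayer:
    """Hypothetical layer with a lazily created scalar state."""

    def __init__(self) -> None:
        self.state = None

    def forward(self, x: float) -> float:
        self.state = x if self.state is None else 0.9 * self.state + x
        return self.state


class Container:
    """Hypothetical container whose children may be layers or other containers."""

    def __init__(self, *children) -> None:
        self.children = list(children)

    def modules(self):
        """Yield self and every descendant, depth-first."""
        yield self
        for child in self.children:
            if isinstance(child, Container):
                yield from child.modules()
            else:
                yield child

    def stateful_layers(self) -> list:
        return [m for m in self.modules() if isinstance(m, StatefulLayer)]

    def zero_states(self) -> None:
        for layer in self.stateful_layers():
            layer.state = None  # re-created lazily on the next forward call


deep = StatefulLayer()  # buried two containers deep
model = Container(StatefulLayer(), Container(Container(deep)))
for layer in model.stateful_layers():
    layer.forward(1.0)
print(deep.state)  # 1.0
model.zero_states()
print(deep.state)  # None
```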
Speaking of layers, at the time of writing traceTorch has a total of 36. tt.rnn is a fair bit smaller and more self-explanatory: it includes SimpleRNN, LSTM, GRU, and SelectiveSSM, with (probably) more to come. The implementations are standard, given the "one timestep at a time" and "as a layer" rules. tt.snn, however, is far more extensive and follows a slightly unconventional but consistent and self-explanatory naming scheme. The names are modular and describe each layer's role and function:
- LI base name stands for Leaky Integrator: the simplest layer type, with just one trace and one decay, the membrane potential and the beta decay. There is no firing and no reset mechanic; this layer type is commonly known as a Readout (although it's not recommended to literally use it as the final layer).
- ~EMA suffix is only used with the LI type of neurons, and it makes the membrane act as an exponential moving average (EMA). This isn't useful in classification, where you explicitly train the model to return large-magnitude values, but it is useful in other cases where the membrane magnitude needs to stay stable.
- ~B suffix stands for Binary: the presence of a strictly positive threshold, meaning that the layer has 2 possible outputs, 1 or 0. LIB is hence traceTorch's name for the LIF (Leaky Integrate-and-Fire) neuron.
- ~T suffix stands for Ternary: the layer has 2 thresholds, one strictly positive and one strictly negative, meaning that it has 3 possible outputs: 1, 0, or -1.
- ~S suffix is only used together with the ~T suffix to form ~TS, which stands for Ternary Scaled: the positive and negative ternary outputs are each multiplied by their own scale, so that the three possible outputs are truly independent from the downstream layer's point of view.
- D~ prefix stands for Dual: all traces (hidden states) and their decay parameters are split into separate positive and negative versions for greater expressivity, unlocking more complex dynamics.
- S~ prefix stands for Synaptic: before the membrane there is a separate synaptic trace, with its own alpha decay, that smooths the inputs over time via an EMA before they are integrated into the membrane.
- R~ prefix stands for Recurrent: the layer records its own outputs into a separate trace with its own gamma decay and re-integrates it into the membrane at the next timestep. The computation graph is built to work even with online learning.
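To make the suffixes concrete, here is a scalar, single-timestep sketch of the EMA, binary, ternary, and scaled-ternary mechanics. The function names are hypothetical, and the update and reset rules are assumptions (plain leaky integration, hard thresholds, reset by subtracting the crossed threshold). The real layers operate on tensors and may reset differently, and a trainable version needs something like surrogate gradients to differentiate through the hard threshold.

```python
# Hypothetical scalar sketches; update and reset rules are assumptions, not traceTorch's code.


def li_ema_step(mem: float, x: float, beta: float) -> float:
    """~EMA: the membrane is an exponential moving average of the input."""
    return beta * mem + (1.0 - beta) * x


def lib_step(mem: float, x: float, beta: float, thresh: float) -> tuple[float, float]:
    """~B (binary): emit 1 when the membrane reaches the positive threshold, else 0."""
    mem = beta * mem + x
    spike = 1.0 if mem >= thresh else 0.0
    mem -= spike * thresh  # assumed reset: subtract the threshold after a spike
    return mem, spike


def lit_step(mem: float, x: float, beta: float,
             pos_thresh: float, neg_thresh: float) -> tuple[float, float]:
    """~T (ternary): +1 at/above pos_thresh (> 0), -1 at/below neg_thresh (< 0), else 0."""
    mem = beta * mem + x
    if mem >= pos_thresh:
        return mem - pos_thresh, 1.0
    if mem <= neg_thresh:
        return mem - neg_thresh, -1.0
    return mem, 0.0


def lits_step(mem: float, x: float, beta: float, pos_thresh: float, neg_thresh: float,
              pos_scale: float, neg_scale: float) -> tuple[float, float]:
    """~TS (ternary scaled): each polarity of the ternary output gets its own scale."""
    mem, spike = lit_step(mem, x, beta, pos_thresh, neg_thresh)
    if spike > 0:
        return mem, spike * pos_scale
    if spike < 0:
        return mem, spike * neg_scale
    return mem, 0.0


# Usage: drive a binary neuron with a constant input; the external loop feeds one step at a time.
mem, spikes = 0.0, []
for x in [0.6, 0.6, 0.6]:
    mem, s = lib_step(mem, x, beta=0.9, thresh=1.0)
    spikes.append(s)
print(spikes)  # [0.0, 1.0, 0.0]
```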
In total, this results in 32 specially made, performant layers that easily integrate and work with other PyTorch layers: LI, DLI, SLI, DSLI, LIEMA, DLIEMA, SLIEMA, DSLIEMA, LIB, DLIB, SLIB, RLIB, DSLIB, DRLIB, SRLIB, DSRLIB, LIT, DLIT, SLIT, RLIT, DSLIT, DRLIT, SRLIT, DSRLIT, LITS, DLITS, SLITS, RLITS, DSLITS, DRLITS, SRLITS, DSRLITS.
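Because the names are modular, the full list can be generated from the rules. The sketch below is an illustration, not library code: it reproduces the 32 names in the order listed above, applying the R~ prefix only to the spiking bases, which matches the list (there is no RLI or RLIEMA).

```python
from itertools import combinations

# Base names in the order the list above uses them.
BASES = ["LI", "LIEMA", "LIB", "LIT", "LITS"]
# In the listed names, the R~ prefix only appears with the spiking bases.
NON_SPIKING = {"LI", "LIEMA"}


def prefixes(allow_recurrent: bool) -> list[str]:
    """All prefix combinations in D, S, R order: '', D, S, R, DS, DR, SR, DSR."""
    letters = "DSR" if allow_recurrent else "DS"
    return ["".join(c) for r in range(len(letters) + 1) for c in combinations(letters, r)]


def layer_names() -> list[str]:
    return [p + base for base in BASES for p in prefixes(base not in NON_SPIKING)]


names = layer_names()
print(len(names))  # 32
print(names[:8])  # ['LI', 'DLI', 'SLI', 'DSLI', 'LIEMA', 'DLIEMA', 'SLIEMA', 'DSLIEMA']
```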
Additionally, both tt.rnn and tt.snn layers handle some extra boilerplate for parameter initialization and hidden state management, thanks to the tt.Layer superclass and its downstream RNN and SNN variants (tt.rnn.Layer and tt.snn.Layer):
- Rank-based parameter scoping for per-layer (scalar) or per-neuron (vector) parameters, defaulting to per-neuron.
- Initialize parameters via a float value or your own desired tensor.
- Make any parameter learnable or static; it is automatically registered as an nn.Parameter or a buffer accordingly. This does not apply to some parameters, such as the linear layers inside tt.rnn.GRU.
- Smooth parameter constraints for those that require them (sigmoid on decays and softplus on thresholds for SNN layers), so that gradients always flow cleanly and accurately. The respective inverse function is applied during initialization if necessary.
- Dimension movement helpers that move the tensor's target dimension (the dim= argument used during initialization) to the last position, so that the layer is agnostic to the tensor shape; for example, it can work with CNNs by setting dim=-3 on [..., C, H, W] data.
- Property generation: parameters that require an activation function are stored in raw_* form to account for the inverse and activation functions, but work intuitively, such that layer.beta returns the sigmoid-activated value, et cetera.
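A minimal sketch of the raw_* idea in plain Python, with hypothetical names: the decay is stored as a raw logit and exposed through a sigmoid property, the threshold is stored as a raw value and exposed through softplus, and the inverse functions are applied at initialization so the effective values equal the requested ones. The compile()/decompile() pair only mimics the purpose of TTcompile/TTdecompile (skipping the activations at inference time); it is not their implementation.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def logit(p: float) -> float:
    return math.log(p / (1.0 - p))  # inverse of sigmoid, for 0 < p < 1


def softplus(x: float) -> float:
    return math.log1p(math.exp(x))


def inv_softplus(y: float) -> float:
    return math.log(math.expm1(y))  # inverse of softplus, for y > 0


class ToyParams:
    """Hypothetical parameter holder using raw_* storage plus activated properties."""

    def __init__(self, beta: float = 0.9, threshold: float = 1.0) -> None:
        # Apply the inverse at init so the effective values match the requested ones.
        self.raw_beta = logit(beta)
        self.raw_threshold = inv_softplus(threshold)
        self._compiled = None  # cached effective values after compile()

    @property
    def beta(self) -> float:
        if self._compiled is not None:
            return self._compiled["beta"]
        return sigmoid(self.raw_beta)  # always in (0, 1)

    @property
    def threshold(self) -> float:
        if self._compiled is not None:
            return self._compiled["threshold"]
        return softplus(self.raw_threshold)  # always > 0

    def compile(self) -> None:
        # Compute the activations once so repeated reads skip them.
        self._compiled = {"beta": sigmoid(self.raw_beta), "threshold": softplus(self.raw_threshold)}

    def decompile(self) -> None:
        self._compiled = None  # go back to computing from the raw values


params = ToyParams(beta=0.9, threshold=1.0)
print(round(params.beta, 6), round(params.threshold, 6))  # 0.9 1.0
```

Storing raw values means an optimizer can move raw_beta anywhere on the real line while the effective decay stays inside (0, 1), which is why the constraint is called smooth.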
Documentation
The online documentation can be found here. It is strongly recommended to at least read the introduction section before proceeding, as it covers some theory behind SNNs, the traceTorch ethos, and the available layers, along with a brief explanation of what each mechanic actually does. It also contains a couple of tutorials that recreate the code found in examples/.
Installation
traceTorch is a PyPI library found here. Requirements for the library are listed
in requirements.txt. Take note that examples found in examples/ may have their own requirements, separate from the
library requirements.
```bash
pip install tracetorch
```
If you'd rather work from source, or just want to run the examples, install traceTorch as an editable installation:
```bash
git clone --branch v0.16.2 https://github.com/Yegor-men/tracetorch
cd tracetorch
pip install -e .
```
Check the releases page for the latest version number, and change the --branch tag if you want a different release.
Quickstart
traceTorch models look barely any different from PyTorch models. Keep in mind that the example code uses positional arguments for brevity; in practice, keyword arguments are recommended for clarity.
```python
import torch
from torch import nn
import tracetorch as tt
from tracetorch import snn, rnn

device = "cuda" if torch.cuda.is_available() else "cpu"


class SNN(tt.Model):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            snn.LIB(32, dim=-3),  # 32 neurons to match the 32 conv channels; dim=-3 is C in [..., C, H, W]
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, padding=1),
            snn.LIB(64, dim=-3),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(7 * 7 * 64, 128),
            rnn.SelectiveSSM(128, 128, 32),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.net(x)


model = SNN().to(device)  # move the model to a device just as before
optimizer = torch.optim.AdamW(model.parameters(), 1e-3)
loss_fn = nn.CrossEntropyLoss()

# Placeholder data so the loop runs: time-major MNIST-shaped batches,
# x: [num_timesteps, batch, 1, 28, 28], y: [num_timesteps, batch]
num_timesteps = 8
train_dataloader = [(torch.rand(num_timesteps, 16, 1, 28, 28), torch.randint(0, 10, (num_timesteps, 16)))]

# TRAINING LOOP WITH DATALOADER
model.train()
for x, y in train_dataloader:
    x, y = x.to(device), y.to(device)
    model.zero_states()  # sets hidden states to None for lazy assignment
    model.zero_grad()
    running_loss = 0.0
    for t in range(num_timesteps):
        model_output = model(x[t])  # one timestep at a time
        loss = loss_fn(model_output, y[t])
        running_loss = running_loss + loss
        # optionally call model.detach_states() for online learning here
    running_loss.backward()
    optimizer.step()
```
Examples
Example code can be found in examples/. To run an example, make sure that you have its respective requirements installed, and that you've installed traceTorch either from PyPI or as an editable installation.
The current examples are unfortunately rather limited: mnist/ contains monotonic.py for rate-coded classification on the entire image and nonmonotonic.py for sequential MNIST with an adjustable kernel size. byte_lm/ is a personal project on a byte-level language model trained on wikitext-103, and BirdCLEF+2026/ is a similarly experimental project on the BirdCLEF+2026 dataset.
Authors
Contributing
Contributions are always welcome. Feel free to fork, submit pull requests or report issues, I will occasionally check in on it.
Roadmap
traceTorch still has a long way to go. Namely:
- Clean up spike_fn and quant_fn
- Fix tt.functional to be cleaner
- Clean up tt.plot plotting functions
- Fix TTcompile and TTdecompile to work with tt.rnn.SelectiveSSM and other layers: this means that parameter initialization must ask for an initialization function aside from just the inverse and activation functions.
- Clean up and make sure that save_states and load_states work as intended without fault
- Create tests for compilation and decompilation, saving and loading
- Finish the examples/ section with example code for various use cases
- Make proper requirements for each example in examples/
- Finish the introduction/ section of the docs
- Do the reference/ section for the docs
- Do the tutorials/ section for the docs, basing it on examples/
- Make docstrings
- Figure out versioning requirements for the library
File details
Details for the file tracetorch-0.16.2.tar.gz.
File metadata
- Download URL: tracetorch-0.16.2.tar.gz
- Upload date:
- Size: 30.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | af67cf305a4f07390a4d2fb0bc496887ca0c098b336988b520925058703fad1c |
| MD5 | 7f30374a5604eb11a923371be077c2d0 |
| BLAKE2b-256 | 7d08ad5b03556e6ae65f793c8a9b7970989cc4d0759675ba6c3d371e9705532a |
File details
Details for the file tracetorch-0.16.2-py3-none-any.whl.
File metadata
- Download URL: tracetorch-0.16.2-py3-none-any.whl
- Upload date:
- Size: 34.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 987d6a3a86503aca113b782a3272bd8597c1ee5f4f88928127d40bd42afda3df |
| MD5 | d98701dcb12b87811e6f93435192adb7 |
| BLAKE2b-256 | 9916d7b3ee7d3b64f7d65c863d97661c77b40734e7e4feb0d3b586bb027676f8 |