traceTorch
A strict, ergonomic, and powerful Spiking Neural Network (SNN) library for PyTorch.
traceTorch employs a two-tier architecture that balances computational efficiency with flexibility. Like other SNN libraries, traceTorch provides a wide variety of distinct, commonly used neuron types with sensible defaults. The naming scheme is consistent (albeit a bit unconventional) and self-explanatory. Together, the name components below control a wide range of neuron mechanics:
- LI base name stands for Leaky Integrator, the simplest layer type, with just one trace: the membrane potential, which is the direct output with no firing. It is commonly known as Readout, although it's not recommended to literally make it the last layer.
- ~B suffix stands for Binary: the layer has a threshold and therefore 2 possible outputs, 0 or 1.
- ~T suffix stands for Ternary: the layer has 2 thresholds, a positive and a negative one, and thus 3 possible outputs: -1, 0 or 1.
- ~S suffix stands for Scaled: the outputs are multiplicatively scaled separately based on their polarity, which makes ternary outputs truly independent of each other.
- D~ prefix stands for Dual: all hidden states and parameters are split into separate positive and negative versions for greater expressivity, treating polarity as a separate, alternate signal.
- S~ prefix stands for Synaptic: before the membrane there is a separate synaptic trace that smooths the inputs over time before they are integrated into the membrane.
- R~ prefix stands for Recurrent: the layer records its own outputs into a separate trace and re-integrates it back into the membrane.
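To make the output mechanics concrete, here is a minimal scalar sketch in plain Python of the LI, ~B, ~T and ~S ideas described above. It is an illustration under stated assumptions, not traceTorch's implementation: the update rule (decay the membrane, then add the input), the subtractive reset after firing, and all function and parameter names are hypothetical.

```python
def leaky_integrate(mem, x, decay):
    """LI: decay the membrane trace, then add the input (assumed update rule)."""
    return decay * mem + x


def binary_fire(mem, threshold):
    """~B: output 1 when the membrane reaches the threshold, else 0."""
    spike = 1 if mem >= threshold else 0
    return spike, mem - spike * threshold  # assumed subtractive reset


def ternary_fire(mem, pos_threshold, neg_threshold):
    """~T: output +1, -1 or 0 depending on which threshold is crossed."""
    if mem >= pos_threshold:
        return 1, mem - pos_threshold
    if mem <= -neg_threshold:
        return -1, mem + neg_threshold
    return 0, mem


def scale_by_polarity(spike, pos_scale, neg_scale):
    """~S: scale positive and negative outputs by separate factors."""
    return spike * (pos_scale if spike > 0 else neg_scale)


def run(inputs, fire, decay=0.5):
    """Feed a sequence of inputs through one neuron and collect its outputs."""
    mem, outputs = 0.0, []
    for x in inputs:
        mem = leaky_integrate(mem, x, decay)
        spike, mem = fire(mem)
        outputs.append(spike)
    return outputs


inputs = [0.75, 0.75, -2.0]
binary = run(inputs, lambda m: binary_fire(m, threshold=1.0))
ternary = run(inputs, lambda m: ternary_fire(m, 1.0, 1.0))
scaled = [scale_by_polarity(s, pos_scale=2.0, neg_scale=0.5) for s in ternary]
print(binary, ternary, scaled)  # [0, 1, 0] [0, 1, -1] [0.0, 2.0, -0.5]
```

The same input sequence produces a spike only on the positive side in the binary case, while the ternary neuron also reports the strongly negative input, and the scaled variant weights the two polarities differently.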
In total, this results in 32 purpose-built, performant layers that integrate easily with other PyTorch
layers: LI, LIB, LIT, LITS, DLI, DLIB, DLIT, DLITS, SLI, SLIB, SLIT, SLITS, RLI, RLIB,
RLIT, RLITS, DSLI, DSLIB, DSLIT, DSLITS, DRLI, DRLIB, DRLIT, DRLITS, SRLI, SRLIB, SRLIT,
SRLITS, DSRLI, DSRLIB, DSRLIT, DSRLITS.
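The count of 32 follows directly from the naming scheme: 8 prefix combinations of D, S and R times 4 output variants of the LI base name. The short sketch below enumerates them with the standard library only; the constant and function names are illustrative and are not part of traceTorch.

```python
from itertools import product

# Every combination of the D~, S~ and R~ prefixes, in the order used above.
PREFIX_COMBOS = ["", "D", "S", "R", "DS", "DR", "SR", "DSR"]
# No suffix (plain LI), ~B, ~T, and ~T combined with ~S.
SUFFIXES = ["", "B", "T", "TS"]


def layer_names():
    """Build every layer name as prefix + 'LI' + suffix."""
    return [f"{prefix}LI{suffix}" for prefix, suffix in product(PREFIX_COMBOS, SUFFIXES)]


names = layer_names()
print(len(names))
print(", ".join(names))
```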
traceTorch also makes hidden state management easy. Hidden states start as None, and their size is assigned
lazily during the forward pass. Simply inherit from the TTModule class to gain access to the recursive
methods .zero_states() and .detach_states(), which respectively set every hidden state to None or detach it, no
matter how deeply nested it is: inside PyTorch modules such as nn.Sequential, or inside Python classes and data
structures. Additionally, traceTorch takes care of:
- Rank-based parameter scoping for per-layer (scalar) or per-neuron (vector) parameters, defaulting to per-neuron.
- Initialize parameters via a float value or your own desired tensor.
- Make any parameter learnable or static, automatically set to an nn.Parameter or registered buffer accordingly.
- Single dim= argument determines the target dimension the layer focuses on: -1 for MLP, -3 for CNN, et cetera.
- Smooth parameter constraints for those that require it (sigmoid on decays and softplus on thresholds), meaning that gradients always flow cleanly and accurately. The respective inverse function is applied if necessary during initialization.
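The last point can be illustrated with scalars. The sketch below, in plain Python rather than PyTorch, shows why the inverse function is needed at initialization: the stored (raw) value is unconstrained, and the constrained value is recovered by applying sigmoid or softplus on every use. The function names are illustrative assumptions, not traceTorch's API.

```python
import math


def sigmoid(x):
    """Map any real number into (0, 1); used here for decays."""
    return 1.0 / (1.0 + math.exp(-x))


def inverse_sigmoid(p):
    """Logit: the raw value whose sigmoid equals p, for 0 < p < 1."""
    return math.log(p / (1.0 - p))


def softplus(x):
    """Map any real number into (0, inf); used here for thresholds."""
    return math.log1p(math.exp(x))


def inverse_softplus(y):
    """The raw value whose softplus equals y, for y > 0."""
    return math.log(math.expm1(y))


# Initialization: the user asks for decay=0.9 and threshold=1.0,
# so the raw stored values are the inverses of those targets.
raw_decay = inverse_sigmoid(0.9)
raw_threshold = inverse_softplus(1.0)

# Forward pass: the constraint is applied to the raw value on each use.
print(sigmoid(raw_decay), softplus(raw_threshold))

# However far an update moves the raw values, the constrained ones stay valid.
print(sigmoid(raw_decay - 5.0), softplus(raw_threshold - 5.0))
```

Because both constraints are smooth and strictly increasing, an optimizer can move the raw value freely without the decay leaving (0, 1) or the threshold becoming non-positive.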
Beyond this, traceTorch also provides one unified architecture, replacing the restrictive layer zoo of
countless disjoint neuron types with the LeakyIntegrator superclass. This design encapsulates the wide range of
possible dynamics as declarative configuration on one class, resulting in thousands of possible combinations of
features. All 32 of the layers also exist in the LeakyIntegrator form, and tests assert that the behavior of the two
versions does not differ.
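As a rough picture of what "declarative configuration" means here, the sketch below decodes a layer name into the set of features it switches on. The NeuronFeatures dataclass and parse_layer_name function are hypothetical and exist only for illustration; they are not the LeakyIntegrator signature, which should be taken from the documentation.

```python
from dataclasses import dataclass

PREFIXES = {"", "D", "S", "R", "DS", "DR", "SR", "DSR"}
OUTPUTS = {"": "continuous", "B": "binary", "T": "ternary", "TS": "ternary"}


@dataclass(frozen=True)
class NeuronFeatures:
    """Hypothetical feature flags implied by a layer name."""
    dual: bool
    synaptic: bool
    recurrent: bool
    output: str  # "continuous", "binary" or "ternary"
    scaled: bool


def parse_layer_name(name):
    """Split a name such as 'DSRLITS' into prefix, 'LI' and suffix, then decode it."""
    prefix, base, suffix = name.partition("LI")
    if base != "LI" or prefix not in PREFIXES or suffix not in OUTPUTS:
        raise ValueError(f"not a recognised layer name: {name!r}")
    return NeuronFeatures(
        dual="D" in prefix,
        synaptic="S" in prefix,
        recurrent="R" in prefix,
        output=OUTPUTS[suffix],
        scaled=suffix.endswith("S"),
    )


print(parse_layer_name("DSRLITS"))
print(parse_layer_name("LIB"))
```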
traceTorch also provides the SetupMixin mixin class, whose helper methods assist with parameter
initialization: checking learnability and ranks, and passing values through an optional inverse function. This
makes it easy to create your own SNN layers, whether you want a custom LeakyIntegrator configuration that's
missing from the pre-made layers or something unique altogether.
All in all, traceTorch provides both the robust simplicity required for fast prototyping via familiar wrappers and the unprecedented flexibility required for real research and models. traceTorch is intentionally designed to follow a philosophy revolving around ergonomics, usability and flexibility.
Documentation
The online documentation can be found here. It is strongly recommended to read
at least the introduction section before proceeding, as it covers the theory behind SNNs, the traceTorch ethos, the
available layers, and a brief explanation of what each mechanic actually does. It also contains a
couple of tutorials that recreate the code found in examples/.
Installation
traceTorch is a PyPI library found here. Requirements for the library are listed
in requirements.txt. Take note that examples found in examples/ may have their own requirements, separate from the
library requirements.
pip install tracetorch
If you would rather work from the source, or just want to test the examples, clone the repository and install traceTorch as an editable installation:
git clone https://github.com/Yegor-men/tracetorch
cd tracetorch
pip install -e .
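Either way, a quick check that the package is visible to the current interpreter can be done with the standard library. This is a small convenience sketch, not something the project ships; installed_version is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of a package, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Prints the version if traceTorch is installed, otherwise None.
print(installed_version("tracetorch"))
```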
Quickstart
traceTorch models look barely any different from PyTorch models:
import torch
from torch import nn
import tracetorch as tt
from tracetorch import snn

device = "cuda" if torch.cuda.is_available() else "cpu"


class SNN(snn.TTModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            snn.LIF(32, dim=-3),  # 32 neurons, one per output channel of the conv above
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, padding=1),
            snn.LIF(64, dim=-3),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(7 * 7 * 64, 128),  # 28x28 input halved twice gives 7x7
            snn.LI(128),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.net(x)


model = SNN().to(device)
optimizer = torch.optim.AdamW(model.parameters(), 1e-3)

# TRAINING LOOP WITH DATALOADER
# train_dataloader, loss_fn and num_timesteps are assumed to be defined elsewhere;
# x and y are assumed to be time-major, so x[t] is the batch at timestep t
model.train()
for x, y in train_dataloader:
    x, y = x.to(device), y.to(device)
    model.zero_states()  # sets hidden states to None for lazy assignment
    model.zero_grad()
    running_loss = 0.0
    for t in range(num_timesteps):
        model_output = model(x[t])
        loss = loss_fn(model_output, y[t])
        running_loss = running_loss + loss
        # optionally call model.detach_states() for online learning here
    running_loss.backward()
    optimizer.step()
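The recursive behavior of .zero_states() can be pictured with a small plain-Python sketch. This is not traceTorch's code: the Stateful class and the iter_stateful and zero_states functions below are hypothetical stand-ins that only show how a traversal can reach hidden states nested inside attributes, lists, tuples and dicts. A detach operation could follow the same traversal.

```python
class Stateful:
    """Hypothetical stand-in for a layer that holds one hidden state."""

    def __init__(self):
        self.state = None  # None means "allocate lazily on the next forward pass"


def iter_stateful(obj, _seen=None):
    """Yield every Stateful object reachable from obj, however deeply nested."""
    seen = set() if _seen is None else _seen
    if id(obj) in seen:  # guard against cycles and shared references
        return
    seen.add(id(obj))
    if isinstance(obj, Stateful):
        yield obj
    if isinstance(obj, dict):
        children = obj.values()
    elif isinstance(obj, (list, tuple)):
        children = obj
    elif hasattr(obj, "__dict__"):
        children = vars(obj).values()
    else:
        children = ()
    for child in children:
        yield from iter_stateful(child, seen)


def zero_states(root):
    """Reset every reachable hidden state to None."""
    for layer in iter_stateful(root):
        layer.state = None


class Block:
    """A container mixing plain attributes, a list and a dict."""

    def __init__(self):
        self.layers = [Stateful(), {"inner": Stateful()}]
        self.head = Stateful()


model = Block()
for layer in iter_stateful(model):
    layer.state = 1.0  # pretend a forward pass filled the states
zero_states(model)
print([layer.state for layer in iter_stateful(model)])  # [None, None, None]
```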
Examples
Example code can be found in examples/. To test the code, make sure that you have the respective requirements
installed for the example, and that you've either installed traceTorch from PyPI or as an editable installation.
Authors
Contributing
Contributions are always welcome. Feel free to fork, submit pull requests or report issues; I will check in on them occasionally.
Roadmap
traceTorch still has a long way to go. Namely:
- Finish the rest of the 32 base classes
- Rewrite LeakyIntegrator to account for duality
- Finish the rest of the 32 base classes for tests in snn.flex
- Create tests to assert working order
- Finish the examples/ section for example code
- Make proper requirements for each example in examples/
- Finish the introduction/ section of the docs
- Do the reference/ section for the docs
- Do the tutorials/ section for the docs, basing it on the examples/
- Make docstrings
- Figure out versioning requirements for the library
File details
Details for the file tracetorch-0.10.0.tar.gz.
File metadata
- Download URL: tracetorch-0.10.0.tar.gz
- Upload date:
- Size: 26.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e437e1b129f1be094cf7cc57f059d2063e04ef482a18b979dbe62c5d9931a665 |
| MD5 | c0f8e25676a15ac07c02721efbc94c55 |
| BLAKE2b-256 | 925f70cb81e109c08368e7bac57518bf54b44e5985eea40345e9bca15599650e |
File details
Details for the file tracetorch-0.10.0-py3-none-any.whl.
File metadata
- Download URL: tracetorch-0.10.0-py3-none-any.whl
- Upload date:
- Size: 36.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c9913a21943f5bc5c57fdc4c39cb80fb2ec0654ab5895917e2587ee2ec0295d3 |
| MD5 | dd43e34a4cf81de49c82709b908a8798 |
| BLAKE2b-256 | 2b0d07b150f4d65b036df0d482b4376d4d4faef948d53deceac61e169f126b2d |