
A strict, ergonomic, and powerful Spiking Neural Network (SNN) library for PyTorch.




traceTorch


Introduction

traceTorch is an SNN library written from the ground up with power, flexibility and extensibility in mind. Like other SNN libraries, it offers a wide variety of distinct, commonly used neuron types with sensible defaults. Its naming schema and API, albeit somewhat unconventional, are consistent and self-explanatory, and give you control over a very wide range of options. Where a parameter needs to be constrained, it is passed through an activation function instead of being clamped, so gradients always flow smoothly. A single dimension argument at initialization determines which of the tensor's dimensions the layer operates on, so the same layer works with various tensor shapes.
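To make the dimension argument concrete, the plain-Python sketch below computes the shape a tensor has after its target dimension is moved to the last position, which is the shape effect of `torch.movedim(x, dim, -1)`. It assumes layers normalize inputs this way (the TTLayer dimension helpers described under Features move the `dim=` dimension to the end); the helper name `moved_shape` is hypothetical and not part of traceTorch.

```python
def moved_shape(shape, dim):
    """Shape after moving axis `dim` to the last position (hypothetical helper).

    Mirrors the shape effect of torch.movedim(x, dim, -1).
    """
    ndim = len(shape)
    if not (-ndim <= dim < ndim):
        raise IndexError(f"dim {dim} is out of range for a {ndim}-D shape")
    dim %= ndim  # normalize negative indices, e.g. -3 -> 1 for a 4-D shape
    rest = [size for i, size in enumerate(shape) if i != dim]
    return tuple(rest + [shape[dim]])


# Conv feature map [batch, channels, height, width] with dim=-3 (channels):
print(moved_shape((8, 32, 28, 28), -3))  # (8, 28, 28, 32)
# MLP activations [batch, features] with dim=-1 are left as they are:
print(moved_shape((8, 128), -1))  # (8, 128)
```

Once the target dimension is last, a per-neuron parameter vector of length 32 (or 128) lines up with the trailing axis and broadcasts against the input without any reshaping, which is why one `dim=` argument is enough to reuse the same layer after a `Linear` or a `Conv2d`.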

traceTorch also helps you manage hidden states, no matter how deeply they are buried in the model. They are lazily initialized, and inheriting from the model superclass unlocks helper methods to find and manage them en masse: saving and loading, zeroing and detaching. The model superclass also lets you "compile" and "uncompile" the model, so that during inference the constrained parameters don't have to be passed through their activation functions at every step.
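Lazy initialization means a hidden state does not exist until the first input arrives, so its size never has to be declared up front. The toy class below is a hypothetical, framework-free sketch of that idea (it is not traceTorch's implementation); the membrane update `mem = beta * mem + x` is an assumed, commonly used form.

```python
class LazyLeakyIntegrator:
    """Hypothetical toy layer whose hidden state is created on first use."""

    def __init__(self, beta):
        self.beta = beta
        self.mem = None  # no state yet: its size is unknown until data arrives

    def zero_states(self):
        # Back to None, so the next call re-creates a state shaped like its input.
        self.mem = None

    def __call__(self, x):
        if self.mem is None:
            self.mem = [0.0] * len(x)  # lazy initialization from the input's size
        # Assumed update rule, applied per neuron: mem = beta * mem + x
        self.mem = [self.beta * m + xi for m, xi in zip(self.mem, x)]
        return self.mem


layer = LazyLeakyIntegrator(beta=0.5)
layer([1.0, 2.0])             # state created for 2 neurons
layer.zero_states()           # start a new sequence
out = layer([1.0, 1.0, 1.0])  # a differently sized input works: state re-created
print(out)  # [1.0, 1.0, 1.0]
```

This is also why zeroing states to `None` (rather than to a zero tensor of a fixed size) is convenient: a new batch size or input shape is picked up automatically on the next call.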

If the existing layers aren't enough, traceTorch also helps you create your own SNN layers that comply with the rest of the traceTorch API. Inheriting from the layer superclass unlocks helper methods for parameter initialization: learnability, rank, and inverse functions (where applicable).

All in all, traceTorch exists to make writing, reading, debugging and, most importantly, experimenting with SNNs in PyTorch feel significantly more natural and less frustrating than in existing alternatives, while preserving (and in many cases enhancing) the expressive power needed for real models and research. traceTorch ultimately rewards users who value minimalism, composition, and long-term extensibility.

Features

traceTorch follows a slightly unconventional but consistent and self-explanatory naming schema. The names are modular, and each part describes its role and function.

  • LI base name stands for Leaky Integrator: the simplest layer type, with just one trace (the membrane potential) and one decay (beta). It has no firing and no reset mechanics; this layer type is commonly known as a Readout (although it's not recommended to literally use it as the final layer).
  • ~B suffix stands for Binary: the layer has a strictly positive threshold, so it has 2 possible outputs, 1 or 0. LIB is therefore traceTorch's name for the LIF (Leaky Integrate-and-Fire) neuron.
  • ~T suffix stands for Ternary: the layer has 2 thresholds, one strictly positive and one strictly negative, so it has 3 possible outputs: 1, 0 or -1.
  • ~S suffix is only used together with ~T to form ~TS, which stands for Ternary Scaled: the positive and negative outputs are each multiplied by their own separate scale. This makes the three possible outputs truly independent from the point of view of the downstream layer.
  • D~ prefix stands for Dual: all traces (hidden states) and their decay parameters are split into separate positive and negative versions, for greater expressivity and more complex dynamics.
  • S~ prefix stands for Synaptic: before the membrane there is a separate synaptic trace with its own alpha decay, which smooths the inputs over time via an exponential moving average (EMA) before they are integrated into the membrane.
  • R~ prefix stands for Recurrent: the layer records its own outputs into a separate trace with its own gamma decay and re-integrates that trace into the membrane on the next timestep. The computation graph is built to work even with online learning.
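The plain-Python sketch below illustrates the dynamics behind the LI base name, the ~B and ~T suffixes and the S~ prefix for a single scalar neuron. It is not traceTorch's code: the membrane update `mem = beta * mem + x` and the subtractive reset are assumptions, and the hard threshold stands in for the `spike_fn`/`quant_fn` pair that traceTorch layers actually use. Only the S~ EMA form follows directly from the description above.

```python
def li_step(mem, x, beta):
    """LI: decay the membrane by beta, then integrate the input (no firing)."""
    return beta * mem + x


def lib_step(mem, x, beta, threshold):
    """LIB: emit 1 once the membrane reaches the positive threshold, then reset."""
    mem = li_step(mem, x, beta)
    spike = 1.0 if mem >= threshold else 0.0
    # Subtractive reset (an assumption): remove the threshold amount after a spike.
    return mem - spike * threshold, spike


def lit_step(mem, x, beta, pos_threshold, neg_threshold):
    """LIT: emit +1 above the positive threshold, -1 below the negative one."""
    mem = li_step(mem, x, beta)
    if mem >= pos_threshold:
        return mem - pos_threshold, 1.0
    if mem <= neg_threshold:  # neg_threshold is strictly negative
        return mem - neg_threshold, -1.0
    return mem, 0.0


def synaptic_step(syn, x, alpha):
    """S~ prefix: EMA of the input; the result is integrated instead of raw x."""
    return alpha * syn + (1.0 - alpha) * x


mem, spikes = 0.0, []
for x in [0.6, 0.6, 0.6]:
    mem, spike = lib_step(mem, x, beta=0.5, threshold=1.0)
    spikes.append(spike)
print(spikes)  # [0.0, 0.0, 1.0]
```

With a constant input of 0.6 and beta = 0.5, the membrane climbs to 0.6, then 0.9, and only crosses the threshold on the third step, which is the "integrate, then fire" behaviour the ~B suffix adds on top of a plain LI.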

In total, this results in 28 purpose-built, performant layers that integrate easily with other PyTorch layers: LI, DLI, SLI, DSLI, LIB, DLIB, SLIB, RLIB, DSLIB, DRLIB, SRLIB, DSRLIB, LIT, DLIT, SLIT, RLIT, DSLIT, DRLIT, SRLIT, DSRLIT, LITS, DLITS, SLITS, RLITS, DSLITS, DRLITS, SRLITS, DSRLITS.
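Because the names are modular, the full list can be rebuilt from the rules above. The sketch below composes the D~, S~ and R~ prefixes (always in that order, as in DSRLIB) with the four base names LI, LIB, LIT and LITS. The function name `layer_names` is hypothetical; it simply reproduces the published list, which contains no R~ variant of LI.

```python
from itertools import product

BASES = ["LI", "LIB", "LIT", "LITS"]


def layer_names():
    """Compose layer names from the D~, S~ and R~ prefixes and the four base names."""
    names = []
    for base in BASES:
        for dual, synaptic, recurrent in product([False, True], repeat=3):
            if base == "LI" and recurrent:
                continue  # the published list has no R~ variant of LI
            prefix = ("D" if dual else "") + ("S" if synaptic else "") + ("R" if recurrent else "")
            names.append(prefix + base)
    return names


names = layer_names()
print(len(names))  # 28
```

The count follows directly: LI contributes 4 variants (D~ and S~ on or off), and each of LIB, LIT and LITS contributes 8 (D~, S~ and R~ on or off), for 4 + 3 × 8 = 28.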

Beyond the naming schema, the layers also have a plethora of extra features:

  • Rank-based parameter scoping for per-layer (scalar) or per-neuron (vector) parameters, defaulting to per-neuron.
  • Initialize parameters via a float value or your own desired tensor.
  • Make any parameter learnable or static; it is automatically registered as an nn.Parameter or a buffer accordingly.
  • Single dim= argument determines the target dimension the layer focuses on: -1 for MLP, -3 for CNN, et cetera.
  • Smooth parameter constraints where required (sigmoid on decays, softplus on thresholds), so gradients always flow cleanly and accurately. The respective inverse function is applied during initialization where needed.
  • All layers with reset mechanics (~LIB, ~LIT, ~LITS) also have a spike_fn and a quant_fn. The former turns the membrane potential into a "probability" to fire, and the latter turns that probability into the actual output. spike_fn defaults to sigmoid(4x) because of its convenient bounds and gradients. quant_fn offers 3 options, round, bernoulli and probabilistic, which respectively round the probability, take a differentiable Bernoulli sample, and take a differentiable Bernoulli sample multiplied by the probability. It defaults to bernoulli for stability.
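The sketch below shows, in plain Python, the forward-pass arithmetic behind the constraint and quantization features listed above: raw values are obtained with the inverse functions (logit for sigmoid-constrained decays, inverse softplus for softplus-constrained thresholds), and the three quant_fn modes turn a firing probability into an output. It is an illustration, not traceTorch's code: applying spike_fn to `membrane - threshold` is an assumption, and the gradient handling that makes traceTorch's Bernoulli sampling differentiable is not reproduced.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def logit(p):
    """Inverse of sigmoid: maps a desired decay in (0, 1) to its raw value."""
    return math.log(p / (1.0 - p))


def softplus(x):
    # Numerically stable form of log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def inv_softplus(y):
    """Inverse of softplus for y > 0: log(exp(y) - 1)."""
    return math.log(math.expm1(y))


# Initialization: store raw values so the constrained values match the request.
raw_beta = logit(0.9)              # decay, kept in (0, 1) by sigmoid
raw_threshold = inv_softplus(1.0)  # threshold, kept in (0, inf) by softplus


def spike_fn(x):
    """Default spike_fn from the text: sigmoid(4x), a firing "probability"."""
    return sigmoid(4.0 * x)


def quant_fn(p, mode="bernoulli", rng=random):
    """Forward-pass behaviour of the three modes; no gradient handling here."""
    if mode == "round":
        return float(round(p))
    fired = 1.0 if rng.random() < p else 0.0  # Bernoulli(p) sample
    if mode == "bernoulli":
        return fired
    if mode == "probabilistic":
        return fired * p
    raise ValueError(f"unknown quant mode: {mode!r}")


# Assumption: spike_fn is applied to (membrane - threshold).
threshold = softplus(raw_threshold)
p = spike_fn(1.5 - threshold)  # sigmoid(2.0), roughly 0.88
print(quant_fn(p, "round"))  # 1.0
```

Because sigmoid and softplus never saturate to a hard boundary the way a clamp does, every raw value keeps a nonzero gradient, which is the motivation given for the smooth constraints.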

traceTorch also provides the TTModel superclass for model management. Inheriting from TTModel grants access to the recursive methods .zero_states() and .detach_states(), which respectively set every hidden state to None or detach it, and to .save_states() and .load_states(), which save and load hidden states in both .pt and .safetensors formats. These work no matter how deeply the states are hidden: inside PyTorch modules such as nn.Sequential, or inside plain Python classes and data structures. There are also .TTcompile() and .TTuncompile() to compile and uncompile a traceTorch model, so that the decays and thresholds are stored as-is and are not passed through their activation functions at every timestep.
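To illustrate how states can be found "no matter how deeply hidden they are", the sketch below walks an object graph through attributes, lists, tuples, sets and dicts and collects every stateful layer it reaches. All names here (`StatefulLayer`, `find_stateful`, `zero_all_states`) are hypothetical stand-ins; this is not TTModel's implementation, only one plausible way such a recursive search can work.

```python
class StatefulLayer:
    """Hypothetical stand-in for a layer that owns a hidden state."""

    def __init__(self):
        self.mem = None

    def zero_states(self):
        self.mem = None


def find_stateful(obj, _seen=None):
    """Yield every StatefulLayer reachable from obj via attributes or containers."""
    if _seen is None:
        _seen = set()
    if id(obj) in _seen:
        return  # already visited: avoids double-counting and infinite cycles
    _seen.add(id(obj))
    if isinstance(obj, StatefulLayer):
        yield obj
    if isinstance(obj, dict):
        children = list(obj.values())
    elif isinstance(obj, (list, tuple, set)):
        children = list(obj)
    elif hasattr(obj, "__dict__"):
        children = list(vars(obj).values())
    else:
        children = []
    for child in children:
        yield from find_stateful(child, _seen)


def zero_all_states(model):
    for layer in find_stateful(model):
        layer.zero_states()


class Model:
    def __init__(self):
        self.encoder = [StatefulLayer(), StatefulLayer()]
        self.heads = {"a": StatefulLayer(), "b": (StatefulLayer(),)}


model = Model()
for layer in find_stateful(model):
    layer.mem = 1.0  # pretend each layer has already run for a few timesteps
zero_all_states(model)
```

The visited-set makes the walk safe on graphs with shared or cyclic references; a production version would likely also restrict which object types it descends into, so it does not wander into functions, classes or imported modules.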

The TTLayer superclass handles all the boilerplate of creating SNN layers. Instead of wrestling with parameter registration, state management, and dimension handling, you can just inherit from TTLayer to create your own SNN layers that comply with the traceTorch ethos. It handles:

  • Automatic parameter registration: rank, learnability, value / tensor initialization and inverse functions for decays and thresholds.
  • State management for hidden states: methods to bulk zero / detach / initialize the hidden states of the layer (TTModel works with the layers in a model, while TTLayer manages the states within the layer itself).
  • Dimension helpers: methods to move the tensor's target dimension (the dim= given at initialization) to the last position, so the layer is agnostic to tensor shape.
  • Property generation: parameters are stored in raw_* form to account for the inverse and activation functions, but are exposed intuitively, so that layer.beta returns the sigmoid-activated value, et cetera.
  • Compiling and uncompiling a model: TTcompile removes the raw_* parameters and activation functions so the values are used directly, and TTuncompile adds them back.
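The raw_* storage and the compile/uncompile round trip can be sketched with an ordinary Python property. In the hypothetical `ToyLayer` below, `beta` is stored as `raw_beta` (the logit of the requested value) and exposed through a sigmoid; `compile()` replaces the raw value with the activated one, and `uncompile()` recovers it via the inverse function. The method names are placeholders for illustration, not traceTorch's TTcompile/TTuncompile signatures.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def logit(p):
    return math.log(p / (1.0 - p))


class ToyLayer:
    """Hypothetical layer: beta is stored as raw_beta and exposed via a property."""

    def __init__(self, beta=0.9):
        self.raw_beta = logit(beta)  # inverse activation applied at initialization
        self._beta = None  # holds the plain value only while compiled

    @property
    def beta(self):
        if self._beta is not None:
            return self._beta  # compiled: use the stored value directly
        return sigmoid(self.raw_beta)  # uncompiled: activate the raw parameter

    def compile(self):
        """Replace raw_beta with its activated value (no per-step sigmoid)."""
        self._beta = self.beta
        del self.raw_beta

    def uncompile(self):
        """Restore raw_beta from the stored value via the inverse function."""
        self.raw_beta = logit(self._beta)
        self._beta = None


layer = ToyLayer(beta=0.9)
layer.compile()
print(hasattr(layer, "raw_beta"), round(layer.beta, 6))  # False 0.9
layer.uncompile()
```

Either way, code that reads `layer.beta` sees the same value; compiling only changes whether the activation is recomputed on every access, which is the inference-time saving the compile feature targets.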

Documentation

The online documentation can be found here. It is strongly recommended to read at least the introduction section before proceeding, as it covers the theory behind SNNs, the traceTorch ethos and the available layers, as well as a brief explanation of what each mechanic actually does. It also contains a couple of tutorials that recreate the code found in examples/.

Installation

traceTorch is published on PyPI. The library's requirements are listed in requirements.txt. Note that the examples in examples/ may have their own requirements, separate from those of the library.

pip install tracetorch

If you would rather work from the source, or just want to run the examples, install traceTorch as an editable installation:

git clone --branch v0.15.0 https://github.com/Yegor-men/tracetorch
cd tracetorch
pip install -e .

Check the releases page for the latest version number, or for another one if you want a different release, and use it in place of v0.15.0.

Quickstart

traceTorch models look barely any different from PyTorch models:

import torch
from torch import nn
import tracetorch as tt
from tracetorch import snn

device = "cuda" if torch.cuda.is_available() else "cpu"


class SNN(snn.TTModel):
    def __init__(self):
        super().__init__()
        # Sizes assume 1x28x28 inputs (e.g. MNIST): two 2x2 pools give 7x7 maps.
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            snn.LIB(32, dim=-3),  # 32 neurons to match the 32 output channels above
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, padding=1),
            snn.LIB(64, dim=-3),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(7 * 7 * 64, 128),
            snn.LI(128),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        return self.net(x)


model = SNN().to(device)
optimizer = torch.optim.AdamW(model.parameters(), 1e-3)
loss_fn = nn.CrossEntropyLoss()  # example choice for 10-class classification
num_timesteps = 10  # example value; must match the time dimension of the data

# TRAINING LOOP WITH DATALOADER
# train_dataloader is assumed to be defined elsewhere and to yield time-major batches:
# x of shape [num_timesteps, batch, 1, 28, 28] and y of shape [num_timesteps, batch]
model.train()
for x, y in train_dataloader:
    x, y = x.to(device), y.to(device)
    model.zero_states()  # sets hidden states to None for lazy assignment
    model.zero_grad()
    running_loss = 0.0
    for t in range(num_timesteps):
        model_output = model(x[t])
        loss = loss_fn(model_output, y[t])
        running_loss = running_loss + loss
        # optionally call model.detach_states() for online learning here
    running_loss.backward()
    optimizer.step()

Examples

Example code can be found in examples/. To test the code, make sure that you have the respective requirements installed for the example, and that you've either installed traceTorch from PyPI or as an editable installation.

Authors

Contributing

Contributions are always welcome. Feel free to fork, submit pull requests or report issues; I will check in on them occasionally.

Roadmap

traceTorch still has a long way to go. Namely:

  • Clean up the experimental TTcompile, TTuncompile, save_states, load_states methods for TTModel
  • Fix the LeakyIntegrator superclass and create the 28 tests
  • Finish the example code in the examples/ section
  • Make proper requirements for each example in examples/
  • Finish the introduction/ section of the docs
  • Do the reference/ section for the docs
  • Do the tutorials/ section for the docs, basing it on the examples/
  • Make docstrings
  • Figure out versioning requirements for the library



Download files


Source Distribution

tracetorch-0.15.0.tar.gz (29.8 kB)


Built Distribution


tracetorch-0.15.0-py3-none-any.whl (33.5 kB)


File details

Details for the file tracetorch-0.15.0.tar.gz.

File metadata

  • Download URL: tracetorch-0.15.0.tar.gz
  • Size: 29.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.12

File hashes

Hashes for tracetorch-0.15.0.tar.gz
  • SHA256: f1f0c14bf7bc18121d8ab6f208faabed070c33ef58624debc892637f1956420c
  • MD5: 4f2744ad35c8fb0d261d5027876e52ea
  • BLAKE2b-256: ba3d6b191bcb5d32bd461667fcfa5de80c8578fae9f74ef4768017a4b0949291


File details

Details for the file tracetorch-0.15.0-py3-none-any.whl.

File metadata

  • Download URL: tracetorch-0.15.0-py3-none-any.whl
  • Size: 33.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.12

File hashes

Hashes for tracetorch-0.15.0-py3-none-any.whl
  • SHA256: 615260b2ba8b5c30738bd31b5174ceb4921be8a593f91b4322646ddf02ddc137
  • MD5: a345ca406650d68fdd5acd76f3b97756
  • BLAKE2b-256: 0bf9732fcba793598e0b8e25b593325a8506c0e39b8b0641f36eb650c622146d

