A strict, ergonomic, and powerful Spiking Neural Network (SNN) library for PyTorch.

traceTorch

traceTorch is built around a single, highly compositional neuron superclass, the LeakyIntegrator, which replaces the restrictive "layer zoo" of countless disjoint neuron types. This design encapsulates a wide range of SNN dynamics:

  • Flexible polarity for spike outputs: positive and/or negative (or none at all, thus creating a readout layer)
  • Optional synaptic and recurrent signal accumulation
  • Rank-based parameter scoping for scalar, per-neuron or matrix weights
  • Optional Exponential Moving Average (EMA) on any hidden state

All of this is exposed as declarative configuration on a single class. By abstracting this complexity, traceTorch provides both the simplicity needed for fast prototyping via familiar wrappers (LIF, RLIF, SLIF, Readout, etc.) and the flexibility needed for real research. traceTorch ships 12 easy-to-use layer types: LIF, BLIF, SLIF, RLIF, BSLIF, BRLIF, SRLIF, BSRLIF, Readout, SReadout, RReadout, SRReadout. The API is simple enough that you can add more with little effort.
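For readers new to SNNs, the leak, threshold and reset behaviour that the wrappers above build on can be sketched in plain PyTorch. This is a generic textbook leaky integrate-and-fire step, not traceTorch's LeakyIntegrator: the function name `lif_step`, the soft reset, and the fixed `beta` and `threshold` values are all assumptions made for illustration, and no surrogate gradient is included.

```python
import torch


def lif_step(mem, x, beta=0.9, threshold=1.0):
    """One discrete-time step of a textbook leaky integrate-and-fire neuron."""
    mem = beta * mem + x                     # leak the old potential, add the input
    spikes = (mem >= threshold).to(x.dtype)  # 1.0 wherever the threshold is crossed
    mem = mem - spikes * threshold           # soft reset: subtract the threshold
    return spikes, mem


mem = torch.zeros(3)
inputs = torch.tensor([0.6, 0.3, 0.0])  # constant input current per neuron
history = []
for _ in range(4):
    spikes, mem = lif_step(mem, inputs)
    history.append(spikes)

spike_counts = torch.stack(history).sum(0)
print(spike_counts)  # stronger input -> more spikes: tensor([2., 1., 0.])
```

A readout-style layer corresponds to skipping the threshold and reset lines and returning the membrane potential itself.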

Why traceTorch?

Existing SNN libraries often feel restrictive or require verbose state management. Aside from the technical features and capabilities, traceTorch follows a different philosophy, revolving around ergonomics:

  • Architectural Flexibility: All existing traceTorch layers are just small wrappers of the LeakyIntegrator superclass, and it's incredibly easy to add your own alterations/combinations of the features you like.
  • Automatic State Management: No need to manually pass hidden states through .forward(). Each layer manages its own hidden states, and calling .zero_states() on a traceTorch model recursively clears every hidden state in the model, no matter how deeply nested. Similarly, .detach_states() detaches the states from the current computation graph.
  • Lazy Initialization: Hidden states are initialized as None and allocated dynamically based on the input shape. This completely eliminates "Batch Size Mismatch" errors during training and inference.
  • Dimension Agnostic: Whether you are working with [Time, Batch, Features] or [Batch, Channels, Height, Width] tensors, layers just work. Change a single dim argument at layer initialization to indicate the dimension the layer acts on. It defaults to -1 for MLPs; -3 works for CNNs (channels are 3rd from last in [B, C, H, W] or [C, H, W]). The layer automatically moves the target dimension to the correct index so that its parameters line up with the input.
  • Smooth Constraints: Parameters like decays and thresholds are constrained via Sigmoid and Softplus respectively. No hard clamping, meaning that gradients flow smoothly and accurately everywhere.
  • Rank Based Parameters: Instead of messy flags like *_is_vector or all_to_all, traceTorch uses a single *_rank integer to define the parameter scope: 0 for a scalar (parameter is shared across the layer), 1 for a vector (per-neuron parameter), 2 for a matrix (dense all-to-all connections for recurrent layer weights).
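The smooth-constraint and rank ideas in the list above can be illustrated with plain PyTorch. The sketch below is not traceTorch's internal code: the variable names and the `param_shape` helper are hypothetical, and it only shows why a Sigmoid or Softplus reparameterization keeps gradients alive where a hard clamp does not, and how a rank integer can map to a parameter shape.

```python
import torch
import torch.nn.functional as F

# Unconstrained ("raw") parameters that an optimizer would update.
raw_beta = torch.tensor(4.0, requires_grad=True)
raw_threshold = torch.tensor(-2.0, requires_grad=True)

beta = torch.sigmoid(raw_beta)          # decay, always in (0, 1)
threshold = F.softplus(raw_threshold)   # threshold, always > 0
(beta + threshold).backward()           # both raw parameters receive a non-zero gradient

# Hard clamping for comparison: outside [0, 1] the gradient is exactly zero.
clamped_raw = torch.tensor(4.0, requires_grad=True)
clamped_raw.clamp(0.0, 1.0).backward()


def param_shape(rank, num_neurons):
    """Hypothetical helper: rank 0 -> scalar, 1 -> per-neuron vector, 2 -> matrix."""
    return (num_neurons,) * rank


print(raw_beta.grad.item() > 0, clamped_raw.grad.item())
print(param_shape(0, 4), param_shape(1, 4), param_shape(2, 4))
```

With the clamp, a parameter that drifts out of range stops learning; with the smooth reparameterization it can always move back.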

Installation

traceTorch is a PyPI library found here. Requirements for the library are listed in requirements.txt. Take note that examples found in examples/ may have their own requirements, separate from the library requirements.

pip install tracetorch

If you want to run the example code without installing the PyPI package, or alternatively want to edit the code yourself, you should install traceTorch as an editable install.

git clone https://github.com/Yegor-men/tracetorch
cd tracetorch
pip install -e .

Quick Start

Making a traceTorch model is barely any different from PyTorch models. Here's how:

1. The "zero-boilerplate" module

Inherit from tracetorch.snn.TTModule instead of torch.nn.Module. This gives your model powerful recursive methods like zero_states() and detach_states() for free, while still integrating with any other PyTorch nn.Module.

import torch
from torch import nn
import tracetorch as tt
from tracetorch import snn


class ConvSNN(snn.TTModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 3),
            # dim=-3 tells the layer that the 3rd-to-last dimension is the channel dim.
            # This works for (B, C, H, W) AND unbatched (C, H, W) inputs automatically.
            snn.LIF(num_neurons=32, beta=0.9, dim=-3),

            nn.Flatten(),
            nn.Linear(32 * 26 * 26, 10),  # 26x26 feature map assumes 28x28 inputs (e.g. MNIST)

            # Readout layer with learnable scalar decay
            snn.Readout(num_neurons=10, beta=0.8, beta_rank=0)
        )

    def forward(self, x):
        return self.net(x)
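One plausible way a negative dim can serve both batched and unbatched inputs is to move the target dimension to the end, apply per-neuron parameters there, and move it back. The sketch below uses only torch.movedim; `apply_per_channel` is an illustrative helper, not a traceTorch function, and the library's actual mechanism may differ.

```python
import torch


def apply_per_channel(x, scale, dim=-3):
    """Multiply x by a per-neuron vector along `dim` (illustrative helper)."""
    x = x.movedim(dim, -1)      # put the target dimension last so `scale` broadcasts
    x = x * scale               # scale has shape (num_neurons,)
    return x.movedim(-1, dim)   # restore the original layout


scale = torch.arange(1.0, 4.0)  # 3 channels: [1., 2., 3.]
batched = apply_per_channel(torch.ones(2, 3, 5, 5), scale)   # (B, C, H, W)
unbatched = apply_per_channel(torch.ones(3, 5, 5), scale)    # (C, H, W)

print(batched.shape, unbatched.shape)
```

Because the index is counted from the end, the same dim=-3 picks out the channel dimension whether or not a batch dimension is present.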

2. The Training Loop

State management is easily handled outside the forward pass. Simply call .zero_states() on the model to reset all hidden states to None, or call .detach_states() to detach the current hidden states (used in truncated BPTT or for online learning).

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ConvSNN().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = tt.loss.soft_cross_entropy  # Handles non-onehot probability distribution targets gracefully

# Training Step (assumes `dataloader` yields (x, y) batches and `num_timesteps` is defined)
for x, y in dataloader:
    x, y = x.to(device), y.to(device)
    model.train()

    model.zero_states()  # Crucial: Reset hidden states for the batch

    # Time loop
    spikes = []
    for step in range(num_timesteps):
        # Just pass x. No state tuples to manage.
        spikes.append(model(x))

    # Stack output and compute loss
    output = torch.stack(spikes)
    loss = loss_fn(output.mean(0), y)  # Rate coding example

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
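The loop above resets states once per batch. For truncated BPTT, the pattern is to keep the state values across windows but cut the graph between them. The sketch below shows that pattern with a hypothetical stand-in layer written in plain PyTorch so it runs without traceTorch installed; `LeakyState`, the window length and the toy data are assumptions, and only the method names zero_states() and detach_states() are taken from the text above.

```python
import torch
from torch import nn


class LeakyState(nn.Module):
    """Hypothetical stand-in for a stateful layer; not traceTorch's implementation."""

    def __init__(self, beta=0.9):
        super().__init__()
        self.beta = beta
        self.mem = None  # lazily allocated on the first forward call

    def forward(self, x):
        if self.mem is None:
            self.mem = torch.zeros_like(x)  # shape follows the input, any batch size
        self.mem = self.beta * self.mem + x
        return self.mem

    def zero_states(self):
        self.mem = None

    def detach_states(self):
        if self.mem is not None:
            self.mem = self.mem.detach()


torch.manual_seed(0)
linear = nn.Linear(4, 2)
leaky = LeakyState()
optimizer = torch.optim.SGD(linear.parameters(), lr=0.1)

x = torch.randn(8, 4)        # one batch, presented at every timestep
target = torch.zeros(8, 2)
window = 5                   # timesteps per truncated-BPTT window

leaky.zero_states()
for step in range(1, 21):
    out = leaky(linear(x))
    if step % window == 0:
        loss = (out - target).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        leaky.detach_states()  # keep the values, drop the graph from this window
```

Without the detach call, the second backward pass would try to traverse the already-freed graph of the first window and PyTorch would raise an error.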

Documentation

The online documentation can be found here. It contains introductory lessons on SNNs, a reference for the traceTorch API and available layers, and a couple of tutorials that recreate the code found in examples/.

Authors

Contributing

Contributions are always welcome. Feel free to fork, submit pull requests, or report issues; I will check in on the repository occasionally.
