
A strict, ergonomic and "just works" Spiking Neural Network library for PyTorch.


traceTorch


traceTorch is designed to eliminate the boilerplate, shape errors, and gradient issues common in SNN development. It treats spiking neurons as first-class PyTorch citizens that handle their own state, broadcasting, and constraints automatically.

Why traceTorch?

Existing SNN libraries often feel restrictive or require verbose state management. traceTorch follows a different philosophy:

  • Automatic State Management: No need to manually pass hidden states through .forward(). Each layer manages its own hidden states, and calling .zero_states() on a traceTorch model recursively clears every hidden state the model uses, however deeply nested the layers are.
  • Lazy Initialization: Hidden states are initialized as None and allocated dynamically from the shape of the first input they receive. This eliminates "Batch Size Mismatch" errors, e.g. when the inference batch size differs from the training batch size.
  • Smooth Constraints: Parameters like the decay ($\beta$) and the threshold are constrained via a sigmoid and a softplus, respectively, so $\beta$ always stays in (0, 1) and the threshold always stays positive. There is no hard clamping, so gradients flow smoothly everywhere.
  • Dimension Agnostic: Whether you are working with [Time, Batch, Features] or [Batch, Channels, Height, Width] tensors, layers just work. Set a single dim argument at layer initialization to indicate the dimension that indexes the neurons. It defaults to -1, which suits MLPs; use -3 for CNNs, since channels are the third-to-last dimension in both [B, C, H, W] and [C, H, W].
  • Rank-Based Parameters: Instead of messy flags like *_is_vector or is_shared, traceTorch uses a single *_rank integer to define a parameter's scope: 0 for a scalar (shared across the whole layer), 1 for a vector (one value per neuron), and 2 for a matrix (dense all-to-all connections for recurrent layers).
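The bullets above describe behavior rather than internals. As an illustration only, here is a minimal plain-PyTorch sketch of a leaky integrate-and-fire layer that combines several of these ideas: a lazily allocated membrane state, a dim argument for broadcasting, rank 0 or rank 1 decay, and unconstrained raw parameters mapped through $\beta = \sigma(\tilde{\beta})$ and $\theta = \operatorname{softplus}(\tilde{\theta}) = \log(1 + e^{\tilde{\theta}})$. The class SketchLIF, its soft reset, and its sigmoid-shaped surrogate gradient are assumptions made for this example, not traceTorch's actual implementation, and rank 2 (recurrent) parameters are left out.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn


class SketchLIF(nn.Module):
    """Hypothetical LIF layer for illustration; not traceTorch's implementation."""

    def __init__(self, num_neurons, beta=0.9, threshold=1.0, dim=-1, beta_rank=1):
        super().__init__()
        self.dim = dim
        # rank 0 -> one scalar shared by the layer, rank 1 -> one value per neuron
        beta_shape = () if beta_rank == 0 else (num_neurons,)
        # Store unconstrained raw values: logit inverts the sigmoid,
        # log(expm1(t)) inverts the softplus.
        self.raw_beta = nn.Parameter(torch.full(beta_shape, math.log(beta / (1 - beta))))
        self.raw_threshold = nn.Parameter(
            torch.full((num_neurons,), math.log(math.expm1(threshold)))
        )
        self.mem = None  # allocated lazily from the first input's shape

    def _along_dim(self, p, x):
        # Reshape a per-neuron vector so it broadcasts along self.dim of x.
        if p.dim() == 0:
            return p
        shape = [1] * x.dim()
        shape[self.dim] = -1
        return p.view(shape)

    def forward(self, x):
        if self.mem is None:
            self.mem = torch.zeros_like(x)
        beta = self._along_dim(torch.sigmoid(self.raw_beta), x)  # always in (0, 1)
        threshold = self._along_dim(F.softplus(self.raw_threshold), x)  # always > 0
        self.mem = beta * self.mem + x
        # Heaviside step in the forward pass, sigmoid slope in the backward pass
        # (an assumed surrogate; the soft - soft.detach() term is exactly zero).
        soft = torch.sigmoid(5.0 * (self.mem - threshold))
        spikes = (self.mem >= threshold).float() + (soft - soft.detach())
        self.mem = self.mem - spikes * threshold  # soft reset
        return spikes

    def zero_states(self):
        self.mem = None


# Usage: the same layer handles batched and unbatched CNN-style inputs.
lif = SketchLIF(num_neurons=4, dim=-3)
batched = lif(torch.rand(2, 4, 5, 5))  # (B, C, H, W)
lif.zero_states()  # state is re-allocated for the new input shape
unbatched = lif(torch.rand(4, 5, 5))  # (C, H, W)
```

Keeping raw parameters unconstrained and mapping them through smooth functions is what lets an optimizer update them freely without ever producing an invalid decay or threshold.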

Installation

traceTorch is available on PyPI. Requirements are listed in requirements.txt.

pip install tracetorch

Quick Start

Building a traceTorch model is barely different from building a regular PyTorch model. Here's how:

1. The "zero-boilerplate" module

Inherit from tracetorch.snn.TTModule instead of torch.nn.Module. This gives your model recursive methods like zero_states() and detach_states() for free, while still working alongside standard PyTorch nn.Module layers.

import torch
from torch import nn
import tracetorch as tt
from tracetorch import snn


class ConvSNN(snn.TTModule):
	def __init__(self):
		super().__init__()
		self.net = nn.Sequential(
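			# Assumes 1x28x28 inputs (e.g. MNIST): a 3x3 conv without padding gives 26x26 maps.
			nn.Conv2d(1, 32, 3),
			# dim=-3 tells the layer that the 3rd-to-last dimension is the channel dim.
			# This works for (B, C, H, W) AND unbatched (C, H, W) inputs automatically.
			snn.LIF(num_neurons=32, beta=0.9, dim=-3),

			# start_dim=-3 flattens (C, H, W) for both batched and unbatched inputs;
			# the default start_dim=1 would break on unbatched inputs.
			nn.Flatten(start_dim=-3),
			nn.Linear(32 * 26 * 26, 10),

			# Readout layer with a learnable scalar decay (beta_rank=0 -> one shared beta)
			snn.Readout(num_neurons=10, beta=0.8, beta_rank=0)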
		)

	def forward(self, x):
		return self.net(x)
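The recursive behavior of zero_states() and detach_states() can be approximated in plain PyTorch by walking self.modules(), which yields every submodule at any nesting depth. The sketch below is a hypothetical illustration, assuming each stateful layer keeps its hidden state in a state attribute; the names StatefulMixin, Leaky, SketchTTModule, and Net are made up for this example and are not traceTorch's classes.

```python
import torch
from torch import nn


class StatefulMixin:
    """Hypothetical hooks for a layer that stores its hidden state in self.state."""

    def zero_states(self):
        self.state = None

    def detach_states(self):
        if self.state is not None:
            self.state = self.state.detach()


class Leaky(StatefulMixin, nn.Module):
    """Toy leaky integrator that keeps its own hidden state."""

    def __init__(self, beta=0.9):
        super().__init__()
        self.beta = beta
        self.state = None

    def forward(self, x):
        self.state = x if self.state is None else self.beta * self.state + x
        return self.state


class SketchTTModule(nn.Module):
    """Base class whose reset/detach reach every stateful submodule."""

    def zero_states(self):
        for m in self.modules():  # modules() recurses through all nesting levels
            if isinstance(m, StatefulMixin):
                m.zero_states()

    def detach_states(self):
        for m in self.modules():
            if isinstance(m, StatefulMixin):
                m.detach_states()


class Net(SketchTTModule):
    def __init__(self):
        super().__init__()
        # The Leaky layers sit inside a nested Sequential, two levels deep.
        self.body = nn.Sequential(
            nn.Linear(3, 3),
            nn.Sequential(Leaky(), nn.Linear(3, 2), Leaky()),
        )

    def forward(self, x):
        return self.body(x)
```

Because the container only needs to find stateful layers, not know where they are, adding or re-nesting layers never requires changing the reset logic.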

2. The Training Loop

State is managed outside the forward pass. Call .zero_states() on the model to reset all hidden states to None, or call .detach_states() to detach the current hidden states from the autograd graph (as needed for truncated BPTT or online learning).

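device = "cuda" if torch.cuda.is_available() else "cpu"
model = ConvSNN().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.005)
loss_fn = tt.loss.soft_cross_entropy  # Handles non-onehot targets gracefully
num_timesteps = 20  # Number of simulation steps per input (example value)

# Training Step (`loader` is assumed to be a DataLoader yielding (x, y) batches)
model.train()
for x, y in loader:
	x, y = x.to(device), y.to(device)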

	# Time loop
	spikes = []
	for step in range(num_timesteps):
		# Just pass x. No state tuples to manage.
		spikes.append(model(x))

	# Stack output and compute loss
	output = torch.stack(spikes)
	loss = loss_fn(output.mean(0), y)  # Rate coding example

	loss.backward()
	optimizer.step()
	optimizer.zero_grad()

	# Crucial: Reset hidden states for the next batch
	model.zero_states() 
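The .detach_states() call mentioned above is what enables truncated backpropagation through time (BPTT): gradients are computed over a short window of timesteps, then the hidden state is detached so the next window starts a fresh graph while still carrying the state's value forward. The following plain-PyTorch sketch shows the pattern with a hypothetical LeakyLinear layer; resetting and detaching model.state by hand stands in for traceTorch's zero_states() and detach_states(), and the window length k and the MSE loss are arbitrary choices for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)


class LeakyLinear(nn.Module):
    """Hypothetical stateful layer: a linear map feeding a leaky integrator."""

    def __init__(self, n_in, n_out, beta=0.9):
        super().__init__()
        self.fc = nn.Linear(n_in, n_out)
        self.beta = beta
        self.state = None

    def forward(self, x):
        cur = self.fc(x)
        self.state = cur if self.state is None else self.beta * self.state + cur
        return self.state


def train_tbptt(model, x_seq, target, optimizer, k=5):
    """Backpropagate every k timesteps, then detach the state so the
    next window does not backpropagate into earlier windows."""
    losses = []
    model.state = None  # stands in for model.zero_states()
    for start in range(0, x_seq.shape[0], k):
        outputs = [model(x_t) for x_t in x_seq[start:start + k]]
        loss = ((torch.stack(outputs).mean(0) - target) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        model.state = model.state.detach()  # stands in for model.detach_states()
    return losses


model = LeakyLinear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x_seq = torch.rand(20, 8, 4)  # [Time, Batch, Features]
target = torch.zeros(8, 2)
losses = train_tbptt(model, x_seq, target, optimizer, k=5)
```

Without the detach, the second loss.backward() would try to backpropagate into the first window's graph, which has already been freed, and PyTorch would raise a RuntimeError.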

Documentation

The online documentation can be found here. It focuses primarily on how the modules work internally, and it also includes a few tutorials that recreate the code in examples/.

Authors

Contributing

Contributions are always welcome. Feel free to fork the project, submit pull requests, or report issues; I will check in on them occasionally.



Download files


Source Distribution

tracetorch-0.4.0.tar.gz (17.5 kB)

Uploaded Source

Built Distribution


tracetorch-0.4.0-py3-none-any.whl (22.8 kB)

Uploaded Python 3

File details

Details for the file tracetorch-0.4.0.tar.gz.

File metadata

  • Download URL: tracetorch-0.4.0.tar.gz
  • Upload date:
  • Size: 17.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.9

File hashes

Hashes for tracetorch-0.4.0.tar.gz
Algorithm Hash digest
SHA256 ea68b4584043da4a01f30390fcbad4b9cac69734213a622bb0113691ec37d805
MD5 118915ae9834df63b53eb8f950fd2c0e
BLAKE2b-256 6e5c66097feb16d6f48fde1b30decdd1446d8860945901e16de85f1904a9b15d


File details

Details for the file tracetorch-0.4.0-py3-none-any.whl.

File metadata

  • Download URL: tracetorch-0.4.0-py3-none-any.whl
  • Upload date:
  • Size: 22.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.9

File hashes

Hashes for tracetorch-0.4.0-py3-none-any.whl
Algorithm Hash digest
SHA256 dea757f92208269c535150ce402015eca563fb02588b6df1dbbbde565f5fa268
MD5 209e512ebb2798163974d40249c4fc12
BLAKE2b-256 49fc4ab2cdb5667334c20c92aee971436e66eb230424a5a76a9910a17eb18f7f
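The SHA256 digests listed above can be used to check a downloaded file before installing it. A minimal sketch using only Python's standard hashlib module; the file name in the usage comment assumes the wheel was saved to the current directory.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected_hex):
    """True if the file's SHA256 digest matches the published one."""
    return sha256_of(path) == expected_hex.strip().lower()


# Usage (assumes the wheel is in the current directory):
# verify("tracetorch-0.4.0-py3-none-any.whl",
#        "dea757f92208269c535150ce402015eca563fb02588b6df1dbbbde565f5fa268")
```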

