
A PyTorch library for event-based data processing

Project description

Torchevent: Spiking Neural Network Framework

Torchevent is a PyTorch-based framework for Spiking Neural Networks (SNNs). It supports training and inference on event-based datasets such as NMNIST and provides custom models, loss functions, and transformations tailored to SNN workflows.


References

This project draws inspiration from the following works:

  1. Paper:
    Wenrui Zhang and Peng Li. Temporal Spike Sequence Learning via Backpropagation for Deep Spiking Neural Networks (TSSL-BP).
    Proceedings of the 34th Conference on Neural Information Processing Systems (NeurIPS), 2020.

  2. GitHub Repository:
    TSSL-BP: Temporal Spike Sequence Learning Framework

We thank the authors of these works for providing valuable insights into spiking neural network research and implementation.


Features

1. SNN Models (TSSL-BP)

  • Models such as NMNISTNet and NCARSNet are designed for specific event-based datasets (NMNIST and N-CARS).
  • Configurable for different spiking network architectures and numbers of simulation time steps (for example, the n_steps argument).
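
The number of simulation time steps determines how many times each neuron's state is updated per sample. As a rough illustration of such time-step dynamics, here is a minimal sketch of a discrete-time leaky integrate-and-fire (LIF) neuron. It is a hypothetical example, not torchevent's neuron model (which follows TSSL-BP and may use different dynamics); the function name lif_simulate and its parameters tau_m and threshold are illustrative assumptions.

```python
def lif_simulate(input_current, tau_m=5.0, threshold=1.0):
    """Simulate a single discrete-time LIF neuron (hypothetical sketch).

    One update per time step, so len(input_current) plays the role of
    the number of simulation steps (n_steps).
    """
    decay = 1.0 - 1.0 / tau_m  # fraction of membrane potential kept per step
    v = 0.0  # membrane potential
    spikes = []
    for current in input_current:
        v = decay * v + current  # leak, then integrate this step's input
        if v >= threshold:
            spikes.append(1)
            v = 0.0  # hard reset after emitting a spike
        else:
            spikes.append(0)
    return spikes


# Constant input over 5 time steps
print(lif_simulate([0.6] * 5))  # [0, 1, 0, 1, 0]
```

Because a neuron like this emits at most one spike per step, a count-based target can never exceed the number of time steps, which is worth keeping in mind when choosing desired_count relative to n_steps.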

2. Event Data Transformations

  • Transformations tailored for event-based data processing:
    • RandomTemporalCrop: Crops the event stream to a window of the given length (time_window) placed at a random position.
    • TemporalCrop: Sequentially crops events within a fixed time window.
    • ToFrameAuto: Converts events into frames with dynamic configurations.
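
To make the crop-then-frame pipeline concrete, here is a minimal NumPy sketch of a random temporal crop (as RandomTemporalCrop is described above) followed by binning events into a fixed number of frames (similar in spirit to Tonic's ToFrame with n_time_bins, used in the usage example). It is an illustrative reimplementation, not torchevent's code. The function names random_temporal_crop and to_frames and the EVENT_DTYPE layout are hypothetical; events are assumed to be a structured array with x, y, t (microseconds), and p fields, sensor_size is assumed to be (width, height, polarities), and the output layout (time bins, polarity, height, width) is intended to mirror Tonic's.

```python
import numpy as np

# Assumed event layout, modeled on Tonic's structured event arrays
EVENT_DTYPE = np.dtype([("x", np.int16), ("y", np.int16), ("t", np.int64), ("p", np.int8)])


def random_temporal_crop(events, time_window, rng):
    """Keep only events inside a randomly placed window of length time_window."""
    t_min, t_max = int(events["t"].min()), int(events["t"].max())
    latest_start = max(t_max - time_window, t_min)
    start = int(rng.integers(t_min, latest_start + 1))  # upper bound is exclusive
    mask = (events["t"] >= start) & (events["t"] < start + time_window)
    return events[mask]


def to_frames(events, sensor_size, n_time_bins):
    """Count events per (time bin, polarity, y, x); sensor_size is (width, height, polarities)."""
    width, height, n_pol = sensor_size
    frames = np.zeros((n_time_bins, n_pol, height, width), dtype=np.int16)
    if len(events) == 0:
        return frames
    t = events["t"]
    span = int(t.max() - t.min()) + 1
    bins = (t - t.min()) * n_time_bins // span  # equal-duration bins in [0, n_time_bins)
    np.add.at(frames, (bins, events["p"], events["y"], events["x"]), 1)
    return frames


# Ten synthetic events, one every 10 microseconds
events = np.zeros(10, dtype=EVENT_DTYPE)
events["t"] = np.arange(0, 100, 10)
events["x"] = np.arange(10) % 4
rng = np.random.default_rng(7)
cropped = random_temporal_crop(events, time_window=35, rng=rng)
frames = to_frames(events, sensor_size=(4, 2, 2), n_time_bins=5)
print(len(cropped), frames.shape)
```

np.add.at is used instead of fancy-index assignment so that several events landing on the same pixel in the same bin are all counted.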

3. Loss Functions

  • Loss functions designed for SNN-specific requirements:
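    • SpikeKernelLoss: Computes the loss on the post-synaptic potentials (PSPs) produced by the output spike trains.
    • SpikeCountLoss: Trains output neurons to match target spike counts (desired_count for the neuron of the correct class, undesired_count for all others).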
    • SpikeSoftmaxLoss: Combines spike data with softmax and cross-entropy for classification tasks.
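
The following NumPy sketch shows one plausible formulation of these three losses, modeled on the TSSL-BP formulation as we understand it. It is an assumption, not torchevent's implementation: output spikes are assumed to have shape (batch, classes, time steps), the function names and the synaptic time constant tau_s are hypothetical, and torchevent's scaling, reduction, and kernel parameters may differ.

```python
import numpy as np


def spike_count_loss(spikes, targets, desired_count, undesired_count):
    """Half squared error between output spike counts and per-class target counts."""
    counts = spikes.sum(axis=-1).astype(float)  # (batch, classes)
    target_counts = np.full(counts.shape, float(undesired_count))
    target_counts[np.arange(len(targets)), targets] = desired_count  # correct class
    return 0.5 * np.sum((counts - target_counts) ** 2)


def spike_softmax_loss(spikes, targets):
    """Cross-entropy of a softmax over spike counts (mean over the batch)."""
    counts = spikes.sum(axis=-1).astype(float)
    shifted = counts - counts.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()


def psp(spikes, tau_s=5.0):
    """Filter spike trains with a first-order exponential synaptic kernel along time."""
    out = np.zeros(spikes.shape, dtype=float)
    syn = np.zeros(spikes.shape[:-1])
    for t in range(spikes.shape[-1]):
        syn = syn - syn / tau_s + spikes[..., t]  # decay, then add this step's spikes
        out[..., t] = syn / tau_s
    return out


def spike_kernel_loss(spikes, desired_spikes, tau_s=5.0):
    """Half squared error between the PSPs of actual and desired spike trains."""
    return 0.5 * np.sum((psp(spikes, tau_s) - psp(desired_spikes, tau_s)) ** 2)


# Batch of 2 samples, 3 classes, 5 time steps; no output neuron fires yet
spikes = np.zeros((2, 3, 5))
targets = np.array([0, 2])
print(spike_count_loss(spikes, targets, desired_count=4, undesired_count=1))  # 18.0
```

In the example, each sample contributes (0 - 4)^2 for its correct class and (0 - 1)^2 for each of the two other classes, so the total is 0.5 * (18 + 18) = 18.0.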

Installation

To install torchevent from source:

git clone https://github.com/devcow85/torchevent.git
cd torchevent
pip install .

Usage

The following script demonstrates training the NMNISTNet model on the NMNIST dataset, loaded with the Tonic library:

import time

import tonic
import tonic.transforms as transforms
import torch
from torch.utils.data import DataLoader

from torchevent.utils import set_seed
from torchevent.transforms import RandomTemporalCrop
from torchevent import models, loss

# This example assumes a CUDA-capable GPU
device = "cuda"

# Set seed for reproducibility
set_seed(7)

# Prepare the dataset: crop a 99 ms window (NMNIST timestamps are in microseconds)
# and accumulate the events into 5 frames, one per simulation time step
transform = transforms.Compose([
    RandomTemporalCrop(time_window=99000),
    transforms.ToFrame(sensor_size=tonic.datasets.NMNIST.sensor_size, n_time_bins=5),
])

train_ds = tonic.datasets.NMNIST(save_to="data", train=True, transform=transform)
val_ds = tonic.datasets.NMNIST(save_to="data", train=False, transform=transform)

# Create data loaders
train_loader = DataLoader(train_ds, shuffle=True, batch_size=32, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_ds, shuffle=False, batch_size=32, num_workers=8, pin_memory=True)

# Initialize model, optimizer, and loss function
model = models.NMNISTNet(5, 1, n_steps=5).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
criterion = loss.SpikeCountLoss(desired_count=4, undesired_count=1)

# Training loop
num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    epoch_start = time.perf_counter()
    for step, (data, targets) in enumerate(train_loader, start=1):
        data = data.to(device, dtype=torch.float32, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad()
        spikes = model(data)
        spike_loss = criterion(spikes, targets)
        spike_loss.backward()
        # Clip the global gradient norm to 1 for training stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        running_loss += spike_loss.item()
        if step % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step}/{len(train_loader)}], Loss: {spike_loss.item():.4f}")
    avg_loss = running_loss / len(train_loader)
    elapsed = time.perf_counter() - epoch_start
    print(f"Epoch [{epoch+1}/{num_epochs}] completed. Average Loss: {avg_loss:.4f}, Elapsed Time: {elapsed:.2f}s")
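
The script above only trains; the validation accuracy in the sample output requires an evaluation pass. A common decoding rule for count-based SNN classifiers is to predict the class whose output neuron fires most often. The NumPy sketch below implements that rule under two assumptions the source does not confirm: the class dimension of the model output is axis 1, and every remaining axis (for example, singleton spatial axes and time) can be summed into a spike count. The function name spike_count_accuracy is hypothetical. In the PyTorch script, the batches would come from running the model on val_loader under model.eval() and torch.no_grad(), with outputs converted via .cpu().numpy().

```python
import numpy as np


def spike_count_accuracy(batches):
    """Accuracy when the predicted class is the output neuron with the most spikes.

    batches: iterable of (spikes, targets); spikes has shape (batch, classes, ...).
    """
    correct = 0
    total = 0
    for spikes, targets in batches:
        # Sum every axis after the class axis (e.g. spatial singletons and time) into a count
        counts = spikes.reshape(spikes.shape[0], spikes.shape[1], -1).sum(axis=-1)
        predictions = counts.argmax(axis=1)  # ties resolve to the lowest class index
        correct += int((predictions == targets).sum())
        total += len(targets)
    return correct / total


# Two toy batches: one with 3-D outputs, one with extra singleton axes
batch_a = np.zeros((2, 3, 4))
batch_a[0, 1, :3] = 1  # sample 0: class 1 fires most (target 1, correct)
batch_a[1, 0, :2] = 1  # sample 1: class 0 fires most (target 2, wrong)
batch_b = np.zeros((1, 3, 1, 1, 4))
batch_b[0, 2, 0, 0, :] = 1  # class 2 fires most (target 2, correct)
acc = spike_count_accuracy([(batch_a, np.array([1, 2])), (batch_b, np.array([2]))])
print(f"Accuracy: {acc:.2%}")  # Accuracy: 66.67%
```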

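Example Output

The log below comes from a run that also timed each step and ended with a validation pass; the script above does not print the validation line.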

Epoch [1/3], Step [10/1875], Loss: 40.6000, Elapsed Time: 0.13s
...
Epoch [1/3] completed. Average Loss: 22.7644, Elapsed Time: 163.37s
...
Epoch [2/3], Step [1870/1875], Loss: 20.1000, Elapsed Time: 0.06s
Epoch [2/3] completed. Average Loss: 18.2996, Elapsed Time: 107.86s
...
Epoch [3/3], Step [1870/1875], Loss: 10.9000, Elapsed Time: 0.06s
Epoch [3/3] completed. Average Loss: 15.9984, Elapsed Time: 108.05s
Validation Loss: 15.2796, Accuracy: 91.01%, Elapsed Time: 5.16s

Contact

For questions, suggestions, or support, please contact Seokhun Jeon (seokhun.jeon@keti.re.kr).



Download files


Source Distribution

torchevent-0.0.3.tar.gz (21.8 kB)

Built Distribution


torchevent-0.0.3-py3-none-any.whl (18.2 kB)

File details

Details for the file torchevent-0.0.3.tar.gz.

File metadata

  • Download URL: torchevent-0.0.3.tar.gz
  • Upload date:
  • Size: 21.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.10.12

File hashes

Hashes for torchevent-0.0.3.tar.gz

Algorithm     Hash digest
SHA256        acfa1408e0a74ddc1ffee049ec9f5b50fb0159b895a5c96c1ca6b633f5f9915b
MD5           f5f5af1e0289db6622060ad664ccd429
BLAKE2b-256   93166e084df26115c4b61a800bf9dec4535f3ed4d6a6831df7c17ac336c17aeb
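
The published digests can be used to check that a downloaded archive is intact. A minimal sketch using Python's standard hashlib module; the helper name sha256_of_file is illustrative, and the file is assumed to have been downloaded already.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps reading until f.read() returns b"" (end of file)
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

For example, sha256_of_file("torchevent-0.0.3.tar.gz") should equal the SHA256 value listed for that file; reading in chunks keeps memory use constant regardless of file size.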


File details

Details for the file torchevent-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: torchevent-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 18.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.10.12

File hashes

Hashes for torchevent-0.0.3-py3-none-any.whl

Algorithm     Hash digest
SHA256        683cd42cb8ee6ed505ed52d11c7a15ec86a877f1fa4de98a4354bf0a8e94d436
MD5           232c7cd3864ae7486e065aeef81e4b09
BLAKE2b-256   c8406bd9c8d97e156415f271452b93b539ba68d3a4240664dee9ef1ba5cf9027

