A PyTorch library for event-based data processing
Torchevent: Spiking Neural Network Framework
Torchevent is a PyTorch-based framework for Spiking Neural Networks (SNNs). It supports training and inference for event-based datasets like NMNIST, offering custom models, loss functions, and transformations optimized for SNN workflows.
References
This project draws inspiration from the following works:
- Paper: TSSL-BP: Temporal-Spike-Sequence Learning via Backpropagation for Spiking Neural Networks. Proceedings of the 34th Conference on Neural Information Processing Systems (NeurIPS), 2020.
- GitHub Repository: TSSL-BP: Temporal Spike-Sequence Learning Framework
We thank the authors of these works for providing valuable insights into spiking neural network research and implementation.
Features
1. SNN Models (TSSL-BP)
- Models like NMNISTNet and NCARSNet are specifically designed for event-based datasets.
- Easily configurable for various spiking network architectures and time-step dynamics.
2. Event Data Transformations
- Transformations tailored for event-based data processing:
  - RandomTemporalCrop: Randomly crops events based on a given time window.
  - TemporalCrop: Sequentially crops events within a fixed time window.
  - ToFrameAuto: Converts events into frames with dynamic configurations.
3. Loss Functions
- Loss functions designed for SNN-specific requirements:
  - SpikeKernelLoss: Computes the loss using Post-Synaptic Potentials (PSPs).
  - SpikeCountLoss: Optimizes models to match desired spike counts.
  - SpikeSoftmaxLoss: Combines spike data with softmax and cross-entropy for classification tasks.
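For intuition, SpikeCountLoss is constructed with a desired_count for the target class and an undesired_count for every other class (see the usage example below). The following is only an illustrative sketch of that idea, not the library's implementation; it assumes the network output is a spike tensor of shape (batch, num_classes, n_steps):

import torch
import torch.nn.functional as F

def spike_count_loss_sketch(spikes, targets, desired_count=4, undesired_count=1):
    # Assumed layout: spikes is (batch, num_classes, n_steps) of binary spike events
    counts = spikes.sum(dim=-1)  # spike count per output neuron
    # Target: desired_count spikes for the labeled class, undesired_count elsewhere
    target_counts = torch.full_like(counts, float(undesired_count))
    target_counts[torch.arange(counts.size(0)), targets] = float(desired_count)
    return F.mse_loss(counts, target_counts)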
Installation
To install torchevent manually:
git clone https://github.com/devcow85/torchevent.git
cd torchevent
pip install .
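Since the package is also published on PyPI, installing a released version with pip should work as well:
pip install torchevent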
Usage
The following script demonstrates training the NMNISTNet model on the NMNIST dataset loaded via the Tonic API:
import tonic
import tonic.transforms as transforms
import torch
from torch.utils.data import DataLoader
from torchevent.utils import set_seed, spike2data
from torchevent.transforms import RandomTemporalCrop, TemporalCrop
from torchevent import models, loss
# Set seed for reproducibility
set_seed(7)
# Prepare the dataset
transform = transforms.Compose([
    RandomTemporalCrop(time_window=99000),
    transforms.ToFrame(sensor_size=tonic.datasets.NMNIST.sensor_size, n_time_bins=5),
])
train_ds = tonic.datasets.NMNIST(save_to="data", train=True, transform=transform)
val_ds = tonic.datasets.NMNIST(save_to="data", train=False, transform=transform)
# Create data loaders
train_loader = DataLoader(train_ds, shuffle=True, batch_size=32, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_ds, shuffle=False, batch_size=32, num_workers=8, pin_memory=True)
# Initialize model, optimizer, and loss function
model = models.NMNISTNet(5, 1, n_steps=5).to("cuda")
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
criterion = loss.SpikeCountLoss(desired_count=4, undesired_count=1)
# Training loop
for epoch in range(3):
    model.train()
    for data, targets in train_loader:
        data, targets = data.to("cuda", non_blocking=True), targets.to("cuda", non_blocking=True)

        optimizer.zero_grad()
        spikes = model(data.to(torch.float32))
        spike_loss = criterion(spikes, targets)
        spike_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()

    print(f"Epoch [{epoch+1}], Loss: {spike_loss.item():.4f}")
Expected Result
Epoch [1/3], Step [10/1875], Loss: 40.6000, Elapsed Time: 0.13s
...
Epoch [1/3] completed. Average Loss: 22.7644, Elapsed Time: 163.37s
...
Epoch [2/3], Step [1870/1875], Loss: 20.1000, Elapsed Time: 0.06s
Epoch [2/3] completed. Average Loss: 18.2996, Elapsed Time: 107.86s
...
Epoch [3/3], Step [1870/1875], Loss: 10.9000, Elapsed Time: 0.06s
Epoch [3/3] completed. Average Loss: 15.9984, Elapsed Time: 108.05s
Validation Loss: 15.2796, Accuracy: 91.01%, Elapsed Time: 5.16s
Contact
For questions, suggestions, or support, please contact Seokhun Jeon (seokhun.jeon@keti.re.kr).