A deep learning framework for SNNs built on PyTorch.

SpikingJelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Networks (SNNs) based on PyTorch.

The documentation of SpikingJelly is written in both English and Chinese: https://spikingjelly.readthedocs.io.

Installation

Note that SpikingJelly is based on PyTorch. Please make sure that you have installed PyTorch before you install SpikingJelly.

Odd version numbers denote development versions, which are kept in sync with the GitHub/OpenI repositories. Even version numbers denote stable versions, which are available on PyPI.
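The odd/even convention can be expressed as a tiny check. This is a minimal sketch that assumes the parity is read from the last dot-separated component of the version string (the text does not say which component), and `is_stable_release` is a hypothetical helper, not part of SpikingJelly:

```python
def is_stable_release(version: str) -> bool:
    """Hypothetical helper: True if the version is a stable (PyPI) release.

    Assumes the parity convention applies to the last dot-separated
    component: even = stable, odd = development (GitHub/OpenI).
    """
    last = int(version.split(".")[-1])
    return last % 2 == 0


print(is_stable_release("0.0.0.0.6"))  # True: stable release
print(is_stable_release("0.0.0.0.7"))  # False: a hypothetical development version
```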

Install the latest stable version (0.0.0.0.6) from PyPI:

pip install spikingjelly

Install the latest development version from source:

From GitHub:

git clone https://github.com/fangwei123456/spikingjelly.git
cd spikingjelly
python setup.py install

From OpenI:

git clone https://git.openi.org.cn/OpenI/spikingjelly.git
cd spikingjelly
python setup.py install

Build SNNs In An Unprecedentedly Simple Way

SpikingJelly is user-friendly. Building an SNN with SpikingJelly is as simple as building an ANN in PyTorch:

import torch.nn as nn
from spikingjelly.clock_driven import neuron

class Net(nn.Module):
    def __init__(self, tau=100.0, v_threshold=1.0, v_reset=0.0):
        super().__init__()
        # A simple two-layer fully connected network; each layer is followed by LIF neurons
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 14 * 14, bias=False),
            neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset),
            nn.Linear(14 * 14, 10, bias=False),
            neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset)
        )

    def forward(self, x):
        # x: a batch of encoded 28x28 images; returns output-layer spikes of shape [N, 10]
        return self.fc(x)
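For intuition, each `LIFNode` in `Net` integrates its input into a membrane potential, fires when the potential reaches `v_threshold`, and then resets to `v_reset`. The sketch below simulates one such neuron in plain Python, assuming the discrete-time, hard-reset LIF update H = V + (X - (V - V_reset)) / tau; it is an illustration, not the library's tensor-based `LIFNode`, and `lif_step`/`simulate_lif` are hypothetical names:

```python
def lif_step(v, x, tau=100.0, v_threshold=1.0, v_reset=0.0):
    """One time step of a leaky integrate-and-fire neuron (assumed hard reset).

    charge: h = v + (x - (v - v_reset)) / tau
    fire:   spike = 1 if h >= v_threshold else 0
    reset:  v = v_reset after a spike, otherwise v = h
    Returns (new membrane potential, spike).
    """
    h = v + (x - (v - v_reset)) / tau
    spike = 1 if h >= v_threshold else 0
    return (v_reset if spike else h), spike


def simulate_lif(inputs, **kwargs):
    """Feed a sequence of inputs to one neuron, starting from v = v_reset."""
    v = kwargs.get("v_reset", 0.0)
    spikes = []
    for x in inputs:
        v, s = lif_step(v, x, **kwargs)
        spikes.append(s)
    return spikes


# With a small tau = 2, a constant input of 1.5 charges the neuron to 0.75,
# then to 0.75 + (1.5 - 0.75) / 2 = 1.125, which crosses the threshold.
print(simulate_lif([1.5] * 4, tau=2.0))  # [0, 1, 0, 1]
```

A small `tau` is used only to keep the example short; with the default `tau=100.0` the potential leaks and charges much more slowly.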

This simple network, combined with a Poisson encoder, can achieve 92% accuracy on the MNIST test set. Read the clock-driven tutorial for more details. You can also run the following in a Python terminal to train it on MNIST classification:

>>> import spikingjelly.clock_driven.examples.lif_fc_mnist as lif_fc_mnist
>>> lif_fc_mnist.main()
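The Poisson encoder mentioned above turns each normalized pixel intensity into a spike train whose firing rate approximates that intensity. A minimal plain-Python sketch of the idea, assuming intensities already lie in [0, 1]; `poisson_encode` is a hypothetical helper, not SpikingJelly's encoder class:

```python
import random


def poisson_encode(intensities, rng):
    """One time step of Poisson-style encoding: each element fires (1) with
    probability equal to its intensity in [0, 1], otherwise outputs 0."""
    return [1 if rng.random() < p else 0 for p in intensities]


rng = random.Random(0)          # fixed seed for reproducibility
pixels = [0.0, 0.25, 0.5, 1.0]  # hypothetical normalized pixel values
T = 1000                        # number of simulated time steps

counts = [0] * len(pixels)
for _ in range(T):
    for i, s in enumerate(poisson_encode(pixels, rng)):
        counts[i] += s

rates = [c / T for c in counts]
print(rates)  # each firing rate approaches the corresponding pixel intensity
```

Because the spikes are random, more time steps give a closer match between firing rates and intensities, at the cost of longer simulation.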

Read spikingjelly.clock_driven.examples to explore more advanced networks!

Fast And Handy ANN-SNN Conversion

SpikingJelly implements a relatively general ANN-SNN conversion interface. Users can perform the conversion through either PyTorch or ONNX. Moreover, users can customize conversion modules and add them to the conversion process.

import torch.nn as nn

class ANN(nn.Module):
    def __init__(self):
        super().__init__()
        # Input: [N, 1, 28, 28] MNIST images; each conv + pool stage shrinks the feature map
        self.network = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.BatchNorm2d(32, eps=1e-3),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),

            nn.Conv2d(32, 32, 3, 1),
            nn.BatchNorm2d(32, eps=1e-3),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),

            nn.Conv2d(32, 32, 3, 1),
            nn.BatchNorm2d(32, eps=1e-3),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),

            nn.Flatten(),
            nn.Linear(32, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.network(x)
        return x
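The input size of the final `nn.Linear(32, 10)` follows from the MNIST image size. A short check of the arithmetic, assuming 28x28 inputs and the standard PyTorch output-size formulas for `nn.Conv2d` (no padding, stride 1) and `nn.AvgPool2d` (no padding, `ceil_mode=False`):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of nn.Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel, stride):
    """Spatial output size of nn.AvgPool2d without padding or ceil_mode."""
    return (size - kernel) // stride + 1


size = 28  # MNIST images are 28x28
for _ in range(3):  # three Conv2d(kernel 3, stride 1) + AvgPool2d(2, 2) stages
    size = conv_out(size, kernel=3)
    size = pool_out(size, kernel=2, stride=2)
    print(size)  # prints 13, then 5, then 1

channels = 32
print(channels * size * size)  # 32, matching nn.Linear(32, 10)
```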

This simple network, with analog encoding, can achieve 98.51% accuracy on the MNIST test set after conversion. Read the ann2snn tutorial for more details. You can also run the following in a Python terminal to train the network on MNIST and classify with the converted model:

>>> import spikingjelly.clock_driven.ann2snn.examples.cnn_mnist as cnn_mnist
>>> cnn_mnist.main()
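Why conversion works can be illustrated with a single integrate-and-fire (IF) neuron. With analog encoding, the input value is fed in as a constant current at every time step, and the neuron's firing rate then approximates ReLU of that input, saturating at one spike per step. A minimal sketch, assuming reset by subtraction and a threshold of 1.0; `if_firing_rate` is a hypothetical helper and not how SpikingJelly's converter is implemented internally:

```python
def if_firing_rate(x, T=1000, v_threshold=1.0):
    """Firing rate of an IF neuron driven by a constant input x for T steps,
    using reset by subtraction (assumed)."""
    v, spikes = 0.0, 0
    for _ in range(T):
        v += x                  # integrate the constant (analog) input
        if v >= v_threshold:
            spikes += 1
            v -= v_threshold    # reset by subtraction keeps the residual charge
    return spikes / T


for x in [-0.5, 0.0, 0.3, 0.7, 1.0]:
    # the firing rate tracks max(x, 0) for inputs up to the threshold
    print(x, if_firing_rate(x), max(x, 0.0))
```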

CUDA-Enhanced Neuron

SpikingJelly provides two backends for multi-step neurons (read the Tutorials for more details). You can use the user-friendly torch backend for easy coding and debugging, and the cupy backend for faster training.

The following figure compares the execution time of the two backends for multi-step LIF neurons:

[Figure: execution time of the torch and cupy backends for multi-step LIF neurons]

To use the cupy backend, please install CuPy. Note that the cupy backend only supports GPU, while the torch backend supports both CPU and GPU.
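When the same script must run on machines with and without CuPy or a GPU, the backend can be chosen at runtime. A minimal sketch; `pick_backend` is a hypothetical helper that only returns the backend name, and how that name is passed to a multi-step neuron should be checked in the Tutorials:

```python
import importlib.util


def pick_backend(device: str) -> str:
    """Hypothetical helper: choose a backend name for multi-step neurons.

    The cupy backend only runs on GPU and needs CuPy installed; otherwise
    fall back to the torch backend, which supports both CPU and GPU.
    """
    cupy_installed = importlib.util.find_spec("cupy") is not None
    if device.startswith("cuda") and cupy_installed:
        return "cupy"
    return "torch"


print(pick_backend("cpu"))     # torch
print(pick_backend("cuda:0"))  # cupy if CuPy is installed, otherwise torch
```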

Device Support

  • NVIDIA GPU
  • CPU

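Moving a network to a device works exactly as in PyTorch:

>>> import torch, torch.nn as nn
>>> from spikingjelly.clock_driven import neuron
>>> net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10, bias=False), neuron.LIFNode(tau=100.0))
>>> net = net.to('cuda' if torch.cuda.is_available() else 'cpu')  # can be a CPU or CUDA device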

Neuromorphic Dataset Support

SpikingJelly includes the following neuromorphic datasets:

  • ASL-DVS (source: Graph-based Object Classification for Neuromorphic Vision Sensing)
  • CIFAR10-DVS (source: CIFAR10-DVS: An Event-Stream Dataset for Object Classification)
  • DVS128 Gesture (source: A Low Power, Fully Event-Based Gesture Recognition System)
  • N-Caltech101 (source: Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades)
  • N-MNIST (source: Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades)

Users can use both the original event data and the frame data integrated by SpikingJelly:

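from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
root_dir = 'D:/datasets/DVS128Gesture'  # directory holding the dataset; change to your own path
event_set = DVS128Gesture(root_dir, train=True, data_type='event')  # raw event samples
# Each sample integrated into 20 frames, each frame covering the same number of events
frame_set = DVS128Gesture(root_dir, train=True, data_type='frame', frames_number=20, split_by='number')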
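With `split_by='number'`, the events are divided into `frames_number` segments holding (roughly) the same number of events, and each segment is accumulated into a two-channel frame, one channel per polarity. The sketch below illustrates that idea in plain Python; the `(t, x, y, p)` tuple format and `events_to_frames` are hypothetical, and SpikingJelly's own implementation (which works on NumPy arrays) may handle segment boundaries differently:

```python
def events_to_frames(events, frames_number, height, width):
    """Integrate events into `frames_number` frames with equal event counts.

    events: list of (t, x, y, p) tuples sorted by t, polarity p in {0, 1}.
    Returns nested lists shaped [frames_number][2][height][width]; channel p
    counts the events of polarity p that fell on each pixel (y, x).
    """
    n = len(events)
    per_frame = n // frames_number
    frames = [[[[0] * width for _ in range(height)] for _ in range(2)]
              for _ in range(frames_number)]
    for i in range(frames_number):
        start = i * per_frame
        # The last frame also takes any leftover events.
        end = n if i == frames_number - 1 else start + per_frame
        for _t, x, y, p in events[start:end]:
            frames[i][p][y][x] += 1
    return frames


# Hypothetical toy stream on a 2x2 sensor: 5 events, 2 frames.
events = [(0, 0, 0, 1), (1, 1, 0, 0), (2, 1, 1, 1), (3, 0, 1, 1), (4, 1, 1, 1)]
frames = events_to_frames(events, frames_number=2, height=2, width=2)
print(frames[1][1])  # [[0, 0], [1, 2]]: frame 1 holds the last 3 events
```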

More datasets will be included in the future.

If a dataset's download link is unavailable to some users, they can download it from the OpenI mirror:

https://git.openi.org.cn/OpenI/spikingjelly/datasets?type=0

All datasets hosted on the OpenI mirror are redistributed as permitted by their licenses or with their authors' agreement.

Tutorials

SpikingJelly provides elaborate tutorials. Here are some of them:

  • Neurons
  • Encoder
  • Use single-layer fully connected SNN to identify MNIST
  • Use convolutional SNN to identify Fashion-MNIST
  • ANN2SNN
  • Reinforcement Learning: Deep Q Learning
  • Propagation Pattern
  • Neuromorphic Datasets Processing
  • Classify DVS128 Gesture

Citation

If you use SpikingJelly in your work, please cite it as follows:

@misc{SpikingJelly,
	title = {SpikingJelly},
	author = {Fang, Wei and Chen, Yanqi and Ding, Jianhao and Chen, Ding and Yu, Zhaofei and Zhou, Huihui and Tian, Yonghong and other contributors},
	year = {2020},
	howpublished = {\url{https://github.com/fangwei123456/spikingjelly}},
	note = {Accessed: YYYY-MM-DD},
}

Contribution

You can browse the issues to see open problems and the latest development plans. We welcome all users to join the discussion of development plans, solve issues, and send pull requests.

About

Multimedia Learning Group, Institute of Digital Media (NELVT), Peking University and Peng Cheng Laboratory are the main developers of SpikingJelly.

The list of developers can be found here.

Download files

Source distribution: spikingjelly-0.0.0.0.6.tar.gz (135.6 kB)

Built distribution (Python 3): spikingjelly-0.0.0.0.6-py3-none-any.whl (177.4 kB)
