A PyTorch-based framework for building and training moment neural networks.
Table of contents
- moment-neural-network
- Dependencies
- Getting Started
- Customize your own MNN model
- Lead authors
- License
moment-neural-network
The moment neural network is a type of second-order artificial neural network model designed to capture the nonlinear coupling of correlated activity of spiking neurons. In brief, moment neural networks extend conventional rate-based artificial neural network models by incorporating the covariance of fluctuating neural activity. This repository provides a comprehensive framework for simulating and training moment neural networks based on the standard PyTorch workflow.
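To make "incorporating the covariance" concrete, the sketch below covers only the linear part of one layer: if the input activity has mean vector μ and covariance matrix C, a weighted sum y = Wx + b has mean Wμ + b and covariance WCW^T. This is plain Python for illustration, not the repository's API. The function name propagate_linear is hypothetical, and the nonlinear moment activation implemented in mnn_core is deliberately not reproduced here.

```python
def propagate_linear(weights, bias, mean, cov):
    """Propagate the first two moments through y = W x + b.

    weights: n_out x n_in matrix (list of rows); bias: length n_out
    mean: length n_in; cov: n_in x n_in matrix (list of rows)
    Returns (mean_out, cov_out) with mean_out = W mu + b and cov_out = W C W^T.
    The bias shifts the mean but does not affect the covariance.
    """
    n_out = len(weights)
    n_in = len(mean)
    # Mean: each output is a weighted sum of input means plus its bias.
    mean_out = [
        sum(weights[i][k] * mean[k] for k in range(n_in)) + bias[i]
        for i in range(n_out)
    ]
    # W C (n_out x n_in)
    wc = [
        [sum(weights[i][k] * cov[k][j] for k in range(n_in)) for j in range(n_in)]
        for i in range(n_out)
    ]
    # (W C) W^T (n_out x n_out): covariance between every pair of outputs.
    cov_out = [
        [sum(wc[i][k] * weights[j][k] for k in range(n_in)) for j in range(n_out)]
        for i in range(n_out)
    ]
    return mean_out, cov_out
```

For example, summing and differencing two unit-variance inputs with correlation 0.5 gives output variances 3 and 1 and zero output covariance: the correlation changes the variance of the sum, which a rate-only model would not track.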
For full details, see the publication: https://arxiv.org/abs/2305.13982
The architecture of this repository
- mnn_core: core modules implementing the moment activation and other building blocks of MNN.
- models: a module containing various network architectures for fast and convenient model construction.
- snn: modules for reconstructing an SNN from an MNN and for simulating the corresponding SNN in a flexible manner.
- utils: a collection of useful utilities for training MNNs (also compatible with ANNs).
Dependencies
- python 3
- pytorch: 1.12.1
- torchvision: 0.13.1
- scipy: 1.7.3
- pyyaml: 6.0
- numpy: 1.22.3
- CUDA (optional)
Getting Started
Quick start: four steps to run your first MNN model
The following provides step-by-step instructions for training an MNN on the MNIST image classification task with a multi-layer perceptron structure.
- Clone the repository to your local drive.
- Copy the demo files ./example/mnist/mnist.py and ./example/mnist/mnist_config.yaml to the root directory.
- Create two directories: ./checkpoint/ (for saving trained model results) and ./data/ (for downloading the MNIST dataset).
- Run the following command to call the script mnist.py, with the config file specified through the --config option:
python mnist.py --config=./mnist_config.yaml
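Steps 2 and 3 above can also be scripted. Below is a small sketch using only the standard library, assuming it is pointed at the root of the cloned repository. prepare_workspace is a hypothetical helper, not part of the package.

```python
import shutil
from pathlib import Path


def prepare_workspace(repo_root="."):
    """Copy the MNIST demo files to the repository root and create the
    ./checkpoint/ and ./data/ directories (quick-start steps 2 and 3)."""
    root = Path(repo_root)
    demo_dir = root / "example" / "mnist"
    # Step 2: copy the demo script and its config next to the package code.
    for name in ("mnist.py", "mnist_config.yaml"):
        shutil.copy(demo_dir / name, root / name)
    # Step 3: create output and dataset directories (no error if they exist).
    for sub in ("checkpoint", "data"):
        (root / sub).mkdir(exist_ok=True)
    return root
```

After calling prepare_workspace("."), run the command from step 4 as usual.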
After training is finished, you should find four files and one directory in the ./checkpoint/mnist/ folder:
- Two '.ph' files which contain the trained model parameters.
- One '.yaml' file which is a copy of the config file used to train the model.
- One '.txt' log file that records the standard output during training (such as model performance).
- One directory called mnn_net_snn_result that stores the simulation results of the SNN reconstructed from the trained MNN (if enabled).
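To sanity-check a finished run, here is a short sketch that tallies the contents of ./checkpoint/mnist/ by file suffix. The README does not give exact file names, so only suffixes are checked. The helper names are hypothetical, and the expected counts simply mirror the list above (the mnn_net_snn_result directory is optional, so it is not required).

```python
from collections import Counter
from pathlib import Path


def summarize_checkpoint(folder="./checkpoint/mnist"):
    """Return (suffix counts of files, sorted names of subdirectories)."""
    folder = Path(folder)
    files = Counter(p.suffix for p in folder.iterdir() if p.is_file())
    dirs = sorted(p.name for p in folder.iterdir() if p.is_dir())
    return files, dirs


def looks_complete(folder="./checkpoint/mnist"):
    """True if the folder holds two .ph files, one .yaml and one .txt file."""
    files, _ = summarize_checkpoint(folder)
    # Counter returns 0 for suffixes that never occurred.
    return files[".ph"] == 2 and files[".yaml"] == 1 and files[".txt"] == 1
```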
Configure the MNN model
Let's review the content of mnist_config.yaml.
The MODEL section is for specifying the architecture of MNN.
- meta: meta information about model construction.
- arch: specifies the model architecture. Currently only an MLP-like architecture is available (arch: mnn_mlp).
- mlp_type: indicates the kind of MLP to be built. For mnn_mlp, the model contains one input layer, an arbitrary number of hidden layers, and a linear decoder.
- mnn_mlp: detailed model specification for the MLP.
- structure: you can change the width of each layer by modifying the values under this field.
- num_class: specifies the output dimension. See mnn.models.mlp for under-the-hood details.
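As a rough sketch of how structure and num_class could translate into layer shapes, assume structure lists the input width followed by each hidden width, and the linear decoder maps the last width to num_class. This interpretation and the helper layer_shapes are assumptions; check mnn.models.mlp for the actual construction.

```python
def layer_shapes(structure, num_class):
    """Return ([(in, out), ...] for the MNN layers, (in, out) for the decoder).

    Assumes structure = [input_width, hidden_1, ..., hidden_n]
    (an interpretation, not taken from the repository).
    """
    if len(structure) < 2:
        raise ValueError("structure needs an input width and at least one hidden width")
    # Consecutive widths define each layer's input and output size.
    hidden = list(zip(structure[:-1], structure[1:]))
    # The linear decoder maps the last width to the number of classes.
    decoder = (structure[-1], num_class)
    return hidden, decoder
```

For example, layer_shapes([784, 800, 800], 10) yields two 800-wide layers and an 800-to-10 decoder, matching the "one input layer, hidden layers, linear decoder" layout for a flattened 28x28 MNIST image.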
The CRITERION section specifies the training criterion, such as the loss function.
name: the name for the loss function. Currently supports ...
source: the name of the directory where the loss function is defined.
arg: input arguments to the loss function.
The code will try to find the criterion in source that matches the name and pass the required args to it.
See mnn_core.nn.criterion for under-the-hood details.
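The name-based lookup described above can be sketched with importlib: import the module given by source, fetch the attribute called name, and call it with arg as keyword arguments. This assumes source is an importable module path. The README describes it as a directory name, so the repository's resolution may differ. build_from_config is a hypothetical helper, not a function from mnn_core.

```python
import importlib


def build_from_config(section):
    """Instantiate section['name'] from module section['source'] with section['arg']."""
    module = importlib.import_module(section["source"])
    factory = getattr(module, section["name"])
    # An empty YAML mapping may load as None, so fall back to no arguments.
    kwargs = section.get("arg") or {}
    return factory(**kwargs)
```

With PyTorch installed, a section such as {"name": "CrossEntropyLoss", "source": "torch.nn", "arg": {}} would resolve to torch.nn.CrossEntropyLoss() under this scheme.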
Similarly, the optimizer and data augmentation policy are defined under OPTIMIZER and DATAAUG_TRAIN/VAL, corresponding to the PyTorch implementations (torch.optim and torchvision.transforms).
There are some advanced options in the config file:
- save_epoch_state: at the start of each epoch, the code stores the model parameters.
- input_prepare: currently only flatten_poisson is valid. It means the input is first flattened to a vector and regarded as an independent Poisson rate code.
- scale_factor: only valid if input_prepare is flatten_poisson; used to control the input range.
- is_classify: the task type. If False, the best model is determined by the epoch with the minimal loss.
- background_noise: this value is added to the diagonal of the input covariance (can be helpful if the input covariance is very weak or close to singular).
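A minimal sketch of what flatten_poisson, scale_factor, and background_noise could mean numerically. It assumes each pixel value times scale_factor is taken as a firing rate. Because a Poisson spike count has variance equal to its mean, the covariance of independent Poisson inputs is diagonal with the rates on the diagonal, and background_noise is then added to that diagonal. The function name and the exact mean/variance convention are assumptions; the repository may, for example, work with standard deviations and correlations instead.

```python
def flatten_poisson(image, scale_factor=1.0, background_noise=0.0):
    """Turn a 2-D image (list of rows) into input moments (mean, cov).

    mean[i] = scale_factor * pixel_i (interpreted as a Poisson rate).
    Independent Poisson inputs -> diagonal covariance with variance = mean;
    background_noise is added to the diagonal to keep it away from singular.
    """
    # Flatten row by row into a single vector of rates.
    mean = [scale_factor * float(v) for row in image for v in row]
    n = len(mean)
    cov = [[0.0] * n for _ in range(n)]
    for i, m in enumerate(mean):
        cov[i][i] = m + background_noise
    return mean, cov
```

The background_noise term matters for black pixels: a zero rate gives zero variance, so without the added diagonal the covariance matrix would be singular.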
Configure additional training options via input arguments.
python main_script.py --config=./your_config_file.yaml --OPT=VALUE
Some examples of the OPT field:
- seed: fixes the seed for all RNGs used by the model. By default it is None (not fixed).
- bs: batch size used in the data loader.
- dir: directory name for saving training data.
- save_name: the prefix of the file names of training data.
- epochs: the number of epochs to train.
- cpu: manually sets the device to CPU.
For details, I recommend reading the function deploy_config() in utils.training_tools.general_prepare.
Note that all manual arguments will be overwritten if the same keys are found in the provided your_config_file.yaml.
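The override rule in that note can be sketched with argparse: command-line values are collected first and then updated with the config file's keys, so the file wins on any conflict. Only the option names listed above come from this README. The argument types, the None defaults, treating --cpu as a boolean flag, and the helper names build_parser and merge_options are assumptions; the real logic lives in deploy_config().

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented OPT names."""
    parser = argparse.ArgumentParser(description="Sketch of MNN training options")
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--seed", type=int, default=None)  # None = not fixed
    parser.add_argument("--bs", type=int, default=None)
    parser.add_argument("--dir", type=str, default=None)
    parser.add_argument("--save_name", type=str, default=None)
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--cpu", action="store_true")  # assumed to be a flag
    return parser


def merge_options(cli_args, file_config):
    """Combine CLI values with config-file values; the config file wins."""
    merged = vars(cli_args).copy()
    merged.update(file_config)  # same key in the file overwrites the CLI value
    return merged
```

In practice, file_config would be the dictionary loaded from your_config_file.yaml, for example with yaml.safe_load from pyyaml.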
Run simulations of the reconstructed SNN
We provide a utility to automatically reconstruct an SNN from the trained MNN. A custom SNN simulator with GPU support is provided, but you may use any SNN simulator of your choice.
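As a rough, self-contained illustration of what simulating such an SNN involves at the single-neuron level, here is a minimal Euler-Maruyama simulation of one leaky integrate-and-fire (LIF) neuron driven by Gaussian white-noise input with a given mean and standard deviation. It assumes LIF dynamics, the neuron model typically used to derive the moment activation. This is not the repository's simulator. The function names and all parameter values (time constant, threshold, reset, refractory period, time step) are assumptions chosen for illustration, not the repository's defaults.

```python
import math
import random


def simulate_lif(mu, sigma, t_total=1000.0, dt=0.1, tau=20.0,
                 v_th=20.0, v_reset=0.0, t_ref=5.0, seed=0):
    """Simulate dV = (-V/tau + mu) dt + sigma dW and return spike times (ms).

    mu is in mV/ms and sigma in mV/sqrt(ms); all defaults are illustrative.
    """
    rng = random.Random(seed)
    v = v_reset
    refractory_left = 0.0
    spikes = []
    sqrt_dt = math.sqrt(dt)
    for step in range(int(t_total / dt)):
        t = step * dt
        if refractory_left > 0.0:
            # Membrane potential is clamped during the refractory period.
            refractory_left -= dt
            continue
        # Euler-Maruyama step: drift term plus Gaussian noise scaled by sqrt(dt).
        v += (-v / tau + mu) * dt + sigma * sqrt_dt * rng.gauss(0.0, 1.0)
        if v >= v_th:
            spikes.append(t)
            v = v_reset
            refractory_left = t_ref
    return spikes


def firing_rate_hz(spike_times, t_total=1000.0):
    """Mean firing rate in Hz for a simulation of length t_total ms."""
    return len(spike_times) * 1000.0 / t_total
```

Repeating such runs over many trials gives empirical spike-count means and variances, which is the kind of quantity a reconstructed SNN can be compared against the trained MNN's predicted moments on.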
Customize your own MNN model
Custom dataset
Custom loss function
Custom model
Lead authors
- Zhichao Zhu - Chief Architect
- Yang Qi - Lead Algorithm Design
License
This project is licensed under the Apache License 2.0 - see the LICENSE.md file for details.
File details
Details for the file moment_neural_network-0.1.0.tar.gz.
File metadata
- Download URL: moment_neural_network-0.1.0.tar.gz
- Upload date:
- Size: 60.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6eb13088625ab1f4c4347972b4fa239b6757fd484fc1ba85d51a0fd0998db115 |
| MD5 | da030fc63fccdbf701df0ac83dc78250 |
| BLAKE2b-256 | 1d6ad986efcbebce664e481b8d95bab56b1eed3397ef8d72ff12993d7100633d |
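To check a downloaded archive against the SHA256 digest listed above, here is a minimal sketch with Python's standard hashlib module. It reads the file in chunks so large files need not fit in memory. The helper names are illustrative.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, expected_sha256):
    """Compare against a published digest (case and surrounding spaces ignored)."""
    return sha256_of_file(path) == expected_sha256.strip().lower()
```

For an intact download, verify_download("moment_neural_network-0.1.0.tar.gz", "6eb13088625ab1f4c4347972b4fa239b6757fd484fc1ba85d51a0fd0998db115") should return True.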
File details
Details for the file moment_neural_network-0.1.0-py3-none-any.whl.
File metadata
- Download URL: moment_neural_network-0.1.0-py3-none-any.whl
- Upload date:
- Size: 73.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ef4635a79a5f81c0ada6258bde5495b014a0bcd573524dcdc2ff89f206ea470a |
| MD5 | efc91a05d12a6338b766b796e5558a90 |
| BLAKE2b-256 | 6cee005ad791271b26bd74dd9c7495c6bf39dff80b9e4165c012b2f4dc70b6bf |