A PyTorch-based framework for building and training moment neural networks.

moment-neural-network

The moment neural network (MNN) is a second-order artificial neural network model designed to capture the nonlinear coupling of correlated activity among spiking neurons. In brief, MNNs extend conventional rate-based artificial neural networks by incorporating the covariance of fluctuating neural activity. This repository provides a comprehensive framework for simulating and training MNNs that follows the standard PyTorch workflow.
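
To illustrate what carrying the covariance alongside the mean involves, the sketch below pushes a mean vector and a covariance matrix through a plain linear map y = W x, for which the output mean is W mu and the output covariance is W C W^T. This is only the linear summation step written in pure Python for illustration; the helper names are hypothetical, and the nonlinear moment activation implemented in mnn_core is not reproduced here.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*m)]


def linear_moments(weight, mean, cov):
    """Map an input mean and covariance through y = W x."""
    out_mean = [sum(w * m for w, m in zip(row, mean)) for row in weight]
    out_cov = matmul(matmul(weight, cov), transpose(weight))
    return out_mean, out_cov


# Illustrative numbers only.
W = [[1.0, 1.0], [1.0, -1.0]]
mu = [2.0, 1.0]
C = [[1.0, 0.5], [0.5, 1.0]]

out_mu, out_C = linear_moments(W, mu, C)
print(out_mu)  # [3.0, 1.0]
print(out_C)   # [[3.0, 0.0], [0.0, 1.0]]
```

In this toy case the two correlated inputs become uncorrelated outputs, which a rate-only model tracking just the mean could not express.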

For full details see publication: https://arxiv.org/abs/2305.13982

The architecture of this repository

  • mnn_core: core modules implementing the moment activation and other building blocks of MNN.
  • models: a module containing various network architectures for fast and convenient model construction.
  • snn: modules for reconstructing a spiking neural network (SNN) from an MNN and for simulating the corresponding SNN in a flexible manner.
  • utils: a collection of useful utilities for training MNN (ANN compatible).

Dependencies

  • python 3
  • pytorch: 1.12.1
  • torchvision: 0.13.1
  • scipy: 1.7.3
  • pyyaml: 6.0
  • numpy: 1.22.3
  • CUDA (optional)

Getting Started

Quick start: four steps to run your first MNN model

The following step-by-step instructions train an MNN with a multi-layer perceptron structure on the MNIST image classification task.

  1. Clone the repository to your local drive.

  2. Copy the demo files, ./example/mnist/mnist.py and ./example/mnist/mnist_config.yaml, to the root directory.

  3. Create two directories, ./checkpoint/ (for saving trained model results) and ./data/ (for downloading the MNIST dataset).

  4. Run the following command, which calls the script mnist.py with the config file specified through the --config option:

    python mnist.py --config=./mnist_config.yaml
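
Steps 2 and 3 can also be scripted. The sketch below is not part of the repository: the helper name prepare_quickstart is hypothetical, and it assumes it is pointed at the root of the cloned repository, where ./example/mnist/ exists.

```python
import shutil
from pathlib import Path


def prepare_quickstart(repo_root="."):
    """Copy the MNIST demo files to the repo root and create output folders."""
    root = Path(repo_root)
    # Step 2: copy the demo script and its config to the root directory.
    for name in ("mnist.py", "mnist_config.yaml"):
        shutil.copy(root / "example" / "mnist" / name, root / name)
    # Step 3: create the folders for checkpoints and for the MNIST download.
    for folder in ("checkpoint", "data"):
        (root / folder).mkdir(exist_ok=True)
    return root


# Usage, from the root of the cloned repository:
#   prepare_quickstart(".")
# Step 4 is then: python mnist.py --config=./mnist_config.yaml
```

Using exist_ok=True makes the helper safe to rerun if the folders are already present.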
    

After training is finished, you should find four files and, if SNN simulation is enabled, one directory in the ./checkpoint/mnist/ folder:

  • Two '.ph' files which contain the trained model parameters.
  • One '.yaml' file which is a copy of the config file used for training the model.
  • One '.txt' log file that records the standard output during training (such as model performance).
  • One directory called mnn_net_snn_result that stores the simulation result of the SNN reconstructed from the trained MNN (if enabled).

Configure the MNN model

Let's review the content of mnist_config.yaml.

The MODEL section specifies the architecture of the MNN.

meta: meta information about model construction.

  • arch: specifies the model architecture. Currently only an MLP-like architecture is available (arch: mnn_mlp).
  • mlp_type: indicates the kind of MLP to be built. For mnn_mlp, the model contains one input layer, an arbitrary number of hidden layers, and a linear decoder.

mnn_mlp: detailed model specification for the MLP.

  • structure: the width of each layer; change the values under this field to resize the layers.
  • num_class: specifies the output dimension. See mnn.models.mlp for under-the-hood details.
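
To make the nesting concrete, here is a sketch of how the MODEL section might look once loaded into a Python dict. The key names come from the description above; the exact nesting, the layer widths, and the number of classes are illustrative assumptions and are not copied from the shipped config file. The helper layer_shapes is hypothetical, and whether structure includes the input width is also an assumption.

```python
# Hypothetical MODEL section as a Python dict, i.e. roughly what a YAML
# loader would return for an equivalent file. All values are illustrative.
model_cfg = {
    "meta": {"arch": "mnn_mlp", "mlp_type": "mnn_mlp"},
    "mnn_mlp": {
        "structure": [784, 1000, 1000],  # assumed: input width, then hidden widths
        "num_class": 10,                 # output dimension of the linear decoder
    },
}


def layer_shapes(cfg):
    """Return the (in_features, out_features) pairs implied by the config."""
    spec = cfg[cfg["meta"]["mlp_type"]]
    widths = list(spec["structure"]) + [spec["num_class"]]
    return list(zip(widths[:-1], widths[1:]))


print(layer_shapes(model_cfg))
# [(784, 1000), (1000, 1000), (1000, 10)]
```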

The CRITERION section specifies the training criterion, such as the loss function.

  • name: the name of the loss function. Currently supports ...
  • source: the name of the directory where the loss function is defined.
  • arg: input arguments to the loss function.

The code looks up the criterion in source that matches name and passes the required args to it. See mnn_core.nn.criterion for under-the-hood details.
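
The lookup described here (find the criterion by name in source, then pass arg to it) follows a common pattern that can be sketched as below. This is not the repository's code: build_criterion, ToyMSE, and toy_source are hypothetical stand-ins, and the real logic lives in mnn_core.nn.criterion.

```python
from types import SimpleNamespace


class ToyMSE:
    """Stand-in loss; NOT one of the criteria shipped with the repository."""

    def __init__(self, reduction="mean"):
        self.reduction = reduction

    def __call__(self, pred, target):
        errors = [(p - t) ** 2 for p, t in zip(pred, target)]
        total = sum(errors)
        return total / len(errors) if self.reduction == "mean" else total


# Stand-in for the module that the `source` field points to.
toy_source = SimpleNamespace(ToyMSE=ToyMSE)


def build_criterion(source, name, arg=None):
    """Look up `name` in `source` and instantiate it with `arg` as kwargs."""
    try:
        factory = getattr(source, name)
    except AttributeError:
        raise ValueError(f"criterion {name!r} not found in source") from None
    return factory(**(arg or {}))


criterion_cfg = {"name": "ToyMSE", "arg": {"reduction": "sum"}}
loss_fn = build_criterion(toy_source, criterion_cfg["name"], criterion_cfg["arg"])
print(loss_fn([1.0, 2.0], [0.0, 0.0]))  # 5.0
```

Raising a clear error for an unknown name makes a misspelled criterion in the config fail early rather than partway through training.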

Similarly, the optimizer and data augmentation policy are defined under OPTIMIZER and DATAAUG_TRAIN/VAL, corresponding to the PyTorch implementations (torch.optim and torchvision.transforms).

There are some advanced options in the config file:

  • save_epoch_state: at the start of each epoch, the code will store the model parameters.
  • input_prepare: currently only flatten_poisson is valid. The input is first flattened to a vector and then treated as an independent Poisson rate code.
  • scale_factor: only valid if input_prepare is flatten_poisson; controls the input range.
  • is_classify: the task type. If False, the best model is the one from the epoch with minimal loss.
  • background_noise: this value is added to the diagonal of the input covariance (helpful if the input covariance is very weak or close to singular).
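
As a rough illustration of how input_prepare, scale_factor, and background_noise could interact, the sketch below flattens a small 2-D input, scales it, and builds a diagonal covariance. It rests on two assumptions that are not confirmed by the repository: that scale_factor multiplies the raw input, and that for independent Poisson inputs the variance is taken equal to the mean. The function is written in pure Python for readability and is not the library implementation.

```python
def flatten_poisson(image, scale_factor=1.0, background_noise=0.0):
    """Flatten a 2-D input into a mean vector and a diagonal covariance."""
    # Flatten row by row and rescale to control the input range.
    mean = [scale_factor * float(pixel) for row in image for pixel in row]
    n = len(mean)
    # Independent Poisson inputs: variance equals the mean, no cross terms.
    cov = [[0.0] * n for _ in range(n)]
    for i, rate in enumerate(mean):
        cov[i][i] = rate + background_noise
    return mean, cov


mean, cov = flatten_poisson([[0, 1], [2, 3]], scale_factor=0.5, background_noise=0.1)
print(mean)       # [0.0, 0.5, 1.0, 1.5]
print(cov[0][0])  # 0.1
```

Under these assumptions a zero-valued pixel would give a zero variance and hence a singular covariance; the background_noise term keeps the diagonal strictly positive.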

Configure additional training options via input arguments.

    python main_script.py --config=./your_config_file.yaml --OPT=VALUE

Some examples of the OPT field:

  • seed: fix the seed for all RNGs used by the model. By default it is None (not fixed)
  • bs: batch size used in the data loader
  • dir: directory name for saving training data
  • save_name: the prefix of file name of training data
  • epochs: the number of epochs to train.
  • cpu: manually set device to CPU

For details, we recommend reading the function deploy_config() in utils.training_tools.general_prepare.

Note that all manually supplied arguments will be overwritten if the same keys are found in the provided your_config_file.yaml.
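
The precedence rule above (config file values win over command-line values for shared keys) can be sketched with argparse and a dictionary merge. The option names are taken from the list above, but their types, the flat key layout of the config, and the helpers parse_cli and merge_options are assumptions made for illustration; deploy_config() remains the authoritative reference.

```python
import argparse


def parse_cli(argv=None):
    """Parse a subset of the options listed above (types are assumed)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--bs", type=int, default=None)
    parser.add_argument("--epochs", type=int, default=None)
    return parser.parse_args(argv)


def merge_options(cli_args, config):
    """Config values win whenever a key appears in both places."""
    merged = dict(vars(cli_args))
    merged.update(config)
    return merged


args = parse_cli(["--config=./your_config_file.yaml", "--bs=64", "--epochs=5"])
config = {"bs": 128}  # stand-in for keys read from the YAML file
options = merge_options(args, config)
print(options["bs"], options["epochs"])  # 128 5
```

Here the batch size given on the command line (64) is replaced by the value from the config (128), while epochs, which the config does not set, keeps its command-line value.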

Run simulations of the reconstructed SNN

We provide a utility that automatically reconstructs an SNN from the trained MNN. A custom SNN simulator with GPU support is included, but you may use any SNN simulator of your choice.

Customize your own MNN model

Custom dataset

Custom loss function

Custom model

Lead authors

License

This project is licensed under the Apache License 2.0 - see the LICENSE.md file for details.
