
PyTorch implementation of Hebbian "Bio-Learning" convolutional layers


BioPyTorch

PyTorch implementation of the learning rule proposed in "Unsupervised learning by competing hidden units" (D. Krotov and J. J. Hopfield, PNAS, 2019) [1]: https://www.pnas.org/content/116/16/7723

Installation

Install with pip:

pip install biopytorch

Usage

The package provides two layers, BioLinear and BioConv2d, which respectively mirror nn.Linear and nn.Conv2d from PyTorch, with added support for training with the alternative learning rule proposed by Krotov & Hopfield [1].

They accept the same parameters as their standard counterparts (except that BioConv2d currently does not support a bias term). To execute a single update step, call the training_step method.
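For illustration, a single unsupervised update might look like the sketch below. The top-level import path and the exact training_step signature are assumptions here; the example notebook is authoritative:

```python
import torch

from biopytorch import BioConv2d, BioLinear  # import path assumed

# Constructed with the same arguments as their nn counterparts
# (BioConv2d currently has no bias term).
conv = BioConv2d(in_channels=3, out_channels=96, kernel_size=5)
fc = BioLinear(in_features=1024, out_features=2000)

x = torch.rand(64, 3, 32, 32)  # e.g. a batch of CIFAR-10-sized images
conv.training_step(x)          # one Hebbian update (signature assumed)
```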

See the example notebook in notebooks for more details.

Other files

  • dev/bioconv2d_dev.py contains an alternative implementation of BioConv2d based on F.unfold. Its performance is significantly worse (especially in memory usage), so it should not be used in practice. However, the algorithm is easier to follow, and it can be used to get a better understanding of the Krotov learning rule (a standalone sketch of the rule follows this list).
  • slides contains a few explanatory slides
  • notebooks contains the example notebooks
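For orientation, the core of the rule from [1] can be written in a few lines of plain PyTorch. The sketch below is an illustrative re-derivation assuming unit radius ($R = 1$), $k \geq 2$, and a flattened input batch; it is not the package's implementation, and the function name krotov_step is made up for the example:

```python
import torch

def krotov_step(weights: torch.Tensor, batch: torch.Tensor,
                p: int = 2, k: int = 2, delta: float = 0.4,
                lr: float = 1e-2) -> torch.Tensor:
    """One Krotov-Hopfield update (R = 1, k >= 2 assumed).

    weights: (n_hidden, n_inputs); batch: (batch_size, n_inputs).
    """
    # Input currents <W, v> under the |W|^(p-2) metric of the paper.
    currents = (weights.sign() * weights.abs() ** (p - 1)) @ batch.t()

    # g(.) = +1 for the most activated unit, -delta for the k-th ranked
    # unit, and 0 for everything else (the "competition").
    ranks = currents.topk(k, dim=0).indices
    cols = torch.arange(batch.shape[0])
    g = torch.zeros_like(currents)
    g[ranks[0], cols] = 1.0
    g[ranks[k - 1], cols] = -delta

    # Plasticity rule dW = g * (v - <W, v> W), summed over the batch.
    xx = (g * currents).sum(dim=1, keepdim=True)  # (n_hidden, 1)
    dw = g @ batch - xx * weights

    # Normalize so that the largest synaptic change has magnitude lr.
    dw = dw / dw.abs().max().clamp(min=1e-30)
    return weights + lr * dw
```

Each call performs one competitive Hebbian step: hidden units are ranked by their input currents, the winner is pulled towards the input, and the $k$-th ranked unit is pushed away with strength $\Delta$.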

Benchmark

Hyperparameters ($p$, $k$, $\Delta$ - for their meaning, see the slides or the docstrings) are optimized with respect to the validation accuracy of classification on the CIFAR-10 dataset, using the Optuna library.
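A hypothetical sketch of such a study is given below; the search ranges and the train_and_evaluate helper are assumptions standing in for the actual training pipeline:

```python
import optuna

def train_and_evaluate(p: int, k: int, delta: float, dropout: float) -> float:
    # Stub standing in for the real pipeline: build the Hebbian stack with
    # these hyperparameters, train it on CIFAR-10, fit the classifier head,
    # and return the validation accuracy.
    return 0.0

def objective(trial: optuna.Trial) -> float:
    # Ranges are assumptions chosen to bracket the optima reported below.
    p = trial.suggest_int("p", 2, 8)
    k = trial.suggest_int("k", 2, 10)
    delta = trial.suggest_float("delta", 0.05, 0.40)
    dropout = trial.suggest_float("dropout", 0.0, 0.5)
    return train_and_evaluate(p, k, delta, dropout)

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100)
print(study.best_params)
```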

Specifically, the architecture is taken from [2]: a stack of Hebbian convolutional layers, numbered (1), (2), ..., topped by a (Classifier) segment.

The (Classifier) segment is inserted at different positions - after (1), after (2), and so on - so that the change in performance contributed by deeper Hebbian layers can be measured.

Depending on the number of Hebbian layers preceding the (Classifier), the performance obtained with the best hyperparameters found is as follows:

| #layers              | 1     | 2     | 3     | 4     | 5      |
|----------------------|-------|-------|-------|-------|--------|
| Accuracy (val) [%]   | 69.20 | 67.13 | 64.91 | 59.83 | 46.25  |
| Accuracy (test) [%]  | 67.06 | 65.22 | 63.08 | 58.86 | 45.45  |
| $p$                  | 2     | 8     | 8     | 8     | 8      |
| $k$                  | 9     | 3     | 5     | 7     | 2      |
| $\Delta$             | 0.08  | 0.34  | 0.25  | 0.235 | 0.335  |
| Dropout              | 0.2   | 0.25  | 0.05  | 0.1   | 0.1    |
| Params               | 195k  | 302k  | 387k  | 804k  | 1.475M |

$p$, $k$, and $\Delta$ are shared by all the BioConv2d layers. When the full architecture is trained, the BioLinear layer is allowed its own hyperparameters; in the best run, however, they turned out to be exactly equal to those of the preceding BioConv2d layers, so the values reported in the table apply to it as well.

Note that these results are slightly better than the ones obtained in [2], reported here for reference:

| #layers              | 1     | 2     | 3     | 4     | 5     |
|----------------------|-------|-------|-------|-------|-------|
| Accuracy (this) [%]  | 67.06 | 65.22 | 63.08 | 58.86 | 45.45 |
| Accuracy ([2]) [%]   | 63.92 | 63.81 | 58.28 | 52.99 | 41.78 |

A full report on the hyperparameter optimization is available on wandb.

Sources

[1] D. Krotov, J. J. Hopfield, "Unsupervised learning by competing hidden units", PNAS, 2019.

[2] Amato et al., "Hebbian Learning Meets Deep Convolutional Neural Networks", 2019.
