# BioPyTorch

PyTorch implementation of Hebbian "bio-learning" convolutional layers: the learning rule proposed in "Unsupervised learning by competing hidden units", D. Krotov & J. J. Hopfield, 2019, https://www.pnas.org/content/116/16/7723
## Installation

Install with pip:

```shell
pip install biopytorch
```
## Usage

The package provides two layers, `BioLinear` and `BioConv2d`, which respectively mirror the features of `nn.Linear` and `nn.Conv2d` from PyTorch, with added support for training with the alternative rule proposed by Krotov & Hopfield. They accept the same parameters as their PyTorch analogues (except for `BioConv2d`, which currently does not support `bias`). To execute a single update step, call the `training_step` method.
See the example notebook in `notebooks` for more details.
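For intuition, the learning rule of Krotov & Hopfield [1] can be sketched in a few lines of NumPy. The function below is illustrative only (the name `krotov_step` and its defaults are not the package's API; `biopytorch` implements the same rule as vectorized PyTorch inside `training_step`): hidden units compete through their input currents under a $p$-metric, the winner is pushed towards the input, and the $k$-th ranked unit is pushed away with strength $\Delta$.

```python
import numpy as np

def krotov_step(weights, v, p=3.0, k=2, delta=0.4, lr=1e-2, R=1.0):
    """One Hebbian update of the Krotov-Hopfield rule (sketch).

    weights: (hidden, inputs) synapse matrix
    v:       (inputs,) input vector
    """
    # Input currents under the p-metric:
    # I_mu = sum_i |W_mu,i|^(p-2) * W_mu,i * v_i
    currents = (np.abs(weights) ** (p - 2) * weights) @ v

    # Rank-based competition: g = 1 for the most driven unit,
    # g = -delta for the k-th ranked unit, g = 0 for all others.
    order = np.argsort(currents)[::-1]
    g = np.zeros(weights.shape[0])
    g[order[0]] = 1.0
    g[order[k - 1]] = -delta

    # dW_mu,i = g_mu * (R^p * v_i - I_mu * W_mu,i)
    # (drives the p-norm of winning rows towards R)
    dW = g[:, None] * (R ** p * v[None, :] - currents[:, None] * weights)
    return weights + lr * dW
```

Note that each step touches at most two rows of the weight matrix (the winner and the $k$-th ranked unit); all other hidden units are left unchanged, which is what makes the competition "local".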
## Other files

- `dev/bioconv2d_dev.py` contains an alternative implementation of `BioConv2d` using `F.unfold`. Its performance is significantly worse (especially in memory), so it should not be used in practice. However, the algorithm is easier to follow, and can be used to get a better understanding of the Krotov learning rule.
- `slides` contains a few explanatory slides
- `notebooks` contains examples
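The idea behind the `F.unfold`-based variant is that a convolution can be rewritten as a single matrix product over flattened image patches, so the linear-layer learning rule applies unchanged to the patch matrix. A minimal NumPy sketch of that reduction (the helper names `unfold2d` and `conv2d_as_matmul` are hypothetical, not part of the package):

```python
import numpy as np

def unfold2d(x, kernel):
    """Extract all (kernel x kernel) patches of a (C, H, W) image as
    columns -- a NumPy analogue of torch.nn.functional.unfold."""
    C, H, W = x.shape
    out_h, out_w = H - kernel + 1, W - kernel + 1
    cols = np.empty((C * kernel * kernel, out_h * out_w))
    for i in range(out_h):
        for j in range(out_w):
            # flatten the patch in (C, kh, kw) order
            cols[:, i * out_w + j] = x[:, i:i + kernel, j:j + kernel].ravel()
    return cols  # shape (C*k*k, number_of_patches)

def conv2d_as_matmul(x, w):
    """Convolution (stride 1, no padding) as one matrix product:
    each flattened filter dotted with every patch column."""
    out_c, C, kh, kw = w.shape
    cols = unfold2d(x, kh)
    out = w.reshape(out_c, -1) @ cols
    H, W = x.shape[1:]
    return out.reshape(out_c, H - kh + 1, W - kw + 1)
```

The price of this formulation is that the unfolded patch matrix duplicates every overlapping pixel, which is why the `F.unfold` implementation is much heavier on memory than the one shipped in the package.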
## Benchmark

Hyperparameters ($p$, $k$, $\Delta$; for their meaning, see the slides or the docstrings) are optimized with respect to the validation accuracy of classification on the CIFAR-10 dataset, using the Optuna library.
Specifically, the architecture is taken from [2]. The `(Classifier)` segment is inserted at different depths, after `(1)`, `(2)`, ..., so that the change in performance given by deeper layers may be measured.

Depending on the number of Hebbian layers preceding the `(Classifier)`, the performance obtained with the best hyperparameters found is as follows:
| #layers | 1 | 2 | 3 | 4 | 5 |
|---|---|---|---|---|---|
| Accuracy (val) | 69.20 | 67.13 | 64.91 | 59.83 | 46.25 |
| Accuracy (test) | 67.06 | 65.22 | 63.08 | 58.86 | 45.45 |
| $p$ | 2 | 8 | 8 | 8 | 8 |
| $k$ | 9 | 3 | 5 | 7 | 2 |
| $\Delta$ | .08 | .34 | .25 | .235 | .335 |
| Dropout | .2 | .25 | .05 | .1 | .1 |
| Params | 195k | 302k | 387k | 804k | 1.475M |
$p$, $k$, and $\Delta$ are shared by all the `BioConv2d` layers. When the full architecture is trained, separate hyperparameters are allowed for the `BioLinear` layer; however, in the best run they turn out to be exactly equal to the ones used for the preceding `BioConv2d` layers, which are the values reported in the table.
Note that performance is slightly better than the results obtained in [2], reported here for reference:

| #layers | 1 | 2 | 3 | 4 | 5 |
|---|---|---|---|---|---|
| Accuracy (this) | 67.06 | 65.22 | 63.08 | 58.86 | 45.45 |
| Accuracy ([2]) | 63.92 | 63.81 | 58.28 | 52.99 | 41.78 |
A full report on the hyperparameter optimization is available on wandb.
## Sources

[1] D. Krotov, J. J. Hopfield, "Unsupervised learning by competing hidden units", PNAS, 2019

[2] Amato et al., "Hebbian Learning Meets Deep Convolutional Neural Networks", 2019