Variational Dropout and Complex-valued Neural Networks in pytorch
CplxModule
A lightweight extension for torch.nn that adds layers and activations which respect algebraic operations over the field of complex numbers, and implements real- and complex-valued Variational Dropout methods for weight sparsification. The complex-valued building blocks and Variational Dropout layers of both kinds can be seamlessly integrated into pytorch-based training pipelines. The package provides the toolset necessary to train, sparsify and fine-tune both real- and complex-valued models.
Documentation
For a high-level description of the implementation, functionality and useful code patterns, please refer to the following READMEs:
- cplxmodule.nn: the implemented complex-valued layers and their basic use
- cplxmodule.nn.relevance: the plug-and-play layers for Variational Dropout and how to use them ([3], [4], [5])
- cplxmodule.nn.masked: supported masked layers for fine-tuning pruned networks, and how to migrate parameters between classic torch.nn layers
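To illustrate what a masked layer does during fine-tuning, here is a plain-Python sketch. It is not the cplxmodule.nn.masked API; the function `masked_linear` and its arguments are hypothetical. The weight is multiplied elementwise by a fixed 0/1 mask before the affine map, so pruned entries contribute nothing to the output.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def masked_linear(x: Vector, weight: Matrix, mask: Matrix, bias: Vector) -> Vector:
    """Compute y = (weight * mask) @ x + bias, with `*` taken elementwise.

    Hypothetical illustration only: a 0 in `mask` removes the matching weight
    from the computation without changing the stored weight value itself.
    """
    out = []
    for w_row, m_row, b in zip(weight, mask, bias):
        # masked dot product of one output row with the input vector
        out.append(sum(w * m * xi for w, m, xi in zip(w_row, m_row, x)) + b)
    return out


# usage: the off-diagonal weights are pruned by the mask
print(masked_linear([1.0, 1.0], [[1.0, 2.0], [3.0, 4.0]], [[1, 0], [0, 1]], [0.0, 0.0]))
```

Because the mask multiplies the weight inside the forward pass, the gradient with respect to a masked-out weight is zero, so fine-tuning only updates the surviving weights.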
Implementation
The core implementation of the complex-valued arithmetic and layers is based on careful tracking of transformations of real and imaginary parts of complex-valued tensors, and leverages differentiable computations of the real-valued pytorch backend.
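The idea of tracking real and imaginary parts can be seen on a complex linear map. With W = A + iB and z = x + iy, the product is Wz = (Ax - By) + i(Ay + Bx), so four real matrix-vector products suffice. The sketch below uses plain Python lists rather than pytorch tensors, and the helper names `matvec` and `cplx_linear` are hypothetical.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def matvec(a: Matrix, x: Vector) -> Vector:
    """Real matrix-vector product."""
    return [sum(aij * xj for aij, xj in zip(row, x)) for row in a]


def cplx_linear(w_re: Matrix, w_im: Matrix, x_re: Vector, x_im: Vector) -> Tuple[Vector, Vector]:
    """Compute (W_re + i W_im)(x_re + i x_im) using only real arithmetic."""
    # real part: W_re x_re - W_im x_im
    re = [p - q for p, q in zip(matvec(w_re, x_re), matvec(w_im, x_im))]
    # imaginary part: W_re x_im + W_im x_re
    im = [p + q for p, q in zip(matvec(w_re, x_im), matvec(w_im, x_re))]
    return re, im


# usage: W = [[1+2j, -1j], [3, 1+1j]], z = [2-1j, 1+3j]
print(cplx_linear([[1, 0], [3, 1]], [[2, -1], [0, 1]], [2, 1], [-1, 3]))
```

Since every step is an ordinary real-valued operation, autograd in a real backend can differentiate through it without any complex-specific support.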
The batch normalization layers and weight initialization routines are based on the ICLR 2018 paper on Deep Complex Networks by Chiheb Trabelsi et al. (2018) [1] and borrow ideas from their implementation (nn.init, nn.modules.batchnorm). The complex-valued magnitude-based max pooling is based on the idea of Zhang et al. (2017) [6].
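Magnitude-based max pooling selects, within each window, the complex entry with the largest modulus and returns that entry itself, so its phase is preserved. A minimal one-dimensional sketch with non-overlapping windows (stride equal to the kernel size, trailing incomplete window dropped) follows; `cplx_max_pool1d` is a hypothetical name and this is not the cplxmodule layer.

```python
from typing import List


def cplx_max_pool1d(z: List[complex], kernel_size: int) -> List[complex]:
    """Non-overlapping 1d max pooling by modulus for complex inputs."""
    out = []
    # windows start at 0, k, 2k, ...; an incomplete tail window is ignored
    for start in range(0, len(z) - kernel_size + 1, kernel_size):
        window = z[start:start + kernel_size]
        # pick the value with the largest |z|, keeping its phase
        out.append(max(window, key=abs))
    return out


print(cplx_max_pool1d([1 + 1j, -3 + 0j, 0.5j, 2 - 2j], 2))
```

Pooling real and imaginary parts separately would instead mix components from different entries, which is why the modulus is used as the ordering key.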
The implementations of the real-valued Variational Dropout and Automatic Relevance Determination are based on the works of Diederik Kingma et al. (2015) [2], Dmitry Molchanov et al. (2017) [3], and Valery Kharitonov et al. (2018) [4].
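In the real-valued Variational Dropout of [3], each weight has mean θ and variance σ², and the quantity that drives sparsification is α = σ²/θ². Training maximizes the data term plus an approximation of -KL to the log-uniform prior, and weights with log α above a threshold (3 in [3]) are pruned. The scalar sketch below reproduces that approximation and the thresholding rule; it is an illustration, not the cplxmodule.nn.relevance implementation, and all function names are hypothetical.

```python
import math
from typing import List

# Constants of the sigmoid approximation of -KL from Molchanov et al. (2017) [3]
K1, K2, K3 = 0.63576, 1.87320, 1.48695


def _sigmoid(t: float) -> float:
    # numerically stable logistic function
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def _softplus(t: float) -> float:
    # numerically stable log(1 + exp(t))
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def log_alpha(theta: float, log_sigma2: float, eps: float = 1e-8) -> float:
    """log alpha = log sigma^2 - log theta^2 (eps guards against theta == 0)."""
    return log_sigma2 - math.log(theta * theta + eps)


def neg_kl_approx(la: float) -> float:
    """Approximate -KL(q(w) || log-uniform prior) for one weight from its log alpha."""
    # -0.5 * log(1 + 1/alpha) is written as -0.5 * softplus(-log alpha)
    return K1 * _sigmoid(K2 + K3 * la) - 0.5 * _softplus(-la) - K1


def sparsity_mask(log_alphas: List[float], threshold: float = 3.0) -> List[int]:
    """1 keeps a weight, 0 prunes it because its log alpha exceeds the threshold."""
    return [0 if la > threshold else 1 for la in log_alphas]


print(sparsity_mask([log_alpha(1.0, 0.0), log_alpha(0.01, 0.0)]))
```

A large α means the weight is dominated by its multiplicative noise, so removing it has little effect on the output; the -KL term tends to zero as log α grows, which is what pushes irrelevant weights towards being pruned.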
Complex-valued Bayesian sparsification layers are based on the research by Nazarov and Burnaev (2020) [5].
Installation
The essential dependencies of cplxmodule are numpy, torch and scipy, which can be installed via:
# essential dependencies
# conda update -n base -c defaults conda
conda create -n cplxmodule "python>=3.7" pip numpy scipy "pytorch::pytorch" \
&& conda activate cplxmodule
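To confirm that the essential dependencies are importable from the active environment, a small standard-library check such as the following can be run with the environment's python. The helper `check_modules` is hypothetical and not part of cplxmodule; note that the pytorch package is imported as `torch`.

```python
import importlib.util
from typing import Dict, Iterable

# import names of the essential dependencies listed above
ESSENTIAL = ("numpy", "scipy", "torch")


def check_modules(names: Iterable[str] = ESSENTIAL) -> Dict[str, bool]:
    """Map each top-level module name to whether it can be found for import."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, ok in check_modules().items():
        print(f"{name}: {'ok' if ok else 'MISSING'}")
```

`importlib.util.find_spec` only locates the module without importing it, so the check is fast even for large packages such as torch.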
Extra dependencies that are used in tests and needed for development can be added on top of the essentials. Check ONNX Runtime to see if your system is compatible.
conda activate cplxmodule
# extra deps for development
conda install -n cplxmodule matplotlib scikit-learn tqdm pytest "pytorch::torchvision" \
&& pip install black pre-commit
# ONNX (for compatible systems)
conda install -n cplxmodule onnx && pip install onnxruntime
The package itself can be installed with pip:
conda activate cplxmodule
pip install cplxmodule
or from the git repo to get the latest version:
conda activate cplxmodule
pip install --upgrade git+https://github.com/ivannz/cplxmodule.git
or from the root of a local clone of the repo, if you prefer an editable developer install:
conda activate cplxmodule
# enable basic checks (codestyle, stray whitespace, eof newline)
pre-commit install
# editable install
pip install -e .
# run tests to verify the installation (see the note on the batchnorm test)
# XXX `test_batchnorm.py` depends on the precision of the outcome of SGD, hence
# may occasionally fail
# XXX A user warning concerning non-writable numpy array is expected
pytest
Additionally, you may want to study the following examples and test Variational Dropout:
conda activate cplxmodule
# test real- and complex-valued Bayesian sparsification layers
python tests/test_relevance.py
# showcase the train-sparsify-fine-tune staged pipeline on a basic
# real-valued CNN on MNIST
python tests/test_mnist.py
Citation
The proper citation for the real-valued Bayesian Sparsification layers from cplxmodule.nn.relevance.real is either [3] (VD) or [4] (ARD). If you find the complex-valued Bayesian Sparsification layers from cplxmodule.nn.relevance.complex useful in your research, please consider citing the following paper [5]:
@inproceedings{nazarov_bayesian_2020,
title = {Bayesian {Sparsification} of {Deep} {C}-valued {Networks}},
volume = {119},
url = {http://proceedings.mlr.press/v119/nazarov20a.html},
language = {en},
urldate = {2021-08-02},
booktitle = {International {Conference} on {Machine} {Learning}},
publisher = {PMLR},
author = {Nazarov, Ivan and Burnaev, Evgeny},
month = nov,
year = {2020},
note = {ISSN: 2640-3498},
pages = {7230--7242}
}
References
[1] Trabelsi, C., Bilaniuk, O., Zhang, Y., Serdyuk, D., Subramanian, S., Santos, J. F., Mehri, S., Rostamzadeh, N., Bengio, Y., & Pal, C. J. (2018). Deep complex networks. In International Conference on Learning Representations.
[2] Kingma, D. P., Salimans, T., & Welling, M. (2015). Variational dropout and the local reparameterization trick. In Advances in neural information processing systems (pp. 2575-2583).
[3] Molchanov, D., Ashukha, A., & Vetrov, D. (2017, August). Variational dropout sparsifies deep neural networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70 (pp. 2498-2507). JMLR.org.
[4] Kharitonov, V., Molchanov, D., & Vetrov, D. (2018). Variational Dropout via Empirical Bayes. arXiv preprint arXiv:1811.00596.
[5] Nazarov, I., & Burnaev, E. (2020, November). Bayesian Sparsification of Deep C-valued Networks. In International Conference on Machine Learning (pp. 7230-7242). PMLR.
[6] Zhang, Z., Wang, H., Xu, F., & Jin, Y. Q. (2017). Complex-valued convolutional neural network and its application in polarimetric SAR image classification. IEEE Transactions on Geoscience and Remote Sensing, 55(12), 7177-7188.
Download files
Source Distribution: cplxmodule-2022.6.tar.gz
Built Distribution: cplxmodule-2022.6-py3-none-any.whl
File details: cplxmodule-2022.6.tar.gz
File metadata
- Download URL: cplxmodule-2022.6.tar.gz
- Upload date:
- Size: 69.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.11.3 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.64.0 CPython/3.10.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 6fc890e2a2b9b6a39ce1ecb1ca298a771ea95848f8abde64e277b3bac244ae0c
MD5 | b1234b957341a051d3043d4ebedf31c3
BLAKE2b-256 | 5aa6de8b67c3943327b30468f775294c14b26a1ea71b7c848d9fe72a6ff538b3
File details: cplxmodule-2022.6-py3-none-any.whl
File metadata
- Download URL: cplxmodule-2022.6-py3-none-any.whl
- Upload date:
- Size: 58.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.11.3 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.64.0 CPython/3.10.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 6e5b8f6d1c5dbc2dcc6221cec51fb6edcdb3c92ec4cc2d4289095184e7e09468
MD5 | eb1c421968c67c21a1c6950b12ae4146
BLAKE2b-256 | 8ea10ea7cd43cebd408110c95135fae0004f9a7e6e26f6f5118441a9847f02c1