
Variational Dropout and Complex-valued Neural Networks in pytorch

Project description

CplxModule

A lightweight extension for torch.nn that adds layers and activations respecting algebraic operations over the field of complex numbers, and implements real- and complex-valued Variational Dropout methods for weight sparsification. The complex-valued building blocks and the Variational Dropout layers of both kinds can be seamlessly integrated into pytorch-based training pipelines. The package provides the toolset necessary to train, sparsify, and fine-tune both real- and complex-valued models.
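
For instance, the complex-valued layers compose with ordinary torch.nn containers. The sketch below is a minimal illustration assuming the converter and layer names documented in the repository README (RealToCplx, CplxToReal, CplxLinear, CplxModReLU) and the interleaved real/imaginary input layout; treat it as illustrative rather than canonical:

import torch
from cplxmodule.nn import RealToCplx, CplxToReal, CplxLinear, CplxModReLU

n_features = 16
z = torch.randn(8, n_features * 2)  # interleaved [re, im, re, im, ...] pairs

model = torch.nn.Sequential(
    RealToCplx(),                        # view the input as a complex tensor
    CplxLinear(n_features, n_features),  # complex-valued affine map
    CplxModReLU(),                       # modulus-based ReLU activation
    CplxToReal(),                        # back to interleaved real pairs
)

print(model(z).shape)  # torch.Size([8, 32])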

Documentation

For a high-level description of the implementation, functionality, and useful code patterns, please refer to the READMEs in the repository.

Implementation

The core implementation of the complex-valued arithmetic and layers is based on careful tracking of transformations of real and imaginary parts of complex-valued tensors, and leverages differentiable computations of the real-valued pytorch backend.
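
Concretely, a single complex-valued linear map unpacks into four real-valued matrix products, each of which pytorch can differentiate as usual. The snippet below illustrates this bookkeeping only; it is not the package's internal code:

import torch

B, n_in, n_out = 8, 16, 4
x_re, x_im = torch.randn(B, n_in), torch.randn(B, n_in)
W_re = torch.randn(n_out, n_in, requires_grad=True)
W_im = torch.randn(n_out, n_in, requires_grad=True)

# (x_re + i x_im) (W_re + i W_im)^T
#   = (x_re W_re^T - x_im W_im^T) + i (x_re W_im^T + x_im W_re^T)
y_re = x_re @ W_re.t() - x_im @ W_im.t()
y_im = x_re @ W_im.t() + x_im @ W_re.t()

# every step is an ordinary real-valued op, so autograd differentiates it as usual
(y_re.sum() + y_im.sum()).backward()
print(W_re.grad.shape)  # torch.Size([4, 16])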

The batch normalization and weight initialization layers are based on the ICLR 2018 paper by Chiheb Trabelsi et al. (2018) on Deep Complex Networks [1] and borrow ideas from their implementation (nn.init, nn.modules.batchnorm). The complex-valued magnitude-based max pooling follows the idea of Zhang et al. (2017) [6].
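
The magnitude-based pooling can be pictured as follows: within each window, select the element with the largest modulus and carry its real and imaginary parts through. This is an illustrative re-implementation with plain torch ops, not the packaged layer:

import torch
import torch.nn.functional as F

x_re, x_im = torch.randn(1, 3, 8, 8), torch.randn(1, 3, 8, 8)
modulus = torch.sqrt(x_re**2 + x_im**2)

# indices of the largest-modulus element in each 2x2 window
_, idx = F.max_pool2d(modulus, kernel_size=2, return_indices=True)

# gather the real and imaginary parts at the winning positions
y_re = x_re.flatten(2).gather(2, idx.flatten(2)).view_as(idx)
y_im = x_im.flatten(2).gather(2, idx.flatten(2)).view_as(idx)
print(y_re.shape, y_im.shape)  # torch.Size([1, 3, 4, 4]) twice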

The implementations of the real-valued Variational Dropout and Automatic Relevance Determination are based on the foundational work of Diederik Kingma et al. (2015) [2], Dmitry Molchanov et al. (2017) [3], and Valery Kharitonov et al. (2018) [4].

Complex-valued Bayesian sparsification layers are based on the research by Nazarov and Burnaev (2020) [5].
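
A hedged sketch of the staged train-sparsify-fine-tune pattern with a real-valued Variational Dropout layer is given below. The helper names (LinearVD, penalties, compute_ard_masks, binarize_masks) follow the repository README and may differ between versions; the complex-valued layers in cplxmodule.nn.relevance.complex are used analogously:

import torch
import torch.nn.functional as F
from cplxmodule.nn.relevance import LinearVD, penalties, compute_ard_masks
from cplxmodule.nn.masked import binarize_masks

model = torch.nn.Sequential(LinearVD(16, 1))
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
X, y = torch.randn(256, 16), torch.randn(256, 1)

# stage 1: fit with the KL-divergence penalty added to the data loss (ELBO)
for _ in range(100):
    optim.zero_grad()
    loss = F.mse_loss(model(X), y) + 1e-2 * sum(penalties(model))
    loss.backward()
    optim.step()

# stage 2: threshold the learnt relevances into binary sparsity masks
masks = compute_ard_masks(model, threshold=3.0)
state_dict, masks = binarize_masks(model.state_dict(), masks)

# stage 3: load `state_dict` and `masks` into a masked twin of the model and
#  fine-tune the surviving weights (see the README for the masked layers)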

Installation

The essential dependencies of cplxmodule are numpy, torch and scipy, which can be installed via

# essential dependencies
# conda update -n base -c defaults conda
conda create -n cplxmodule "python>=3.7" pip numpy scipy "pytorch::pytorch" \
  && conda activate cplxmodule

Extra dependencies, which are used in the tests and needed for development, can be added on top of the essentials. Check ONNX Runtime to see if your system is compatible.

conda activate cplxmodule

# extra deps for development
conda install -n cplxmodule matplotlib scikit-learn tqdm pytest "pytorch::torchvision" \
  && pip install black pre-commit

# ONNX (for compatible systems)
conda install -n cplxmodule onnx && pip install onnxruntime

The package itself can be installed with pip:

conda activate cplxmodule

pip install cplxmodule

or from the git repo to get the latest version:

conda activate cplxmodule

pip install --upgrade git+https://github.com/ivannz/cplxmodule.git

or from the root of a locally cloned repo, if you prefer an editable developer install:

conda activate cplxmodule

# enable basic checks (codestyle, stray whitespace, eof newline)
pre-commit install

# editable install
pip install -e .

# run tests to verify the installation
# XXX `test_batchnorm.py` depends on the precision of the outcome of SGD, hence
#  may occasionally fail
# XXX a user warning about a non-writable numpy array is expected
pytest

Additionally, you may want to study the following example scripts, which exercise Variational Dropout:

conda activate cplxmodule

# test real- and complex-valued Bayesian sparsification layers
python tests/test_relevance.py

# showcase the train-sparsify-fine-tune staged pipeline on a basic
#  real-valued CNN on MNIST
python tests/test_mnist.py
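
Once installed, a quick smoke test in the interpreter confirms that the core complex tensor wrapper works (Cplx is the top-level wrapper class per the repository README):

import torch
from cplxmodule import Cplx

z = Cplx(torch.randn(4, 8), torch.randn(4, 8))
print(z.real.shape, z.imag.shape)  # torch.Size([4, 8]) torch.Size([4, 8])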

Citation

The proper citation for the real-valued Bayesian Sparsification layers from cplxmodule.nn.relevance.real is either [3] (VD) or [4] (ARD). If you find the complex-valued Bayesian Sparsification layers from cplxmodule.nn.relevance.complex useful in your research, please consider citing the following paper [5]:

@inproceedings{nazarov_bayesian_2020,
    title = {Bayesian {Sparsification} of {Deep} {C}-valued {Networks}},
    volume = {119},
    url = {http://proceedings.mlr.press/v119/nazarov20a.html},
    language = {en},
    urldate = {2021-08-02},
    booktitle = {International {Conference} on {Machine} {Learning}},
    publisher = {PMLR},
    author = {Nazarov, Ivan and Burnaev, Evgeny},
    month = nov,
    year = {2020},
    note = {ISSN: 2640-3498},
    pages = {7230--7242}
}

References

[1] Trabelsi, C., Bilaniuk, O., Zhang, Y., Serdyuk, D., Subramanian, S., Santos, J. F., Mehri, S., Rostamzadeh, N., Bengio, Y., & Pal, C. J. (2018). Deep complex networks. In International Conference on Learning Representations.

[2] Kingma, D. P., Salimans, T., & Welling, M. (2015). Variational dropout and the local reparameterization trick. In Advances in neural information processing systems (pp. 2575-2583).

[3] Molchanov, D., Ashukha, A., & Vetrov, D. (2017, August). Variational dropout sparsifies deep neural networks. In Proceedings of the 34th International Conference on Machine Learning, Volume 70 (pp. 2498-2507). JMLR.org.

[4] Kharitonov, V., Molchanov, D., & Vetrov, D. (2018). Variational Dropout via Empirical Bayes. arXiv preprint arXiv:1811.00596.

[5] Nazarov, I., & Burnaev, E. (2020, November). Bayesian Sparsification of Deep C-valued Networks. In International Conference on Machine Learning (pp. 7230-7242). PMLR.

[6] Zhang, Z., Wang, H., Xu, F., & Jin, Y. Q. (2017). Complex-valued convolutional neural network and its application in polarimetric SAR image classification. IEEE Transactions on Geoscience and Remote Sensing, 55(12), 7177-7188.

Download files

Download the file for your platform.

Source Distribution

cplxmodule-2022.6.tar.gz (69.8 kB)

Uploaded Source

Built Distribution

cplxmodule-2022.6-py3-none-any.whl (58.1 kB)

Uploaded Python 3

File details

Details for the file cplxmodule-2022.6.tar.gz.

File metadata

  • Download URL: cplxmodule-2022.6.tar.gz
  • Size: 69.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.11.3 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.64.0 CPython/3.10.4

File hashes

Hashes for cplxmodule-2022.6.tar.gz

Algorithm    Hash digest
SHA256       6fc890e2a2b9b6a39ce1ecb1ca298a771ea95848f8abde64e277b3bac244ae0c
MD5          b1234b957341a051d3043d4ebedf31c3
BLAKE2b-256  5aa6de8b67c3943327b30468f775294c14b26a1ea71b7c848d9fe72a6ff538b3


File details

Details for the file cplxmodule-2022.6-py3-none-any.whl.

File metadata

  • Download URL: cplxmodule-2022.6-py3-none-any.whl
  • Size: 58.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.11.3 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.64.0 CPython/3.10.4

File hashes

Hashes for cplxmodule-2022.6-py3-none-any.whl

Algorithm    Hash digest
SHA256       6e5b8f6d1c5dbc2dcc6221cec51fb6edcdb3c92ec4cc2d4289095184e7e09468
MD5          eb1c421968c67c21a1c6950b12ae4146
BLAKE2b-256  8ea10ea7cd43cebd408110c95135fae0004f9a7e6e26f6f5118441a9847f02c1

