
Variational Neural Networks


Variational Neural Networks PyTorch

This repository contains a PyTorch implementation of Variational Neural Networks (VNNs) and the image classification experiments for the Variational Neural Networks paper.

The corresponding package contains layer implementations for VNNs and for the other architectures used in the experiments. It can be installed with pip install vnn.

Bayesian Neural Networks (BNNs) estimate the uncertainty of a neural network by placing a distribution over its weights and sampling different models for each input. In the paper, we propose an uncertainty estimation method called Variational Neural Networks (VNNs) which, instead of placing a distribution over the weights, generates the parameters of a layer's output distribution by transforming the layer's inputs with learnable sub-layers. In uncertainty quality estimation experiments, we show that VNNs achieve better uncertainty quality than the Monte Carlo Dropout and Bayes By Backpropagation methods.
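The core idea, generating the parameters of a layer's output distribution from its inputs, can be sketched in a few lines of PyTorch. This is a minimal sketch under stated assumptions, not the vnn package's implementation: it assumes a Gaussian output distribution and uses a softplus to keep the standard deviation positive, and the class name GaussianOutputLinear is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GaussianOutputLinear(nn.Module):
    """Hypothetical illustration of the VNN idea (not the vnn package's class):
    two learnable sub-layers map the input to the mean and the standard
    deviation of a Gaussian, and the layer returns one sample from it."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.mean_layer = nn.Linear(in_features, out_features)  # sub-layer for the mean
        self.std_layer = nn.Linear(in_features, out_features)   # sub-layer for the spread

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = self.mean_layer(x)
        # Assumption of this sketch: softplus keeps the standard deviation positive
        std = F.softplus(self.std_layer(x)) + 1e-6
        # Reparameterized sample, differentiable w.r.t. both sub-layers
        return mean + std * torch.randn_like(mean)


torch.manual_seed(0)
layer = GaussianOutputLinear(4, 3)
x = torch.randn(2, 4)

# Every forward pass draws a new sample, so repeated passes differ;
# the spread across passes serves as the uncertainty estimate.
samples = torch.stack([layer(x) for _ in range(10)])
uncertainty = samples.var(dim=0)
print(samples.shape)  # torch.Size([10, 2, 3])
```

Note that no distribution is placed over the weights here: the weights of both sub-layers are ordinary point estimates, and the randomness enters only through the sampled output.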

Run

Use run_example.sh to train and evaluate a single model on MNIST. A corresponding reproducible capsule is available on CodeOcean.

Package

Use pip install vnn or python3 -m pip install vnn to install the package. The package includes only layer implementations: the VNN layers, as well as dropout, functional, and classic layers. All of them share the same interface, so different versions of a network can be built by changing only the layer class names.

An example of a simple convolutional network:

import torch
from vnn import VariationalConvolution, VariationalLinear

class Based(torch.nn.Module):

    def __init__(self, **kwargs) -> None:

        super().__init__()

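        # Three variational convolutions followed by a variational classifier with 10 outputs;
        # **kwargs is forwarded to every layer
        self.model = torch.nn.Sequential(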
            VariationalConvolution(1, 256, 9, 1, **kwargs),
            VariationalConvolution(256, 256, 9, 2, **kwargs),
            VariationalConvolution(256, 16, 4, 1, **kwargs),
            torch.nn.Flatten(start_dim=1),
            VariationalLinear(3 * 3 * 16, 10, **kwargs),
        )

    def forward(self, x):

        return self.model(x)
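The input size of the final linear layer, 3 * 3 * 16, follows from standard convolution arithmetic. The sketch below assumes 28x28 MNIST inputs, no padding, and that the positional arguments of VariationalConvolution are (in_channels, out_channels, kernel_size, stride); the helper conv_output_size is hypothetical and not part of vnn.

```python
def conv_output_size(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Spatial output size of a 2D convolution along one dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel_size, stride) of the three convolutions, assuming no padding
conv_layers = [(9, 1), (9, 2), (4, 1)]

size = 28  # MNIST images are 28x28
for kernel, stride in conv_layers:
    size = conv_output_size(size, kernel, stride)
    print(size)  # 20, then 6, then 3

flattened = size * size * 16  # 16 output channels in the last convolution
print(flattened)  # 144, i.e. 3 * 3 * 16
```

If the inputs have a different resolution, the first argument of the final linear layer must be recomputed the same way.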

The same classic network:

import torch
from vnn.classic import ClassicConvolution, ClassicLinear

class Based(torch.nn.Module):

    def __init__(self, **kwargs) -> None:

        super().__init__()

        # Same architecture as above, built from classic layers
        self.model = torch.nn.Sequential(
            ClassicConvolution(1, 256, 9, 1, **kwargs),
            ClassicConvolution(256, 256, 9, 2, **kwargs),
            ClassicConvolution(256, 16, 4, 1, **kwargs),
            torch.nn.Flatten(start_dim=1),
            ClassicLinear(3 * 3 * 16, 10, **kwargs),
        )

    def forward(self, x):

        return self.model(x)

Or a generalized network class:

import torch
from vnn import VariationalConvolution, VariationalLinear
from vnn.classic import ClassicConvolution, ClassicLinear
from vnn.dropout import DropoutConvolution, DropoutLinear
from vnn.functional import FunctionalConvolution, FunctionalLinear

def create_based(Convolution, Linear):
    # Builds the same architecture from any convolution/linear pair sharing the vnn interface
    class Based(torch.nn.Module):

        def __init__(self, **kwargs) -> None:

            super().__init__()

            self.model = torch.nn.Sequential(
                Convolution(1, 256, 9, 1, **kwargs),
                Convolution(256, 256, 9, 2, **kwargs),
                Convolution(256, 16, 4, 1, **kwargs),
                torch.nn.Flatten(start_dim=1),
                Linear(3 * 3 * 16, 10, **kwargs),
            )

        def forward(self, x):

            return self.model(x)

    return Based  # return the class so the caller can instantiate it

based_vnn = create_based(VariationalConvolution, VariationalLinear)
based_classic = create_based(ClassicConvolution, ClassicLinear)
based_dropout = create_based(DropoutConvolution, DropoutLinear)
based_functional = create_based(FunctionalConvolution, FunctionalLinear) # see hypermodels on how to use functional layers
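Because all layer types share one interface, the uncertainty of any of these networks can be estimated in the same way: run several stochastic forward passes and summarize the sampled predictions. The helper below is a hypothetical sketch, not part of vnn. It assumes the model is stochastic at prediction time (whether the vnn layers sample in eval mode should be checked against the package), uses predictive entropy as one common summary (the paper's own metrics may differ), and uses a plain PyTorch dropout network as a stand-in so that the example runs without vnn.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def predict_with_uncertainty(model: nn.Module, x: torch.Tensor, samples: int = 20):
    """Hypothetical helper: run a stochastic model several times on the same
    input and summarize the sampled class probabilities."""
    probs = torch.stack([torch.softmax(model(x), dim=-1) for _ in range(samples)])
    mean_probs = probs.mean(dim=0)  # predictive distribution, shape (batch, classes)
    # Predictive entropy: higher values mean a less certain prediction
    entropy = -(mean_probs * torch.log(mean_probs.clamp_min(1e-12))).sum(dim=-1)
    return mean_probs, entropy


# Stand-in stochastic model: plain PyTorch with dropout kept active
torch.manual_seed(0)
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 64),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(64, 10),
)
model.train()  # keep dropout sampling at prediction time (Monte Carlo Dropout)

x = torch.randn(4, 1, 28, 28)
mean_probs, entropy = predict_with_uncertainty(model, x)
print(mean_probs.shape, entropy.shape)  # torch.Size([4, 10]) torch.Size([4])
```

With a network built by create_based, the same helper would be called on an instance such as based_vnn(), under the assumption above that its layers sample on every forward pass.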

Citation

If you use this work for your research, you can cite it as:

Library:

@article{oleksiienko2022vnntorchjax,
    title = {Variational Neural Networks implementation in Pytorch and JAX},
    author = {Oleksiienko, Illia and Tran, Dat Thanh and Iosifidis, Alexandros},
    journal = {Software Impacts},
    volume = {14},
    pages = {100431},
    year = {2022},
}

Paper:

@article{oleksiienko2023vnn,
    title = {Variational Neural Networks},
    author = {Oleksiienko, Illia and Tran, Dat Thanh and Iosifidis, Alexandros},
    journal = {arXiv:2207.01524},
    year = {2023},
}

Download files


Source Distribution

vnn-0.2.0.tar.gz (3.0 MB)

Built Distribution

vnn-0.2.0-py3-none-any.whl (11.0 kB)

File details

Details for the file vnn-0.2.0.tar.gz.

File metadata

  • Download URL: vnn-0.2.0.tar.gz
  • Upload date:
  • Size: 3.0 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.3

File hashes

Hashes for vnn-0.2.0.tar.gz
Algorithm Hash digest
SHA256 ede917439cc9325b9e434cae63dc63b47040c92055743d4a51abf6253b2b103d
MD5 73432d62bb99f9854657f7c491f15b40
BLAKE2b-256 7d1b681cb86e0d1090b8c7b738d8c23f8e089ac02835960ed245edf93af782e7


File details

Details for the file vnn-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: vnn-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 11.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.3

File hashes

Hashes for vnn-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 4db3f6955e2b8dcb48f55bac06e26952b5a2d57721bab7a68f915ef9e36f1978
MD5 de99ebd42bdc7130a30b27a2e7ba4cc9
BLAKE2b-256 a9594f5e316d879d8e176aaaee463a4123cfdb8aeca2f38ee8550faefda2583a

