Library for BNNs


A PyTorch-like library for Bayesian learning and uncertainty estimation


torch_blue provides a simple way for non-expert users to implement and train Bayesian Neural Networks (BNNs). Currently, it only supports Variational Inference (VI), but it will hopefully grow and expand in the future. To make the user experience as easy as possible, most components mirror their PyTorch counterparts.

Installation

We strongly recommend installing torch_blue in a dedicated Python 3.10+ virtual environment. You can install torch_blue from PyPI:

$ pip install torch-blue

Alternatively, you can install torch_blue from a local clone of the repository in two steps:

  1. Clone the repository and change into its directory
$ git clone https://github.com/RAI-SCC/torch_blue
$ cd torch_blue
  2. Install the code locally
$ pip install -e .

To get the development dependencies, run:

$ pip install -e ".[dev]"

To install the additional dependencies required to run the scripts in the scripts directory, run:

$ pip install -e ".[scripts]"

Documentation

Documentation is available online at readthedocs.

Quickstart

This Quickstart guide assumes basic familiarity with PyTorch and knowledge of how to implement the intended model in it. For a (potentially familiar) example, see scripts/mnist_tutorial (available as a commented Jupyter notebook or as a plain Python script), which contains a copy of the PyTorch Quickstart tutorial modified to train a BNN with variational inference.

Five levels are introduced in this guide:

  • Level 1: PyTorch-Module auto-conversion
  • Level 2: Simple sequential layer stacks
  • Level 3: Customizing Bayesian assumptions and VI kwargs
  • Level 4: Non-sequential models and log probabilities
  • Level 5: Custom modules with weights

Level 1

For simple usage and convenience, torch_blue provides the option to convert PyTorch models into Bayesian torch_blue models. Given a model represented by a single PyTorch nn.Module (with any number of submodules), conversion is performed by calling convert_to_vimodule:

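import torch.nn as nn
from torch_blue.vi import convert_to_vimodule

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))  # example PyTorch model
convert_to_vimodule(model)  # converts the model in place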

Note that many in-place operations, e.g., +=, -=, *=, /=, cannot be used in torch_blue modules, for compatibility with PyTorch's vmap. As long as your model works with vmap, auto-conversion should work. If you encounter further problems, please open an issue on GitHub.

[!IMPORTANT] convert_to_vimodule is an in-place operation. Additionally, it has several advanced options to control the conversion and the resulting model. Setting the prior and variational distribution is discussed in Level 3. Further options, e.g., to keep pre-initialized weights or to exclude certain layers from conversion, are described in its documentation.

Additionally, the loss must be replaced. To start out, use vi.KullbackLeiblerLoss, which requires a Distribution with self.is_predictive_distribution=True and the size of the training dataset (this is important for balancing the prior assumptions against the data). Choose your Distribution from the table below based on the loss you would use in PyTorch.
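Why the dataset size matters can be sketched in plain Python. The sketch assumes a common normalization of the VI objective (torch_blue's exact scaling may differ): the data term is averaged over the batch, and the KL divergence between variational distribution and prior, which is paid once for the whole training set, is divided by the number of training samples N. The function name is hypothetical.

```python
from statistics import fmean


def per_sample_negative_elbo(batch_nll: list[float], kl: float, dataset_size: int) -> float:
    """Negative ELBO per training sample, estimated from one mini-batch (illustrative).

    batch_nll: negative log-likelihood of each sample in the batch
    kl: KL divergence between variational distribution and prior for the whole model
    dataset_size: number of samples in the training set (not the number of batches)
    """
    if dataset_size <= 0:
        raise ValueError("dataset_size must be positive")
    # Data term: batch average, i.e. an estimate of the per-sample NLL.
    data_term = fmean(batch_nll)
    # Prior term: the KL is paid once for the whole dataset, so each sample carries 1/N of it.
    return data_term + kl / dataset_size


# The same KL value matters far less when it is balanced against more data.
small = per_sample_negative_elbo([0.5, 0.7], kl=1000.0, dataset_size=100)
large = per_sample_negative_elbo([0.5, 0.7], kl=1000.0, dataset_size=60000)
print(small, large)
```

With the same KL value, the pull of the prior weakens as N grows, so passing the wrong N shifts the balance between assumptions and data.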

[!IMPORTANT] KullbackLeiblerLoss requires the length of the dataset, not the length of the dataloader, which is just the number of batches.
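The two lengths are easy to mix up because both are available via len(). The plain-Python sketch below computes how many batches a PyTorch DataLoader yields per epoch (ceil(N / batch_size), or the floor with drop_last=True). Assuming the loss divides the KL term by the number it receives, it also shows how much too strong the KL term would become if the number of batches were passed instead of the dataset size. num_batches is a hypothetical helper.

```python
import math


def num_batches(dataset_size: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches per epoch for a DataLoader over a map-style dataset."""
    if drop_last:
        return dataset_size // batch_size  # the incomplete last batch is dropped
    return math.ceil(dataset_size / batch_size)


dataset_size = 60000  # len(train_dataset), e.g. the MNIST training set
batch_size = 64
n_batches = num_batches(dataset_size, batch_size)  # corresponds to len(train_dataloader)

# Per-sample KL weight when the correct vs. the wrong length is passed.
correct_kl_weight = 1 / dataset_size
wrong_kl_weight = 1 / n_batches
inflation = wrong_kl_weight / correct_kl_weight
print(n_batches, round(inflation, 1))  # 938 64.0
```

Under that assumption, passing len(dataloader) would weight the prior roughly batch_size times too strongly.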

| PyTorch | vi replacement (from vi.distributions) |
|---|---|
| nn.MSELoss | MeanFieldNormal |
| nn.CrossEntropyLoss | Categorical |
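The pairing in this table follows a standard correspondence (general background, not torch_blue code): MSE is, up to scale and an additive constant, the negative log-likelihood of a normal distribution with fixed standard deviation, and cross-entropy is the negative log-likelihood of a categorical distribution over softmax probabilities. A plain-Python sketch with illustrative function names:

```python
import math


def gaussian_nll(y: float, mean: float, std: float = 1.0) -> float:
    """Negative log-likelihood of y under Normal(mean, std**2)."""
    return 0.5 * ((y - mean) / std) ** 2 + math.log(std) + 0.5 * math.log(2 * math.pi)


def categorical_nll(logits: list[float], target: int) -> float:
    """Negative log-likelihood of class `target` under softmax(logits),
    i.e. the per-sample cross-entropy computed from logits."""
    # Subtract the maximum logit for numerical stability (log-sum-exp trick).
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_norm - logits[target]


# With std fixed to 1, the Gaussian NLL is half the squared error plus a constant,
# so minimizing it is equivalent to minimizing MSE.
const = 0.5 * math.log(2 * math.pi)
print(gaussian_nll(3.0, 1.0) - const)  # equals 0.5 * (3 - 1)**2 up to rounding

# With equal logits, each of the 3 classes has probability 1/3, so the NLL is log(3).
print(categorical_nll([0.0, 0.0, 0.0], target=1))
```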

Level 2

Many parts of a neural network remain completely unchanged when turning it into a BNN. Indeed, only Modules containing nn.Parameters need to be changed. Therefore, if all PyTorch Modules that have weights and should be Bayesian have equivalents in this package (see table below), conversion should be relatively straightforward.

| PyTorch | vi replacement |
|---|---|
| nn.Linear | VILinear |
| nn.Conv1d | VIConv1d |
| nn.Conv2d | VIConv2d |
| nn.Conv3d | VIConv3d |
| nn.Transformer (including sublayers) | VITransformer (including sublayers) |

Any custom modules should inherit from vi.VIModule instead of nn.Module. Then replace all layers containing parameters as shown in the table above. For basic usage, initialize these modules with the same arguments as their PyTorch equivalents. For advanced usage, see Quickstart: Level 3. Many other layers can be included as-is, in particular activation functions, pooling, and padding (even dropout, though it should not be necessary since the prior acts as regularization). Currently, recurrent and transposed convolution layers are not supported. Normalization layers may have parameters depending on their settings, but can likely be left non-Bayesian. The loss needs to be adapted as described in Level 1.
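The remark that the prior acts as regularization can be made concrete in the simpler point-estimate (MAP) picture: the negative log density of a zero-mean normal prior with standard deviation sigma is an L2 penalty with coefficient 1 / (2 sigma^2) plus a constant. In VI the prior enters through the KL term instead, so the plain-Python sketch below is intuition rather than torch_blue's actual loss; the function name is hypothetical.

```python
import math


def neg_log_normal_prior(weights: list[float], std: float = 1.0) -> float:
    """Negative log density of weights under independent Normal(0, std**2) priors."""
    squared_norm = sum(w * w for w in weights)
    log_norm_const = len(weights) * (math.log(std) + 0.5 * math.log(2 * math.pi))
    return squared_norm / (2 * std**2) + log_norm_const


w = [0.5, -1.0, 2.0]  # squared norm 5.25
zeros = [0.0] * len(w)

# Subtracting the value at w = 0 removes the constant and leaves the L2 penalty.
penalty_std1 = neg_log_normal_prior(w, std=1.0) - neg_log_normal_prior(zeros, std=1.0)
penalty_std05 = neg_log_normal_prior(w, std=0.5) - neg_log_normal_prior(zeros, std=0.5)
print(penalty_std1, penalty_std05)  # approximately 2.625 and 10.5
```

A smaller prior standard deviation therefore corresponds to a larger weight-decay coefficient, i.e., stronger regularization.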

Level 3

While the interface of VIModules is kept intentionally similar to PyTorch, all provided layers accept additional arguments that customize the Bayesian assumptions. Custom modules should generally accept these arguments as well and pass them on to their submodules:

  • variational_distribution (Distribution): defines the weight distribution and the variational parameters. The default MeanFieldNormal assumes normally distributed, uncorrelated weights described by a mean and a standard deviation. While there are currently no alternatives, the initial value of the standard deviation can be customized here.
  • prior (Distribution): defines the assumptions on the weight distribution and acts as a regularizer. The default MeanFieldNormal assumes normally distributed, uncorrelated weights with mean 0 and standard deviation 1 (also known as a standard normal prior). Mean and standard deviation can be adapted here. In particular, reducing the standard deviation may help convergence, at the risk of an overconfident model. Other available priors:
    • NonBayesian/UniformPrior: Under this prior, all weight values are equally likely. While not recommended for Bayesian models, it can be used in combination with a NonBayesian variational and predictive distribution to recover non-Bayesian training (useful for debugging or for obtaining a baseline).
    • BasicQuietPrior: An experimental prior that correlates mean and standard deviation to disincentivize noisy weights.
  • rescale_prior (bool): Experimental. Scales the prior similarly to Kaiming initialization. May help with convergence, but may lead to overconfidence. This is a subject of current research.
  • prior_initialization (bool): Experimental. Initializes parameters from the prior instead of according to standard non-Bayesian methods. May lead to much faster convergence, but can cause the issues that Kaiming initialization counteracts unless rescale_prior is also set to True. This is a subject of current research.
  • return_log_probs (bool): This is the topic of Quickstart: Level 4.
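The default MeanFieldNormal variational distribution and prior can be pictured as one independent normal per weight. The plain-Python sketch below (illustrative function names, not torch_blue's API) draws a weight sample with the reparameterization trick, which is the usual approach for normal variational distributions and an assumption here, and compares its log density under the standard normal prior and under a narrower prior.

```python
import math
import random


def sample_mean_field(means: list[float], stds: list[float], rng: random.Random) -> list[float]:
    """Reparameterized sample: w_i = mean_i + std_i * eps_i with eps_i ~ Normal(0, 1)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(means, stds)]


def mean_field_log_prob(weights: list[float], means: list[float], stds: list[float]) -> float:
    """Log density under independent normals: a sum of per-weight log densities."""
    return sum(
        -0.5 * ((w - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
        for w, m, s in zip(weights, means, stds)
    )


rng = random.Random(0)
# Variational distribution: one mean and one std per weight (std initialized small).
q_means, q_stds = [0.8, -1.2], [0.05, 0.05]
w = sample_mean_field(q_means, q_stds, rng)

# Two priors: the default standard normal and a narrower one (std 0.5).
log_p_standard = mean_field_log_prob(w, [0.0, 0.0], [1.0, 1.0])
log_p_narrow = mean_field_log_prob(w, [0.0, 0.0], [0.5, 0.5])
print(w, log_p_standard, log_p_narrow)
```

The narrower prior assigns lower density to these weights, which are far from zero; this is the stronger regularization, and the overconfidence risk, described in the prior bullet.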

Level 4

For more advanced models, one feature of Variational Inference (VI) needs to be taken into account. Generally, a loss for VI requires the log probability of the weights actually used (which are sampled on each forward pass) under both the variational and the prior distribution. Since storing the samples would be quite inefficient, these log probabilities are evaluated during the forward pass and returned by the model. Since this is only necessary for training, it can be controlled with the argument return_log_probs. Once the model is initialized, this flag can be changed by setting VIModule.return_log_probs, which either enables (True) or disables (False) the returning of log probabilities for all submodules.
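These log probabilities enable the standard Monte Carlo estimate of the KL term used in VI: for each weight actually sampled, record log q(w) under the variational distribution and log p(w) under the prior, then average the difference. The plain-Python sketch below does this for a single weight and compares the result with the closed-form KL between two normals. It illustrates the general technique; the function names are hypothetical rather than torch_blue's API.

```python
import math
import random


def normal_log_prob(x: float, mean: float, std: float) -> float:
    """Log density of x under Normal(mean, std**2)."""
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def mc_kl_estimate(q_mean: float, q_std: float, p_mean: float, p_std: float,
                   n_samples: int, rng: random.Random) -> float:
    """Monte Carlo estimate of KL(q || p) for one weight.

    Each iteration plays the role of a forward pass: a weight is sampled from q and
    its log probabilities are evaluated immediately instead of storing the sample.
    """
    total = 0.0
    for _ in range(n_samples):
        w = q_mean + q_std * rng.gauss(0.0, 1.0)  # weight used in this forward pass
        total += normal_log_prob(w, q_mean, q_std) - normal_log_prob(w, p_mean, p_std)
    return total / n_samples


def closed_form_kl(q_mean: float, q_std: float, p_mean: float, p_std: float) -> float:
    """Exact KL(q || p) between two univariate normals, for comparison."""
    return math.log(p_std / q_std) + (q_std**2 + (q_mean - p_mean) ** 2) / (2 * p_std**2) - 0.5


rng = random.Random(42)
estimate = mc_kl_estimate(0.5, 0.2, 0.0, 1.0, n_samples=20000, rng=rng)
exact = closed_form_kl(0.5, 0.2, 0.0, 1.0)
print(estimate, exact)
```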

torch_blue calculates and aggregates log probs internally; the aggregation is handled by the outermost VIModule. When returning log probs, this module does not have the expected output signature, but instead returns a VIReturn object. This class is a PyTorch Tensor that also carries log prob information in its additional log_probs attribute, which is the format torch_blue losses expect. Therefore, if you feed the output directly into a loss, there should be no issues. While all PyTorch tensor operations can be performed on VIReturns, many of them discard the log prob information and turn the object back into a plain Tensor. This needs to be considered when performing further operations on the model output. The simplest way to avoid issues is to wrap all operations (except the loss) in a VIModule, since log prob aggregation is only performed by the outermost module. For deployment, return_log_probs should be set to False. If the model returns multiple Tensors, each of them carries all log probs.

[!NOTE] Always make sure your outermost module is a VIModule and keep in mind that the output of that module will be a VIReturn object, which behaves like a Tensor, but carries weight log probabilities, if return_log_probs == True. Losses in torch_blue expect this format.

[!NOTE] Due to Autosampling, all output Tensors (i.e., each VIReturn in the model output and the Tensor containing the log probs) have an additional leading dimension representing the multiple samples necessary to properly evaluate the stochastic forward pass. This is only relevant for VIModules that are not contained within other VIModules. Loss functions are designed to expect and handle this output format, i.e., you can simply feed the model output into the loss and everything will work.
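Outside the loss, this leading sample dimension is what carries the uncertainty estimate: reducing over it gives a predictive mean, and the spread across samples indicates how uncertain the model is for each input. The plain-Python sketch below uses nested lists with made-up numbers in place of a tensor of shape (num_samples, batch_size). With PyTorch tensors, out.mean(dim=0) and out.std(dim=0) are the analogous reductions (Tensor.std applies Bessel's correction by default, unlike pstdev).

```python
from statistics import fmean, pstdev

# Made-up output with a leading sample dimension:
# outputs[s][b] is the prediction of weight sample s for input b
# (shape: num_samples x batch_size = 3 x 3).
outputs = [
    [0.9, 2.1, -0.4],
    [1.1, 1.9, 0.6],
    [1.0, 2.0, -1.2],
]


def summarize_samples(outputs: list[list[float]]) -> tuple[list[float], list[float]]:
    """Reduce over the sample dimension (dim 0): predictive mean and spread per input."""
    per_input = list(zip(*outputs))  # transpose to batch_size x num_samples
    means = [fmean(col) for col in per_input]
    spreads = [pstdev(col) for col in per_input]
    return means, spreads


means, spreads = summarize_samples(outputs)
print(means, spreads)  # the third input has by far the largest spread
```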

Level 5

Creating VIModules with Bayesian weights (which are typically called random variables in documentation and code) is arguably simpler than in PyTorch. Since the number of weight matrices that need to be created depends on the variational distribution, the process is completely automated. VIModules without weights call super().__init__ without arguments. Modules with random variables expect VIkwargs (which you should be familiar with from Level 3), but use defaults if none are passed. More importantly, VIModules with weights call super().__init__ with the argument variable_shapes, a dictionary whose keys are the names of the random variables and whose values are the shapes of the weight matrices as tuple or list. A value may also be set to None, in which case None is always returned for that variable.
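For a layer shaped like nn.Linear, variable_shapes could look as follows. This plain-Python sketch only builds the dictionary that, per the description above, a custom VIModule would pass to super().__init__(variable_shapes=...). The helper and the variable names weight and bias are hypothetical; the (out_features, in_features) layout mirrors nn.Linear.

```python
def linear_variable_shapes(in_features: int, out_features: int, bias: bool = True) -> dict:
    """Hypothetical variable_shapes for a Bayesian linear layer.

    Keys name the random variables, values are their shapes. A None value marks a
    variable that is always returned as None (here: a layer without bias).
    """
    return {
        "weight": (out_features, in_features),
        "bias": (out_features,) if bias else None,
    }


shapes = linear_variable_shapes(784, 128)
no_bias = linear_variable_shapes(784, 128, bias=False)

# Dicts preserve insertion order, so "weight" comes before "bias".
random_variables = tuple(shapes)
print(random_variables, shapes, no_bias)
```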

The insertion order of variable_shapes matters, as it becomes the order of the names in the module attribute random_variables. random_variables, the shapes, and a similar attribute of the variational distribution called distribution_parameters are used to dynamically create the weight matrices. The weight matrices can be accessed as attributes of the module, which causes a sample to be drawn and its log prob to be stored if needed.
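The dynamic creation can be pictured as one tensor per pair of random variable and distribution parameter, in the insertion order of variable_shapes. In the plain-Python sketch below, the parameter names mean and log_std and the _<variable>_<parameter> naming pattern are invented for illustration; torch_blue derives the real attribute names via variational_parameter_name, and the actual parameter names of MeanFieldNormal may differ. Skipping variables whose shape is None is likewise an assumption.

```python
from math import prod


def plan_variational_parameters(
    variable_shapes: dict, distribution_parameters: tuple[str, ...]
) -> dict[str, tuple[int, ...]]:
    """Hypothetical sketch: map an invented attribute name to the shape to allocate."""
    plan = {}
    for variable, shape in variable_shapes.items():  # insertion order is kept
        if shape is None:
            continue  # assumed: a None variable needs no variational parameters
        for parameter in distribution_parameters:
            plan[f"_{variable}_{parameter}"] = tuple(shape)  # lists become tuples
    return plan


plan = plan_variational_parameters(
    {"weight": (128, 784), "bias": [128]},
    distribution_parameters=("mean", "log_std"),
)
n_values = sum(prod(shape) for shape in plan.values())
print(list(plan), n_values)
```

With two parameters per weight, a mean-field normal roughly doubles the number of trainable values compared with the equivalent non-Bayesian layer.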

Should you need to access the underlying weight tensors directly, you can use getattr and derive the attribute name using the method variational_parameter_name.

[!IMPORTANT] Every access of the weights will yield a new sample and log probability to be stored. Aggregation of multiple log probs is handled internally, but unnecessary calls will distort the result.
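The consequence of per-access sampling is illustrated by the hypothetical plain-Python class below, in which a property draws a fresh sample on every access and records it (standing in for a stored log prob). Reading a weight twice in one forward pass uses two different samples and records two entries; reading it once into a local variable keeps the forward pass consistent. The class mimics the described behavior and is not torch_blue code.

```python
import random


class SampledWeightSketch:
    """Hypothetical module whose .weight is drawn anew on every attribute access."""

    def __init__(self, mean: float, std: float, seed: int = 0):
        self.mean = mean
        self.std = std
        self._rng = random.Random(seed)
        self.recorded: list[float] = []  # stands in for the stored log probs

    @property
    def weight(self) -> float:
        w = self.mean + self.std * self._rng.gauss(0.0, 1.0)
        self.recorded.append(w)  # every access records one more entry
        return w

    def forward_wasteful(self, x: float) -> float:
        # Two accesses: two different weights are used and two entries are recorded.
        return self.weight * x + self.weight

    def forward_cached(self, x: float) -> float:
        # One access, reused: a single consistent sample per forward pass.
        w = self.weight
        return w * x + w


module = SampledWeightSketch(mean=1.0, std=0.5)
module.forward_wasteful(2.0)
wasteful_count = len(module.recorded)
module.forward_cached(2.0)
cached_count = len(module.recorded) - wasteful_count
print(wasteful_count, cached_count)  # 2 1
```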

Download files


Source Distribution

torch_blue-1.0.0.tar.gz (49.0 kB)

Uploaded Source

Built Distribution


torch_blue-1.0.0-py3-none-any.whl (54.5 kB)

Uploaded Python 3

File details

Details for the file torch_blue-1.0.0.tar.gz.

File metadata

  • Download URL: torch_blue-1.0.0.tar.gz
  • Upload date:
  • Size: 49.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for torch_blue-1.0.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2889919f086815f68d6de7a7576cff8d6020c74ec476938305db2efb5beec589 |
| MD5 | e0f717ae9b02f050ad6855fe9883523b |
| BLAKE2b-256 | dae695264cefd9ec56f87340a33e6c529436ac6fa910bc553d0cd3ea535fdc0b |
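To verify a downloaded file against the SHA256 digest listed above, Python's standard hashlib module is sufficient. The helper names are hypothetical, and the commented-out check assumes the sdist was saved to the current directory.

```python
from __future__ import annotations

import hashlib
from collections.abc import Iterable
from pathlib import Path


def sha256_hexdigest(chunks: Iterable[bytes]) -> str:
    """Hex SHA256 digest of a stream of byte chunks."""
    digest = hashlib.sha256()
    for chunk in chunks:
        digest.update(chunk)
    return digest.hexdigest()


def sha256_of_file(path: str | Path, chunk_size: int = 1 << 16) -> str:
    """Hash a file in fixed-size chunks so large files need not fit in memory."""
    with open(path, "rb") as f:
        return sha256_hexdigest(iter(lambda: f.read(chunk_size), b""))


EXPECTED_SDIST_SHA256 = "2889919f086815f68d6de7a7576cff8d6020c74ec476938305db2efb5beec589"
# assert sha256_of_file("torch_blue-1.0.0.tar.gz") == EXPECTED_SDIST_SHA256
```

For installations, pip's hash-checking mode (--require-hashes together with --hash entries in a requirements file) performs the same check automatically.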


Provenance

The following attestation bundles were made for torch_blue-1.0.0.tar.gz:

Publisher: release.yaml on RAI-SCC/torch_blue

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file torch_blue-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: torch_blue-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 54.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for torch_blue-1.0.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 77702ef4fc7373c7323e8717ed3c0197af7f3be82f25d6dc48dc890ba2427ffb |
| MD5 | db4314f0e33b4fae2cad2f196ad6dc05 |
| BLAKE2b-256 | aac23385adb70739b189aab22776e98f5150a5bdc7c95d7229c6b930c67908e2 |


Provenance

The following attestation bundles were made for torch_blue-1.0.0-py3-none-any.whl:

Publisher: release.yaml on RAI-SCC/torch_blue

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
