
Lightweight Bayesian deep learning library for fast prototyping based on PyTorch


BayesTorch

Python version: 3.6 | 3.7 | 3.8 | 3.9 | 3.10 | Code style: black | Imports: isort

Welcome to bayestorch, a lightweight Bayesian deep learning library for fast prototyping based on PyTorch. It provides the basic building blocks for Bayesian inference algorithms such as Bayes by Backprop, which is used in the examples below.


💡 Key features

  • Low-code definition of Bayesian (or partially Bayesian) models
  • Support for custom neural network layers
  • Support for custom prior/posterior distributions
  • Support for layer/parameter-wise prior/posterior distributions
  • Support for composite prior/posterior distributions
  • Highly modular object-oriented design
  • User-friendly and easily extensible APIs
  • Detailed API documentation

🛠️ Installation

Using Pip

First of all, install Python 3.6 or later. Open a terminal and run:

pip install bayestorch

From source

First of all, install Python 3.6 or later. Clone the repository (or download and extract it), open a terminal, navigate to <path-to-repository> and run:

pip install -e .

▶️ Quickstart

Here are a few code snippets showcasing some key features of the library. For complete training loops, please refer to examples/mnist and examples/regression.
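As background for the snippets below: Bayes by Backprop fits a variational posterior q(w) by minimizing the expected negative log-likelihood plus the KL divergence KL(q || p) between posterior and prior, where the KL term is usually estimated by Monte Carlo with reparameterized samples w = mu + sigma * eps. The stdlib-only sketch below is not bayestorch's API (all function names are hypothetical); it estimates that KL term for a single scalar weight and compares it with the closed form available when both q and p are normal.

```python
import math
import random


def normal_log_pdf(x: float, loc: float, scale: float) -> float:
    """Log density of a univariate normal distribution."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def mc_kl_estimate(q_loc, q_scale, p_loc, p_scale, num_samples, seed=0):
    """Monte Carlo estimate of KL(q || p) using reparameterized samples."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)
        w = q_loc + q_scale * eps  # reparameterization trick: w ~ N(q_loc, q_scale^2)
        total += normal_log_pdf(w, q_loc, q_scale) - normal_log_pdf(w, p_loc, p_scale)
    return total / num_samples


def analytic_kl(q_loc, q_scale, p_loc, p_scale):
    """Closed-form KL divergence between two univariate normals."""
    return (
        math.log(p_scale / q_scale)
        + (q_scale**2 + (q_loc - p_loc) ** 2) / (2.0 * p_scale**2)
        - 0.5
    )


estimate = mc_kl_estimate(0.3, 0.5, 0.0, 1.0, num_samples=100_000)
exact = analytic_kl(0.3, 0.5, 0.0, 1.0)
print(f"MC estimate: {estimate:.3f}, exact: {exact:.3f}")
```

The Monte Carlo route matters because priors such as the scale mixture used below have no closed-form KL against a normal posterior, while the sampled estimate works for any pair of densities that can be evaluated.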

Bayesian model trainable via Bayes by Backprop

from torch.nn import Linear

from bayestorch.distributions import (
    get_mixture_log_scale_normal,
    get_softplus_inv_scale_normal,
)
from bayestorch.nn import VariationalPosteriorModule


# Define model
model = Linear(5, 1)

# Define log scale normal mixture prior over the model parameters
prior_builder, prior_kwargs = get_mixture_log_scale_normal(
    model.parameters(),
    weights=[0.75, 0.25],
    locs=(0.0, 0.0),
    log_scales=(-1.0, -6.0)
)

# Define inverse softplus scale normal posterior over the model parameters
posterior_builder, posterior_kwargs = get_softplus_inv_scale_normal(
    model.parameters(), loc=0.0, softplus_inv_scale=-7.0, requires_grad=True,
)

# Define Bayesian model trainable via Bayes by Backprop
model = VariationalPosteriorModule(
    model, prior_builder, prior_kwargs, posterior_builder, posterior_kwargs
)
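The hyperparameters in this snippet are expressed on transformed scales. Assuming, as the names suggest, that `log_scales` are natural logarithms of the mixture components' standard deviations and that `softplus_inv_scale` is mapped to a standard deviation through softplus(x) = log(1 + exp(x)), the stdlib sketch below shows the values they correspond to and evaluates the resulting scale-mixture prior density at one point. The helper functions are hypothetical and only mirror the math, not the library's code.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable softplus: log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def normal_log_pdf(x: float, loc: float, scale: float) -> float:
    """Log density of a univariate normal distribution."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def mixture_log_pdf(x, weights, locs, log_scales):
    """Log density of a mixture of normals, combined via log-sum-exp."""
    terms = [
        math.log(w) + normal_log_pdf(x, loc, math.exp(log_scale))
        for w, loc, log_scale in zip(weights, locs, log_scales)
    ]
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))


# Prior: 0.75 * N(0, e^-1) + 0.25 * N(0, e^-6), matching the snippet above
prior_scales = [math.exp(s) for s in (-1.0, -6.0)]  # approx. 0.368 and 0.00248
# Initial posterior standard deviation implied by softplus_inv_scale=-7.0
posterior_scale = softplus(-7.0)  # approx. 9.11e-4

print("prior component scales:", prior_scales)
print("posterior initial scale:", posterior_scale)
print("log prior at w=0.01:", mixture_log_pdf(0.01, [0.75, 0.25], (0.0, 0.0), (-1.0, -6.0)))
```

Under these assumptions the posterior starts with a very small spread, so the freshly wrapped model initially behaves almost like the deterministic `Linear(5, 1)`, while the two prior components combine a broad and a sharply peaked normal.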

Partially Bayesian model trainable via Bayes by Backprop

from torch.nn import Linear

from bayestorch.distributions import (
    get_mixture_log_scale_normal,
    get_softplus_inv_scale_normal,
)
from bayestorch.nn import VariationalPosteriorModule


# Define model
model = Linear(5, 1)

# Define log scale normal mixture prior over `model.weight`
prior_builder, prior_kwargs = get_mixture_log_scale_normal(
    [model.weight],
    weights=[0.75, 0.25],
    locs=(0.0, 0.0),
    log_scales=(-1.0, -6.0)
)

# Define inverse softplus scale normal posterior over `model.weight`
posterior_builder, posterior_kwargs = get_softplus_inv_scale_normal(
    [model.weight], loc=0.0, softplus_inv_scale=-7.0, requires_grad=True,
)

# Define partially Bayesian model trainable via Bayes by Backprop
model = VariationalPosteriorModule(
    model, prior_builder, prior_kwargs,
    posterior_builder, posterior_kwargs, [model.weight],
)
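In the partially Bayesian model only `model.weight` receives a prior and a variational posterior, while `model.bias` remains an ordinary point estimate. The stdlib sketch below mimics one stochastic forward pass of a `Linear(5, 1)` layer under that setup: each weight is drawn from a factorized normal posterior parameterized by a mean and an inverse-softplus scale, and the bias is used as is. All names and values here are hypothetical and illustrate the idea, not bayestorch's implementation; averaging several passes approximates the posterior predictive mean.

```python
import math
import random
from typing import List


def softplus(x: float) -> float:
    """Numerically stable softplus: log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def stochastic_linear(
    x: List[float],
    weight_locs: List[float],
    weight_rhos: List[float],
    bias: float,
    rng: random.Random,
) -> float:
    """One forward pass of a 5 -> 1 linear layer with a Bayesian weight
    (sampled via the reparameterization trick) and a deterministic bias."""
    output = bias  # the bias is not sampled: it is a plain point estimate
    for xi, loc, rho in zip(x, weight_locs, weight_rhos):
        w = loc + softplus(rho) * rng.gauss(0.0, 1.0)
        output += w * xi
    return output


rng = random.Random(42)
x = [1.0, -0.5, 0.25, 2.0, 0.0]
weight_locs = [0.1, 0.2, -0.3, 0.05, 0.4]
weight_rhos = [-7.0] * 5  # same initial value as softplus_inv_scale in the snippet above
bias = 0.5

samples = [stochastic_linear(x, weight_locs, weight_rhos, bias, rng) for _ in range(100)]
predictive_mean = sum(samples) / len(samples)
deterministic = bias + sum(xi * w for xi, w in zip(x, weight_locs))
print(predictive_mean, deterministic)
```

Keeping some parameters deterministic shrinks the number of variational parameters and the KL term, at the cost of not capturing uncertainty in those parameters.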

Composite prior

from torch.distributions import Independent
from torch.nn import Linear

from bayestorch.distributions import (
    CatDistribution,
    get_laplace,
    get_normal,
    get_softplus_inv_scale_normal,
)
from bayestorch.nn import VariationalPosteriorModule


# Define model
model = Linear(5, 1)

# Define normal prior over `model.weight`
weight_prior_builder, weight_prior_kwargs = get_normal(
    [model.weight],
    loc=0.0,
    scale=1.0,
    prefix="weight_",
)

# Define Laplace prior over `model.bias`
bias_prior_builder, bias_prior_kwargs = get_laplace(
    [model.bias],
    loc=0.0,
    scale=1.0,
    prefix="bias_",
)

# Define composite prior over the model parameters
prior_builder = (
    lambda **kwargs: CatDistribution([
        Independent(weight_prior_builder(**kwargs), 1),
        Independent(bias_prior_builder(**kwargs), 1),
    ])
)
prior_kwargs = {**weight_prior_kwargs, **bias_prior_kwargs}

# Define inverse softplus scale normal posterior over the model parameters
posterior_builder, posterior_kwargs = get_softplus_inv_scale_normal(
    model.parameters(), loc=0.0, softplus_inv_scale=-7.0, requires_grad=True,
)

# Define Bayesian model trainable via Bayes by Backprop
model = VariationalPosteriorModule(
    model, prior_builder, prior_kwargs, posterior_builder, posterior_kwargs,
)
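In the composite prior, the weight and bias priors are each wrapped in `Independent(..., 1)` and concatenated with `CatDistribution`. Assuming that the composite log density is the sum of each component's log density on its own slice of the parameters (which is what treating the components as independent implies), log p(θ) = Σ log N(w_i; 0, 1) + Σ log Laplace(b_j; 0, 1). The stdlib sketch below computes that sum for a `Linear(5, 1)`-shaped parameter vector; the helper names and the weights-then-bias ordering are assumptions for illustration.

```python
import math
from typing import List


def normal_log_pdf(x: float, loc: float, scale: float) -> float:
    """Log density of a univariate normal distribution."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def laplace_log_pdf(x: float, loc: float, scale: float) -> float:
    """Log density of a univariate Laplace distribution."""
    return -abs(x - loc) / scale - math.log(2.0 * scale)


def composite_log_prior(params: List[float], num_weights: int) -> float:
    """Hypothetical composite prior: the first `num_weights` entries get a
    N(0, 1) prior, the remaining (bias) entries a Laplace(0, 1) prior.
    The joint log density is a sum because the components are independent."""
    weights, biases = params[:num_weights], params[num_weights:]
    return sum(normal_log_pdf(w, 0.0, 1.0) for w in weights) + sum(
        laplace_log_pdf(b, 0.0, 1.0) for b in biases
    )


# Linear(5, 1): 5 weights followed by 1 bias (ordering assumed)
theta = [0.1, -0.2, 0.0, 0.3, 0.05, 1.5]
print(composite_log_prior(theta, num_weights=5))
```

Because the terms are separable, swapping the distribution family for one parameter group only changes that group's term; here the Laplace prior penalizes large biases linearly rather than quadratically.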

📧 Contact

luca.dellalib@gmail.com




Download files


Source Distribution

bayestorch-0.0.3.tar.gz (28.2 kB)

Built Distribution

bayestorch-0.0.3-py3-none-any.whl (44.2 kB)

File details

Details for the file bayestorch-0.0.3.tar.gz.

File metadata

  • Download URL: bayestorch-0.0.3.tar.gz
  • Size: 28.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.16

File hashes

Hashes for bayestorch-0.0.3.tar.gz

  • SHA256: f597e41a4567b084985b367880f951d4be5562aa50368b2e9cb7b68a09dcff5d
  • MD5: 0f3d0902d4cf65ef1bbc4c97d19aa8da
  • BLAKE2b-256: 73497484c24ee265d54ad4fbda7c9c4b5d177a428fb36bfa9379381d2c0dd9b7


File details

Details for the file bayestorch-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: bayestorch-0.0.3-py3-none-any.whl
  • Size: 44.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.16

File hashes

Hashes for bayestorch-0.0.3-py3-none-any.whl

  • SHA256: 55dc0c1c9f39f8ce4f9ece586760086004a55326dfb3c594f40b5b5d52118c6d
  • MD5: 40b9d6b4d48173c5287692acbc0972e9
  • BLAKE2b-256: 70ccb3959b49f22f3bc18380fba34dabeb0b4517cbe5d28c2b08f0e9ce768706

