laplace - Laplace approximations for deep learning

Project description

Laplace

The laplace package facilitates the application of Laplace approximations for entire neural networks or just their last layer. The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations. The library documentation is available at https://aleximmer.github.io/Laplace.

There is also a corresponding paper, Laplace Redux — Effortless Bayesian Deep Learning, which introduces the library, provides an introduction to the Laplace approximation, reviews its use in deep learning, and empirically demonstrates its versatility and competitiveness. Please consider citing the paper when using our library:

@article{daxberger2021laplace,
  title={Laplace Redux--Effortless Bayesian Deep Learning},
  author={Daxberger, Erik and Kristiadi, Agustinus and Immer, Alexander
          and Eschenhagen, Runa and Bauer, Matthias and Hennig, Philipp},
  journal={arXiv preprint arXiv:2106.14806},
  year={2021}
}

Setup

We assume Python 3.8, since the package was developed with that version. To install laplace with pip, run the following:

pip install laplace-torch

For development purposes, clone the repository and then install:

# after cloning the repository
pip install -e .
# run tests
pip install -e .[tests]
pytest tests/

Structure

The laplace package consists of two main components:

  1. The subclasses of laplace.BaseLaplace that implement different sparsity structures: different subsets of weights ('all' and 'last_layer') and different structures of the Hessian approximation ('full', 'kron', and 'diag'). This results in six currently available options: laplace.FullLaplace, laplace.KronLaplace, laplace.DiagLaplace, and the corresponding last-layer variations laplace.FullLLLaplace, laplace.KronLLLaplace, and laplace.DiagLLLaplace, which are all subclasses of laplace.LLLaplace. All of these can be conveniently accessed via the laplace.Laplace function, as sketched after this list.
  2. The backends in laplace.curvature, which provide access to Hessian approximations for the corresponding sparsity structures, for example, the diagonal GGN.
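
For instance, the laplace.Laplace function selects among the six classes through its two keyword arguments; a minimal sketch (assuming model is a trained torch.nn.Module classifier) that yields the Kronecker-factored last-layer variant laplace.KronLLLaplace:

from laplace import Laplace

# 'last_layer' + 'kron' selects the laplace.KronLLLaplace subclass
la = Laplace(model, 'classification',
             subset_of_weights='last_layer',
             hessian_structure='kron')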

Additionally, the package provides utilities for decomposing a neural network into feature extractor and last layer for LLLaplace subclasses (laplace.feature_extractor) and effectively dealing with Kronecker factors (laplace.matrix).
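
Conceptually, this decomposition treats everything up to the last layer as a fixed feature map feeding a single linear layer. A plain-PyTorch sketch of the idea (a hypothetical toy model for illustration, not the laplace.feature_extractor API itself):

import torch.nn as nn

# toy classifier: feature extractor followed by a linear last layer
model = nn.Sequential(
    nn.Linear(784, 128), nn.ReLU(),  # feature extractor phi(x)
    nn.Linear(128, 10),              # last layer, treated probabilistically
)
feature_extractor = model[:-1]  # kept at its MAP estimate
last_layer = model[-1]          # receives the Laplace posterior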

Extendability

To extend the laplace package, new BaseLaplace subclasses can be designed, for example, a block-diagonal structure or subset-of-weights Laplace. Alternatively, extending or integrating backends (subclasses of curvature.curvature) makes it possible to provide different Hessian approximations to the Laplace approximations, as sketched below. For example, curvature.BackPackInterface, based on BackPACK, and curvature.AsdlInterface, based on ASDL, are currently available. The curvature.AsdlInterface provides a Kronecker-factored empirical Fisher while the curvature.BackPackInterface does not, and only the curvature.BackPackInterface provides access to Hessian approximations for a regression (MSELoss) loss function.
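
As a rough illustration, a new backend would subclass the curvature interface and implement the Hessian factors it supports. The import path, class name, and method signatures below are assumptions made for this sketch, not the package's exact interface:

# hypothetical skeleton -- names and signatures are illustrative only
from laplace.curvature import CurvatureInterface

class MyBackend(CurvatureInterface):
    # toy backend exposing diagonal and full curvature approximations

    def diag(self, X, y):
        # return the batch loss and a per-parameter (diagonal)
        # Hessian approximation, e.g., a diagonal GGN
        raise NotImplementedError

    def full(self, X, y):
        # return the batch loss and a dense Hessian approximation;
        # feasible only for small networks
        raise NotImplementedError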

Example usage

Post-hoc prior precision tuning of last-layer LA

In the following example, a pre-trained model is loaded, the Laplace approximation is fit to the training data, and the prior precision is optimized with cross-validation ('CV'). After that, the resulting LA is used for prediction with the 'probit' predictive for classification.

from laplace import Laplace

# pre-trained model
model = load_map_model()  

# User-specified LA flavor
la = Laplace(model, 'classification',
             subset_of_weights='all',
             hessian_structure='diag')
la.fit(train_loader)
la.optimize_prior_precision(method='CV', val_loader=val_loader)

# User-specified predictive approx.
pred = la(x, link_approx='probit')
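
Alternatively, the prior precision can be tuned without a validation set by maximizing the marginal likelihood; a one-line sketch, assuming the 'marglik' option of optimize_prior_precision is available:

# marginal-likelihood-based tuning; no val_loader needed
la.optimize_prior_precision(method='marglik')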

Differentiating the log marginal likelihood w.r.t. hyperparameters

The marginal likelihood can be used for model selection and is differentiable with respect to continuous hyperparameters like the prior precision or observation noise. Here, we fit the library default, a KFAC last-layer LA, and differentiate the log marginal likelihood.

from laplace import Laplace

# Un- or pre-trained model
model = load_model()  

# Default to recommended last-layer KFAC LA:
la = Laplace(model, likelihood='regression')
la.fit(train_loader)

# log marginal likelihood w.r.t. prior precision and observation noise
ml = la.log_marginal_likelihood(prior_prec, obs_noise)
ml.backward()
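
Since the returned value is an ordinary differentiable torch scalar, the call above can be wrapped in a standard optimization loop. A sketch using torch.optim, where log_prior and log_sigma are assumed to parameterize the prior precision and observation noise on the log scale so that both stay positive:

import torch

n_steps = 100
log_prior = torch.ones(1, requires_grad=True)
log_sigma = torch.ones(1, requires_grad=True)
hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)

for _ in range(n_steps):
    hyper_optimizer.zero_grad()
    # exponentiate so both hyperparameters remain positive
    neg_marglik = -la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
    neg_marglik.backward()
    hyper_optimizer.step()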

Documentation

The documentation is available at https://aleximmer.github.io/Laplace or can be generated and/or viewed locally:

# assuming the repository was cloned
pip install -e .[docs]
# create docs and write to html
bash update_docs.sh
# .. or serve the docs directly
pdoc --http 0.0.0.0:8080 laplace --template-dir template

References

This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1].

[1] MacKay, DJC. A Practical Bayesian Framework for Backpropagation Networks. Neural Computation, 1992.

Download files

Download the file for your platform.

Source Distribution

laplace-torch-0.1a1.tar.gz (206.6 kB)

Built Distribution

laplace_torch-0.1a1-py3-none-any.whl (30.5 kB)

File details

Details for the file laplace-torch-0.1a1.tar.gz.

File metadata

  • Download URL: laplace-torch-0.1a1.tar.gz
  • Upload date:
  • Size: 206.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.6.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.61.2 CPython/3.9.5

File hashes

Hashes for laplace-torch-0.1a1.tar.gz:

  • SHA256: 022c534d707247fd22a23b3ff33fd65020b6f52ff1a461011b8ea25e67094886
  • MD5: c396f5521d2d3fe1f34b9075bd4df2e8
  • BLAKE2b-256: 8e33bcd054eee803e6c4bea20fbb6faf56e431aef0041e974d9af0c1e693b57a

File details

Details for the file laplace_torch-0.1a1-py3-none-any.whl.

File metadata

  • Download URL: laplace_torch-0.1a1-py3-none-any.whl
  • Upload date:
  • Size: 30.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.6.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.61.2 CPython/3.9.5

File hashes

Hashes for laplace_torch-0.1a1-py3-none-any.whl:

  • SHA256: 6088054c0e1d807c0787d8bf548d9b19a84bada0a7da520d3bd4e74e2347c786
  • MD5: 6d7efe696466e06af9ef33c81ccb6f30
  • BLAKE2b-256: ed6266866e6f08975cfa928addfd8b49b0373eb4c190895e06130f1df0b76ac4
