
laplace - Laplace approximations for deep learning

Project description

The laplace package facilitates the application of Laplace approximations for entire neural networks, subnetworks of neural networks, or just their last layer. The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations.

There is also a corresponding paper, Laplace Redux — Effortless Bayesian Deep Learning, which introduces the library, reviews the Laplace approximation and its use in deep learning, and empirically demonstrates its versatility and competitiveness. Please consider citing the paper when using our library:

@inproceedings{laplace2021,
  title={Laplace Redux--Effortless {B}ayesian Deep Learning},
  author={Erik Daxberger and Agustinus Kristiadi and Alexander Immer
          and Runa Eschenhagen and Matthias Bauer and Philipp Hennig},
  booktitle={{N}eur{IPS}},
  year={2021}
}

The code to reproduce the experiments in the paper is also publicly available; it provides examples of how to use our library for predictive uncertainty quantification, model selection, and continual learning.

[!IMPORTANT] As a user, one should not expect Laplace to work automatically. That is, one should experiment with Laplace's different options (hessian_structure, prior-precision tuning method, predictive method, backend, etc.). For guidance on how to set these options for the application or problem at hand, see the various papers that use Laplace. One common alternative configuration is sketched below.
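
For instance, a last-layer Laplace approximation with a Kronecker-factored Hessian and a Monte Carlo predictive is a common alternative to the all-weights, diagonal configuration used in the usage example further below. A minimal sketch, assuming a pre-trained model and data (load_map_model, train_loader, and x are placeholders for your own model and inputs):

from laplace import Laplace

# Placeholder for a pre-trained (MAP-trained) model
model = load_map_model()

# Alternative flavor: last layer only, with a Kronecker-factored Hessian
la = Laplace(model, "classification",
             subset_of_weights="last_layer",
             hessian_structure="kron")
la.fit(train_loader)

# Alternative predictive: Monte Carlo samples of the network output
pred = la(x, pred_type="nn", link_approx="mc", n_samples=100)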

Installation

[!IMPORTANT] We assume Python >= 3.9, since lower versions are (soon to be) deprecated. PyTorch >= 2.0 is also required for full compatibility.

To install laplace with pip, run the following:

pip install laplace-torch

Additionally, if you want to use the asdfghjkl backend, please install it via:

pip install git+https://git@github.com/wiseodd/asdl@asdfghjkl
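
As a quick sanity check that the package imports correctly after installation (a suggestion on our part, not part of the official instructions), run:

python -c "from laplace import Laplace"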

Simple usage

[!TIP] Check out https://aleximmer.github.io/Laplace for more usage examples and API reference.

In the following example, a pre-trained model is loaded, the Laplace approximation is fit to the training data (using a diagonal Hessian approximation over all parameters), and the prior precision is optimized via grid search ("gridsearch") on a validation set. The resulting LA is then used for prediction with the "probit" predictive for classification.

from laplace import Laplace

# Pre-trained model
model = load_map_model()

# User-specified LA flavor
la = Laplace(model, "classification",
             subset_of_weights="all",
             hessian_structure="diag")
la.fit(train_loader)
la.optimize_prior_precision(
    method="gridsearch",
    pred_type="glm",
    link_approx="probit",
    val_loader=val_loader
)

# User-specified predictive approx.
pred = la(x, pred_type="glm", link_approx="probit")
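
The package also supports marginal-likelihood estimation, as mentioned above. A minimal sketch for a regression model (load_map_model and train_loader are again placeholders; optimize_prior_precision with method="marglik" tunes the prior precision by maximizing the log marginal likelihood instead of using grid search):

from laplace import Laplace

# Placeholder for a pre-trained regression model
model = load_map_model()

la = Laplace(model, "regression",
             subset_of_weights="all",
             hessian_structure="diag")
la.fit(train_loader)

# Tune the prior precision by maximizing the evidence
la.optimize_prior_precision(method="marglik")

# Log marginal likelihood, e.g., for model selection
log_marglik = la.log_marginal_likelihood()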

Contributing

Pull requests are very welcome. Please follow the guidelines at https://aleximmer.github.io/Laplace/devs_guide.

