
An optimizer for neural networks based on variational learning


Improved Variational Online Newton (IVON)


We provide PyTorch code for the IVON optimizer to train deep neural networks, along with a usage guide and small-scale examples. For the experiments from the paper, see the ivon-experiments repository. An experimental JAX implementation, in the form of an optax optimizer, is also available.

Variational Learning is Effective for Large Deep Networks
Y. Shen*, N. Daheim*, B. Cong, P. Nickl, G.M. Marconi, C. Bazan, R. Yokota, I. Gurevych, D. Cremers, M.E. Khan, T. Möllenhoff
International Conference on Machine Learning (ICML), 2024 (spotlight)

ArXiv: https://arxiv.org/abs/2402.17641
Blog: https://team-approx-bayes.github.io/blog/ivon/
Tutorial: https://ysngshn.github.io/research/why-ivon/

Installation of IVON

To install the IVON optimizer, run:

pip install ivon-opt

Dependencies

Install PyTorch as described in the official PyTorch installation instructions, for example with CUDA 11.8:

pip3 install torch --index-url https://download.pytorch.org/whl/cu118

Usage guide

In Appendix A of the paper, we provide practical guidelines for choosing IVON hyperparameters.
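One quantity these hyperparameters control is the width of the Gaussian posterior. Following the IVON update as we read it in the paper, the posterior standard deviation of each weight is σ = 1/√(λ(h + δ)), where λ is the effective sample size (`ess`), h the running Hessian estimate and δ the weight decay. The sketch below only evaluates that relation; the helper `posterior_std` and the example values are hypothetical and not part of the `ivon` package.

```python
import math


def posterior_std(ess: float, hess: float, weight_decay: float) -> float:
    """Per-weight posterior std: sigma = 1 / sqrt(ess * (hess + weight_decay)).

    Hypothetical helper for illustration; the relation is assumed from the
    paper's description of the IVON update, not taken from the ivon package.
    """
    if ess <= 0 or hess + weight_decay <= 0:
        raise ValueError("ess and hess + weight_decay must be positive")
    return 1.0 / math.sqrt(ess * (hess + weight_decay))


# A larger effective sample size (e.g. a larger training set) shrinks sigma.
for n in (100, 10_000, 1_000_000):
    print(n, posterior_std(ess=n, hess=1.0, weight_decay=0.0))
```

Under this relation, setting `ess=len(train_dataset)` as in the training loop means that more training data yields a more concentrated posterior.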

Training loop

The code snippet below shows how a training loop with IVON differs from one with a standard optimizer such as SGD or Adam: lines starting with + are added and lines starting with - are removed relative to an Adam training loop. The standard setting is to draw one Monte Carlo (MC) weight sample per training step (train_samples=1).

import torch
import torch.nn.functional as F
+import ivon

# train_dataset, test_dataset and MLP are defined by the user
train_loader = torch.utils.data.DataLoader(train_dataset)
test_loader = torch.utils.data.DataLoader(test_dataset)
model = MLP()

-optimizer = torch.optim.Adam(model.parameters())
+optimizer = ivon.IVON(model.parameters(), lr=0.1, ess=len(train_dataset))
+train_samples = 1  # number of MC weight samples per training step

for X, y in train_loader:

+    for _ in range(train_samples):
+        # Weights are sampled from the posterior; gradients are computed at the sample
+        with optimizer.sampled_params(train=True):
            optimizer.zero_grad()
            logit = model(X)
            loss = F.cross_entropy(logit, y)
            loss.backward()

    optimizer.step()
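To make the role of `sampled_params(train=True)` more concrete, here is a toy, dependency-free sketch of the pattern we understand it to follow: inside the context the weights are temporarily replaced by a sample θ = m + σ·ε from a diagonal Gaussian posterior, and on exit the mean m is restored. The class `ToyGaussianWeights`, the function `toy_loss` and all values are hypothetical; the real `ivon.IVON` optimizer does considerably more, since it also uses each sample for its update.

```python
import random
from contextlib import contextmanager


class ToyGaussianWeights:
    """Hypothetical diagonal-Gaussian posterior over a flat list of weights."""

    def __init__(self, mean, std, seed=0):
        self.mean = list(mean)      # posterior mean m
        self.std = list(std)        # posterior standard deviation sigma
        self.weights = list(mean)   # weights the "model" currently uses
        self._rng = random.Random(seed)

    @contextmanager
    def sampled_params(self):
        # Draw theta = m + sigma * eps with eps ~ N(0, 1), elementwise.
        self.weights = [m + s * self._rng.gauss(0.0, 1.0)
                        for m, s in zip(self.mean, self.std)]
        try:
            yield self.weights
        finally:
            # Restore the mean so code outside the context sees m again.
            self.weights = list(self.mean)


def toy_loss(w):
    # Quadratic loss with its minimum at w = [1, -2].
    return (w[0] - 1.0) ** 2 + (w[1] + 2.0) ** 2


post = ToyGaussianWeights(mean=[0.5, -1.5], std=[0.1, 0.1])
train_samples = 4
losses = []
for _ in range(train_samples):
    with post.sampled_params() as theta:
        losses.append(toy_loss(theta))  # each sample sees different weights

# Monte Carlo estimate of the expected loss under the posterior.
mc_loss = sum(losses) / len(losses)
print(mc_loss, post.weights)
```

Using a context manager guarantees that the mean is restored even if the forward or backward pass raises an exception.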

Prediction

There are two different ways of using the variational posterior of IVON for prediction:

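(1) Predict with the mean of the variational posterior, which IVON keeps in the model parameters, just as with a standard optimizer:

with torch.no_grad():
    for X, y in test_loader:
        logit = model(X)
        _, prediction = logit.max(1)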

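(2) A better way is posterior averaging: draw test_samples weight vectors from the Gaussian variational posterior, predict with each one, and average the predictions to obtain predictive probabilities.

import torch
import torch.nn.functional as F

test_samples = 10  # number of posterior samples (example value; choose as needed)

with torch.no_grad():
    for X, y in test_loader:
        sampled_probs = []
        for _ in range(test_samples):
            # Temporarily replace the weights with a sample from the posterior
            with optimizer.sampled_params():
                sampled_logit = model(X)
                sampled_probs.append(F.softmax(sampled_logit, dim=1))
        # Average the predictive probabilities over all samples
        prob = torch.mean(torch.stack(sampled_probs), dim=0)
        _, prediction = prob.max(1)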
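The difference between the two prediction modes can be seen on a toy binary logistic model in plain Python. The sketch compares the probability obtained from the mean weights with the Monte Carlo average over weights drawn from a diagonal Gaussian. The weights, standard deviations and input are hypothetical, and the samples are drawn in antithetic pairs (ε and −ε) only to make the comparison deterministic; the PyTorch loop above draws independent samples.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def logit(w, x):
    return sum(wi * xi for wi, xi in zip(w, x))


mean = [2.0, -1.0]   # hypothetical posterior mean
std = [0.8, 0.8]     # hypothetical posterior standard deviations
x = [1.5, 0.5]       # a single test input

# (1) Predict with the mean weights, as with a standard optimizer.
p_mean = sigmoid(logit(mean, x))

# (2) Posterior averaging over test_samples weight draws.
rng = random.Random(0)
test_samples = 32    # even, because samples are drawn in antithetic pairs
probs = []
for _ in range(test_samples // 2):
    eps = [rng.gauss(0.0, 1.0) for _ in mean]
    for sign in (1.0, -1.0):
        w = [m + sign * s * e for m, s, e in zip(mean, std, eps)]
        probs.append(sigmoid(logit(w, x)))
p_avg = sum(probs) / len(probs)

print(f"mean weights: {p_mean:.3f}  posterior average: {p_avg:.3f}")
```

Because the sigmoid flattens out for large logits, averaging over weight samples pulls a confident prediction back toward 0.5, which is why posterior averaging tends to give less over-confident probabilities than predicting with the mean alone.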

Examples

We include three Google Colab notebooks that demonstrate the use of the IVON optimizer on small-scale problems.

  1. MNIST image classification
    • We compare IVON to an SGD baseline.
  2. 1-D Regression
    • IVON captures uncertainty in regions with little data, whereas AdamW does not.
  3. 2-D Logistic Regression
    • SGD finds the mode of the weight posterior, while IVON converges to a region that is more robust to perturbation.

How to cite

@inproceedings{shen2024variational,
      title={Variational Learning is Effective for Large Deep Networks}, 
      author={Yuesong Shen and Nico Daheim and Bai Cong and Peter Nickl and Gian Maria Marconi and Clement Bazan and Rio Yokota and Iryna Gurevych and Daniel Cremers and Mohammad Emtiyaz Khan and Thomas Möllenhoff},
      booktitle={International Conference on Machine Learning (ICML)},
      year={2024},
      url={https://arxiv.org/abs/2402.17641}
}


Download files


Source Distribution

ivon_opt-0.1.3.tar.gz (18.2 kB)

Uploaded Source

Built Distribution

ivon_opt-0.1.3-py3-none-any.whl (18.4 kB)

Uploaded Python 3

File details

Details for the file ivon_opt-0.1.3.tar.gz.

File metadata

  • Download URL: ivon_opt-0.1.3.tar.gz
  • Size: 18.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.14

File hashes

Hashes for ivon_opt-0.1.3.tar.gz
Algorithm    Hash digest
SHA256       586e66aac197e6a00deba852860412f58b6d06cd8bf3413d4c78ddb321f08c81
MD5          ade2c50ba7fb6275e6655afa2c1b8b82
BLAKE2b-256  2911e0cb5ed999dd070862863adbc46ee4193cadd1480eec2cd081b5fa06d4d9


File details

Details for the file ivon_opt-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: ivon_opt-0.1.3-py3-none-any.whl
  • Size: 18.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.14

File hashes

Hashes for ivon_opt-0.1.3-py3-none-any.whl
Algorithm    Hash digest
SHA256       2d528953eea524fbd92b375649fefb032ae28c45b171eaca590aef696537b27e
MD5          b0ffee6cb89e8dce50f9902fc6e6c7e3
BLAKE2b-256  9f252b1e477cf53bfa47926a18288dee5190474b9cd46f70c11048ac9efc6ff6
