A very basic implementation of SVGD, based on https://github.com/dilinwang820/Stein-Variational-Gradient-Descent

Project description

simpleSVGD

This package is a tiny SVGD algorithm specifically developed to operate on distributions found in HMCLab.

By default, this package uses radial basis functions to compute sample interaction and AdaGrad to optimize the samples.

Installation:

We recommend using at least Python 3.7.

To get the latest release, simply use pip inside your favourite environment:

pip install simpleSVGD

To install the latest version directly from GitHub:

git clone git@github.com:larsgeb/simpleSVGD.git
cd simpleSVGD
pip install -e .

Mini-tutorial

This package can be used with minimal development. The only things one needs to supply to the algorithm are:

  1. The gradient of the function to optimize, gradient_fn(samples). The function itself is not needed.
  2. An initial collection of samples initial_samples, a numpy.array. It helps if these are close to the target function/distribution.

Input/output of your gradient_fn

It is essential to get the input/output shapes of the target gradient right. As input, it should accept an arbitrary number of samples of the appropriate dimensionality. This means that if one wants 430 samples on a 3-dimensional function, the input/output shapes look like this:

output_gradient = gradient_fn(input_samples)

input_samples.shape = (430, 3)
output_gradient.shape = (430, 3)
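To make this contract concrete, here is a minimal shape check using a stand-in gradient (the quadratic f(x) = ½‖x‖², whose gradient is the input itself; this `gradient_fn` is a placeholder for illustration, not part of the package):

```python
import numpy as np

def gradient_fn(samples: np.ndarray) -> np.ndarray:
    # Stand-in gradient: for f(x) = 0.5 * ||x||^2 the gradient is x itself,
    # so the output shape trivially matches the input shape.
    return samples

input_samples = np.zeros((430, 3))  # 430 samples in 3 dimensions
output_gradient = gradient_fn(input_samples)

assert input_samples.shape == (430, 3)
assert output_gradient.shape == (430, 3)
```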

Typically, it is useful to instantiate the samples using a Normal distribution. Using NumPy, this is done with:

import numpy as np
np.random.seed(235)

mean = 0
standard_dev = 1
n_samples = 100
dimensions = 2

initial_samples = np.random.normal(mean, standard_dev, [n_samples, dimensions])

Defining an example target

A good 2-dimensional test function would be the Himmelblau function:

def Himmelblau(input_array: np.ndarray) -> np.ndarray:

    # As this is a 2-dimensional function, assert that the passed
    # input_array has the correct shape.
    assert input_array.shape[1] == 2

    # Splitting the coordinates first makes the formula easier to read;
    # it is not the most efficient way to program this.
    x = input_array[:, 0, None]
    y = input_array[:, 1, None]

    output_array = (x ** 2 + y - 11) ** 2 + (x + y ** 2 - 7) ** 2

    # As the function is scalar-valued, assert that the output has
    # length 1 along its second dimension.
    assert output_array.shape == (input_array.shape[0], 1)

    smoothing = 100
    return output_array / smoothing

and its gradient:

def Himmelblau_grad(input_array: np.ndarray) -> np.ndarray:

    # As this is a 2-dimensional function, assert that the passed
    # input_array has the correct shape.
    assert input_array.shape[1] == 2

    # Splitting the coordinates first makes the formula easier to read;
    # it is not the most efficient way to program this.
    x = input_array[:, 0, None]
    y = input_array[:, 1, None]

    # Compute the partial derivatives and combine them
    output_array_dx = 2 * (x ** 2 + y - 11) * (2 * x) + 2 * (x + y ** 2 - 7)
    output_array_dy = 2 * (x ** 2 + y - 11) + 2 * (x + y ** 2 - 7) * (2 * y)
    output_array = np.hstack((output_array_dx, output_array_dy))

    # The gradient must have the same shape as the input
    assert output_array.shape == input_array.shape

    smoothing = 100
    return output_array / smoothing
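Before handing the gradient to the algorithm, it is worth sanity-checking it: the value and the gradient should both vanish at a known minimum such as (3, 2), and a central finite-difference approximation should agree with the analytical gradient. A minimal check, restating the two functions above in compact form:

```python
import numpy as np

def Himmelblau(a: np.ndarray) -> np.ndarray:
    x, y = a[:, 0, None], a[:, 1, None]
    return ((x ** 2 + y - 11) ** 2 + (x + y ** 2 - 7) ** 2) / 100

def Himmelblau_grad(a: np.ndarray) -> np.ndarray:
    x, y = a[:, 0, None], a[:, 1, None]
    dx = 2 * (x ** 2 + y - 11) * (2 * x) + 2 * (x + y ** 2 - 7)
    dy = 2 * (x ** 2 + y - 11) + 2 * (x + y ** 2 - 7) * (2 * y)
    return np.hstack((dx, dy)) / 100

# The function has a global minimum of 0 at (3, 2); the value and the
# gradient should both vanish there.
minimum = np.array([[3.0, 2.0]])
assert np.allclose(Himmelblau(minimum), 0.0)
assert np.allclose(Himmelblau_grad(minimum), 0.0)

# A central finite difference at a random point should match the
# analytical gradient to within the truncation error.
rng = np.random.default_rng(42)
point = rng.normal(size=(1, 2))
eps = 1e-6
fd = np.zeros((1, 2))
for d in range(2):
    step = np.zeros((1, 2))
    step[0, d] = eps
    fd[0, d] = (Himmelblau(point + step) - Himmelblau(point - step)).item() / (2 * eps)
assert np.allclose(fd, Himmelblau_grad(point), atol=1e-5)
```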

Running the algorithm

To run the algorithm with 1000 samples that are initially normally distributed (mean = 0, standard deviation = 3, parameters chosen based on prior belief), we simply call simpleSVGD.update() in the following way:

import numpy as np
import matplotlib.pyplot as plt

import simpleSVGD

initial_samples = np.random.normal(0, 3, [1000, 2])

# %matplotlib notebook

figure = plt.figure(figsize=(6, 6))
plt.xlabel("Parameter 0")
plt.ylabel("Parameter 1")
plt.title("SVGD animation on the Himmelblau function")

final_samples = simpleSVGD.update(
    initial_samples,
    Himmelblau_grad,
    n_iter=130,
    # AdaGrad parameters
    stepsize=1e-1,
    alpha=0.9,
    fudge_factor=1e-3,
    historical_grad=1,
    #animate=True,
    #background=background,
    #figure=figure,
)

To animate the algorithm, simply uncomment the commented-out arguments. The result should be similar to this:

https://user-images.githubusercontent.com/21038893/151603377-a473e7b1-f7b4-417b-a685-9c0cfa98dc15.mov

The origins of SVGD

SVGD is a general purpose variational inference algorithm that forms a natural counterpart of gradient descent for optimization. SVGD iteratively transports a set of particles to match with the target distribution, by applying a form of functional gradient descent that minimizes the KL divergence.
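The update described above can be sketched in a few lines of NumPy: each particle moves along a kernel-weighted average of the other particles' gradients (the driving term) plus a kernel-gradient term that repels particles from each other. This is a minimal illustration with an RBF kernel and the median bandwidth heuristic from the SVGD paper, using plain gradient steps rather than the package's AdaGrad; the names here (`svgd_step`) are illustrative, not the package's API:

```python
import numpy as np

def svgd_step(samples: np.ndarray, gradient_fn, stepsize: float = 1e-1) -> np.ndarray:
    n = samples.shape[0]

    # RBF kernel with the median heuristic for the bandwidth, as in the
    # original SVGD paper.
    diffs = samples[:, None, :] - samples[None, :, :]
    sq_dists = np.sum(diffs ** 2, axis=-1)
    h = np.median(sq_dists) / np.log(n + 1)
    kernel = np.exp(-sq_dists / h)

    # Sum over j of the kernel gradient w.r.t. x_j; this repulsive term
    # keeps the particles from collapsing onto a single point.
    grad_kernel = (2.0 / h) * (
        kernel.sum(axis=1, keepdims=True) * samples - kernel @ samples
    )

    # gradient_fn is the gradient of the function being *minimized*, so
    # the driving (attractive) term carries a minus sign.
    phi = (-kernel @ gradient_fn(samples) + grad_kernel) / n
    return samples + stepsize * phi

# Example: the gradient of f(x) = 0.5 * ||x||^2 is x, so the target density
# exp(-f) is a standard normal; the particle spread should contract from
# its initial standard deviation of 3 towards 1.
rng = np.random.default_rng(0)
samples = rng.normal(0, 3, (200, 2))
for _ in range(500):
    samples = svgd_step(samples, lambda s: s)
print(samples.std())  # should be much closer to 1 than to the initial 3
```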

For more information, please visit the original implementers' project website, or see their publication: Qiang Liu and Dilin Wang, "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm," NIPS, 2016.

