Library to run VI algorithms on Stan models.

vistan

vistan is a simple library to run variational inference algorithms on Stan models.

Features

  • Initialization: Laplace's method to initialize full-rank Gaussian
  • Gradient Estimators: Total-gradient, STL, DReG, closed-form entropy
  • Variational Families: Full-rank Gaussian, Diagonal Gaussian, RealNVP
  • Objectives: ELBO, IW-ELBO
  • IW-sampling: Posterior samples using importance weighting
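As a rough illustration of the two objectives listed above, the following sketch estimates the ELBO and the IW-ELBO for a toy one-dimensional target using NumPy and SciPy. It does not use vistan's internals: the target density, the variational distribution, and the function name `iw_elbo` are assumptions made for this example only.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm


def iw_elbo(log_p, q_mean, q_std, M, n_outer, rng):
    """Monte Carlo estimate of the IW-ELBO with M importance samples.

    With M = 1 this reduces to the standard ELBO estimate.
    """
    # Draw n_outer independent batches of M samples from q.
    z = rng.normal(q_mean, q_std, size=(n_outer, M))
    # Log importance weights: log p(z) - log q(z).
    log_w = log_p(z) - norm.logpdf(z, q_mean, q_std)
    # Log of the average weight in each batch, then average over batches.
    return np.mean(logsumexp(log_w, axis=1) - np.log(M))


# Hypothetical unnormalized target: a N(1, 0.5^2) density scaled by exp(-2),
# so the true log normalizer is -2.
def log_p(z):
    return norm.logpdf(z, 1.0, 0.5) - 2.0


rng = np.random.default_rng(0)
elbo = iw_elbo(log_p, q_mean=0.0, q_std=1.0, M=1, n_outer=20000, rng=rng)
iw10 = iw_elbo(log_p, q_mean=0.0, q_std=1.0, M=10, n_outer=20000, rng=rng)
print(f"ELBO    (M=1):  {elbo:.3f}")
print(f"IW-ELBO (M=10): {iw10:.3f}")
```

In expectation the IW-ELBO lies between the ELBO and the true log normalizer, which is why increasing the number of importance samples tightens the bound.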

Installation

pip install vistan

Usage

Meanfield VI

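import vistan
import matplotlib.pyplot as plt

# Stan program for the eight schools model (non-centered parameterization)
code = """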
data {
  int<lower=0> J;         // number of schools
  real y[J];              // estimated treatment effects
  real<lower=0> sigma[J]; // standard error of effect estimates
}
parameters {
  real mu;                // population treatment effect
  real<lower=0> tau;      // standard deviation in treatment effects
  vector[J] eta;          // unscaled deviation from mu by school
}
transformed parameters {
  vector[J] theta = mu + tau * eta;        // school treatment effects
}
model {
  target += normal_lpdf(eta | 0, 1);       // prior log-density
  target += normal_lpdf(y | theta, sigma); // log-likelihood
}
"""

data = {"J": 8,
        "y": [28,  8, -3,  7, -1,  1, 18, 12],
        "sigma": [15, 10, 16, 11,  9, 11, 10, 18]}

posterior, model, results = vistan.infer(code = code, data = data) # runs Meanfield VI by default

samples = posterior.sample(1000)
for i in range(samples["eta"].shape[1]):
    plt.plot(samples["eta"][:, i], label=f"eta[{i}]")  # one line per school
plt.legend()
plt.show()
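
The plot above draws the raw sample sequences; a numerical summary is often useful alongside it. The sketch below assumes, as the indexing in the example suggests, that `posterior.sample` returns a dict mapping parameter names to NumPy arrays whose first axis indexes samples. The helper `summarize` and the stand-in `fake_samples` are hypothetical and not part of vistan.

```python
import numpy as np


def summarize(samples):
    """Return {name: (mean, std)} computed over the sample axis (axis 0)."""
    summary = {}
    for name, draws in samples.items():
        draws = np.asarray(draws)
        summary[name] = (draws.mean(axis=0), draws.std(axis=0))
    return summary


# Stand-in for `posterior.sample(1000)`; replace with the real samples.
rng = np.random.default_rng(0)
fake_samples = {
    "mu": rng.normal(4.0, 3.0, size=1000),
    "eta": rng.normal(0.0, 1.0, size=(1000, 8)),
}

for name, (mean, std) in summarize(fake_samples).items():
    print(name, np.round(mean, 2), np.round(std, 2))
```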

Gaussian VI

hyperparams = vistan.hyperparams(method = 'gaussian')

posterior, model, results = vistan.infer(code = code, data = data, 
                        hyperparams = hyperparams, verbose = True)

samples = posterior.sample(1000)
for i in range(samples["eta"].shape[1]):
    plt.plot(samples["eta"][:, i], label=f"eta[{i}]")  # one line per school
plt.legend()
plt.show()
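
To make the difference between the diagonal and full-rank Gaussian families listed under Features concrete, here is a minimal NumPy sketch of how samples from each can be drawn by reparameterization. It assumes that `method = 'gaussian'` selects the full-rank family, which the text above does not state explicitly. The function names are hypothetical and this is not vistan's implementation.

```python
import numpy as np


def sample_diagonal(mean, log_std, n, rng):
    """Draw n samples from N(mean, diag(exp(log_std))**2)."""
    eps = rng.standard_normal((n, mean.shape[0]))
    return mean + np.exp(log_std) * eps


def sample_full_rank(mean, chol, n, rng):
    """Draw n samples from N(mean, chol @ chol.T), chol lower triangular."""
    eps = rng.standard_normal((n, mean.shape[0]))
    return mean + eps @ chol.T


rng = np.random.default_rng(0)
mean = np.zeros(2)
chol = np.array([[1.0, 0.0],
                 [0.8, 0.6]])  # implies covariance [[1, 0.8], [0.8, 1]]

z_diag = sample_diagonal(mean, np.zeros(2), 50000, rng)
z_full = sample_full_rank(mean, chol, 50000, rng)

# The diagonal family cannot represent correlation; the full-rank one can.
print(np.round(np.corrcoef(z_diag.T), 2))
print(np.round(np.corrcoef(z_full.T), 2))
```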

Flow-based VI

hyperparams = vistan.hyperparams(method = 'flows')

posterior, model, results = vistan.infer(code = code, data = data, 
                        hyperparams = hyperparams, verbose = True)

samples = posterior.sample(1000)
for i in range(samples["eta"].shape[1]):
    plt.plot(samples["eta"][:, i], label=f"eta[{i}]")  # one line per school
plt.legend()
plt.show()

ADVI

hyperparams = vistan.hyperparams(method = 'advi')

posterior, model, results = vistan.infer(code = code, data = data, 
                                hyperparams = hyperparams, verbose = True)

samples = posterior.sample(1000)
for i in range(samples["eta"].shape[1]):
    plt.plot(samples["eta"][:, i], label=f"eta[{i}]")  # one line per school
plt.legend()
plt.show()

Custom

hyperparams = vistan.hyperparams(method = 'custom',
                                 vi_family = "gaussian",
                                 M_training = 10,
                                 grad_estimator = "DReG",
                                 LI = True)

posterior, model, results = vistan.infer(code = code, data = data, 
                                hyperparams = hyperparams, verbose = True)

samples = posterior.sample(1000)
for i in range(samples["eta"].shape[1]):
    plt.plot(samples["eta"][:, i], label=f"eta[{i}]")  # one line per school
plt.legend()
plt.show()

Limitations

  • vistan currently supports inference only over all latent parameters in the model jointly

Source Distribution

vistan-0.0.0.2.tar.gz (20.8 kB)

Uploaded Source

Built Distribution

vistan-0.0.0.2-py3-none-any.whl (34.7 kB)

Uploaded Python 3

File details

Details for the file vistan-0.0.0.2.tar.gz.

File metadata

  • Download URL: vistan-0.0.0.2.tar.gz
  • Upload date:
  • Size: 20.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.25.0 setuptools/50.3.2.post20201201 requests-toolbelt/0.9.1 tqdm/4.54.1 CPython/3.9.0

File hashes

Hashes for vistan-0.0.0.2.tar.gz
Algorithm Hash digest
SHA256 44969def8069573970eb58be92fca37f7bc4f4b7b50956ee2f18267ceb68dcc6
MD5 a1406d79ff4b2865c4d59fb80d112ea8
BLAKE2b-256 309b5cb5c859c34e9e23b0e79b6fef18ae8635fda28b272b054289531a3aeded

File details

Details for the file vistan-0.0.0.2-py3-none-any.whl.

File metadata

  • Download URL: vistan-0.0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 34.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.25.0 setuptools/50.3.2.post20201201 requests-toolbelt/0.9.1 tqdm/4.54.1 CPython/3.9.0

File hashes

Hashes for vistan-0.0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 4456787fffd0cc6779f2282d9d0ece255abb6bc400cd9de938cc988b5b64cfa3
MD5 26617e862d8eb9aad8f90693063e5ce4
BLAKE2b-256 eefca800e703b270672ec4cf469e3be7e203a288a25925441ddb3322e28142c7
