Library to run VI algorithms on Stan models.
vistan
vistan is a simple library to run variational inference algorithms on Stan models.
Features
- Initialization: Laplace's method to initialize full-rank Gaussian
- Gradient Estimators: Total-gradient, STL, DReG, closed-form entropy
- Variational Families: Full-rank Gaussian, Diagonal Gaussian, RealNVP
- Objectives: ELBO, IW-ELBO
- IW-sampling: Posterior samples using importance weighting
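To make the Objectives and IW-sampling items concrete, here is a minimal NumPy sketch on a toy one-dimensional problem. It is not vistan's implementation: the target `log_p`, the Gaussian proposal, and the function names `iw_elbo` and `iw_sample` are all assumptions made for illustration. With `M = 1` the estimate reduces to the standard ELBO; larger `M` gives the tighter IW-ELBO bound on the log normalizer.

```python
import numpy as np


def log_normal(x, mean, std):
    """Log-density of N(mean, std^2), evaluated elementwise."""
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((x - mean) / std) ** 2


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    a_max = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(a_max, axis=axis) + np.log(np.sum(np.exp(a - a_max), axis=axis))


def log_p(z):
    """Toy unnormalized target (assumption): 3 * N(z | 1, 0.5^2), so log Z = log 3."""
    return np.log(3.0) + log_normal(z, 1.0, 0.5)


def iw_elbo(rng, q_mean, q_std, M, n_outer=2000):
    """Monte Carlo estimate of E[log (1/M) sum_m p(z_m) / q(z_m)] with z_m ~ q."""
    z = rng.normal(q_mean, q_std, size=(n_outer, M))
    log_w = log_p(z) - log_normal(z, q_mean, q_std)
    return np.mean(logsumexp(log_w, axis=1) - np.log(M))


def iw_sample(rng, q_mean, q_std, M, n_samples):
    """Draw M proposals per output sample and keep one in proportion to its weight."""
    z = rng.normal(q_mean, q_std, size=(n_samples, M))
    log_w = log_p(z) - log_normal(z, q_mean, q_std)
    w = np.exp(log_w - logsumexp(log_w, axis=1)[:, None])  # self-normalized weights
    idx = np.array([rng.choice(M, p=w_row) for w_row in w])
    return z[np.arange(n_samples), idx]


rng = np.random.default_rng(0)
elbo = iw_elbo(rng, q_mean=0.0, q_std=1.0, M=1)   # standard ELBO
iw = iw_elbo(rng, q_mean=0.0, q_std=1.0, M=10)    # IW-ELBO with 10 samples
draws = iw_sample(rng, q_mean=0.0, q_std=1.0, M=50, n_samples=2000)

print("ELBO:", elbo, "IW-ELBO:", iw, "log Z:", np.log(3.0))
print("IW-sample mean/std:", draws.mean(), draws.std())
```

In this toy setup the proposal is deliberately mismatched with the target, so the gap between the two bounds is easy to see, and the importance-weighted draws land much closer to the target than raw proposal draws would.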
Installation
pip install vistan
Usage
Meanfield VI
import vistan
import matplotlib.pyplot as plt

code = """
data {
  int<lower=0> J; // number of schools
  real y[J]; // estimated treatment effects
  real<lower=0> sigma[J]; // standard error of effect estimates
}
parameters {
  real mu; // population treatment effect
  real<lower=0> tau; // standard deviation in treatment effects
  vector[J] eta; // unscaled deviation from mu by school
}
transformed parameters {
  vector[J] theta = mu + tau * eta; // school treatment effects
}
model {
  target += normal_lpdf(eta | 0, 1); // prior log-density
  target += normal_lpdf(y | theta, sigma); // log-likelihood
}
"""
data = {"J": 8,
        "y": [28, 8, -3, 7, -1, 1, 18, 12],
        "sigma": [15, 10, 16, 11, 9, 11, 10, 18]}
posterior, model, results = vistan.infer(code = code, data = data) # runs Meanfield VI by default
samples = posterior.sample(1000)
# plot the draws of each eta component as its own labeled line
for i in range(samples['eta'].shape[1]):
    plt.plot(samples["eta"][:,i], label = f"eta[{i}]")
plt.legend()
plt.show()
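A numeric summary is often useful alongside the plot. The sketch below relies only on what the snippet above implies, namely that `samples` can be indexed by parameter name and yields NumPy arrays whose first axis indexes draws; that it also supports `.items()` like a dict is an assumption. The helper `summarize` and the synthetic `fake_samples` are illustrative and not part of vistan.

```python
import numpy as np


def summarize(samples):
    """Return the per-parameter mean and standard deviation over the draw axis."""
    summary = {}
    for name, draws in samples.items():
        draws = np.asarray(draws, dtype=float)
        summary[name] = {"mean": draws.mean(axis=0), "std": draws.std(axis=0)}
    return summary


# Synthetic stand-in for posterior.sample(1000); this is not real vistan output.
rng = np.random.default_rng(0)
fake_samples = {
    "mu": rng.normal(4.0, 3.0, size=1000),
    "eta": rng.normal(0.0, 1.0, size=(1000, 8)),
}

for name, stats in summarize(fake_samples).items():
    print(name, np.round(stats["mean"], 2), np.round(stats["std"], 2))
```

With real output, `summarize(samples)` would be called in place of `summarize(fake_samples)`.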
Gaussian VI
hyperparams = vistan.hyper_params(method = 'gaussian')
posterior, model, results = vistan.infer(code = code, data = data,
                                         hyperparams = hyperparams, verbose = True)
samples = posterior.sample(1000)
for i in range(samples['eta'].shape[1]):
    plt.plot(samples["eta"][:,i], label = f"eta[{i}]")
plt.legend()
plt.show()
Flow-based VI
hyperparams = vistan.hyper_params(method = 'flows')
posterior, model, results = vistan.infer(code = code, data = data,
                                         hyperparams = hyperparams, verbose = True)
samples = posterior.sample(1000)
for i in range(samples['eta'].shape[1]):
    plt.plot(samples["eta"][:,i], label = f"eta[{i}]")
plt.legend()
plt.show()
ADVI
hyperparams = vistan.hyperparams(method = 'advi')
posterior, model, results = vistan.infer(code = code, data = data,
                                         hyperparams = hyperparams, verbose = True)
samples = posterior.sample(1000)
for i in range(samples['eta'].shape[1]):
    plt.plot(samples["eta"][:,i], label = f"eta[{i}]")
plt.legend()
plt.show()
Custom
hyperparams = vistan.hyperparams(method = 'custom',
                                 vi_family = "gaussian",
                                 M_training = 10,
                                 grad_estimator = "DReG",
                                 LI = True)
posterior, model, results = vistan.infer(code = code, data = data,
                                         hyperparams = hyperparams, verbose = True)
samples = posterior.sample(1000)
for i in range(samples['eta'].shape[1]):
    plt.plot(samples["eta"][:,i], label = f"eta[{i}]")
plt.legend()
plt.show()
Limitations
- We currently support inference only on all latent parameters in the model; inference on a subset of parameters is not supported.