ADVI library for generalized graphical models
Project description
ADVI in JAX
Design considerations
- The ADVI class is an object, but ADVI.objective_fun is a pure function that can be optimized with optax, jaxopt, or any other JAX-compatible optimizer (see the optax sketch after the coin toss example below).
- Variational distribution parameters can be initialized with ADVI.init using a distrax or tfp distribution as the initializer (or any JAX-compatible distribution that implements a .sample() method in a similar way).
- Users can pass suitable bijectors of class distrax.Bijector to the variational distribution.
- The transformation is applied directly to the posterior, so the prior and likelihood stay untouched during the entire process. This way, after training, the variational distribution is ready for sampling without any additional transformation. It also gives the variational distribution the freedom to be constructed in a more complex way, since it is separated from the other parts of the model (see the example below).
- If the key is not changed during training, the method becomes deterministic ADVI.
- Users can implement their own likelihood_log_prob_fun, because the likelihood does not necessarily have to be a distribution (a minimal sketch follows this list).
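For instance, the likelihood term can be any JAX function of the latent parameter that returns a scalar log probability. Below is a minimal, hypothetical sketch (not part of the package) of a hand-written Gaussian log-likelihood for a toy regression problem; the names x, y, theta, and noise_scale are illustrative assumptions.

import jax.numpy as jnp

# Toy data, assumed for illustration only
x = jnp.linspace(0.0, 1.0, 20)
y = 2.0 * x + 0.1

def likelihood_log_prob_fun(theta):
    # Hand-written Gaussian log-likelihood log N(y | theta * x, noise_scale^2),
    # summed over all observations; no distribution object is required.
    noise_scale = 0.1
    residuals = y - theta * x
    return jnp.sum(
        -0.5 * jnp.log(2.0 * jnp.pi * noise_scale**2)
        - 0.5 * (residuals / noise_scale) ** 2
    )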
A Coin Toss Example
import jax
import jax.numpy as jnp
from advi_jax import ADVI
from advi_jax.variational_distributions import MeanField
from advi_jax.init import initialize
import distrax
import tensorflow_probability.substrates.jax as tfp
dist = tfp.distributions
# Data
tosses = jnp.array([0, 1, 0, 0, 1, 0])
# Prior and likelihood
prior_dist = dist.Beta(2.0, 3.0)
likelihood_log_prob_fun = lambda theta: dist.Bernoulli(probs=theta).log_prob(tosses).sum()
# ADVI model
model = ADVI(prior_dist, likelihood_log_prob_fun, tosses)
# Variational distribution and bijector
bijector = distrax.Sigmoid()
variational_dist = MeanField(u_mean=jnp.array(0.0), u_scale=jnp.array(0.0), bijector=bijector)
# Initialize the parameters of variational distribution
key = jax.random.PRNGKey(0)
variational_dist = initialize(key, variational_dist, initializer=dist.Normal(0.0, 1.0))
# Define the value and grad function
value_and_grad_fun = jax.jit(jax.value_and_grad(model.objective_fun, argnums=1), static_argnums=2)
# Do gradient descent!
learning_rate = 0.01
for i in range(100):
    key = jax.random.PRNGKey(i)  # If this key stays constant, this becomes deterministic ADVI
    loss_value, grads = value_and_grad_fun(key, variational_dist, 10)  # n_samples=10, passed positionally because of static_argnums=2
    variational_dist = variational_dist - learning_rate * grads
# Get the posterior samples
key = jax.random.PRNGKey(2)
posterior_samples = variational_dist.sample(seed=key, sample_shape=(100,))
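Because model.objective_fun is a pure function of (key, variational_dist, n_samples), the manual update above can also be replaced with an optax optimizer. The following is a minimal sketch, assuming the MeanField object is registered as a JAX pytree whose leaves are its trainable parameters, and reusing model and variational_dist from the example above.

import optax

optimizer = optax.adam(learning_rate=1e-2)
opt_state = optimizer.init(variational_dist)

@jax.jit
def train_step(key, variational_dist, opt_state):
    # n_samples=10 is a concrete Python constant inside the traced function
    loss, grads = jax.value_and_grad(model.objective_fun, argnums=1)(key, variational_dist, 10)
    updates, opt_state = optimizer.update(grads, opt_state)
    variational_dist = optax.apply_updates(variational_dist, updates)
    return loss, variational_dist, opt_state

for i in range(100):
    key = jax.random.PRNGKey(i)  # keep this key fixed for deterministic ADVI
    loss, variational_dist, opt_state = train_step(key, variational_dist, opt_state)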