ADVI library for generalized graphical models

ADVI in JAX

Design considerations

  • ADVI is a class, but ADVI.objective_fun is a pure function, so it can be optimized with optax, jaxopt, or any other JAX-compatible optimizer.
  • The parameters of the variational distribution can be initialized with ADVI.init, using a distrax or tfp distribution as the initializer (or any JAX distribution that implements a .sample() method in a similar way).
  • Users can pass suitable bijectors of class distrax.Bijector to the variational distribution.
  • The transformation is applied directly to the posterior, so the prior and likelihood stay untouched during the entire process. After training, the variational distribution is therefore ready for sampling without any additional transformations. This also allows the variational distribution to be constructed in more complex ways, because it is separate from the other parts of the model (see the example below).
  • If the key is not changed during training, the method is called deterministic ADVI.
  • Users can implement their own likelihood_log_prob_fun, because the likelihood does not necessarily have to be a distribution.
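
To make the points above concrete, here is a minimal sketch in plain JAX of what a pure ADVI objective can look like. It is not the library's implementation: the names log_joint, objective_fun, and params, the softplus parameterization of the scale, and the hand-written sigmoid change of variables are all assumptions made for illustration. The transformation is applied only to the variational samples, while the prior and likelihood are evaluated directly on theta, and reusing one key makes the objective deterministic.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import beta, norm

tosses = jnp.array([0, 1, 0, 0, 1, 0])

def log_joint(theta):
    """Log prior Beta(2, 3) plus Bernoulli log-likelihood for a scalar theta in (0, 1)."""
    log_prior = beta.logpdf(theta, 2.0, 3.0)
    log_lik = jnp.sum(tosses * jnp.log(theta) + (1 - tosses) * jnp.log1p(-theta))
    return log_prior + log_lik

def objective_fun(key, params, n_samples):
    """Monte Carlo estimate of the negative ELBO (illustrative stand-in, not the library's)."""
    scale = jax.nn.softplus(params["raw_scale"])  # assumption: softplus keeps the scale positive
    u = params["mean"] + scale * jax.random.normal(key, (n_samples,))  # reparameterized draw
    theta = jax.nn.sigmoid(u)  # the transformation touches only the variational samples
    # Change of variables: log q(theta) = log N(u) - log(theta * (1 - theta))
    log_q = norm.logpdf(u, params["mean"], scale) - jnp.log(theta) - jnp.log1p(-theta)
    return jnp.mean(log_q - jax.vmap(log_joint)(theta))

value_and_grad_fun = jax.jit(jax.value_and_grad(objective_fun, argnums=1), static_argnums=2)

params = {"mean": jnp.array(0.0), "raw_scale": jnp.array(0.0)}

# Deterministic ADVI: the same key yields the same objective value.
loss_a, _ = value_and_grad_fun(jax.random.PRNGKey(0), params, 10)
loss_b, _ = value_and_grad_fun(jax.random.PRNGKey(0), params, 10)

# Stochastic ADVI: plain gradient descent with a fresh key at every step.
for i in range(500):
    _, grads = value_and_grad_fun(jax.random.PRNGKey(i), params, 50)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.05 * g, params, grads)

# Sampling needs no extra transformation beyond the one built into q.
u = params["mean"] + jax.nn.softplus(params["raw_scale"]) * jax.random.normal(
    jax.random.PRNGKey(123), (5000,)
)
posterior_samples = jax.nn.sigmoid(u)
```

Because objective_fun depends only on its arguments, the hand-written update in the loop could be swapped for an optimizer from optax or jaxopt without changing the objective itself.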

A Coin Toss Example

import jax
import jax.numpy as jnp

from advi_jax import ADVI
from advi_jax.variational_distributions import MeanField
from advi_jax.init import initialize
import distrax  # needed for distrax.Sigmoid below
import tensorflow_probability.substrates.jax as tfp
dist = tfp.distributions

# Data
tosses = jnp.array([0, 1, 0, 0, 1, 0])

# Prior and likelihood
prior_dist = dist.Beta(2.0, 3.0)
likelihood_log_prob_fun = lambda theta: dist.Bernoulli(probs=theta).log_prob(tosses).sum()

# ADVI model
model = ADVI(prior_dist, likelihood_log_prob_fun, tosses)

# Variational distribution and bijector
bijector = distrax.Sigmoid()
variational_dist = MeanField(u_mean = jnp.array(0.0), u_scale = jnp.array(0.0), bijector = bijector)

# Initialize the parameters of variational distribution
key = jax.random.PRNGKey(0)
variational_dist = initialize(key, variational_dist, initializer=dist.Normal(0.0, 1.0))

# Define the value and grad function
value_and_grad_fun = jax.jit(jax.value_and_grad(model.objective_fun, argnums=1), static_argnums=2)

# Do gradient descent!
learning_rate = 0.01
for i in range(100):
    key = jax.random.PRNGKey(i)  # If this is constant, this becomes deterministic ADVI
    loss_value, grads = value_and_grad_fun(key, variational_dist, n_samples=10)
    variational_dist = variational_dist - learning_rate * grads

# Get the posterior samples
key = jax.random.PRNGKey(2)
posterior_samples = variational_dist.sample(seed=key, sample_shape=(100,))
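
As a sanity check for this example, the Beta prior is conjugate to the Bernoulli likelihood, so the exact posterior is available in closed form. The sketch below computes it; the variable names (alpha_post, beta_post, exact_mean, exact_var) are illustrative and not part of the library.

```python
import jax.numpy as jnp

tosses = jnp.array([0, 1, 0, 0, 1, 0])
alpha_prior, beta_prior = 2.0, 3.0  # same Beta(2, 3) prior as in the example

heads = int(tosses.sum())
tails = int(tosses.size) - heads

# Beta-Bernoulli conjugacy: add the observed counts to the prior parameters.
alpha_post = alpha_prior + heads
beta_post = beta_prior + tails

total = alpha_post + beta_post
exact_mean = alpha_post / total
exact_var = alpha_post * beta_post / (total**2 * (total + 1.0))

print(f"Exact posterior: Beta({alpha_post}, {beta_post})")
print(f"mean = {exact_mean:.4f}, variance = {exact_var:.4f}")
```

If training has converged, posterior_samples.mean() from the variational distribution should land close to exact_mean, up to Monte Carlo noise and the approximation error of the chosen variational family.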


Source Distribution

advi_jax-0.0.1.tar.gz (1.2 MB)
