
ADVI library for generalized graphical models

Project description

ADVI in JAX

Design considerations

  • The ADVI class is an object, but ADVI.objective_fun is a pure function that can be optimized with optax, jaxopt, or any other JAX-compatible optimizer (see the optax sketch after the example below).
  • Variational distribution parameters can be initialized with ADVI.init using a distrax or tfp distribution as an initializer (or any JAX distribution that implements a .sample() method in a similar way).
  • Users can pass suitable bijectors of class distrax.Bijector to the variational distribution.
  • The transformation is applied directly to the posterior, so the prior and likelihood stay untouched during the entire process. This way, after training, the variational distribution is ready for sampling without any additional transformations. It also gives the variational distribution the freedom to be constructed in more complex ways, since it is separated from the other parts of the model (see the example below).
  • If the random key is not changed during training, the method becomes deterministic ADVI.
  • Users can implement their own likelihood_log_prob_fun, since the likelihood does not necessarily have to be a distribution object (a hand-written sketch follows this list).
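
For instance, a likelihood can be any function that maps parameters to a scalar log-probability. A minimal hand-written sketch (the Bernoulli form and the toss data mirror the coin toss example below; this is only an illustration, not part of the library):

import jax.numpy as jnp

tosses = jnp.array([0, 1, 0, 0, 1, 0])

def likelihood_log_prob_fun(theta):
    # Hand-written Bernoulli log-likelihood: sum_i log p(toss_i | theta)
    return jnp.sum(tosses * jnp.log(theta) + (1 - tosses) * jnp.log(1 - theta))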

A Coin Toss Example

import jax
import jax.numpy as jnp

from advi_jax import ADVI
from advi_jax.variational_distributions import MeanField
from advi_jax.init import initialize
import distrax
import tensorflow_probability.substrates.jax as tfp
dist = tfp.distributions

# Data
tosses = jnp.array([0, 1, 0, 0, 1, 0])

# Prior and likelihood
prior_dist = dist.Beta(2.0, 3.0)
likelihood_log_prob_fun = lambda theta: dist.Bernoulli(probs=theta).log_prob(tosses).sum()

# ADVI model
model = ADVI(prior_dist, likelihood_log_prob_fun, tosses)

# Variational distribution and bijector
bijector = distrax.Sigmoid()
variational_dist = MeanField(u_mean=jnp.array(0.0), u_scale=jnp.array(0.0), bijector=bijector)

# Initialize the parameters of variational distribution
key = jax.random.PRNGKey(0)
variational_dist = initialize(key, variational_dist, initializer=dist.Normal(0.0, 1.0))

# Define the value and grad function
value_and_grad_fun = jax.jit(jax.value_and_grad(model.objective_fun, argnums=1), static_argnames=("n_samples",))

# Do gradient descent!
learning_rate = 0.01
for i in range(100):
    key = jax.random.PRNGKey(i)  # If this is constant, this becomes deterministic ADVI
    loss_value, grads = value_and_grad_fun(key, variational_dist, n_samples=10)
    variational_dist = variational_dist - learning_rate * grads

# Get the posterior samples
key = jax.random.PRNGKey(2)
posterior_samples = variational_dist.sample(seed=key, sample_shape=(100,))
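
Since ADVI.objective_fun is a pure function, the manual update above can also be driven by optax. A minimal sketch, assuming MeanField is registered as a JAX pytree so that optax can treat it as the parameter tree (value_and_grad_fun and variational_dist are as defined above):

import optax

optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(variational_dist)

for i in range(100):
    key = jax.random.PRNGKey(i)
    loss_value, grads = value_and_grad_fun(key, variational_dist, n_samples=10)
    # Compute parameter updates from the gradients and apply them to the pytree
    updates, opt_state = optimizer.update(grads, opt_state)
    variational_dist = optax.apply_updates(variational_dist, updates)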



Download files

Download the file for your platform.

Source Distribution

advi_jax-0.0.1.tar.gz (1.2 MB)


File details

Details for the file advi_jax-0.0.1.tar.gz.

File metadata

  • Download URL: advi_jax-0.0.1.tar.gz
  • Upload date:
  • Size: 1.2 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.10.4

File hashes

Hashes for advi_jax-0.0.1.tar.gz

  • SHA256: 006a616b270bd0a435e406a24627f02af69a98dc14303ab7582fbb2ce606f5ce
  • MD5: e771f3d017c1bd4d3c2c008ce985c198
  • BLAKE2b-256: 7b2704fe3a570e2b8315435135282a19111007fc8ce36d414a2f6d40024390d8
