Approximate Bayesian Inference in JAX

BIJAX

Bayesian Inference in JAX.

Installation

pip install git+https://github.com/patel-zeel/bijax.git

Methods implemented in BIJAX

  • from bijax.advi import ADVI - Automatic Differentiation Variational Inference.
  • [WIP] from bijax.laplace import ADLaplace - Automatic Differentiation Laplace approximation.
  • from bijax.mcmc import MCMC - A helper class for external Markov Chain Monte Carlo (MCMC) sampling.

How to use BIJAX?

BIJAX is built without layers of abstraction and without proposing new conventions, which also makes it useful for educational purposes. If you would like to dive directly into the examples, please refer to the examples directory.

There are a few core components of bijax:

Prior

tensorflow_probability.substrates.jax should be used to define the prior distributions.

import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions

The prior distribution for the coin toss problem can be defined as follows:

prior = {"p_of_heads": tfd.Beta(0.5, 0.5)}

The prior distribution for the Linear Regression problem can be defined as follows:

import jax.numpy as jnp

shape_of_weights = 5
prior = {"weights": tfd.MultivariateNormalDiag(
                        loc=jnp.zeros(shape_of_weights),
                        scale_diag=jnp.ones(shape_of_weights),
                    )}

Bijectors

Bijectors available in tensorflow_probability.substrates.jax are used to facilitate the change-of-variable (change-of-support) trick. Here, a bijector should transform an unconstrained Gaussian random variable, whose support is the whole real line, into a random variable whose support matches that of the corresponding prior.

import tensorflow_probability.substrates.jax as tfp
tfb = tfp.bijectors

To perform Automatic Differentiation Variational Inference for the coin toss problem, a bijector can be defined as follows:

prior = {"p_of_heads": tfd.Beta(0.5, 0.5)}
bijector = {"p_of_heads": tfb.Sigmoid()}

For the Linear Regression problem, a bijector can be defined as follows:

shape_of_weights = 5
prior = {"weights": tfd.MultivariateNormalDiag(
                        loc=jnp.zeros(shape_of_weights),
                        scale_diag=jnp.ones(shape_of_weights),
                    )}
bijector = {"weights": tfb.Identity()}

Likelihood

Users have full freedom in how they define the log likelihood function, provided it adheres to a few conditions. The log likelihood function should take the following arguments:

  • latent_sample: a dictionary holding one sample from the latent (prior) parameter distributions. It has the same keys as the prior.
  • outputs: the observed outputs. The function computes the log probability of these outputs given a latent sample.
  • inputs: input data required to evaluate the likelihood. For example, in the Linear Regression problem, X is the inputs; for the coin toss problem, inputs is None.
  • kwargs: the trainable parameters are passed internally as keyword arguments, so the user can declare additional learnable parameters here and they will be trained.

For the coin toss problem, we can define the log likelihood function as follows:

def log_likelihood_fn(latent_sample, outputs, inputs, **kwargs):
    # Sum of Bernoulli log-probabilities over all observed coin tosses.
    p_of_heads = latent_sample["p_of_heads"]
    log_likelihood = tfd.Bernoulli(probs=p_of_heads).log_prob(outputs).sum()
    return log_likelihood

For the Linear Regression problem with a learnable noise scale, we can define the log likelihood function as follows:

def log_likelihood_fn(latent_sample, outputs, inputs, **kwargs):
    weights = latent_sample["weights"]
    # Assuming inputs["X"] has shape (num_features, num_data).
    loc = jnp.dot(weights, inputs["X"])
    # kwargs carries the learnable noise parameter; exponentiate to keep the scale positive.
    noise_scale = jnp.exp(kwargs["log_noise_scale"])
    log_likelihood = tfd.MultivariateNormalDiag(
        loc=loc, scale_diag=noise_scale * jnp.ones_like(loc)
    ).log_prob(outputs).sum()
    return log_likelihood

Initialization

We can automatically initialize the parameters of the model.

Here is an example with the ADVI model.

import jax
from bijax.advi import ADVI

model = ADVI(prior, bijector, log_likelihood_fn, vi_type="mean_field")
seed = jax.random.PRNGKey(0)
params = model.init(seed)

Optimization

Models in bijax have a loss_fn method which can be used to compute the loss. The loss can be optimized with any method that works with JAX. There is also a utility function (from bijax.utils import train) for training the model with optax optimizers.
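
As an alternative to the train utility, a minimal optax loop can be written by hand. The sketch below is only illustrative: it assumes a loss_fn signature of loss_fn(params, outputs=..., inputs=..., seed=...) and that outputs and inputs hold the observed data; the actual signature may differ, so please check the examples directory.

import jax
import optax

optimizer = optax.adam(learning_rate=1e-2)
opt_state = optimizer.init(params)

@jax.jit
def step(params, opt_state, seed):
    # Differentiate the model's loss with respect to the trainable parameters.
    # NOTE: the keyword arguments passed to loss_fn here are assumptions.
    loss, grads = jax.value_and_grad(model.loss_fn)(
        params, outputs=outputs, inputs=inputs, seed=seed
    )
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

seed = jax.random.PRNGKey(1)
for _ in range(500):
    seed, subkey = jax.random.split(seed)
    params, opt_state, loss = step(params, opt_state, subkey)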

Get the posterior distribution

Some of the models (ADVI and ADLaplace) support the .apply() method, which returns the posterior distribution.

posterior = model.apply(params, ...)
posterior.sample(...)
posterior.log_prob(...)

