Vijax

A JAX library for variational inference (VI).

vijax is a flexible and modular library for variational inference (VI) in Python. The library is designed to be accessible to a wide range of users and applications, enabling them to perform VI without getting tied down by specific abstractions.

The key components of vijax are based on a set of abstract base classes and interfaces, which provide a consistent and adaptable structure for building custom VI algorithms. Users can easily extend and modify the library to suit their specific needs, as long as they follow the abstractions provided.

In this document, we present the API design for the main components of vijax:

  1. Model: Represents a probability model $p(z,x)$, where $z$ are latent variables and $x$ are observed data.
  2. VarDist: Represents a variational distribution $q_w(z)$.
  3. Recipe: Provides a high-level interface for running pre-defined VI algorithms.

By adhering to these abstractions, users can easily plug in their own optimization routines, models, and variational distributions, while still benefiting from the core features and utilities provided by the vijax library.

Model

A Model is an abstract base class that represents a probability model.

A concrete implementation of a Model must support the following:

  • model = Model(*args, **vargs) - the constructor can use any argument structure.

  • model.ndim - an integer giving the dimensionality of the model's (unconstrained) latent space.

  • model.log_prob(z) - a required method that evaluates the log-probability of a single ndim-length vector z.

A concrete implementation of a Model can also optionally support:

  • model.sample_prior(PRNGKey) - a method that samples a single ndim-length vector from the model's prior over the latent variables, in unconstrained space. The method takes a JAX PRNG key as input and returns the sampled vector.

  • model.reference_samples(nsamps) - a method that returns nsamps samples from the model's posterior. This may be implemented by (1) exploiting a special algorithm to sample exactly from the posterior, (2) running MCMC, or (3) looking up samples stored in a database.

  • model.constrain(z) - a method that transforms a single unconstrained latent vector z into constrained space.
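
To make the interface concrete, here is a minimal sketch of a conforming model. The class StandardGaussianModel and all of its internals are illustrative assumptions, not part of vijax:

import jax
import jax.numpy as jnp

# Hypothetical model: a standard Gaussian over ndim unconstrained latents.
class StandardGaussianModel:
    def __init__(self, ndim):
        self.ndim = ndim

    def log_prob(self, z):
        # Log-density of a single ndim-length vector under N(0, I).
        return -0.5 * (self.ndim * jnp.log(2 * jnp.pi) + jnp.sum(z**2))

    # Optional: exact prior sampling from a JAX PRNG key.
    def sample_prior(self, key):
        return jax.random.normal(key, (self.ndim,))

    # Optional: here the posterior equals the prior, so reference
    # samples can be drawn exactly.
    def reference_samples(self, nsamps):
        keys = jax.random.split(jax.random.PRNGKey(0), nsamps)
        return jax.vmap(self.sample_prior)(keys)

    # Optional: identity map, since this model is already unconstrained.
    def constrain(self, z):
        return z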

VarDist

A VarDist is a class that represents a variational distribution $q_w(z)$ with parameters $w$. A VarDist must support the following:

  • vardist = VarDist(ndim, *args, **vargs) - initialization must take ndim as the first argument but is otherwise class-dependent. One might provide defaults for all arguments other than ndim.

  • vardist.ndim - the number of dimensions of the unconstrained latent variables.

  • vardist.initial_params() - initializes the variational parameters (in unconstrained space).

  • vardist.sample(params, key) - draw a single ndim-length vector from the variational distribution. Must be differentiable w.r.t. params.

  • vardist.log_prob(params, z) - evaluate the log-probability of a single ndim-length vector z.

A concrete implementation of VarDist can optionally support:

  • vardist.sample_and_log_prob(params, key) - generate a sample and evaluate its log-probability in one call (can be more efficient and numerically stable for some distributions)

  • vardist.sample_and_log_prob_stl(params, key) - a "sticking the landing" (STL) compatible version of vardist.sample_and_log_prob (not expected to be more efficient in general, but should be more numerically stable)

  • vardist.mean_cov(params) - get the (closed-form) mean and covariance for this variational distribution

  • vardist_new, params = vardist.match_mean_inv_cov(mean,inv_cov) - get parameters and a new distribution object that match a given mean and inverse covariance

  • vardist_new, params = vardist.match_mean_cov(mean,cov) - get parameters and a new distribution object that match a given mean and covariance
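
As a concrete illustration, here is a minimal sketch of a diagonal-Gaussian VarDist. vijax ships its own Diagonal Gaussian (listed below); this hypothetical DiagGaussianVarDist only shows how the pieces of the interface fit together:

import jax
import jax.numpy as jnp

# Hypothetical variational distribution: Gaussian with diagonal covariance.
class DiagGaussianVarDist:
    def __init__(self, ndim):
        self.ndim = ndim

    def initial_params(self):
        # Unconstrained parameters: mean and log standard deviation.
        return {"mean": jnp.zeros(self.ndim), "log_std": jnp.zeros(self.ndim)}

    def sample(self, params, key):
        # Reparameterized sampling, so gradients flow through params.
        eps = jax.random.normal(key, (self.ndim,))
        return params["mean"] + jnp.exp(params["log_std"]) * eps

    def log_prob(self, params, z):
        std = jnp.exp(params["log_std"])
        return jnp.sum(
            -0.5 * jnp.log(2 * jnp.pi)
            - params["log_std"]
            - 0.5 * ((z - params["mean"]) / std) ** 2
        )

    # Optional: draw a sample and evaluate its log-probability in one call.
    def sample_and_log_prob(self, params, key):
        z = self.sample(params, key)
        return z, self.log_prob(params, z)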

Recipe

A Recipe is a class that will "do inference". Along with its constructor, it must support three methods:

  • recipe = Recipe(*args,**vargs) - Initialization can use any argument structure
  • new_vardist, new_params, results = recipe.run(model, vardist, params) - actually run the recipe
    • The first two return values are the updated variational distribution and its parameters. The last is a recipe-dependent structure that can contain information about convergence, time used, etc.
  • z = recipe.sample(model, vardist, params, key) - draw a sample from the recipe for this model and variational distribution
    • In many cases this would just return vardist.sample(params, key), but for importance-weighted objectives and the like it could be more complex.
  • l = recipe.objective(model, vardist, params, key) - estimate the recipe's objective for a given set of parameters
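
As a sketch of how these pieces compose, here is a hypothetical recipe that maximizes a reparameterization-gradient ELBO estimate with fixed-step gradient ascent. vijax's actual recipes (such as SimpleVI) are likely more sophisticated; this NaiveELBORecipe only illustrates the contract:

import jax
import jax.numpy as jnp

# Hypothetical recipe: naive ELBO maximization by gradient ascent.
class NaiveELBORecipe:
    def __init__(self, maxiter=1000, batchsize=32, stepsize=1e-3):
        self.maxiter = maxiter
        self.batchsize = batchsize
        self.stepsize = stepsize

    def objective(self, model, vardist, params, key):
        # Monte Carlo ELBO estimate: E_q[log p(z, x) - log q_w(z)].
        keys = jax.random.split(key, self.batchsize)
        zs = jax.vmap(lambda k: vardist.sample(params, k))(keys)
        return jnp.mean(
            jax.vmap(model.log_prob)(zs)
            - jax.vmap(lambda z: vardist.log_prob(params, z))(zs)
        )

    def run(self, model, vardist, params):
        key = jax.random.PRNGKey(0)
        # Minimize the negative ELBO with plain gradient descent.
        grad_fn = jax.jit(jax.grad(
            lambda p, k: -self.objective(model, vardist, p, k)
        ))
        for _ in range(self.maxiter):
            key, subkey = jax.random.split(key)
            grads = grad_fn(params, subkey)
            params = jax.tree_util.tree_map(
                lambda p, g: p - self.stepsize * g, params, grads
            )
        results = {"final_elbo": self.objective(model, vardist, params, key)}
        return vardist, params, results

    def sample(self, model, vardist, params, key):
        # For a plain ELBO recipe this is just a draw from q_w.
        return vardist.sample(params, key)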

Example

Here's an example of how the API might work.

# Create an instance of the model
model = models.Funana(3)

# Create an instance of the variational distribution
gaussian_q = vardists.Gaussian(3)

# Initialize the parameters of the variational distribution
gaussian_w = gaussian_q.initial_params()

# Create an instance of the recipe
recipe = recipes.SimpleVI(maxiter=10, batchsize=128)

# Run the recipe for variational inference
new_q, new_w, vi_rez = recipe.run(target=model, vardist=gaussian_q, params=gaussian_w)

# Create a flow (RealNVP) variational distribution
flow_q = vardists.RealNVP(
    3,
    num_transformations=6,
    num_hidden_units=16,
    num_hidden_layers=2,
    params_init_scale=0.001
)
# Initialize the parameters of the flow variational distribution
flow_w = flow_q.initial_params()

# Run the same recipe with a flow variational distribution
new_q, new_w, vi_rez = recipe.run(target=model, vardist=flow_q, params=flow_w)
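
Since SimpleVI follows the Recipe interface above, the quality of the two fits can be compared by estimating each objective. This usage is a sketch under the assumptions above; the exact contents of vi_rez are recipe-dependent:

import jax

# Estimate the recipe's objective for the fitted flow parameters.
key = jax.random.PRNGKey(0)
flow_objective = recipe.objective(model, new_q, new_w, key)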

Implemented Models

This repository includes several probabilistic models that can be used for variational inference. The table below describes each model; see the corresponding implementation files in the repository.

Model Name | Description
---------- | -----------
Well-Conditioned Gaussian | A Gaussian model with a well-conditioned covariance structure.
Ill-Conditioned Gaussian | A Gaussian model with an ill-conditioned covariance structure.
Neal's Funnel | A model with a funnel-shaped distribution, often used to test sampling algorithms.
Banana | A model with a banana-shaped distribution, used to test inference algorithms.
Funana | A custom model that combines the densities of Neal's Funnel and Banana.
Studentt-1.5 | A multivariate Student-t distribution with 1.5 degrees of freedom.
Studentt-2.5 | A multivariate Student-t distribution with 2.5 degrees of freedom.

Implemented Variational Distributions

This repository includes several variational distributions that can be used for variational inference. The table below describes each distribution; see the corresponding implementation files in the repository.

Distribution Name | Description
----------------- | -----------
Gaussian | A standard Gaussian variational distribution.
Diagonal Gaussian | A Gaussian distribution with a diagonal covariance matrix.
RealNVP | A flow-based variational distribution using RealNVP transformations.

Requirements

This project requires Python 3.9 or higher. The dependencies for this project are managed using Conda and are listed in the environment.yml file. The main libraries used in this project include:

  • jax: A library for high-performance machine learning research.
  • jaxlib: The support library for JAX, containing the XLA compiler and runtime components.
  • numpyro: A probabilistic programming library built on JAX.
  • tensorflow-probability: A library for probabilistic reasoning and statistical analysis.
  • inference-gym: A suite of probabilistic models for benchmarking inference algorithms.

Setting Up the Environment

To set up the environment, follow these steps:

  1. Install Conda: If you don't have Conda installed, download and install it.

  2. Create the Conda Environment: Use the environment.yml file to create a new Conda environment. Run the following command in your terminal:

    conda env create -f environment.yml
    
  3. Activate the Environment: Once the environment is created, activate it using:

    conda activate vi_jax
    
  4. Run the Tests: To ensure everything is set up correctly, run the test file test_models_and_vardists.py:

    python test_models_and_vardists.py
    

This will execute the tests and verify that all models and variational distributions are working as expected with the specified recipes.
