MCCube
Markov chain cubature via JAX
MCCube is a JAX library for constructing Markov chain cubatures (MCCs) that (weakly) solve certain SDEs and can thus be used to perform Bayesian inference.
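For context, the canonical SDE of this kind is the overdamped Langevin diffusion. The quick example below uses a `LangevinDiffusionPropagator`, although the exact discretization MCCube applies is not described here. Under mild regularity conditions on the target density $\pi$, the law of $X_t$ converges to $\pi$, so weakly solving the SDE over a long horizon yields approximate posterior samples. Only $\nabla \log \pi$ enters the dynamics, which is why an unnormalized log-density $\log \tilde{\pi}$ suffices (notation ours):

```latex
\mathrm{d}X_t = \nabla_x \log \pi(X_t)\,\mathrm{d}t + \sqrt{2}\,\mathrm{d}W_t,
\qquad
\operatorname{Law}(X_t) \to \pi \ \text{as } t \to \infty,
\qquad
\nabla_x \log \frac{\tilde{\pi}(x)}{Z} = \nabla_x \log \tilde{\pi}(x).
```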
The core features of MCCube are:
- Approximate Bayesian inference of JAX-transformable functions (support for PyTorch, TensorFlow and NumPy functions is provided via Ivy);
- A simple Markov chain cubature inference loop;
- A component framework for constructing cubature steps as a composition of a Propagator and a Recombinator;
- Trace-time component validation that ensures components obey certain expected mathematical properties, with minimal runtime overhead;
- Visualization tools for evaluating and debugging inference/solving performance;
- A Blackjax-like interface, provided by `mccube.extensions` (coming soon);
- A custom solver for using MCC in Diffrax, also provided by `mccube.extensions` (coming soon).
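To illustrate the idea behind the component framework, here is a minimal, framework-free NumPy sketch of a cubature step built as a recombinator applied after a propagator. All names (`CubatureStep`, `make_shift_propagator`, `make_subsample_recombinator`) are hypothetical and are not MCCube's API. The shape check only gestures at the kind of consistency checks that component validation could perform.

```python
from dataclasses import dataclass
from typing import Callable

import numpy as np

# Hypothetical component types (not MCCube's API):
# a propagator maps (n, d) particles to (n * m, d) particles,
# a recombinator maps those back to a smaller (n_out, d) set.
Propagator = Callable[[np.ndarray], np.ndarray]
Recombinator = Callable[[np.ndarray], np.ndarray]


@dataclass
class CubatureStep:
    propagator: Propagator
    recombinator: Recombinator

    def __call__(self, particles: np.ndarray) -> np.ndarray:
        _, d = particles.shape
        expanded = self.propagator(particles)
        # Simple shape check on the propagator's output before recombining.
        if expanded.ndim != 2 or expanded.shape[1] != d:
            raise ValueError(f"propagator returned shape {expanded.shape}")
        return self.recombinator(expanded)


def make_shift_propagator(vectors: np.ndarray) -> Propagator:
    """Branch every particle x into the children x + v, one per row v of `vectors`."""
    def propagate(particles: np.ndarray) -> np.ndarray:
        children = particles[:, None, :] + vectors[None, :, :]
        return children.reshape(-1, particles.shape[1])
    return propagate


def make_subsample_recombinator(n_out: int, seed: int = 0) -> Recombinator:
    """Keep n_out particles chosen uniformly at random without replacement."""
    rng = np.random.default_rng(seed)

    def recombine(particles: np.ndarray) -> np.ndarray:
        idx = rng.choice(particles.shape[0], size=n_out, replace=False)
        return particles[idx]
    return recombine


# Usage: 8 particles x 4 shift vectors = 32 children, recombined back to 8.
vectors = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
step = CubatureStep(make_shift_propagator(vectors), make_subsample_recombinator(n_out=8))
particles = np.zeros((8, 2))
new_particles = step(particles)
```

Keeping the two roles separate means either one (the cubature formula driving the propagation, or the recombination scheme) can be swapped without touching the other.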
In addition, like the samplers in Blackjax, MCCube can easily be integrated with probabilistic programming languages (PPLs), as long as they can provide a (potentially unnormalized) log-density function.
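Because only a (possibly unnormalized) log-density is needed, adapting one from a PPL or writing one by hand mostly amounts to matching the `(t, p, args)` call signature used in the quick example. The adapter below is a hypothetical helper, not part of MCCube, and uses NumPy purely for illustration:

```python
import numpy as np


def as_time_dependent_logdensity(logdensity):
    """Hypothetical adapter: lift logdensity(p) to the (t, p, args) signature.

    Both t and args are ignored, so the resulting density is time-homogeneous.
    """
    def wrapped(t, p, args):
        return logdensity(p)
    return wrapped


def unnormalized_gaussian_logdensity(p, mean=2.0, var=6.0):
    # log N(p; mean * 1, var * I) with the normalizing constant dropped.
    p = np.asarray(p, dtype=float)
    return -0.5 * np.sum((p - mean) ** 2, axis=-1) / var


target_logdensity = as_time_dependent_logdensity(unnormalized_gaussian_logdensity)
value_at_mean = target_logdensity(0.0, np.array([2.0, 2.0]), None)  # largest value (zero)
```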
> [!WARNING]
> This package is currently a work in progress and experimental. Expect bugs and API instability, and treat all results with a healthy degree of skepticism.
Who should use MCCube?
MCCube should appeal to:
- Users of Blackjax (people who need/want modular GPU/TPU-capable samplers);
- Users of Diffrax (people who need to solve SDEs/CDEs);
- Markov chain cubature researchers/developers.
Installation
To install the base package:

```bash
pip install mccube
```

If you want all the extras provided in `mccube.extensions`:

```bash
pip install "mccube[extras]"
```
Requires Python 3.9+, JAX 0.4.11+, and Equinox 0.10.5+.
By default, a CPU-only version of JAX will be installed. To make use of other JAX/XLA-compatible accelerators (GPUs/TPUs), please follow the JAX installation instructions. Windows support for JAX is currently experimental; WSL2 is the recommended approach for using JAX on Windows.
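The minimum versions listed above (Python 3.9+, JAX 0.4.11+, Equinox 0.10.5+) can be checked with the standard library alone. This is a small convenience sketch, not part of MCCube. It compares only the numeric release segment and ignores pre-release tags such as `rc1` or `.dev0`:

```python
import re
import sys
from importlib.metadata import PackageNotFoundError, version

# Minimum versions as stated in the requirements above.
MINIMUMS = {"jax": (0, 4, 11), "equinox": (0, 10, 5)}


def parse_release(v: str) -> tuple:
    # Keep only the leading numeric release segment, e.g. "0.4.11.dev0" -> (0, 4, 11).
    parts = []
    for piece in v.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def check_requirements() -> list:
    """Return a list of human-readable problems (empty if everything is satisfied)."""
    problems = []
    if sys.version_info < (3, 9):
        problems.append(f"Python {sys.version.split()[0]} < 3.9")
    for name, minimum in MINIMUMS.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            problems.append(f"{name} is not installed")
            continue
        if parse_release(installed) < minimum:
            problems.append(f"{name} {installed} < {'.'.join(map(str, minimum))}")
    return problems


if __name__ == "__main__":
    print(check_requirements() or "All documented minimum versions are satisfied.")
```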
Documentation
Coming soon at https://mccube.readthedocs.io/.
What is Markov chain cubature?
MCC is an approach to constructing a cubature on Wiener space that avoids exponential growth of the particle count over time (particle count explosion) by applying (partitioned) recombination in each cubature step/transition kernel.
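To make the scaling claim concrete: if every cubature step branches each particle along all $m$ paths of a cubature formula, $N$ initial particles become $N m^k$ after $k$ steps, whereas recombining back to $N$ after every step keeps the count fixed. The NumPy sketch below shows this for the Gaussian target of the quick example, under several assumptions that are ours rather than MCCube's: an Euler discretization of the Langevin SDE, the degree-3 cubature formula $\pm\sqrt{d}\,e_i$ for the standard Gaussian, and plain uniform subsampling as a crude stand-in for recombination.

```python
import numpy as np

d, n_particles, dt, n_steps = 2, 4096, 0.5, 100
target_mean, target_var = 2.0, 6.0  # target N(2 * 1, 6 * I), as in the quick example
rng = np.random.default_rng(0)

# Degree-3 cubature formula for N(0, I_d): the 2d points +-sqrt(d) e_i, equal weights.
cubature_vectors = np.sqrt(d) * np.concatenate([np.eye(d), -np.eye(d)])
m = cubature_vectors.shape[0]


def grad_logdensity(x):
    # Gradient of the target log-density (the normalizing constant is irrelevant).
    return -(x - target_mean) / target_var


def propagate(x):
    # Euler step of dX = grad log pi(X) dt + sqrt(2) dW, with the Brownian increment
    # replaced by every cubature vector in turn: (n, d) -> (n * m, d).
    drifted = x + grad_logdensity(x) * dt
    children = drifted[:, None, :] + np.sqrt(2 * dt) * cubature_vectors[None, :, :]
    return children.reshape(-1, d)


def recombine(x, n_out):
    # Crude recombination stand-in: uniform subsampling without replacement.
    return x[rng.choice(x.shape[0], size=n_out, replace=False)]


particles = rng.uniform(size=(n_particles, d))
for _ in range(n_steps):
    particles = recombine(propagate(particles), n_particles)

print("particles after", n_steps, "steps with recombination:", particles.shape[0])
print("particles after 10 steps without recombination:", n_particles * m**10)
print("estimated mean:", particles.mean(axis=0))
print("estimated variance:", particles.var(axis=0))
```

With this fairly large step size, the Euler chain's stationary variance is $144/23 \approx 6.26$ rather than 6 (a time-discretization bias), and uniform subsampling adds resampling noise to the estimates. A dedicated recombination scheme aims to reduce exactly that noise.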
Quick Example
The toy example below demonstrates using MCCube to infer the moments of a two-dimensional Gaussian with mean two and diagonal covariance six, given its log-density function. More in-depth examples are coming soon.
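In formulas (notation ours), with $d$ = `target_dimension`, the target in the example is given below. Note that `3 * np.diag(target_mean)` equals $6 I_d$ because every entry of `target_mean` is 2:

```latex
\begin{aligned}
\pi(x) &= \mathcal{N}(x;\ \mu,\ \Sigma), \qquad \mu = 2\,\mathbf{1}_d, \qquad \Sigma = 6\,I_d, \\
\log \pi(x) &= -\tfrac{1}{12}\,\lVert x - \mu \rVert_2^2 + \text{const}, \\
\nabla_x \log \pi(x) &= -\tfrac{1}{6}\,(x - \mu).
\end{aligned}
```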
```python
import jax
import numpy as np
from jax.scipy.stats import multivariate_normal
from mccube import MCCubatureStep, mccubaturesolve, minimal_cubature_formula
from mccube.components import LangevinDiffusionPropagator, MonteCarloRecombinator
from mccube.metrics import cubature_target_error

# Set up the problem.
n_particles = 8192
target_dimension = 2
rng = np.random.default_rng(42)
prior_particles = rng.uniform(size=(n_particles, target_dimension))
target_mean = 2 * np.ones(target_dimension)
target_cov = 3 * np.diag(target_mean)  # = 6 * I

# MCCube expects the logdensity to have call signature (t, p(t), args), allowing the
# density to be time dependent, or to rely on some other generic args.
# Note: You can obtain significantly better performance by defining a custom JVP here.
def target_logdensity(t, p, args):
    return multivariate_normal.logpdf(p, target_mean, target_cov)

# Set up the MCCubature.
recombinator_key = jax.random.PRNGKey(42)
cfv = minimal_cubature_formula(target_dimension, degree=3).vectors
cs = MCCubatureStep(
    propagator=LangevinDiffusionPropagator(cfv),
    recombinator=MonteCarloRecombinator(recombinator_key),
)

# Construct the MCCubature/solve for the MCCubature paths.
mccubature_paths = mccubaturesolve(
    logdensity=target_logdensity,
    transition_kernel=cs,
    initial_particles=prior_particles,
)

# Compare the mean and covariance of the inferred cubature to the target.
posterior_particles = mccubature_paths.particles[-1, :, :]
mean_err, cov_err = cubature_target_error(posterior_particles, target_mean, target_cov)
print(f"Mean Error: {mean_err}\nCov Error: {cov_err}")
```
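`cubature_target_error` comes from `mccube.metrics`, and its exact definition is not given here. As a rough, hypothetical stand-in (not MCCube's implementation), one could report the Euclidean norm of the mean error and the Frobenius norm of the covariance error, checked below against exact Gaussian samples:

```python
import numpy as np


def moment_errors(particles, target_mean, target_cov):
    """Hypothetical stand-in for a mean/covariance error metric."""
    particles = np.asarray(particles, dtype=float)
    mean_err = np.linalg.norm(particles.mean(axis=0) - target_mean)
    # rowvar=False: rows are particles, columns are dimensions.
    cov_err = np.linalg.norm(np.cov(particles, rowvar=False) - target_cov, ord="fro")
    return mean_err, cov_err


rng = np.random.default_rng(0)
mu, cov = 2 * np.ones(2), 6 * np.eye(2)
samples = rng.multivariate_normal(mu, cov, size=8192)
mean_err, cov_err = moment_errors(samples, mu, cov)
print(f"Mean Error: {mean_err}\nCov Error: {cov_err}")
```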
Note that `mccubaturesolve` returns the cubature paths, but does not return any other intermediate step information. If such information is required, a 'visualizer' callback can be used, for example:
```python
from mccube.extensions.visualizers import TensorboardVisualizer

with TensorboardVisualizer() as tbv:
    cubature = mccubaturesolve(..., visualization_callback=tbv)
```
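The callback protocol expected by `visualization_callback` is not documented here. Purely to illustrate the pattern of a context-managed callback that collects per-step diagnostics, here is a hypothetical, framework-free version; its `(step, particles)` call signature is an assumption, not MCCube's API:

```python
import numpy as np


class MomentRecorder:
    """Hypothetical callback: records the particle mean each time it is called."""

    def __init__(self):
        self.history = []

    def __enter__(self):
        return self

    def __call__(self, step, particles):
        self.history.append((step, np.asarray(particles).mean(axis=0)))

    def __exit__(self, exc_type, exc, tb):
        # A real visualizer might flush or close a log writer here.
        return False


# Usage with a dummy loop standing in for a solver.
with MomentRecorder() as recorder:
    particles = np.zeros((4, 2))
    for step in range(3):
        particles = particles + 1.0
        recorder(step, particles)
```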
To use the TensorBoard visualization suite, run the following command during or after each experimental run:

```bash
tensorboard --logdir=experiments
```
Citation
Please cite this repository if it has been useful in your work:
```bibtex
@software{mccube2023github,
  author={},
  title={{MCC}ube: Markov chain cubature via {JAX}},
  url={},
  version={<insert current release tag>},
  year={2023},
}
```
See Also
Some other Python/JAX packages that you may find interesting: