Markov chain cubature via JAX.


MCCube
Markov chain cubature via JAX

MCCube is a JAX library for constructing Markov chain cubatures (MCCs) that (weakly) solve certain stochastic differential equations (SDEs) and can therefore be used to perform Bayesian inference.
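For context, the textbook example of an SDE that can be used for Bayesian inference is the overdamped Langevin diffusion. Given a (possibly unnormalized) target density π, it is written below in its standard form. Whether MCCube's LangevinDiffusionPropagator (used in the quick example) uses exactly this scaling is an assumption here.

```latex
\mathrm{d}X_t = \nabla_x \log \pi(X_t)\,\mathrm{d}t + \sqrt{2}\,\mathrm{d}W_t
```

Under mild regularity conditions π is the stationary distribution of X_t, so weakly approximating the law of X_t for large t (for instance with a cubature) yields particles whose empirical moments approximate those of π.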

The core features of MCCube are:

  • Approximate Bayesian inference of JAX-transformable functions (support for PyTorch, TensorFlow and NumPy functions is provided via Ivy);
  • A simple Markov chain cubature inference loop;
  • A component framework for constructing cubature steps as a composition of a propagator and a recombinator;
  • Trace-time component validation that ensures components obey certain expected mathematical properties, with minimal runtime overhead;
  • Visualization tools for evaluating and debugging inference/solving performance;
  • A Blackjax-like interface provided by mccube.extensions (Coming Soon);
  • A custom solver for using MCC in Diffrax, also provided by mccube.extensions (Coming Soon).
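The propagator/recombinator split can be pictured as function composition: a propagator maps each particle to one successor per cubature vector, and a recombinator reduces the enlarged set back to a fixed number of particles. The sketch below is purely illustrative and does not use MCCube's API. The functions `propagate`, `recombine_monte_carlo` and `cubature_step` are hypothetical, the propagator is assumed to be a Langevin step with the cubature vectors standing in for Gaussian increments, and the recombinator simply resamples uniformly at random.

```python
import jax
import jax.numpy as jnp


def propagate(particles, vectors, grad_logdensity, dt):
    """Hypothetical propagator: one Langevin step per cubature vector.

    particles: (n, d), vectors: (k, d). Returns (n * k, d) propagated particles.
    """
    drift = jax.vmap(grad_logdensity)(particles) * dt  # (n, d)
    # Each cubature vector stands in for a standard Gaussian increment,
    # scaled by sqrt(2 * dt) to match the sqrt(2) dW term of the Langevin SDE.
    noise = jnp.sqrt(2.0 * dt) * vectors  # (k, d)
    expanded = particles[:, None, :] + drift[:, None, :] + noise[None, :, :]  # (n, k, d)
    return expanded.reshape(-1, particles.shape[-1])


def recombine_monte_carlo(key, particles, n_out):
    """Hypothetical recombinator: keep n_out particles chosen uniformly without replacement."""
    idx = jax.random.choice(key, particles.shape[0], shape=(n_out,), replace=False)
    return particles[idx]


def cubature_step(key, particles, vectors, grad_logdensity, dt):
    """Propagate, then recombine back to the original particle count."""
    expanded = propagate(particles, vectors, grad_logdensity, dt)
    return recombine_monte_carlo(key, expanded, particles.shape[0])


def grad_logdensity(x):
    # Closed-form score of the N(2, 6 I) target used in the quick example.
    return -(x - 2.0) / 6.0


# The 2d vectors +/- sqrt(d) e_i form a degree-3 cubature formula
# (equal weights) for the standard Gaussian in d dimensions.
d = 2
vectors = jnp.sqrt(d) * jnp.concatenate([jnp.eye(d), -jnp.eye(d)])  # (2d, d)

key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)
particles = jax.random.uniform(init_key, (512, d))
for _ in range(100):
    key, step_key = jax.random.split(key)
    particles = cubature_step(step_key, particles, vectors, grad_logdensity, 0.05)
```

Uniform resampling is the crudest possible recombinator: it keeps the particle count fixed but only matches moments in expectation and adds Monte Carlo noise at every step. Keeping the recombinator as a separate component is what allows it to be swapped for a more accurate scheme.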

In addition, like the samplers in Blackjax, MCCube can easily be integrated with probabilistic programming languages (PPLs), as long as they can provide a (potentially unnormalized) log-density function.
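An unnormalized log-density suffices because Langevin-type dynamics depend on the target only through the score ∇ log π, and the additive log normalizing constant has zero gradient. The sketch below checks this with `jax.grad`, using the (t, p, args) call signature from the quick example; the function names and target parameters are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

target_mean = 2.0 * jnp.ones(2)
target_cov = 6.0 * jnp.eye(2)


def normalized_logdensity(t, p, args):
    # Exact Gaussian log-density, including the normalizing constant.
    return multivariate_normal.logpdf(p, target_mean, target_cov)


def unnormalized_logdensity(t, p, args):
    # Same density up to an additive constant: only the quadratic form is kept.
    diff = p - target_mean
    return -0.5 * diff @ jnp.linalg.solve(target_cov, diff)


# Differentiate with respect to the particle position p (argument index 1).
score_normalized = jax.grad(normalized_logdensity, argnums=1)
score_unnormalized = jax.grad(unnormalized_logdensity, argnums=1)

p = jnp.array([0.5, -1.0])
s_norm = score_normalized(0.0, p, None)
s_unnorm = score_unnormalized(0.0, p, None)
# Both equal -(p - target_mean) / 6, i.e. approximately [0.25, 0.5].
print(s_norm, s_unnorm)
```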

Warning: This package is currently a work in progress and experimental. Expect bugs and API instability, and treat all results with a healthy degree of skepticism.

Who should use MCCube?

MCCube should appeal to:

  • Users of Blackjax (people who need/want modular GPU/TPU capable samplers);
  • Users of Diffrax (people who need to solve SDEs/CDEs);
  • Markov chain cubature researchers/developers.

Installation

To install the base package:

pip install mccube

If you want all the extras provided in mccube.extensions:

pip install mccube[extras]

Requires Python 3.9+, JAX 0.4.11+, and Equinox 0.10.5+.

By default, a CPU-only version of JAX will be installed. To make use of other JAX/XLA-compatible accelerators (GPUs/TPUs), please follow the JAX installation instructions. Windows support for JAX is currently experimental; WSL2 is the recommended approach for using JAX on Windows.
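After installing, a quick way to confirm which backend JAX picked up (and therefore whether computations will run on an accelerator) is to query JAX directly. This uses only standard JAX calls; nothing here is MCCube-specific.

```python
import jax

# Typically "cpu" for the default install, and "gpu" or "tpu" once an
# accelerator-enabled jaxlib has been installed.
backend = jax.default_backend()
devices = jax.devices()
print(f"JAX {jax.__version__}: backend={backend}, devices={devices}")
```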

Documentation

Coming soon at https://mccube.readthedocs.io/.

What is Markov chain cubature?

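MCC is an approach to constructing a cubature on Wiener space that does not suffer from exponential growth in the number of particles over time (particle count explosion). It avoids this by applying (partitioned) recombination in each cubature step/transition kernel: propagating every particle along each of the k cubature paths multiplies the particle count by k per step, and recombination reduces the propagated particles back to a fixed number while aiming to preserve the relevant moments.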
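To make the scaling concrete: without recombination, N initial particles become N·k^n after n steps, whereas recombining back to N after every step keeps the count constant. The arithmetic below uses the quick example's 8192 particles in two dimensions and assumes, for illustration only, a formula with k = 2d = 4 cubature vectors; the helper `particle_counts` is hypothetical.

```python
def particle_counts(n_particles, n_vectors, n_steps):
    """Particle counts after each step, without and with recombination."""
    without = [n_particles * n_vectors**step for step in range(n_steps + 1)]
    with_recombination = [n_particles] * (n_steps + 1)
    return without, with_recombination


# 8192 particles in d = 2, assuming k = 2 * d = 4 cubature vectors.
without, with_rc = particle_counts(n_particles=8192, n_vectors=4, n_steps=5)
print(without)  # [8192, 32768, 131072, 524288, 2097152, 8388608]
print(with_rc)  # [8192, 8192, 8192, 8192, 8192, 8192]
```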

Quick Example

The toy example below demonstrates using MCCube to infer the moments of a two-dimensional Gaussian with every mean component equal to two and a diagonal covariance with entries six, given its log-density function. More in-depth examples are coming soon.

import jax
import numpy as np
from jax.scipy.stats import multivariate_normal

from mccube import MCCubatureStep, mccubaturesolve, minimal_cubature_formula
from mccube.components import LangevinDiffusionPropagator, MonteCarloRecombinator
from mccube.metrics import cubature_target_error

# Setup the problem.
n_particles = 8192
target_dimension = 2

rng = np.random.default_rng(42)
prior_particles = rng.uniform(size=(n_particles, target_dimension))
target_mean = 2 * np.ones(target_dimension)
target_cov = 6 * np.eye(target_dimension)


# MCCube expects the logdensity to have call signature (t, p(t), args), allowing the
# density to be time-dependent, or to rely on some other generic args.
# Note: You can obtain significantly better performance by defining a custom jvp here.
def target_logdensity(t, p, args):
    return multivariate_normal.logpdf(p, target_mean, target_cov)


# Setup the MCCubature.
recombinator_key = jax.random.PRNGKey(42)
cfv = minimal_cubature_formula(target_dimension, degree=3).vectors
cs = MCCubatureStep(
    propagator=LangevinDiffusionPropagator(cfv),
    recombinator=MonteCarloRecombinator(recombinator_key),
)

# Construct the MCCubature/solve for the MCCubature paths.
mccubature_paths = mccubaturesolve(
    logdensity=target_logdensity,
    transition_kernel=cs,
    initial_particles=prior_particles,
)

# Compare mean and covariance of the inferred cubature to the target.
posterior_particles = mccubature_paths.particles[-1, :, :]
mean_err, cov_err = cubature_target_error(posterior_particles, target_mean, target_cov)
print(f"Mean Error: {mean_err}\nCov Error: {cov_err}")
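The comment in the example above suggests defining a custom JVP for the log-density. One way to do this for the Gaussian target is sketched below with `jax.custom_jvp`, supplying the closed-form gradient −Σ⁻¹(p − μ) instead of differentiating through a generic logpdf. How MCCube differentiates the log-density internally is not stated here, so treat any speed-up as an assumption to benchmark. This version handles a single particle of shape (d,); use `jax.vmap` for batches.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

target_dimension = 2
target_mean = 2 * jnp.ones(target_dimension)
target_cov = 6 * jnp.eye(target_dimension)
target_prec = jnp.linalg.inv(target_cov)  # precomputed once
_, logdet = jnp.linalg.slogdet(target_cov)
log_norm = -0.5 * (target_dimension * jnp.log(2 * jnp.pi) + logdet)


@jax.custom_jvp
def gaussian_logpdf(p):
    diff = p - target_mean
    return log_norm - 0.5 * diff @ target_prec @ diff


@gaussian_logpdf.defjvp
def gaussian_logpdf_jvp(primals, tangents):
    (p,) = primals
    (p_dot,) = tangents
    score = -target_prec @ (p - target_mean)  # closed-form gradient
    return gaussian_logpdf(p), score @ p_dot


def target_logdensity(t, p, args):
    # Same (t, p, args) signature as in the example; t and args are unused.
    return gaussian_logpdf(p)


p = jnp.array([0.5, -1.0])
value = target_logdensity(0.0, p, None)
reference = multivariate_normal.logpdf(p, target_mean, target_cov)
grad = jax.grad(target_logdensity, argnums=1)(0.0, p, None)  # approx. [0.25, 0.5]

# Batched gradients over several particles.
batch_grad = jax.vmap(jax.grad(target_logdensity, argnums=1), in_axes=(None, 0, None))
batch_grads = batch_grad(0.0, jnp.stack([p, p]), None)
```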

Note that mccubaturesolve returns the cubature paths, but does not return any other intermediate step information. If such information is required, a 'visualizer' callback can be used, for example:

from mccube.extensions.visualizers import TensorboardVisualizer

with TensorboardVisualizer() as tbv:
    cubature = mccubaturesolve(..., visualization_callback=tbv)

To use the TensorBoard visualization suite, run the following command during or after each experimental run:

tensorboard --logdir=experiments

Citation

Please cite this repository if it has been useful in your work:

@software{mccube2023github,
    author={},
    title={{MCC}ube: Markov chain cubature via {JAX}},
    url={},
    version={<insert current release tag>},
    year={2023},
}

See Also

Some other Python/JAX packages that you may find interesting:

  • PySR: High-performance symbolic regression in Python and Julia.
  • Equinox: A JAX library for parameterised functions.
  • Diffrax: A JAX library providing numerical differential equation solvers.
  • Lineax: A JAX library for linear solves and linear least squares.
  • OTT-JAX: A JAX library for optimal transport.

Download files

Source distribution: mccube-0.0.1.tar.gz (24.3 kB)

Built distribution: mccube-0.0.1-py3-none-any.whl (28.6 kB, Python 3)
