
Markov chain cubature via JAX.


MCCube
Markov chain cubature via JAX


MCCube provides tools for performing Markov chain cubature (MCC) in Diffrax.

Key features:

  • Custom terms, paths, and solvers that provide a painless means to perform MCC in diffrax.
  • A small library of recombination kernels, conventional cubature formulae, and metrics.
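To illustrate what a conventional cubature formula is, the sketch below builds a degree-3 formula for the standard Gaussian N(0, I_d) from the rows of a Sylvester Hadamard matrix. The points are the ± rows of the first d columns, with equal weights. Because the ± symmetry makes every odd moment vanish and the column orthogonality makes the second moment the identity, the formula matches the Gaussian moments up to degree 3. This is a textbook construction. The helper names are hypothetical, and we assume without checking the MCCube source that its `Hadamard` formula belongs to this family.

```python
import jax.numpy as jnp


def sylvester_hadamard(m):
    # Sylvester construction: H_1 = [1], H_2k = [[H, H], [H, -H]]; m must be a power of two.
    if m < 1 or m & (m - 1):
        raise ValueError("m must be a power of two")
    h = jnp.ones((1, 1))
    while h.shape[0] < m:
        h = jnp.block([[h, h], [h, -h]])
    return h


def hadamard_gaussian_cubature(d):
    # Hypothetical helper: smallest power of two m >= d, so the first d columns
    # of H_m exist and are mutually orthogonal (H_m^T H_m = m I).
    m = 1 << (d - 1).bit_length()
    rows = sylvester_hadamard(m)[:, :d]                # (m, d), entries +/-1
    points = jnp.concatenate([rows, -rows], axis=0)    # (2m, d), symmetric about 0
    weights = jnp.full((2 * m,), 1.0 / (2 * m))        # equal weights summing to 1
    return points, weights


points, weights = hadamard_gaussian_cubature(10)
mean = weights @ points                                # first moment, should be 0
cov = (points * weights[:, None]).T @ points           # second moment, should be I
print(points.shape, jnp.allclose(mean, 0.0), jnp.allclose(cov, jnp.eye(10)))
```

For d = 10 the formula uses m = 16, that is, 32 points. This is far fewer than the 2^d points of a full tensor-product rule.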

Installation

To install the base package:

pip install mccube

Requires Python 3.12+, Diffrax 0.5.0+, and Equinox 0.11.3+.

By default, a CPU-only version of JAX is installed. To use other JAX/XLA-compatible accelerators (GPUs/TPUs), please follow the JAX installation instructions. Windows support for JAX is currently experimental; WSL2 is the recommended approach for using JAX on Windows.
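After installation, you can quickly confirm which backend JAX actually picked up. This sketch relies only on the standard `jax.default_backend()` and `jax.devices()` calls. It makes no MCCube-specific assumptions.

```python
import jax

# "cpu" is expected for the default install; an accelerator-enabled jaxlib with a
# visible device typically reports "gpu" or "tpu" instead.
backend = jax.default_backend()
devices = jax.devices()  # list of devices on the default backend
print(f"backend={backend}, devices={devices}")
```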

Documentation

Available at https://mccube.readthedocs.io/.

What is Markov chain cubature?

MCC is an approach to constructing a cubature on Wiener space that avoids exponential growth of the particle count over time (particle count explosion). It does this by applying (partitioned) recombination in the (approximate) cubature kernel, which reduces the particles back to a fixed number after every step.
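The particle count explosion is easy to see in a toy setting. If a cubature formula has k points, each step branches every particle into k children, so n particles become n·k^m after m steps. The following is a toy illustration only, not MCCube's implementation. It uses the degree-3 formula with the 2d points ±√d·e_i, and it stands in for recombination with plain uniform subsampling back to n particles. True recombination preserves moments exactly; this Monte Carlo stand-in only does so in expectation. All function names and the drift are hypothetical.

```python
import jax.numpy as jnp
import jax.random as jr


def symmetric_gaussian_points(d):
    # Degree-3 cubature for N(0, I_d): the 2d points +/- sqrt(d) * e_i, equal weights 1/(2d).
    eye = jnp.eye(d)
    return jnp.sqrt(float(d)) * jnp.concatenate([eye, -eye], axis=0)  # (2d, d)


def cubature_step(particles, dt, drift, points):
    # Euler-type step for dX = drift(X) dt + sqrt(2) dW: every particle branches
    # along each of the k cubature points, so n particles become n * k.
    base = particles + dt * drift(particles)          # (n, d)
    incr = jnp.sqrt(2.0 * dt) * points                # (k, d)
    branched = base[:, None, :] + incr[None, :, :]    # (n, k, d)
    return branched.reshape(-1, particles.shape[-1])  # (n * k, d)


def subsample(key, particles, n):
    # Monte Carlo stand-in for recombination: keep n particles uniformly at random.
    # Uniform selection is only reasonable because every branch has equal weight here.
    idx = jr.choice(key, particles.shape[0], shape=(n,), replace=False)
    return particles[idx]


def drift(p):
    # Hypothetical target N(2, 3 I): grad log density is -(p - 2) / 3.
    return -(p - 2.0) / 3.0


n, d, dt, steps = 4, 2, 0.05, 3
points = symmetric_gaussian_points(d)  # k = 2d = 4 points
key = jr.PRNGKey(0)

naive = jnp.ones((n, d))
mcc = jnp.ones((n, d))
for _ in range(steps):
    naive = cubature_step(naive, dt, drift, points)  # grows by a factor of k each step
    key, subkey = jr.split(key)
    mcc = subsample(subkey, cubature_step(mcc, dt, drift, points), n)  # stays at n

print(naive.shape, mcc.shape)  # (256, 2) (4, 2)
```

Even in this tiny case the naive scheme reaches 4·4^3 = 256 particles after three steps, while the recombined chain stays at 4.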

Example

import diffrax
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import multivariate_normal

from mccube import (
    GaussianRegion,
    Hadamard,
    LocalLinearCubaturePath,
    MCCSolver,
    MCCTerm,
    MonteCarloKernel,
    gaussian_wasserstein_metric,
)

key = jr.PRNGKey(42)
# n particles in d dimensions, all initialised at 1, evolved for `epochs` steps of size dt0.
n, d = 512, 10
t0 = 0.0
epochs = 512
dt0 = 0.05
t1 = t0 + dt0 * epochs
y0 = jnp.ones((n, d))

target_mean = 2 * jnp.ones(d)
target_cov = 3 * jnp.eye(d)


def logdensity(p):
    return multivariate_normal.logpdf(p, mean=target_mean, cov=target_cov)


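# Overdamped Langevin dynamics dX = grad(log p)(X) dt + sqrt(2) dW, whose
# stationary distribution is the target Gaussian. Drift term, vectorised over particles:
ode = diffrax.ODETerm(lambda t, p, args: jax.vmap(jax.grad(logdensity))(p))
# Diffusion term: sqrt(2) times a control given by a cubature path built on the
# Hadamard cubature formula for the Gaussian measure.
cde = diffrax.WeaklyDiagonalControlTerm(
    lambda t, p, args: jnp.sqrt(2.0),
    LocalLinearCubaturePath(Hadamard(GaussianRegion(d))),
)
terms = MCCTerm(ode, cde)
# Euler steps, followed by a Monte Carlo recombination kernel that keeps n particles.
solver = MCCSolver(diffrax.Euler(), MonteCarloKernel(n, key=key))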

sol = diffrax.diffeqsolve(terms, solver, t0, t1, dt0, y0)
# Compare the empirical mean and covariance of the final particles with the target.
res_mean = jnp.mean(sol.ys[-1], axis=0)
res_cov = jnp.cov(sol.ys[-1], rowvar=False)
metric = gaussian_wasserstein_metric((target_mean, res_mean), (target_cov, res_cov))

print(f"Result 2-Wasserstein distance: {metric}")
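The example scores the result with `gaussian_wasserstein_metric`. For reference, the standard closed-form 2-Wasserstein distance between N(m1, C1) and N(m2, C2) satisfies W2^2 = ||m1 - m2||^2 + tr(C1 + C2 - 2 (C2^(1/2) C1 C2^(1/2))^(1/2)). The sketch below implements that formula directly. `gaussian_w2` and `psd_sqrt` are hypothetical helpers. We have not checked whether MCCube's function returns W2 or its square, so treat this as a cross-check rather than a replacement.

```python
import jax.numpy as jnp


def psd_sqrt(c):
    # Square root of a symmetric PSD matrix via eigendecomposition; tiny negative
    # eigenvalues from round-off are clamped to zero.
    w, v = jnp.linalg.eigh(c)
    return (v * jnp.sqrt(jnp.maximum(w, 0.0))) @ v.T


def gaussian_w2(mean1, cov1, mean2, cov2):
    # Hypothetical helper (not MCCube's gaussian_wasserstein_metric):
    # W2^2 = ||m1 - m2||^2 + tr(C1 + C2 - 2 (C2^1/2 C1 C2^1/2)^1/2)
    s2 = psd_sqrt(cov2)
    cross = psd_sqrt(s2 @ cov1 @ s2)
    w2_sq = jnp.sum((mean1 - mean2) ** 2) + jnp.trace(cov1 + cov2 - 2.0 * cross)
    return jnp.sqrt(jnp.maximum(w2_sq, 0.0))  # clamp round-off below zero


d = 10
mu, cov = 2.0 * jnp.ones(d), 3.0 * jnp.eye(d)
print(gaussian_w2(mu, cov, mu + 1.0, cov))  # approximately sqrt(10), about 3.162
```

With the example's variables, `gaussian_w2(target_mean, target_cov, res_mean, res_cov)` would give the distance between the target and the fitted Gaussian.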

Citation

Please cite this repository if it has been useful in your work:

@software{mccube2023github,
    author={},
    title={{MCC}ube: Markov chain cubature via {JAX}},
    url={},
    version={<insert current release tag>},
    year={2023},
}

See Also

Some other Python/JAX packages that you may find interesting:

  • Markov-Chain-Cubature: A PyTorch implementation of Markov Chain Cubature.
  • PySR: High-Performance Symbolic Regression in Python and Julia.
  • Equinox: A JAX library for parameterised functions.
  • Diffrax: A JAX library providing numerical differential equation solvers.
  • Lineax: A JAX library for linear solves and linear least squares.
  • OTT-JAX: A JAX library for optimal transport.

Download files


Source distribution: mccube-0.0.3.tar.gz (23.6 kB)

Built distribution: mccube-0.0.3-py3-none-any.whl (31.9 kB)

File details

Details for the file mccube-0.0.3.tar.gz.

File metadata

  • Download URL: mccube-0.0.3.tar.gz
  • Upload date:
  • Size: 23.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.7

File hashes

Hashes for mccube-0.0.3.tar.gz
Algorithm Hash digest
SHA256 2de8f20f7fc167a6bad6dd15a7aaa1a24a3d372f234b8b6319020016c2f54e6e
MD5 5dc681c3b6f67c5c17a95e8d2e637b82
BLAKE2b-256 82693a70c96af8955d6f414392fa1b13e6b8ddeb628ec474959a678bdb22fea3


File details

Details for the file mccube-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: mccube-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 31.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.7

File hashes

Hashes for mccube-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 b5e88dee14edf0dc97346216e0d660663a78507350a53343147951e8d9c69fc2
MD5 f28a157a69b1b9752a09ed446d2e2a94
BLAKE2b-256 c48b7752f21ef44d4ca346614db339a1112c2f36dc5685ed057b5742b984f06f
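To check a downloaded file against the SHA256 digests listed above, you can use the standard-library `hashlib` module, as in this minimal sketch. The helper names are hypothetical. Pip's hash-checking mode (`--require-hashes` with pinned `--hash=sha256:...` entries) is an alternative that performs the check at install time.

```python
import hashlib
from pathlib import Path

# SHA256 digests copied from the file details above.
EXPECTED_SHA256 = {
    "mccube-0.0.3.tar.gz": "2de8f20f7fc167a6bad6dd15a7aaa1a24a3d372f234b8b6319020016c2f54e6e",
    "mccube-0.0.3-py3-none-any.whl": "b5e88dee14edf0dc97346216e0d660663a78507350a53343147951e8d9c69fc2",
}


def sha256_of(path, chunk_size=1 << 16):
    # Hypothetical helper: stream the file in chunks so it never has to fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path):
    # Hypothetical helper: True if the file's SHA256 matches the published digest for its name.
    path = Path(path)
    return sha256_of(path) == EXPECTED_SHA256[path.name]
```

For example, `verify_download("mccube-0.0.3.tar.gz")` run next to the downloaded archive should return `True`.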

