Markov chain cubature via JAX.
MCCube
MCCube provides tools for performing Markov chain cubature (MCC) in Diffrax.
Key features:
- Custom terms, paths, and solvers that provide a painless means to perform MCC in diffrax.
- A small library of recombination kernels, conventional cubature formulae, and metrics.
Installation
To install the base package:
```shell
pip install mccube
```
Requires Python 3.12+, Diffrax 0.5.0+, and Equinox 0.11.3+.
By default, a CPU-only version of JAX will be installed. To make use of other JAX/XLA-compatible accelerators (GPUs/TPUs), please follow the JAX installation instructions. Windows support for JAX is currently experimental; WSL2 is the recommended approach for using JAX on Windows.
Documentation
Available at https://mccube.readthedocs.io/.
What is Markov chain cubature?
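MCC is an approach to constructing a Cubature on Wiener Space that does not suffer from exponential growth of the particle count over time (particle count explosion). In a naive cubature scheme, every particle branches into one new particle per cubature path at each step, so the particle count grows exponentially with the number of steps; MCC avoids this by applying (partitioned) recombination in the (approximate) cubature kernel, which reduces the branched particles back to a fixed number at every step.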
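A toy sketch of particle count explosion and how recombination avoids it, written in plain JAX. It is not MCCube's implementation: the helpers `cubature_points`, `branch` and `recombine` are hypothetical, the cubature formula is an assumed simple degree-3 rule for N(0, I_d) (the 2d equally weighted points ±√d·e_i), and uniform random subsampling stands in for a recombination kernel (it preserves moments only in expectation, whereas exact recombination preserves them exactly). The drift is the score of the same N(2, 3I) target used in the example below.

```python
import jax.numpy as jnp
import jax.random as jr


def cubature_points(d):
    # Degree-3 cubature for N(0, I_d): 2d equally weighted points +/- sqrt(d) * e_i.
    eye = jnp.eye(d)
    return jnp.sqrt(d) * jnp.concatenate([eye, -eye], axis=0)  # shape (2d, d)


def branch(y, drift, dt, points):
    # Follow every cubature path from every particle: (n, d) -> (n * k, d).
    n, d = y.shape
    k = points.shape[0]
    base = y + drift(y) * dt  # deterministic (Euler) part, shape (n, d)
    # sqrt(2) times a cubature increment scaled by sqrt(dt), shape (k, d).
    incs = jnp.sqrt(2.0 * dt) * points
    return (base[:, None, :] + incs[None, :, :]).reshape(n * k, d)


def recombine(key, y, n):
    # Monte Carlo stand-in for recombination: keep n particles chosen uniformly at random.
    idx = jr.choice(key, y.shape[0], shape=(n,), replace=False)
    return y[idx]


key = jr.PRNGKey(0)
n, d, dt = 8, 2, 0.05
points = cubature_points(d)
drift = lambda y: -(y - 2.0) / 3.0  # score of N(2, 3I)

naive = jnp.ones((n, d))  # cubature without recombination
mcc = jnp.ones((n, d))    # cubature followed by recombination
for _ in range(3):
    naive = branch(naive, drift, dt, points)
    key, sub = jr.split(key)
    mcc = recombine(sub, branch(mcc, drift, dt, points), n)

print(naive.shape)  # (512, 2): n * (2d)^3 particles after three steps
print(mcc.shape)    # (8, 2): the particle count stays at n
```

Without recombination the particle count after m steps is n(2d)^m; with it, the count stays at n regardless of the number of steps.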
Example
```python
import diffrax
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import multivariate_normal

from mccube import (
    GaussianRegion,
    Hadamard,
    LocalLinearCubaturePath,
    MCCSolver,
    MCCTerm,
    MonteCarloKernel,
    gaussian_wasserstein_metric,
)

# n particles in d dimensions, integrated for `epochs` steps of size dt0.
key = jr.PRNGKey(42)
n, d = 512, 10
t0 = 0.0
epochs = 512
dt0 = 0.05
t1 = t0 + dt0 * epochs
y0 = jnp.ones((n, d))

# Target distribution: N(target_mean, target_cov).
target_mean = 2 * jnp.ones(d)
target_cov = 3 * jnp.eye(d)


def logdensity(p):
    return multivariate_normal.logpdf(p, mean=target_mean, cov=target_cov)


# Overdamped Langevin dynamics dY = grad log p(Y) dt + sqrt(2) dW, with the
# Brownian driver replaced by a cubature path over the Gaussian region.
ode = diffrax.ODETerm(lambda t, p, args: jax.vmap(jax.grad(logdensity))(p))
cde = diffrax.WeaklyDiagonalControlTerm(
    lambda t, p, args: jnp.sqrt(2.0),
    LocalLinearCubaturePath(Hadamard(GaussianRegion(d))),
)
terms = MCCTerm(ode, cde)
# Euler steps along the cubature paths; MonteCarloKernel recombines back down to n particles.
solver = MCCSolver(diffrax.Euler(), MonteCarloKernel(n, key=key))

sol = diffrax.diffeqsolve(terms, solver, t0, t1, dt0, y0)

# Compare the empirical mean and covariance of the final particles with the target.
res_mean = jnp.mean(sol.ys[-1], axis=0)
res_cov = jnp.cov(sol.ys[-1], rowvar=False)
metric = gaussian_wasserstein_metric((target_mean, res_mean), (target_cov, res_cov))
print(f"Result 2-Wasserstein distance: {metric}")
```
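For comparison, the sketch below runs the same Langevin dynamics with ordinary Gaussian increments (plain Monte Carlo Euler-Maruyama, with no cubature path and no recombination) and scores the result with the closed-form 2-Wasserstein distance between Gaussians, W2^2 = |m1 - m2|^2 + tr(S1 + S2 - 2 (S2^(1/2) S1 S2^(1/2))^(1/2)). The helpers `langevin_baseline`, `psd_sqrt` and `gaussian_w2` are hypothetical and not part of mccube; this is an independent sketch and is not assumed to match the conventions of `gaussian_wasserstein_metric` (for example, whether it returns W2 or W2^2, or how it orders its arguments).

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import multivariate_normal


def psd_sqrt(a):
    # Square root of a symmetric PSD matrix via its eigendecomposition.
    w, v = jnp.linalg.eigh(a)
    return (v * jnp.sqrt(jnp.clip(w, 0.0))) @ v.T


def gaussian_w2(mean1, cov1, mean2, cov2):
    # Closed-form 2-Wasserstein distance between N(mean1, cov1) and N(mean2, cov2).
    s2 = psd_sqrt(cov2)
    cross = psd_sqrt(s2 @ cov1 @ s2)
    sq = jnp.sum((mean1 - mean2) ** 2) + jnp.trace(cov1 + cov2 - 2.0 * cross)
    return jnp.sqrt(jnp.clip(sq, 0.0))


def langevin_baseline(key, logdensity, y0, dt, steps):
    # Unadjusted Langevin: y <- y + grad log p(y) dt + sqrt(2 dt) * N(0, I),
    # i.e. Euler-Maruyama with Gaussian increments for every particle.
    score = jax.vmap(jax.grad(logdensity))

    def step(y, k):
        noise = jr.normal(k, y.shape)
        return y + score(y) * dt + jnp.sqrt(2.0 * dt) * noise, None

    y, _ = jax.lax.scan(step, y0, jr.split(key, steps))
    return y


# Same target, particle count, step size and number of steps as the example above.
n, d, dt, steps = 512, 10, 0.05, 512
target_mean = 2 * jnp.ones(d)
target_cov = 3 * jnp.eye(d)


def logdensity(p):
    return multivariate_normal.logpdf(p, mean=target_mean, cov=target_cov)


ys = langevin_baseline(jr.PRNGKey(42), logdensity, jnp.ones((n, d)), dt, steps)
res_mean = jnp.mean(ys, axis=0)
res_cov = jnp.cov(ys, rowvar=False)
print(f"Baseline 2-Wasserstein distance: {gaussian_w2(target_mean, target_cov, res_mean, res_cov)}")
```

The closed form only compares first and second moments, which is exact when both distributions are Gaussian and a moment-matching summary otherwise.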
Citation
Please cite this repository if it has been useful in your work:
@software{mccube2023github,
author={},
title={{MCC}ube: Markov chain cubature via {JAX}},
url={},
version={<insert current release tag>},
year={2023},
}
See Also
Some other Python/JAX packages that you may find interesting:
- Markov-Chain-Cubature: A PyTorch implementation of Markov chain cubature.
- PySR: High-performance symbolic regression in Python and Julia.
- Equinox: A JAX library for parameterised functions.
- Diffrax: A JAX library providing numerical differential equation solvers.
- Lineax: A JAX library for linear solves and linear least squares.
- OTT-JAX: A JAX library for optimal transport.