Monte Carlo Testbed
Project description
mocat
All things Monte Carlo, written in JAX.
-
Markov chain Monte Carlo
-
Transport samplers
- Sequential Monte Carlo samplers (likelihood tempering)
- Stein variational gradient descent
-
Approximate Bayesian computation
- Rejection/Importance ABC
- MCMC ABC
- SMC ABC
-
State-space models
- Particle filtering
- Particle smoothing
- Kalman filtering + smoothing
Install
pip install mocat
Define a target distribution
We always work with the target's potential, i.e. its negative log density.
from jax import numpy as jnp, random
import matplotlib.pyplot as plt
import mocat
class Rastrigin(mocat.Scenario):
    """Rastrigin target distribution, specified through its potential.

    The potential (negative log density) is

        a * dim + sum(x**2 - a * cos(2 * pi * x))

    with the sum taken over the last axis of ``x``.
    """
    name = "Rastrigin"

    def __init__(self, dim: int = 1, a: float = 0.5):
        # Set attributes before handing over to the base-class initialiser.
        self.dim = dim
        self.a = a
        super().__init__()

    def potential(self, x: jnp.ndarray, random_key: jnp.ndarray) -> float:
        # random_key is required by the Scenario interface; this potential
        # does not use it.
        offset = self.a * self.dim
        ripple = self.a * jnp.cos(2 * jnp.pi * x)
        return offset + jnp.sum(x ** 2 - ripple, axis=-1)
Compare samplers
Run MALA and HMC with a Robbins-Monro schedule that adapts the stepsize towards the desired acceptance rate (defined in, e.g., `mala.tuning`).
# Fixed PRNG key; both samplers below are run from this same key.
random_key = random.PRNGKey(0)
# Rastrigin target in 5 dimensions (a keeps its default of 0.5).
scenario_rastrigin = Rastrigin(5)
# Number of iterations for each sampler run.
n = int(1e5)
# Overdamped Langevin sampler: run with a Metropolis-type correction, this is
# the "MALA" referred to in the text above.
mala = mocat.Overdamped()
# RMMetropolis is the Robbins-Monro adaptive Metropolis correction described above.
mala_samps = mocat.run(scenario_rastrigin, mala, n, random_key, correction=mocat.RMMetropolis())
# Hamiltonian Monte Carlo with 10 leapfrog steps per proposal.
hmc = mocat.HMC(leapfrog_steps=10)
hmc_samps = mocat.run(scenario_rastrigin, hmc, n, random_key, correction=mocat.RMMetropolis())
Plot the first two dimensions, along with trace plots and the autocorrelation of the potential.
# One column per sampler:
#   row 0: samples in the first two dimensions
#   row 1: trace plot over the last 1000 iterations
#   row 2: autocorrelation plot
fig, axes = plt.subplots(3, 2)
for col, (sampler, samps) in enumerate([(mala, mala_samps), (hmc, hmc_samps)]):
    mocat.plot_2d_samples(samps, ax=axes[0, col])
    mocat.trace_plot(samps, last_n=1000, ax=axes[1, col], title=None)
    mocat.autocorrelation_plot(samps, ax=axes[2, col], title=None)
    # Title each column with the sampler that produced it, so the HMC column
    # is not labelled as MALA.
    axes[0, col].set_title(scenario_rastrigin.name + ': ' + sampler.name)
plt.tight_layout()
There is also functionality for effective sample size, acceptance rate, squared jumping distance, kernelised Stein discrepancies and more.
Create your own MCMC sampler
class Underdamped(mocat.MCMCSampler):
    """Underdamped Langevin MCMC sampler with partial momentum refreshment.

    Each iteration:
      * ``always``: negates the momenta, then partially refreshes them with
        Gaussian noise at a rate set by ``friction`` and ``stepsize``.
      * ``proposal``: runs ``leapfrog_steps`` leapfrog steps and negates the
        final momenta.
      * ``acceptance_probability``: Metropolis-Hastings ratio on the joint
        energy, i.e. the potential plus the Gaussian momentum potential.

    Args:
        stepsize: Leapfrog stepsize. ``None`` presumably defers to the
            correction's tuning (TODO confirm).
        leapfrog_steps: Number of leapfrog steps per proposal.
        friction: Friction coefficient used in the momentum refresh.
    """
    name = 'Underdamped'
    default_correction = mocat.Metropolis()

    def __init__(self,
                 stepsize=None,
                 leapfrog_steps=1,
                 friction=1.0):
        super().__init__()
        self.parameters.stepsize = stepsize
        self.parameters.leapfrog_steps = leapfrog_steps
        self.parameters.friction = friction
        # Desired acceptance rate that stepsize tuning aims for.
        self.tuning.target = 0.651

    # Annotations are quoted so they are never evaluated at definition time.
    # The bare names Scenario / cdict / Tuple are not in scope in this
    # snippet, and unquoted they would raise NameError when the class is
    # defined.
    def startup(self,
                scenario: 'mocat.Scenario',
                n: int,
                initial_state: 'mocat.cdict',
                initial_extra: 'mocat.cdict',
                **kwargs) -> 'tuple[mocat.cdict, mocat.cdict]':
        """Initialise the state and ensure it carries potential, gradient and momenta.

        Returns:
            The (possibly updated) ``initial_state`` and ``initial_extra``.
        """
        initial_state, initial_extra = super().startup(scenario, n,
                                                       initial_state, initial_extra, **kwargs)
        initial_extra.random_key, scen_key = random.split(initial_extra.random_key)
        initial_state.potential, initial_state.grad_potential = scenario.potential_and_grad(initial_state.value,
                                                                                            scen_key)
        # Start from zero momenta if none were supplied, or if the supplied
        # ones do not match the scenario's dimension.
        if not hasattr(initial_state, 'momenta') or initial_state.momenta.shape[-1] != scenario.dim:
            initial_state.momenta = jnp.zeros(scenario.dim)
        return initial_state, initial_extra

    def always(self, scenario, reject_state, reject_extra):
        """Negate the momenta, then partially refresh them with Gaussian noise.

        The refresh is
        ``m <- m * exp(-friction * stepsize)
               + sqrt(1 - exp(-2 * friction * stepsize)) * N(0, I)``.
        """
        d = scenario.dim
        stepsize = reject_extra.parameters.stepsize
        friction = reject_extra.parameters.friction

        reject_state.momenta = reject_state.momenta * -1

        reject_extra.random_key, subkey = random.split(reject_extra.random_key)
        reject_state.momenta = reject_state.momenta * jnp.exp(- friction * stepsize) \
            + jnp.sqrt(1 - jnp.exp(- 2 * friction * stepsize)) * random.normal(subkey, (d,))
        return reject_state, reject_extra

    def proposal(self,
                 scenario: 'mocat.Scenario',
                 reject_state: 'mocat.cdict',
                 reject_extra: 'mocat.cdict') -> 'tuple[mocat.cdict, mocat.cdict]':
        """Propose by running ``leapfrog_steps`` leapfrog steps from the current state.

        Returns:
            The final leapfrog state, with negated momenta, and ``reject_extra``.
        """
        # leapfrog_steps is read from self.parameters rather than the
        # per-iteration reject_extra.parameters. It fixes the number of keys
        # split here.
        random_keys = random.split(reject_extra.random_key, self.parameters.leapfrog_steps + 1)
        reject_extra.random_key = random_keys[0]
        all_leapfrog_state = mocat.utils.leapfrog(scenario.potential_and_grad,
                                                  reject_state,
                                                  reject_extra.parameters.stepsize,
                                                  random_keys[1:])
        proposed_state = all_leapfrog_state[-1]
        # Negate the final momenta (standard HMC-style momentum flip).
        proposed_state.momenta *= -1
        return proposed_state, reject_extra

    def acceptance_probability(self, scenario, reject_state, reject_extra, proposed_state, proposed_extra):
        """Return ``min(1, exp(-H(proposed) + H(current)))``.

        Here ``H`` is the potential plus the Gaussian potential of the momenta.
        """
        pre_min_alpha = jnp.exp(- proposed_state.potential
                                + reject_state.potential
                                - mocat.utils.gaussian_potential(proposed_state.momenta)
                                + mocat.utils.gaussian_potential(reject_state.momenta))
        return jnp.minimum(1., pre_min_alpha)
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
mocat-0.2.6.tar.gz
(63.3 kB
view details)
File details
Details for the file mocat-0.2.6.tar.gz
.
File metadata
- Download URL: mocat-0.2.6.tar.gz
- Upload date:
- Size: 63.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/32.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.62.3 importlib-metadata/4.10.1 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.10
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | 1de08511e20e8d67cbade4ffb974baa5dec1c8df637e64535e3634295c10cf9e |
|
MD5 | 290781065479e92e3012949927f041c1 |
|
BLAKE2b-256 | 49bed3bff6bd31ca6fd1b0985d612246125c5d5a7f57536fffe81e3a9fe0b9a8 |