Faster Hamiltonian trajectories for problems with guesses, using Blackjax
grapevine
JAX/Blackjax implementation of the grapevine method for reusing the solutions of guessing problems embedded in Hamiltonian trajectories.
The grapevine method can dramatically speed up MCMC for statistical models with embedded equation-solving problems.
Installation
pip install grapevine-mcmc
Usage
First make a suitable log density function.
This function should take two arguments: a set of parameters (a Pytree) and a guess (also a Pytree). It should return the log density of these parameters (a number) and a new guess. It should also be compatible with JAX, and will probably involve differentiable numerical solving, for example using optimistix.
Here is a simple example of such a function:
from functools import partial
import jax
from jax.scipy.stats import norm
from jax.scipy.special import expit
from jax import numpy as jnp
import optimistix as optx
# equation-solving problems often need 64-bit floats
jax.config.update("jax_enable_x64", True)
solver = optx.Newton(rtol=1e-8, atol=1e-8)
obs = jnp.array(0.7)

def fn(y, args):
    """Equation defining a root-finding problem."""
    a = args
    return y - jnp.tanh(y * expit(a) + 1)

def joint_logdensity(a, guess, obs):
    """An example log density.

    `obs` comes last so that, once it is bound with `partial`, the remaining
    arguments are (parameters, guess) in that order.
    """
    # solve the embedded root-finding problem, starting from the guess
    sol = optx.root_find(fn, solver, guess, args=a)
    log_prior = norm.logpdf(a, loc=0.0, scale=1.0)
    log_likelihood = norm.logpdf(obs, loc=sol.value, scale=0.5)
    # return the solution as well, so it can serve as the next guess
    return log_prior + log_likelihood, sol.value

posterior_logdensity = partial(joint_logdensity, obs=obs)
posterior_logdensity(a=0.0, guess=0.01)
# (Array(-1.22095095, dtype=float64), Array(0.8952192, dtype=float64))
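If you want to sanity-check the numbers above without installing JAX or optimistix, the same model can be written with the Python standard library alone. This is a hedged sketch, not part of grapevine: `solve` is a hand-written Newton iteration standing in for `optx.root_find`, `norm_logpdf` stands in for `jax.scipy.stats.norm.logpdf`, and both names are hypothetical. It keeps the same (parameters, guess) -> (log density, new guess) contract described above.

```python
import math

OBS = 0.7  # same observation as in the example above


def expit(x):
    """Logistic sigmoid, a stdlib counterpart of jax.scipy.special.expit."""
    return 1.0 / (1.0 + math.exp(-x))


def norm_logpdf(x, loc, scale):
    """Log density of a normal distribution (hypothetical helper)."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def solve(a, guess, tol=1e-12, max_steps=50):
    """Newton's method for y - tanh(y * expit(a) + 1) = 0, started at `guess`."""
    s = expit(a)
    y = guess
    for _ in range(max_steps):
        t = math.tanh(y * s + 1.0)
        # residual divided by its derivative, 1 - s * (1 - tanh^2)
        step = (y - t) / (1.0 - s * (1.0 - t * t))
        y -= step
        if abs(step) < tol:
            return y
    raise RuntimeError("Newton's method did not converge")


def posterior_logdensity(a, guess):
    """Return (log density, new guess), mirroring the contract described above."""
    y = solve(a, guess)
    log_prior = norm_logpdf(a, 0.0, 1.0)
    log_likelihood = norm_logpdf(OBS, y, 0.5)
    return log_prior + log_likelihood, y


lp, new_guess = posterior_logdensity(0.0, 0.01)
print(lp, new_guess)
```

The result should agree with the output shown above to roughly six decimal places; small differences can come from solver tolerances. Feeding `new_guess` into the next call is the kind of reuse the grapevine method is built around.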
Now you can run MCMC on your model using GrapeNUTS, the grapevine version of the NUTS sampler!
from grapevine import run_grapenuts

INITIAL_POSITION = jnp.array(0.0)
DEFAULT_GUESS = jnp.array(0.01)
SEED = 1234
key = jax.random.key(SEED)

samples, info = run_grapenuts(
    logdensity_fn=posterior_logdensity,
    rng_key=key,
    init_parameters=INITIAL_POSITION,
    num_warmup=10,
    num_samples=10,
    default_guess=DEFAULT_GUESS,
    progress_bar=False,
    initial_step_size=0.01,
    max_num_doublings=4,
    is_mass_matrix_diagonal=True,
    target_acceptance_rate=0.8,
)
jnp.quantile(samples.position, jnp.array([0.01, 0.5, 0.99]))
# Array([-1.26712677, 0.12950684, 0.93903677], dtype=float64)
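To make the idea behind reusing guesses concrete, here is a toy, dependency-free sketch; it is not grapevine's implementation. A plain leapfrog integrator follows one Hamiltonian trajectory for the example model, and every gradient evaluation solves the embedded root-finding problem with Newton's method. With `reuse_guesses=True` each solve starts from the previous solution along the trajectory; otherwise it starts from the default guess. The helper names (`solve`, `grad_logdensity`, `leapfrog`), the step size, the trajectory length and the tolerance are illustrative assumptions.

```python
import math

OBS = 0.7             # observation from the example above
DEFAULT_GUESS = 0.01  # same default guess as the example above


def expit(x):
    return 1.0 / (1.0 + math.exp(-x))


def solve(a, guess, tol=1e-10, max_steps=50):
    """Newton's method for y = tanh(y * expit(a) + 1); returns (root, iterations used)."""
    s = expit(a)
    y = guess
    for n in range(1, max_steps + 1):
        t = math.tanh(y * s + 1.0)
        step = (y - t) / (1.0 - s * (1.0 - t * t))
        y -= step
        if abs(step) < tol:
            return y, n
    raise RuntimeError("Newton's method did not converge")


def grad_logdensity(a, guess):
    """Gradient of log N(a | 0, 1) + log N(OBS | y(a), 0.5).

    dy/da comes from the implicit function theorem applied to
    f(y, a) = y - tanh(y * expit(a) + 1) = 0.
    """
    y, n_iter = solve(a, guess)
    s = expit(a)
    sech2 = 1.0 - math.tanh(y * s + 1.0) ** 2
    dy_da = sech2 * y * s * (1.0 - s) / (1.0 - s * sech2)
    grad = -a + (OBS - y) / 0.5**2 * dy_da
    return grad, y, n_iter


def leapfrog(a, p, step_size, n_steps, reuse_guesses):
    """Integrate one trajectory; return final position, momentum and total Newton iterations."""
    grad, guess, total_iter = grad_logdensity(a, DEFAULT_GUESS)
    for _ in range(n_steps):
        p += 0.5 * step_size * grad  # half step for momentum
        a += step_size * p           # full step for position
        start = guess if reuse_guesses else DEFAULT_GUESS
        grad, guess, n_iter = grad_logdensity(a, start)
        total_iter += n_iter
        p += 0.5 * step_size * grad  # second half step for momentum
    return a, p, total_iter


cold = leapfrog(0.0, 1.0, 0.1, 20, reuse_guesses=False)
warm = leapfrog(0.0, 1.0, 0.1, 20, reuse_guesses=True)
print("default guess every time:", cold[2], "Newton iterations")
print("previous solution as guess:", warm[2], "Newton iterations")
```

Both runs should end at the same position and momentum, since only the starting point of each Newton solve changes, but the warm-started run needs fewer iterations because consecutive leapfrog positions are close together. This sketch only shows why reusing solutions along a trajectory can save work; how GrapeNUTS does this inside NUTS is up to the library.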
File details
Details for the file grapevine_mcmc-0.1.0.tar.gz.
File metadata
- Download URL: grapevine_mcmc-0.1.0.tar.gz
- Upload date:
- Size: 19.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 676a4335b1232bb2e8c00c000f9bbce10a64b1525196d203a6b2ff9a0e41130c
MD5 | 7ed1eb96812e8da608b9af98ec6c1e7e
BLAKE2b-256 | 7a869c1b40a52dd13ff9598cb205c8b4d8aa0b34a592119246c1c07e5da8504f
File details
Details for the file grapevine_mcmc-0.1.0-py3-none-any.whl.
File metadata
- Download URL: grapevine_mcmc-0.1.0-py3-none-any.whl
- Upload date:
- Size: 7.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.1 CPython/3.12.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 3b37167f2b16eb58349cf432b455556113f6fe7770d60fa03f413e6fc3552baf
MD5 | 023545633fa6e4a810f810b09db3785b
BLAKE2b-256 | e5f03e9ee182eb476b093a2d34c32c183ba8cb971444047c0a0ee67dbc54131b