
Faster Hamiltonian trajectories for problems with guesses, using Blackjax


grapevine

Project status: WIP (initial development is in progress, but there has not yet been a stable, usable release suitable for the public). Supported Python versions: 3.12 and newer.

A JAX/Blackjax implementation of the grapevine method, which reuses the solutions of embedded problems that need an initial guess (such as equation solving) along Hamiltonian trajectories.

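The grapevine method can dramatically speed up MCMC for statistical models with embedded equation-solving problems, since each solve can start from a nearby, previously found solution instead of a fixed default guess.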
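Why reusing a nearby solution helps can be seen with a plain Newton solver. The sketch below is a dependency-free illustration, not part of grapevine: it solves the same equation as the usage example, y = tanh(y * expit(a) + 1), at a = 0.05, once from a fixed default guess and once from the solution already found at a = 0.0. The helper names `residual`, `residual_dy` and `newton` are hypothetical.

```python
import math


def residual(y, a):
    """g(y) = y - tanh(y * expit(a) + 1), the root-finding problem from the usage example."""
    s = 1.0 / (1.0 + math.exp(-a))  # expit(a)
    return y - math.tanh(y * s + 1.0)


def residual_dy(y, a):
    """Derivative of the residual with respect to y (always >= 1 - expit(a) > 0)."""
    s = 1.0 / (1.0 + math.exp(-a))
    t = math.tanh(y * s + 1.0)
    return 1.0 - s * (1.0 - t * t)


def newton(a, guess, tol=1e-10, max_steps=100):
    """Plain Newton iteration; returns (root, number of steps taken)."""
    y = guess
    for step in range(1, max_steps + 1):
        delta = residual(y, a) / residual_dy(y, a)
        y -= delta
        if abs(delta) < tol:
            return y, step
    raise RuntimeError("Newton iteration did not converge")


# Cold start: a fixed default guess. Warm start: the solution found at a
# nearby parameter value, as neighboring points on a trajectory would provide.
y_prev, _ = newton(a=0.0, guess=0.01)
y_cold, cold_steps = newton(a=0.05, guess=0.01)
y_warm, warm_steps = newton(a=0.05, guess=y_prev)
print(cold_steps, warm_steps)
```

Both starts reach the same root; the warm start simply needs fewer Newton steps, and that saving accumulates over every log density evaluation in a trajectory.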

Installation

pip install grapevine-mcmc

Usage

First, make a suitable log density function.

This function should take two arguments: a set of parameters (a Pytree) and a guess (also a Pytree). It should return the log density of these parameters (a scalar) and a new guess. It should also be generally compatible with JAX, and will probably involve some differentiable numerical solving, for example using optimistix.
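As a minimal, dependency-free sketch of this contract, the hypothetical function below uses plain Python floats in place of Pytrees and a made-up embedded problem, y^3 + y = a, solved with Newton's method. It is not JAX-compatible; it only shows the shape of the interface and how each returned guess can be fed into the next call.

```python
import math


def normal_logpdf(x, loc, scale):
    """Log density of a normal distribution."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def solve_cubic(a, guess, tol=1e-12, max_steps=100):
    """Newton's method for the hypothetical embedded problem y**3 + y - a = 0."""
    y = guess
    for _ in range(max_steps):
        delta = (y**3 + y - a) / (3.0 * y**2 + 1.0)
        y -= delta
        if abs(delta) < tol:
            return y
    raise RuntimeError("Newton iteration did not converge")


def logdensity(params, guess, obs=0.7):
    """(parameters, guess) -> (log density, new guess)."""
    y = solve_cubic(params, guess)
    log_prior = normal_logpdf(params, 0.0, 1.0)
    log_likelihood = normal_logpdf(obs, y, 0.5)
    # The solution is returned so the caller can use it as the next guess.
    return log_prior + log_likelihood, y


# Thread each returned guess into the next call, as successive points on a
# trajectory would.
guess = 0.0
for a in [0.0, 0.1, 0.2, 0.3]:
    lp, guess = logdensity(a, guess)
    print(f"a={a:.1f}  log density={lp:.4f}  new guess={guess:.6f}")
```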

Here is a simple example of such a function:

```python
from functools import partial

import jax

from jax.scipy.stats import norm
from jax.scipy.special import expit
from jax import numpy as jnp

import optimistix as optx

# Equation-solving problems often need 64-bit floats.
jax.config.update("jax_enable_x64", True)

solver = optx.Newton(rtol=1e-8, atol=1e-8)
obs = jnp.array(0.7)


def fn(y, args):
    """Equation defining a root-finding problem: y = tanh(y * expit(a) + 1)."""
    a = args
    return y - jnp.tanh(y * expit(a) + 1)


def joint_logdensity(a, obs, guess):
    """An example log density that also returns the solver's solution as a new guess."""
    # Solve the embedded equation, starting Newton's method from the guess.
    sol = optx.root_find(fn, solver, guess, args=a)
    log_prior = norm.logpdf(a, loc=0.0, scale=1.0)
    log_likelihood = norm.logpdf(obs, loc=sol.value, scale=0.5)
    # The solution doubles as the guess for the next evaluation.
    return log_prior + log_likelihood, sol.value


posterior_logdensity = partial(joint_logdensity, obs=obs)
posterior_logdensity(a=0.0, guess=0.01)
# (Array(-1.22095095, dtype=float64), Array(0.8952192, dtype=float64))
```
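The log density in the output comment above can be checked by hand. The following pure-Python sketch uses no JAX or optimistix, and swaps in simple fixed-point iteration for the Newton solver (the map y -> tanh(y * expit(a) + 1) is a contraction here, because its slope is at most expit(0) = 0.5). At a = 0 the total is the prior term log N(0; 0, 1) plus the likelihood term log N(0.7; y, 0.5), where y is the embedded solution.

```python
import math


def normal_logpdf(x, loc, scale):
    """Log density of a normal distribution."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


a, obs = 0.0, 0.7
s = 1.0 / (1.0 + math.exp(-a))  # expit(0) = 0.5

# Fixed-point iteration for y = tanh(y * s + 1), starting from the same guess.
y = 0.01
for _ in range(200):
    y = math.tanh(y * s + 1.0)

log_prior = normal_logpdf(a, 0.0, 1.0)
log_likelihood = normal_logpdf(obs, y, 0.5)
print(y, log_prior + log_likelihood)
```

Up to solver tolerance, this reproduces both numbers printed by `posterior_logdensity`.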

Now you can run MCMC on your model using GrapeNUTS, the grapevine version of the NUTS sampler!

```python
from grapevine import run_grapenuts

INITIAL_POSITION = jnp.array(0.0)
DEFAULT_GUESS = jnp.array(0.01)
SEED = 1234

key = jax.random.key(SEED)
samples, info = run_grapenuts(
    logdensity_fn=posterior_logdensity,
    rng_key=key,
    init_parameters=INITIAL_POSITION,
    num_warmup=10,
    num_samples=10,
    default_guess=DEFAULT_GUESS,
    progress_bar=False,
    initial_step_size=0.01,
    max_num_doublings=4,
    is_mass_matrix_diagonal=True,
    target_acceptance_rate=0.8,
)
# 1%, 50% and 99% quantiles of the sampled values of a
jnp.quantile(samples.position, jnp.array([0.01, 0.5, 0.99]))
# Array([-1.26712677,  0.12950684,  0.93903677], dtype=float64)
```
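The internals of `run_grapenuts` are not described here, so the following is only a hypothetical, pure-Python sketch of the general idea the grapevine method is built on: along one leapfrog trajectory, each gradient evaluation warm-starts its solver from the previous step's solution rather than from the default guess. It uses the same model as the example above, with the gradient of the embedded solution computed by hand via the implicit function theorem instead of JAX autodiff. The names `solve`, `grad_logdensity` and `leapfrog` are illustrative, not part of the library.

```python
import math


def expit(x):
    return 1.0 / (1.0 + math.exp(-x))


def solve(a, guess, tol=1e-12, max_steps=100):
    """Newton's method for y = tanh(y * expit(a) + 1); returns (root, steps)."""
    s = expit(a)
    y = guess
    for step in range(1, max_steps + 1):
        t = math.tanh(y * s + 1.0)
        delta = (y - t) / (1.0 - s * (1.0 - t * t))
        y -= delta
        if abs(delta) < tol:
            return y, step
    raise RuntimeError("Newton iteration did not converge")


def grad_logdensity(a, guess, obs=0.7):
    """Gradient in a of log N(a; 0, 1) + log N(obs; y(a), 0.5).

    dy/da comes from the implicit function theorem applied to
    y - tanh(y * expit(a) + 1) = 0. Returns (gradient, solution, Newton steps).
    """
    y, steps = solve(a, guess)
    s = expit(a)
    sech2 = 1.0 - math.tanh(y * s + 1.0) ** 2
    dy_da = sech2 * y * s * (1.0 - s) / (1.0 - s * sech2)
    grad = -a + (obs - y) / 0.5**2 * dy_da
    return grad, y, steps


def leapfrog(q, p, step_size, num_steps, default_guess=0.01, warm_start=True):
    """Leapfrog integration; returns (position, momentum, total Newton steps)."""
    guess = default_guess
    grad, y, total_steps = grad_logdensity(q, guess)
    for _ in range(num_steps):
        p += 0.5 * step_size * grad
        q += step_size * p
        if warm_start:
            guess = y  # reuse the most recent solution as the next guess
        grad, y, n = grad_logdensity(q, guess)
        total_steps += n
        p += 0.5 * step_size * grad
    return q, p, total_steps


q_warm, p_warm, warm_total = leapfrog(0.0, 1.0, 0.05, 20, warm_start=True)
q_cold, p_cold, cold_total = leapfrog(0.0, 1.0, 0.05, 20, warm_start=False)
print(warm_total, cold_total)
```

Both runs follow the same trajectory, because each solve converges to the same root; only the number of Newton steps differs.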



Download files


Source Distribution

grapevine_mcmc-0.1.0.tar.gz (19.2 kB)

Uploaded Source

Built Distribution

grapevine_mcmc-0.1.0-py3-none-any.whl (7.3 kB)

Uploaded Python 3

File details

Details for the file grapevine_mcmc-0.1.0.tar.gz.

File metadata

  • Download URL: grapevine_mcmc-0.1.0.tar.gz
  • Upload date:
  • Size: 19.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for grapevine_mcmc-0.1.0.tar.gz
  • SHA256: 676a4335b1232bb2e8c00c000f9bbce10a64b1525196d203a6b2ff9a0e41130c
  • MD5: 7ed1eb96812e8da608b9af98ec6c1e7e
  • BLAKE2b-256: 7a869c1b40a52dd13ff9598cb205c8b4d8aa0b34a592119246c1c07e5da8504f


File details

Details for the file grapevine_mcmc-0.1.0-py3-none-any.whl.

File hashes

Hashes for grapevine_mcmc-0.1.0-py3-none-any.whl
  • SHA256: 3b37167f2b16eb58349cf432b455556113f6fe7770d60fa03f413e6fc3552baf
  • MD5: 023545633fa6e4a810f810b09db3785b
  • BLAKE2b-256: e5f03e9ee182eb476b093a2d34c32c183ba8cb971444047c0a0ee67dbc54131b

