
Faster Hamiltonian trajectories for problems with guesses, using Blackjax


grapevine

Project status: WIP. Initial development is in progress, but there has not yet been a stable, usable release suitable for the public.

Supported Python versions: 3.12 and newer.

JAX/Blackjax implementation of the grapevine method for reusing the solutions of guessing problems embedded in Hamiltonian trajectories.

The grapevine method can dramatically speed up MCMC for statistical models with embedded equation-solving problems.
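To give a feel for where the speed-up comes from, here is a minimal plain-Python sketch. It is not the package's implementation: `newton_steps` is a hypothetical helper written for this illustration, and the equation is borrowed from the usage example further down. It compares how many Newton steps a solve needs when started from a fixed default guess with how many it needs when started from the solution at a nearby parameter value, which is roughly the situation at consecutive points of a trajectory.

```python
import math


def expit(a):
    """Logistic function, a plain-Python stand-in for jax.scipy.special.expit."""
    return 1.0 / (1.0 + math.exp(-a))


def newton_steps(a, guess, tol=1e-12, max_steps=50):
    """Solve y = tanh(y * expit(a) + 1) by Newton's method (hypothetical helper).

    Returns the root and the number of Newton steps taken.
    """
    s = expit(a)
    y = guess
    for step in range(max_steps):
        t = math.tanh(y * s + 1.0)
        residual = y - t
        if abs(residual) < tol:
            return y, step
        slope = 1.0 - s * (1.0 - t * t)  # derivative of the residual in y
        y -= residual / slope
    raise RuntimeError("Newton iteration did not converge")


DEFAULT_GUESS = 0.01

# Solve at one parameter value, then at a nearby one.
root_here, _ = newton_steps(a=0.0, guess=DEFAULT_GUESS)
cold_root, cold_steps = newton_steps(a=0.05, guess=DEFAULT_GUESS)
warm_root, warm_steps = newton_steps(a=0.05, guess=root_here)

print(f"cold start: {cold_steps} steps, warm start: {warm_steps} steps")
```

Both starts reach the same root, but the warm start begins much closer to it and so needs fewer steps. In a real model each step can be expensive, and a solve happens at every gradient evaluation, so the saving adds up.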

Installation

pip install grapevine-mcmc

Usage

First make a suitable log density function.

This function should have two arguments: a set of parameters (a Pytree) and a guess (also a Pytree). It should return the log density of these parameters (a scalar) and a new guess, typically the solution it just found. It should also be generally compatible with JAX, and will probably involve some differentiable numerical solving, for example using optimistix.

Here is a simple example of such a function:

from functools import partial

import jax

from jax.scipy.stats import norm
from jax.scipy.special import expit
from jax import numpy as jnp

import optimistix as optx

# equation solving problems often need 64 bit floats
jax.config.update("jax_enable_x64", True)

solver = optx.Newton(rtol=1e-8, atol=1e-8)
obs = jnp.array(0.7)


def fn(y, args):
    """Equation defining a root-finding problem."""
    a = args
    return y - jnp.tanh(y * expit(a) + 1)


def joint_logdensity(a, obs, guess):
    """An example log density."""
    sol = optx.root_find(fn, solver, guess, args=a)
    log_prior = norm.logpdf(a, loc=0.0, scale=1.0)
    log_likelihood = norm.logpdf(obs, loc=sol.value, scale=0.5)
    return log_prior + log_likelihood, sol.value


posterior_logdensity = partial(joint_logdensity, obs=obs)
posterior_logdensity(a=0.0, guess=0.01)
# (Array(-1.22095095, dtype=float64), Array(0.8952192, dtype=float64))
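The reason the function returns a new guess is so that the caller can feed it into the next evaluation. The sketch below mimics that contract in plain Python, without JAX or optimistix, so that the data flow is easy to follow. The helpers `norm_logpdf`, `solve` and `logdensity`, and the list `path`, are hypothetical names made up for this illustration; `path` is just a hand-picked sequence of nearby parameter values standing in for the points a sampler would visit, and we assume the guess starts from a default value at the beginning of the sequence.

```python
import math

OBS = 0.7
DEFAULT_GUESS = 0.01


def norm_logpdf(x, loc, scale):
    """Log density of a normal distribution (stand-in for norm.logpdf)."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def solve(a, guess, tol=1e-12, max_steps=50):
    """Newton solve of y = tanh(y * expit(a) + 1), starting from guess."""
    s = 1.0 / (1.0 + math.exp(-a))
    y = guess
    for _ in range(max_steps):
        t = math.tanh(y * s + 1.0)
        if abs(y - t) < tol:
            return y
        y -= (y - t) / (1.0 - s * (1.0 - t * t))
    raise RuntimeError("Newton iteration did not converge")


def logdensity(a, guess):
    """Same contract as above: (parameters, guess) -> (log density, new guess)."""
    root = solve(a, guess)
    lp = norm_logpdf(a, 0.0, 1.0) + norm_logpdf(OBS, root, 0.5)
    return lp, root


# Hypothetical sequence of nearby parameter values.
path = [0.0, 0.1, 0.2, 0.3]

guess = DEFAULT_GUESS  # assumed starting point for the first solve
log_densities = []
for a in path:
    # Each call starts its solve from the previous call's solution.
    lp, guess = logdensity(a, guess)
    log_densities.append(lp)

print(log_densities[0])
```

The first evaluation uses the same inputs as the call above (`a=0.0`, `guess=0.01`), so it should agree with the value shown there up to floating-point error. Every later evaluation starts from the root found at the previous parameter value rather than from the default.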

Now you can run MCMC on your model using GrapeNUTS, the grapevine version of the NUTS sampler!

from grapevine import run_grapenuts

INITIAL_POSITION = jnp.array(0.0)
DEFAULT_GUESS = jnp.array(0.01)
SEED = 1234

key = jax.random.key(SEED)
samples, info = run_grapenuts(
    logdensity_fn=posterior_logdensity,
    rng_key=key,
    init_parameters=INITIAL_POSITION,
    num_warmup=10,
    num_samples=10,
    default_guess=DEFAULT_GUESS,
    progress_bar=False,
    initial_step_size=0.01,
    max_num_doublings=4,
    is_mass_matrix_diagonal=True,
    target_acceptance_rate=0.8,
)
jnp.quantile(samples.position, jnp.array([0.01, 0.5, 0.99]))
# Array([-1.26712677,  0.12950684,  0.93903677], dtype=float64)

