Python implementation of the No-U-Turn Sampler
jaxnuts
A Python implementation of the No-U-Turn Sampler (NUTS) from Hoffman and Gelman (Algorithm 6, NUTS with dual-averaging step-size adaptation), built on JAX.
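NUTS builds on two ingredients: the leapfrog integrator from Hamiltonian Monte Carlo and a stopping rule that halts a trajectory once it starts to double back on itself. The sketch below shows only these two pieces, assuming an identity mass matrix. It is not jaxnuts' internal code: the full Algorithm 6 also draws a slice variable, grows the trajectory by repeated doubling in random directions, and checks the criterion on every subtree. The names leapfrog and no_u_turn are illustrative.

```python
import jax
import jax.numpy as jnp

def logprob(x):
    """Standard normal log-density, up to an additive constant."""
    return -0.5 * jnp.dot(x, x)

# Gradient of the log-density, derived automatically by JAX.
grad_logprob = jax.grad(logprob)

def leapfrog(theta, r, eps):
    """One leapfrog step of size eps (identity mass matrix)."""
    r_half = r + 0.5 * eps * grad_logprob(theta)   # half step for momentum
    theta_new = theta + eps * r_half               # full step for position
    r_new = r_half + 0.5 * eps * grad_logprob(theta_new)  # second half step
    return theta_new, r_new

def no_u_turn(theta_minus, theta_plus, r_minus, r_plus):
    """True while neither end of the trajectory is moving back toward the other."""
    dtheta = theta_plus - theta_minus
    return (jnp.dot(dtheta, r_minus) >= 0) & (jnp.dot(dtheta, r_plus) >= 0)

# Usage: extend a trajectory forward only, until it makes a U-turn.
theta0 = jnp.array([1.0, 0.0])
r0 = jnp.array([0.0, 1.0])
theta, r = theta0, r0
n_steps = 0
while no_u_turn(theta0, theta, r0, r):
    theta, r = leapfrog(theta, r, 0.1)
    n_steps += 1
print(n_steps)
```

For a standard normal, the trajectory traces a circle, so it turns back after roughly half a period (about pi / 0.1, or 31 to 32 steps, here). This is why NUTS does not need a hand-tuned number of leapfrog steps.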
Usage
Import libraries
import jax
import jax.numpy as jnp
import jax.random as random
from jaxnuts.sampler import NUTS
For low-dimensional problems such as this simple example, force JAX to use the CPU to avoid GPU overhead:
jax.config.update('jax_platform_name', 'cpu')
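To check that the setting took effect, you can ask JAX which backend it will use by default. This is a small sketch, not part of jaxnuts. It assumes the update is made at the top of the script before any arrays are created, which is the safest place for it.

```python
import jax

# Select the CPU backend before any JAX computation runs.
jax.config.update('jax_platform_name', 'cpu')

import jax.numpy as jnp

x = jnp.ones(2)               # created on the default backend
print(jax.default_backend())  # name of the backend JAX is using, e.g. 'cpu'
```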
Define the log-density of the target distribution (it only needs to be correct up to an additive constant):
def logprob(x):
    """Standard normal"""
    return -.5 * jnp.dot(x, x)
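NUTS moves through the parameter space using the gradient of the log-density. The sketch below assumes the sampler obtains this gradient through JAX's automatic differentiation, so logp should be written with jax.numpy rather than plain NumPy. It simply evaluates the value and gradient of logprob at one point. For a standard normal, the gradient is -x.

```python
import jax
import jax.numpy as jnp

def logprob(x):
    """Standard normal log-density, up to an additive constant."""
    return -.5 * jnp.dot(x, x)

# Compile a function returning both log p(x) and its gradient.
logp_and_grad = jax.jit(jax.value_and_grad(logprob))

value, grad_value = logp_and_grad(jnp.array([1.0, 2.0]))
print(value)       # -2.5
print(grad_value)  # [-1. -2.]
```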
Generate samples
key = random.PRNGKey(0)
sampler = NUTS(jnp.ones(2), logp=logprob, target_acceptance=.5, M_adapt=1000)
key, samples, step_size = sampler.sample(1000, key)
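The target_acceptance and M_adapt arguments appear to correspond to the target acceptance statistic delta and the number of adaptation iterations M_adapt in Hoffman and Gelman's Algorithm 6. During those iterations, the step size is tuned by dual averaging. The class below is a plain-Python sketch of those updates using the paper's constants (gamma = 0.05, t0 = 10, kappa = 0.75, mu = log(10 * eps0)). The name DualAveraging and the toy acceptance model exp(-eps) are hypothetical. In real NUTS, accept_stat is the average Metropolis acceptance probability over the tree built in that iteration.

```python
import math

class DualAveraging:
    """Hypothetical sketch of dual-averaging step-size adaptation (Hoffman & Gelman)."""

    def __init__(self, eps0, target_acceptance, gamma=0.05, t0=10.0, kappa=0.75):
        self.mu = math.log(10.0 * eps0)  # point the iterates are shrunk toward
        self.delta = target_acceptance
        self.gamma, self.t0, self.kappa = gamma, t0, kappa
        self.m = 0
        self.h_bar = 0.0                 # running average of (delta - accept_stat)
        self.log_eps_bar = 0.0           # weighted average of log step sizes

    def update(self, accept_stat):
        """Feed one iteration's acceptance statistic and return the next step size."""
        self.m += 1
        m = self.m
        w = 1.0 / (m + self.t0)
        self.h_bar = (1.0 - w) * self.h_bar + w * (self.delta - accept_stat)
        log_eps = self.mu - math.sqrt(m) / self.gamma * self.h_bar
        eta = m ** (-self.kappa)
        self.log_eps_bar = eta * log_eps + (1.0 - eta) * self.log_eps_bar
        return math.exp(log_eps)

    @property
    def final_step_size(self):
        """Step size to freeze once adaptation ends."""
        return math.exp(self.log_eps_bar)

# Usage with a toy (hypothetical) model where acceptance falls as exp(-eps).
adapt = DualAveraging(eps0=1.0, target_acceptance=0.5)
eps = 1.0
for _ in range(1000):
    accept_stat = math.exp(-eps)
    eps = adapt.update(accept_stat)
print(adapt.final_step_size)  # close to log(2), where exp(-eps) = 0.5
```

After adaptation, the sampler keeps the averaged step size fixed. Averaging gives a more stable value than the last raw iterate, which can still oscillate.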
Download files
Source distribution: jaxnuts-0.0.1.tar.gz (9.1 kB)