Sampling with Blackjax on Aesara
AeX
The following currently works:
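```python
import aesara.tensor as at
import jax
import aex

srng = at.random.RandomStream(0)  # seeded stream for building random variables
sigma_rv = srng.halfnormal(0., 1.)  # a scale parameter must be positive
mu_rv = srng.normal(0, 1)
Y_rv = srng.normal(mu_rv, sigma_rv)

rng_key = jax.random.PRNGKey(0)  # JAX PRNG key consumed by the sampler

# Draw 1,000,000 joint prior samples of the requested variables
sampler = aex.prior_sampler(Y_rv, mu_rv)
sampler(rng_key, 1_000_000)
```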
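Conceptually, prior sampling of this model is ancestral (forward) sampling: each variable is drawn given its parents, in topological order. The sketch below illustrates the idea with Python's standard `random` module instead of JAX, and assumes sigma follows a half-normal(1) prior so the scale stays positive. `sample_prior` is a hypothetical helper, not part of AeX.

```python
import random


def sample_prior(num_samples, seed=0):
    """Ancestral sampling of mu ~ N(0, 1), sigma ~ HalfNormal(1), Y ~ N(mu, sigma).

    Returns a list of (Y, mu) pairs, mirroring the variables passed to
    aex.prior_sampler(Y_rv, mu_rv) in the example above.
    """
    rng = random.Random(seed)
    draws = []
    for _ in range(num_samples):
        sigma = abs(rng.gauss(0.0, 1.0))  # half-normal(1) draw: |N(0, 1)|
        mu = rng.gauss(0.0, 1.0)
        y = rng.gauss(mu, sigma)  # Y is drawn after its parents mu and sigma
        draws.append((y, mu))
    return draws


draws = sample_prior(10_000)
mean_mu = sum(mu for _, mu in draws) / len(draws)
print(len(draws), round(mean_mu, 2))
```

Under these assumptions the marginal variance of Y is Var(mu) + E[sigma^2] = 1 + 1 = 2, which a large enough batch of draws should reproduce approximately.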
Coming soon
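Sampling from the posterior distribution, conditioned on the observation Y_rv = 1, using Blackjax's NUTS sampler:
```python
# Condition on Y_rv = 1 and sample the remaining latent variables with NUTS
sampler = aex.mcmc({Y_rv: 1.}, aex.NUTS())
samples, info = sampler(rng_key, 1000, 1000)
```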
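Conditioning on `{Y_rv: 1.}` means the sampler targets the unnormalized posterior log p(mu, sigma | Y = 1) = log p(mu) + log p(sigma) + log p(Y = 1 | mu, sigma) + const. Below is a minimal sketch of that target in plain Python, again assuming a half-normal(1) prior on sigma. `log_posterior` and its helpers are hypothetical names; how AeX actually derives this density from the Aesara graph is not shown. A gradient-based sampler such as NUTS evaluates a function like this, together with its gradient, at every integration step.

```python
import math


def normal_logpdf(x, loc, scale):
    """Log density of N(loc, scale**2) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def halfnormal_logpdf(x, scale):
    """Log density of HalfNormal(scale) at x; -inf outside the support x >= 0."""
    if x < 0.0:
        return -math.inf
    return math.log(2.0) + normal_logpdf(x, 0.0, scale)


def log_posterior(mu, sigma, y_obs=1.0):
    """Unnormalized log p(mu, sigma | Y = y_obs) for the model in the example."""
    if sigma <= 0.0:
        return -math.inf  # sigma is a scale and must be strictly positive
    return (
        normal_logpdf(mu, 0.0, 1.0)          # prior on mu
        + halfnormal_logpdf(sigma, 1.0)      # prior on sigma (assumed half-normal)
        + normal_logpdf(y_obs, mu, sigma)    # likelihood of the observed Y
    )


print(log_posterior(1.0, 1.0))
```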
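Sampling from the posterior by arbitrarily combining Blackjax step functions, for example NUTS for mu_rv and RMH for sigma_rv:
```python
# Y_rv is observed, so only the latent variables mu_rv and sigma_rv get step functions
sampler = aex.mcmc({Y_rv: 1.}, {mu_rv: aex.NUTS(), sigma_rv: aex.RMH()})
samples, info = sampler(rng_key, 1000)
```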
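Combining step functions amounts to a Gibbs-style sweep: each step function updates its own block of latent variables while the others are held fixed, and composing such updates still leaves the joint posterior invariant. The sketch below illustrates this in plain Python under two stated assumptions: both blocks use a random-walk Metropolis update (standing in for NUTS on mu, which would require gradients), and sigma has a half-normal(1) prior. `rmh_step`, `gibbs_sweep` and `run_chain` are hypothetical helpers, not AeX or Blackjax functions.

```python
import math
import random


def log_posterior(mu, sigma, y_obs=1.0):
    """Unnormalized log p(mu, sigma | Y = y_obs): mu ~ N(0, 1), sigma ~ HalfNormal(1)."""
    if sigma <= 0.0:
        return -math.inf
    # Additive constants are dropped: Metropolis only uses differences of log densities
    return -0.5 * mu**2 - 0.5 * sigma**2 - math.log(sigma) - 0.5 * ((y_obs - mu) / sigma) ** 2


def rmh_step(rng, logdensity, x, step_size):
    """One random-walk Metropolis update of a scalar x under logdensity."""
    proposal = x + step_size * rng.gauss(0.0, 1.0)
    log_ratio = logdensity(proposal) - logdensity(x)
    # Symmetric proposal, so accept with probability min(1, exp(log_ratio))
    if rng.random() < math.exp(min(0.0, log_ratio)):
        return proposal
    return x


def gibbs_sweep(rng, mu, sigma, step_size=0.5):
    """Update mu with sigma fixed, then sigma with the new mu fixed."""
    mu = rmh_step(rng, lambda m: log_posterior(m, sigma), mu, step_size)
    sigma = rmh_step(rng, lambda s: log_posterior(mu, s), sigma, step_size)
    return mu, sigma


def run_chain(num_samples, num_warmup=1000, seed=0):
    rng = random.Random(seed)
    mu, sigma = 0.0, 1.0  # starting point with finite log density
    samples = []
    for i in range(num_warmup + num_samples):
        mu, sigma = gibbs_sweep(rng, mu, sigma)
        if i >= num_warmup:  # keep only post-warm-up draws
            samples.append((mu, sigma))
    return samples


samples = run_chain(5000)
print(sum(m for m, _ in samples) / len(samples))
```

Proposals with sigma <= 0 get a log density of -inf and are always rejected, so every retained draw stays inside the support.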
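Sampling from the posterior predictive distribution, given the posterior samples returned by aex.mcmc:
```python
# `samples` holds the posterior draws returned by aex.mcmc
sampler = aex.posterior_predictive(samples, Y_rv)
sampler(rng_key, 1000)
```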
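The posterior predictive distribution p(y_new | Y = 1) is sampled by pushing each posterior draw of (mu, sigma) back through the likelihood: for every draw, simulate a fresh y_new ~ N(mu, sigma). A minimal sketch in plain Python, assuming the posterior draws are a list of (mu, sigma) pairs; `posterior_predictive_draws` is a hypothetical helper, and the layout of AeX's own sample container is not assumed.

```python
import random


def posterior_predictive_draws(posterior_samples, seed=0):
    """Return one simulated observation y_new ~ N(mu, sigma) per posterior draw."""
    rng = random.Random(seed)
    return [rng.gauss(mu, sigma) for mu, sigma in posterior_samples]


# Hypothetical placeholder draws, used only to exercise the function
posterior_samples = [(0.4, 0.8), (0.6, 1.1), (0.5, 0.9)]
y_new = posterior_predictive_draws(posterior_samples)
print(y_new)
```

Because each y_new inherits the uncertainty in (mu, sigma), the predictive spread is wider than that of any single fixed-parameter normal.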
Download files
Source Distribution
aex-0.0.2.tar.gz (6.1 kB)
Built Distribution
aex-0.0.2-py3-none-any.whl (2.5 kB)