Bayeux
Stitching together models and samplers
bayeux lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods. The API aims to be simple, self-descriptive, and helpful. Simply provide a log density function (which doesn't even have to be normalized) along with a single point, specified as a pytree, where that log density is finite. Then let bayeux do the rest!
Installation
pip install bayeux-ml
Quickstart
We define a model by providing a log density in JAX. This could be defined using a probabilistic programming language (PPL) like numpyro, PyMC, TFP, distrax, oryx, coix, or directly in JAX.
import bayeux as bx
import jax
normal_density = bx.Model(
    log_density=lambda x: -x*x,
    test_point=1.)
seed = jax.random.key(0)
# Optimize the log density with Adam (via optax).
opt_results = normal_density.optimize.optax_adam(seed=seed)
# Or run MCMC with NUTS (via numpyro).
idata = normal_density.mcmc.numpyro_nuts(seed=seed)
# Or fit a factored variational surrogate posterior (via TFP).
surrogate_posterior, loss = normal_density.vi.tfp_factored_surrogate_posterior(seed=seed)
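The quickstart's test point is a scalar, but as noted above, the test point can be any pytree. Below is a minimal sketch of a multi-parameter model, reusing the bx.Model and mcmc.numpyro_nuts calls from the quickstart; the parameter names and the quadratic log density are purely illustrative.

import bayeux as bx
import jax
import jax.numpy as jnp

def log_density(params):
    # Unnormalized Gaussian log density over every leaf of the pytree.
    return -jnp.sum(params["loc"] ** 2) - jnp.sum(params["log_scale"] ** 2)

pytree_model = bx.Model(
    log_density=log_density,
    test_point={"loc": jnp.zeros(3), "log_scale": jnp.ones(3)})

idata = pytree_model.mcmc.numpyro_nuts(seed=jax.random.key(1))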
Read more
- Defining models
- Inspecting models
- Testing and debugging
- Also see bayeux integration with numpyro, PyMC, and TFP!
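As one illustration of plugging a PPL-defined density into bayeux, the sketch below hands a TFP distribution's log_prob straight to bx.Model. This uses only the generic log-density route from the quickstart; the dedicated integrations documented above may offer more direct constructors, and the StudentT model here is just an example.

import bayeux as bx
import jax
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# bayeux only needs a callable log density and a point where it is finite,
# so a TFP distribution's log_prob can be passed in directly.
student_t = bx.Model(
    log_density=tfd.StudentT(df=5., loc=0., scale=1.).log_prob,
    test_point=0.)

idata = student_t.mcmc.numpyro_nuts(seed=jax.random.key(0))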
This is not an officially supported Google product.