Bayesian Inference with JAX
Bayinx: Bayesian Inference with JAX
Bayinx is an embedded probabilistic programming language in Python, powered by JAX. It is heavily inspired by Stan and aims for feature parity with it, but it extends the types of objects you can work with and focuses on normalizing-flow variational inference for sampling.
Coming From Stan
There are a few differences between the syntax of Bayinx and Stan.
First, because Bayinx is embedded in Python, model definitions are Pythonic: you define a class that inherits from the `Model` base class:
```python
class MyModel(Model, init=False):
    ...  # attributes and the model method go here
```
Note: Users should specify `init=False` to prevent static type checkers from raising irrelevant errors; more importantly, it reminds you that you should NOT implement your own `__init__` method!
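Python passes class keywords such as `init=False` to the base class's `__init_subclass__` hook, and base classes decorated with `typing.dataclass_transform` (PEP 681) let type checkers read `init=False` as "do not synthesize an `__init__`". The sketch below shows only the runtime side of that mechanism using a hypothetical stand-in `Model`; it is not Bayinx's implementation, and raising `TypeError` for a user-defined `__init__` is an illustrative choice.

```python
from typing import Any


class Model:
    """Hypothetical stand-in for a base class like Bayinx's `Model`."""

    def __init_subclass__(cls, init: bool = True, **kwargs: Any) -> None:
        # Python forwards class keywords such as `init=False` to this hook.
        # Accepting `init` here keeps it from reaching object.__init_subclass__,
        # which rejects unknown keyword arguments.
        super().__init_subclass__(**kwargs)
        # Enforce the README's rule: subclasses must not write their own __init__.
        if "__init__" in cls.__dict__:
            raise TypeError(f"{cls.__name__} must not define __init__")

    def __init__(self, **nodes: Any) -> None:
        # The base class owns construction and simply stores nodes by name.
        for name, value in nodes.items():
            setattr(self, name, value)


class MyModel(Model, init=False):
    pass


model = MyModel(mean=0.0)
print(model.mean)  # 0.0
```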
Stan's `data` and `parameters` blocks are combined into attribute definitions in Bayinx. For example, if we are modelling a normal distribution with an unknown mean and a known variance of 1, we might write:
```python
class MyModel(Model, init=False):
    mean: Continuous[Array] = define(shape = ())    # a scalar mean parameter
    x: Observed[Array] = define(shape = 'n_obs')  # a vector of observed values
    # ...
```
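The string `'n_obs'` names a dimension whose size is supplied later, for example as `n_obs = 3` when the posterior is constructed. Below is a minimal sketch of how such shape specifications could be resolved, assuming `()` means a scalar, a string is looked up among keyword-supplied dimension sizes, and tuples may mix integers and names. `resolve_shape` is a hypothetical helper, not part of Bayinx.

```python
from typing import Mapping, Tuple, Union

ShapeSpec = Union[int, str, Tuple[Union[int, str], ...]]


def resolve_shape(spec: ShapeSpec, dims: Mapping[str, int]) -> Tuple[int, ...]:
    """Turn a shape spec such as (), 'n_obs', or (2, 'n_obs') into integers."""
    if not isinstance(spec, tuple):
        spec = (spec,)  # a bare int or name denotes a 1-D shape
    resolved = []
    for entry in spec:
        if isinstance(entry, str):
            # Named dimensions must be supplied by the caller, e.g. n_obs=3.
            if entry not in dims:
                raise KeyError(f"dimension {entry!r} was not supplied")
            resolved.append(dims[entry])
        else:
            resolved.append(int(entry))
    return tuple(resolved)


dims = {"n_obs": 3}
print(resolve_shape((), dims))       # ()
print(resolve_shape("n_obs", dims))  # (3,)
```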
Stan's `model` block corresponds to implementing the `model` method in Bayinx:
```python
class MyModel(Model, init=False):
    mean: Continuous[Array] = define(shape = ())
    x: Observed[Array] = define(shape = 'n_obs')

    def model(self, target):
        # Equivalent to 'x ~ normal(mean, 1.0)' in Stan
        self.x << Normal(self.mean, 1.0)
        return target
```
Notice that Stan's `~` operator is replaced with `<<`, and that model nodes are referenced through `self`.
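To see what a statement like `self.x << Normal(self.mean, 1.0)` computes, the toy sketch below overloads Python's `<<` operator (`__lshift__`) so that it adds the normal log density of each observed value to a running total, playing the role of Stan's `target`. `LogDensity`, `Node`, and this `Normal` are hypothetical classes, and Bayinx's internals may differ. Stan's `~` statement drops constant terms, whereas this sketch keeps the full log density.

```python
import math


class LogDensity:
    """Hypothetical accumulator playing the role of Stan's `target`."""

    def __init__(self) -> None:
        self.value = 0.0


class Normal:
    """Minimal normal distribution with a log-density method."""

    def __init__(self, mu: float, sigma: float) -> None:
        self.mu, self.sigma = mu, sigma

    def logpdf(self, x: float) -> float:
        z = (x - self.mu) / self.sigma
        return -0.5 * z * z - math.log(self.sigma) - 0.5 * math.log(2.0 * math.pi)


class Node:
    """Hypothetical observed node: `node << dist` adds log p(node | dist)."""

    def __init__(self, values: list, target: LogDensity) -> None:
        self.values, self.target = values, target

    def __lshift__(self, dist: Normal) -> "Node":
        # Python calls __lshift__ when evaluating `self << dist`.
        self.target.value += sum(dist.logpdf(v) for v in self.values)
        return self


target = LogDensity()
x = Node([-1.0, 0.0, 1.0], target)
x << Normal(0.0, 1.0)  # like `x ~ normal(0, 1)`, but with constants kept
print(round(target.value, 4))  # -3.7568
```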
Note: Bayinx does not currently have an equivalent of Stan's `transformed data` or `transformed parameters` blocks, although these are likely to be included in a future release.
You can then construct the variational approximation to the posterior:
```python
import bayinx as byx
from bayinx.flows import DiagAffine
import jax.numpy as jnp

# Fit variational approximation; dimension sizes and observed data
# are passed as keyword arguments
posterior = byx.Posterior(MyModel, n_obs = 3, x = jnp.array([-1.0, 0.0, 1.0]))
posterior.configure(flowspecs = [DiagAffine()])  # use a diagonal affine flow
posterior.fit()
```
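As a rough illustration of what fitting with a `DiagAffine` flow involves, the sketch below fits a one-dimensional affine transform `mean = m + s * eps`, with `eps ~ N(0, 1)`, by stochastic gradient ascent on the evidence lower bound (ELBO) using the reparameterization trick. It assumes a flat prior on `mean` (the README does not state one), computes gradients by hand instead of with JAX autodiff, and uses plain SGD; Bayinx's actual objective, optimizer, and flow parameterization may differ. Under the flat prior the exact posterior is `N(mean(x), 1/n)`, which gives a check on the fit.

```python
import math
import random

# Observed data from the example above; the flat prior on `mean` is an assumption.
x = [-1.0, 0.0, 1.0]
n = len(x)
x_bar = sum(x) / n


def fit_diag_affine(steps: int = 3000, lr: float = 0.01, batch: int = 16, seed: int = 0):
    """Fit q(mean) as m + s * eps, eps ~ N(0, 1), by stochastic ELBO ascent."""
    rng = random.Random(seed)
    m, log_s = 1.0, 0.0  # arbitrary starting point
    for _ in range(steps):
        s = math.exp(log_s)
        grad_m, grad_log_s = 0.0, 0.0
        for _ in range(batch):
            eps = rng.gauss(0.0, 1.0)
            z = m + s * eps  # reparameterized draw of `mean`
            # d/dz of sum_i log N(x_i | z, 1) is sum_i (x_i - z) = n * (x_bar - z)
            dlogp_dz = n * (x_bar - z)
            grad_m += dlogp_dz / batch                # dz/dm = 1
            grad_log_s += dlogp_dz * eps * s / batch  # dz/dlog_s = s * eps
        grad_log_s += 1.0  # entropy of N(m, s^2) is log(s) + const
        m += lr * grad_m
        log_s += lr * grad_log_s
    return m, math.exp(log_s)


m, s = fit_diag_affine()
# Exact flat-prior posterior: N(x_bar, 1/n) = N(0, 1/3)
print(m, s, math.sqrt(1.0 / n))
```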
You can then work with the fitted `posterior` by sampling its nodes:
```python
mean_draws = posterior.sample('mean', 10000)
print(mean_draws.mean())
```
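For an affine flow over a standard normal base, sampling a node amounts to pushing standard normal draws through the fitted map. The sketch below does this in plain Python with hypothetical fitted values `m = 0` and `s = sqrt(1/3)` (the exact posterior for the example data under an assumed flat prior); `sample_diag_affine` is illustrative and not the implementation behind `posterior.sample`.

```python
import random
import statistics


def sample_diag_affine(m: float, s: float, n_draws: int, seed: int = 0) -> list:
    """Push standard normal draws through the affine map z = m + s * eps."""
    rng = random.Random(seed)
    return [m + s * rng.gauss(0.0, 1.0) for _ in range(n_draws)]


# Hypothetical fitted parameters close to the exact posterior N(0, 1/3).
mean_draws = sample_diag_affine(m=0.0, s=(1.0 / 3.0) ** 0.5, n_draws=10000)
print(statistics.fmean(mean_draws), statistics.stdev(mean_draws))
```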
Roadmap
- Implement OT-Flow: https://arxiv.org/abs/2006.00104
- Allow shape definitions to include expressions (e.g., shape = 'n_obs + 1' will evaluate to the correct specification)
- Figure out how to construct distributions dynamically so that different parameterizations don't require calling separate functions, only specifying different keywords, e.g. `Exponential(rate = ...)` vs. `Exponential(scale = ...)`
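One safe way to support shape expressions like `'n_obs + 1'` is to parse them with Python's `ast` module and evaluate only integer literals, dimension names, and a small set of arithmetic operators, rather than calling `eval`. This is a hypothetical sketch of the roadmap item; `eval_dim` is an illustrative name, not a description of how Bayinx will implement it.

```python
import ast
import operator
from typing import Mapping

# Only integer arithmetic on dimension names and literals is allowed.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.FloorDiv: operator.floordiv,
}


def eval_dim(expr: str, dims: Mapping[str, int]) -> int:
    """Safely evaluate a dimension expression such as 'n_obs + 1'."""

    def visit(node: ast.AST) -> int:
        if isinstance(node, ast.Expression):
            return visit(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, int):
            return node.value
        if isinstance(node, ast.Name):
            return dims[node.id]  # KeyError if the dimension was not supplied
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](visit(node.left), visit(node.right))
        # Anything else (calls, attributes, floats, ...) is rejected.
        raise ValueError(f"unsupported syntax in shape expression: {expr!r}")

    return visit(ast.parse(expr, mode="eval"))


print(eval_dim("n_obs + 1", {"n_obs": 3}))  # 4
```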
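For the parameterization item, one common pattern is a single constructor with keyword-only alternatives that are converted to one canonical parameter. In the hypothetical `Exponential` below, exactly one of `rate` or `scale` (with `scale = 1 / rate`) must be given; this is a sketch of the idea, not Bayinx's distribution API.

```python
import math
from typing import Optional


class Exponential:
    """Hypothetical Exponential accepting either `rate` or `scale` (= 1 / rate)."""

    def __init__(self, *, rate: Optional[float] = None, scale: Optional[float] = None) -> None:
        # Require exactly one parameterization.
        if (rate is None) == (scale is None):
            raise TypeError("specify exactly one of `rate` or `scale`")
        # Store a single canonical parameter internally.
        self.rate = rate if rate is not None else 1.0 / scale
        if self.rate <= 0:
            raise ValueError("rate must be positive")

    def logpdf(self, x: float) -> float:
        # log(rate * exp(-rate * x)) for x >= 0, and -inf outside the support
        if x < 0:
            return -math.inf
        return math.log(self.rate) - self.rate * x


a = Exponential(rate=2.0)
b = Exponential(scale=0.5)
print(a.logpdf(1.0) == b.logpdf(1.0))  # True
```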