Bayinx: Bayesian Inference with JAX
The original aim of this project was to build a probabilistic programming language (PPL) in Python that is similar in feel to Stan or Nimble (with a nice declarative syntax for defining the model) and allows for arbitrary models (e.g., ones with discrete parameters that may not be just integers); most of this goal has been moved to baycian for the foreseeable future.
Part of the reason for this move is that embedding a "nice" DSL is comparatively easier in Rust thanks to Rust macros: I can define syntax similar to Stan and parse it into valid Rust code. Additionally, the current state of bayinx is relatively functional (give or take a few things to clean up and document), and it offers enough for one of my other projects: disize! I plan to rewrite disize in Python with JAX, and bayinx makes it easy to handle constraining transformations, filter parameters for gradient calculations, etc.
Instead, this project is narrowing its focus to implementing much of Stan's functionality (restricted to continuously parameterized models, point estimation + VI + MCMC, etc.) without most of the nice syntax, at least for versions 0.4.#. Users will therefore work with target directly and return the density, as below:
    class NormalDist(Model):
        x: Parameter[Array] = define(shape = (2,))

        def eval(self, data: Dict[str, Array]):
            # Constrain parameters
            self, target = self.constrain_params() # this does nothing for the current model

            # Evaluate x ~ Normal(10.0, 1.0)
            target += normal.logprob(self.x(), 10.0, 1.0).sum()

            return target
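For readers without bayinx at hand, the accumulation pattern in the model above can be sketched in plain JAX. This is not bayinx's implementation: `log_density` and `log_density_constrained` are hypothetical names, `jax.scipy.stats.norm.logpdf` stands in for `normal.logprob`, and the log-Jacobian term in the second function is only an assumption about why a constraining step would hand back a `target` alongside the model.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_density(x: jax.Array) -> jax.Array:
    """Log density of x ~ Normal(10.0, 1.0), summed over the elements of x."""
    # Start the accumulator at zero (no constraints, so nothing to adjust).
    target = jnp.zeros(())
    # Add the log-probability of every element and reduce to a scalar.
    target += norm.logpdf(x, loc=10.0, scale=1.0).sum()
    return target


def log_density_constrained(x: jax.Array, raw_sigma: jax.Array) -> jax.Array:
    """Hypothetical variant with a positive scale, sigma = exp(raw_sigma)."""
    # Map the unconstrained value onto the positive reals.
    sigma = jnp.exp(raw_sigma)
    # Assumed adjustment: log|d sigma / d raw_sigma| = raw_sigma for the exp map.
    target = raw_sigma
    target += norm.logpdf(x, loc=10.0, scale=sigma).sum()
    return target


x = jnp.array([9.0, 11.0])
value = log_density(x)
# The density is a scalar function of x, so jax.grad applies directly.
grad = jax.grad(log_density)(x)
# With raw_sigma = 0 the scale is 1 and the adjustment is 0.
value_constrained = log_density_constrained(x, jnp.array(0.0))
```

Because each function returns a scalar, the same definition can be reused for point estimation, VI, and gradient-based MCMC by differentiating it with `jax.grad`.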
I have ideas for using a context manager and implementing Node: Observed/Stochastic classes that will try to replicate what baycian is trying to do, but that is for the future; versions 0.4.# will retain the functionality needed for disize.
TODO
- For optimization and variational methods, offer a way for users to define custom stopping conditions (perhaps stop if a single parameter has converged, etc.).
- Control variates for mean-field VI? Look at https://proceedings.mlr.press/v33/ranganath14.html more closely.
- Low-rank affine flow?
- Implement Sylvester flows (https://arxiv.org/pdf/1803.05649).
- Learn how to generate documentation.
- Figure out how to make transform_pars for flows such that there is no performance loss. I am noticing some weird behaviour when adding constraints.
- Look into adaptively tuning Adam hyperparameters for VI.
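As a rough illustration of the first item, custom stopping conditions could be passed in as a plain callable. The sketch below is not bayinx's API: `maximize`, `all_converged`, and `any_converged` are hypothetical names, and fixed-step gradient ascent stands in for whatever optimizer is actually used.

```python
from typing import Callable

import jax
import jax.numpy as jnp


def maximize(
    log_density: Callable[[jax.Array], jax.Array],
    init: jax.Array,
    should_stop: Callable[[jax.Array, jax.Array], bool],
    step_size: float = 0.1,
    max_iters: int = 1_000,
) -> tuple[jax.Array, int]:
    """Gradient ascent that defers the stopping decision to `should_stop`."""
    grad_fn = jax.jit(jax.grad(log_density))
    params = init
    for i in range(1, max_iters + 1):
        new_params = params + step_size * grad_fn(params)
        # The user-supplied predicate sees the previous and current values.
        if should_stop(params, new_params):
            return new_params, i
        params = new_params
    return params, max_iters


def all_converged(prev: jax.Array, curr: jax.Array, tol: float = 1e-3) -> bool:
    """Stop once every parameter has moved by less than `tol`."""
    return bool(jnp.all(jnp.abs(curr - prev) < tol))


def any_converged(prev: jax.Array, curr: jax.Array, tol: float = 1e-3) -> bool:
    """Stop as soon as a single parameter has moved by less than `tol`."""
    return bool(jnp.any(jnp.abs(curr - prev) < tol))


def log_density(x: jax.Array) -> jax.Array:
    # Unnormalized log density of x ~ Normal(10.0, 1.0).
    return -0.5 * jnp.sum((x - 10.0) ** 2)


init = jnp.array([0.0, 9.9])
x_all, n_all = maximize(log_density, init, all_converged)
x_any, n_any = maximize(log_density, init, any_converged)
```

Here the second parameter starts close to its optimum, so `any_converged` stops well before `all_converged` does; keeping the predicate outside the loop lets users swap in either rule without touching the optimizer.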
File details
Details for the file bayinx-0.4.1.tar.gz.
File metadata
- Download URL: bayinx-0.4.1.tar.gz
- Upload date:
- Size: 43.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.7.16
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b497b28164e7e02d51febba6ea33c0a7a47391fdcd54902fea7b38c25f9e1bbd |
| MD5 | ac2130650d06426ef1f61deac8cb0484 |
| BLAKE2b-256 | a54d39b0fd269887bcdf8ac1a996715f30fdc716a605fbf3ec73db116194371d |
File details
Details for the file bayinx-0.4.1-py3-none-any.whl.
File metadata
- Download URL: bayinx-0.4.1-py3-none-any.whl
- Upload date:
- Size: 27.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.7.16
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7a021ff20895acb08db85cd7217df16fe3a403c3b73e2df84640977e1514d262 |
| MD5 | 87df5863bebe0cb93c8a7fe167554620 |
| BLAKE2b-256 | 7c28790941d1ad89b607365250e811c355c51cf1d8dfcf74be82b6c456f478d6 |