Project description
sbijax
Simulation-based inference in JAX
About
Sbijax is a Python library for neural simulation-based inference and approximate Bayesian computation using JAX.
It implements recent methods, such as Simulated-annealing ABC, Surjective Neural Likelihood Estimation, Neural Approximate Sufficient Statistics, and Consistency Model Posterior Estimation, as well as methods for computing model diagnostics and visualizing posterior distributions.
[!CAUTION] ⚠️ As per the LICENSE file, there is no warranty whatsoever for this free software tool. If you discover bugs, please report them.
Examples
Sbijax implements a slim object-oriented API with functional elements stemming from JAX. All a user needs to define is a prior model, a simulator function, and an inferential algorithm.
For example, you can define a neural likelihood estimation method and generate posterior samples like this:
```python
from jax import numpy as jnp, random as jr
from sbijax import NLE
from sbijax.nn import make_maf
from tensorflow_probability.substrates.jax import distributions as tfd


def prior_fn():
    # prior over the two model parameters
    prior = tfd.JointDistributionNamed(dict(
        theta=tfd.Normal(jnp.zeros(2), jnp.ones(2))
    ), batch_ndims=0)
    return prior


def simulator_fn(seed, theta):
    # simulate a noisy observation given a parameter draw
    p = tfd.Normal(jnp.zeros_like(theta["theta"]), 0.1)
    y = theta["theta"] + p.sample(seed=seed)
    return y


# neural likelihood estimation with a masked autoregressive flow
fns = prior_fn, simulator_fn
model = NLE(fns, make_maf(2))

# simulate training data, fit the model, and draw posterior samples
y_observed = jnp.array([-1.0, 1.0])
data, _ = model.simulate_data(jr.PRNGKey(1))
params, _ = model.fit(jr.PRNGKey(2), data=data)
posterior, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)
```
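Before training, it can be useful to exercise the prior and simulator on their own. The short sketch below continues the example above and uses only plain TFP/JAX calls, no sbijax API; it is merely an illustrative check, not part of the library's interface.

```python
# Sanity check of prior_fn and simulator_fn from the example above
# (plain TFP/JAX; independent of the sbijax API).
samples = prior_fn().sample(seed=jr.PRNGKey(0), sample_shape=(5,))
ys = simulator_fn(jr.PRNGKey(1), samples)
print(samples["theta"].shape, ys.shape)  # (5, 2) (5, 2)
```

Catching shape or dtype mismatches at this stage is usually easier than debugging them through a training loop.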
More self-contained examples can be found in the examples folder of the repository.
Documentation
Documentation can be found here.
Installation
Make sure to have a working JAX installation. Depending on whether you want to use CPU, GPU, or TPU, please follow the JAX installation instructions.
To install from PyPI, just call the following on the command line:
pip install sbijax
To install the latest GitHub release, use:
pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
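As an optional post-installation check (a minimal sketch, not taken from the sbijax documentation), you can confirm that JAX sees the intended backend and that sbijax imports cleanly:

```python
# Optional check after installation: list the available JAX devices
# and make sure sbijax imports without error.
import jax
import sbijax  # noqa: F401

print(jax.devices())  # e.g. [CpuDevice(id=0)] on a CPU-only install
```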
Contributing
Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled good first issue.
In order to contribute:
- Clone sbijax and install `hatch` via `pip install hatch`,
- create a new branch locally, e.g., `git checkout -b feature/my-new-feature` or `git checkout -b issue/fixes-bug`,
- implement your contribution and ideally a test case,
- test it by calling `make tests`, `make lints` and `make format` on the (Unix) command line,
- submit a PR 🙂
Acknowledgements
[!NOTE] 📝 The API of the package is heavily inspired by the excellent PyTorch-based sbi package.
Author
Simon Dirmeier sfyrbnd @ pm me
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution: sbijax-0.3.0.tar.gz
Built Distribution: sbijax-0.3.0-py3-none-any.whl
File details
Details for the file sbijax-0.3.0.tar.gz.
File metadata
- Download URL: sbijax-0.3.0.tar.gz
- Upload date:
- Size: 20.2 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.12.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 3bc95817cf1966c5e78b813d75e61c3ebea4872f6056b9cef9f0405233d37724
MD5 | 71b58d31661b61ca0a3eac7d5dd48444
BLAKE2b-256 | 4ce5b2212322c1d82bd8f31102d1462a4d0e5155a073e202d0b9060d94193cd6
File details
Details for the file sbijax-0.3.0-py3-none-any.whl.
File metadata
- Download URL: sbijax-0.3.0-py3-none-any.whl
- Upload date:
- Size: 64.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.12.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 6561755757f1ab67c318f161f063ae45e704de258b75967c104378baba4a98c1
MD5 | 041ac3ec0894459abd2d13dc90d067c6
BLAKE2b-256 | c10cf0cb57100bb3511cc83abd88f42566270720bda7bfd2f7a981257edaa68a