
Simulation-based inference in JAX


sbijax



About

Sbijax is a Python library for neural simulation-based inference and approximate Bayesian computation using JAX. It implements recent methods such as Simulated-annealing ABC, Surjective Neural Likelihood Estimation, Neural Approximate Sufficient Statistics, and Consistency model posterior estimation, as well as methods for computing model diagnostics and visualizing posterior distributions.
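To make the idea behind approximate Bayesian computation concrete, here is a minimal sketch of plain rejection ABC in JAX: draw parameters from the prior, simulate data for each, and keep the draws whose simulated data land within a tolerance of the observed data. This is the simplest ABC variant, not the simulated-annealing ABC that sbijax implements, and it does not use the sbijax API. The toy model mirrors the example below (standard normal prior, Gaussian noise with scale 0.1); the names rejection_abc, n_proposals and epsilon are hypothetical.

```python
import jax
import jax.numpy as jnp


def simulate(key, theta):
    # Toy simulator: y = theta + Gaussian noise with standard deviation 0.1.
    return theta + 0.1 * jax.random.normal(key, theta.shape)


def rejection_abc(key, y_observed, n_proposals=100_000, epsilon=0.1):
    """Plain rejection ABC (hypothetical helper, not part of sbijax)."""
    key_prior, key_sim = jax.random.split(key)
    # Candidate parameters drawn from the standard normal prior.
    theta = jax.random.normal(key_prior, (n_proposals, y_observed.shape[0]))
    # One simulated data set per candidate parameter.
    y = simulate(key_sim, theta)
    # Euclidean distance between each simulated data set and the observation.
    dist = jnp.linalg.norm(y - y_observed, axis=-1)
    # Boolean-mask indexing has a data-dependent shape, so this runs outside jit.
    return theta[dist < epsilon]


y_observed = jnp.array([-1.0, 1.0])
accepted = rejection_abc(jax.random.PRNGKey(0), y_observed)
print(accepted.shape[0], accepted.mean(axis=0))
```

Only a small fraction of proposals is accepted here, which illustrates why more efficient schemes (sequential or annealed ABC, or neural estimators) are preferred in practice.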

[!CAUTION] ⚠️ As per the LICENSE file, there is no warranty whatsoever for this free software tool. If you discover bugs, please report them.

Examples

Sbijax implements a slim object-oriented API with functional elements stemming from JAX. All a user needs to define is a prior model, a simulator function and an inferential algorithm. For example, you can define a neural likelihood estimation method and generate posterior samples like this:

from jax import numpy as jnp, random as jr
from sbijax import NLE
from sbijax.nn import make_maf
from tensorflow_probability.substrates.jax import distributions as tfd

def prior_fn():
    prior = tfd.JointDistributionNamed(dict(
        theta=tfd.Normal(jnp.zeros(2), jnp.ones(2))
    ), batch_ndims=0)
    return prior

def simulator_fn(seed, theta):
    p = tfd.Normal(jnp.zeros_like(theta["theta"]), 0.1)
    y = theta["theta"] + p.sample(seed=seed)
    return y


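# Bundle prior and simulator; use a masked autoregressive flow as the density estimator
fns = prior_fn, simulator_fn
model = NLE(fns, make_maf(2))

y_observed = jnp.array([-1.0, 1.0])
# Simulate parameter/data pairs for training
data, _ = model.simulate_data(jr.PRNGKey(1))
# Train the neural likelihood estimator
params, _ = model.fit(jr.PRNGKey(2), data=data)
# Draw posterior samples conditional on the observed data
posterior, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)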
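Because this example uses a standard normal prior and Gaussian simulator noise with scale 0.1, its posterior is available in closed form, which makes a handy sanity check for the NLE samples. Per coordinate, the posterior precision is 1/1² + 1/0.1² = 101, so the posterior mean is 100·y/101 and the standard deviation is 1/√101. A minimal sketch, assuming independent coordinates as in the example; gaussian_posterior is a hypothetical helper, not part of sbijax:

```python
import jax.numpy as jnp


def gaussian_posterior(y_observed, prior_scale=1.0, noise_scale=0.1):
    """Exact posterior for theta ~ N(0, prior_scale^2), y | theta ~ N(theta, noise_scale^2).

    Coordinates are independent, so prior and likelihood precisions add.
    """
    precision = 1.0 / prior_scale**2 + 1.0 / noise_scale**2
    variance = 1.0 / precision
    # The prior mean is zero, so only the data term contributes to the mean.
    mean = variance * y_observed / noise_scale**2
    return mean, jnp.sqrt(variance)


mean, sd = gaussian_posterior(jnp.array([-1.0, 1.0]))
print(mean, sd)  # mean ≈ ±0.990 per coordinate, sd ≈ 0.0995
```

Comparing these values with the empirical mean and standard deviation of the posterior draws gives a quick check of the fit; how to extract the theta draws depends on the object returned by sample_posterior, which is not shown here.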

More self-contained examples can be found in the examples directory of the repository.

Documentation

Documentation can be found here.

Installation

Make sure you have a working JAX installation. Depending on whether you want to use CPU, GPU, or TPU, please follow the corresponding JAX installation instructions.

To install from PyPI, just call the following on the command line:

pip install sbijax

To install a GitHub release directly (replace <RELEASE> with the desired release tag), use:

pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
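After installing, it can help to confirm which backend JAX picked up and whether sbijax is registered in the environment. The sketch below uses only jax and the Python standard library; describe_environment is a hypothetical helper, and checking package metadata confirms installation, not that every optional dependency imports cleanly.

```python
import importlib.metadata

import jax


def describe_environment():
    """Report the JAX backend and devices, plus the installed sbijax version if any."""
    try:
        sbijax_version = importlib.metadata.version("sbijax")
    except importlib.metadata.PackageNotFoundError:
        sbijax_version = None  # sbijax is not installed in this environment
    return {
        "jax_version": jax.__version__,
        "backend": jax.default_backend(),  # e.g. "cpu", "gpu" or "tpu"
        "devices": [str(d) for d in jax.devices()],
        "sbijax_version": sbijax_version,
    }


print(describe_environment())
```

If the reported backend is "cpu" although you expected a GPU or TPU, the JAX installation (not sbijax) is usually the place to look.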

Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled "good first issue".

In order to contribute:

  1. Clone sbijax and install hatch via pip install hatch,
  2. create a new branch locally via git checkout -b feature/my-new-feature or git checkout -b issue/fixes-bug,
  3. implement your contribution and ideally a test case,
  4. test it by calling make tests, make lints and make format on the (Unix) command line,
  5. submit a PR 🙂
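Step 3 asks for a test case alongside a contribution. A hypothetical example of what a small test might look like, assuming a pytest-style runner that collects functions named test_*; the simulator here uses jax.random instead of tensorflow_probability, and neither function is part of sbijax or its test suite:

```python
import jax
import jax.numpy as jnp


def simulator_fn(seed, theta):
    # Hypothetical simulator under test: y = theta + Gaussian noise (sd 0.1).
    return theta["theta"] + 0.1 * jax.random.normal(seed, theta["theta"].shape)


def test_simulator_shape_and_reproducibility():
    theta = {"theta": jnp.zeros((5, 2))}
    key = jax.random.PRNGKey(0)
    y1 = simulator_fn(key, theta)
    y2 = simulator_fn(key, theta)
    # The output keeps the batch and parameter dimensions.
    assert y1.shape == (5, 2)
    # JAX randomness is explicit: the same key yields identical draws.
    assert jnp.array_equal(y1, y2)
    # A different key should yield different noise.
    assert not jnp.array_equal(y1, simulator_fn(jax.random.PRNGKey(1), theta))
```

Testing shapes and key-reproducibility is cheap and catches the most common mistakes when writing JAX simulators, such as reusing or ignoring the PRNG key.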

Acknowledgements

[!NOTE] 📝 The API of the package is heavily inspired by the excellent PyTorch-based sbi package.

Author

Simon Dirmeier sfyrbnd @ pm me



Download files


Source Distribution

sbijax-0.3.0.tar.gz (20.2 MB)

Uploaded Source

Built Distribution

sbijax-0.3.0-py3-none-any.whl (64.3 kB)

Uploaded Python 3

File details

Details for the file sbijax-0.3.0.tar.gz.

File metadata

  • Download URL: sbijax-0.3.0.tar.gz
  • Upload date:
  • Size: 20.2 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.12.5

File hashes

Hashes for sbijax-0.3.0.tar.gz
Algorithm Hash digest
SHA256 3bc95817cf1966c5e78b813d75e61c3ebea4872f6056b9cef9f0405233d37724
MD5 71b58d31661b61ca0a3eac7d5dd48444
BLAKE2b-256 4ce5b2212322c1d82bd8f31102d1462a4d0e5155a073e202d0b9060d94193cd6


File details

Details for the file sbijax-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: sbijax-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 64.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.12.5

File hashes

Hashes for sbijax-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 6561755757f1ab67c318f161f063ae45e704de258b75967c104378baba4a98c1
MD5 041ac3ec0894459abd2d13dc90d067c6
BLAKE2b-256 c10cf0cb57100bb3511cc83abd88f42566270720bda7bfd2f7a981257edaa68a

