Simulation-based inference in JAX

sbijax

About

Sbijax is a Python library for neural simulation-based inference and approximate Bayesian computation using JAX. It implements recent methods such as Simulated-annealing ABC, Surjective Neural Likelihood Estimation, Neural Approximate Sufficient Statistics and Consistency model posterior estimation, as well as methods for computing model diagnostics and visualizing posterior distributions.
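To make the idea behind approximate Bayesian computation concrete, here is a minimal sketch of plain rejection ABC. It is not one of the sbijax algorithms (such as Simulated-annealing ABC) and does not use the sbijax API. It uses plain NumPy and the same two-dimensional toy model as the example further down (prior N(0, I), Gaussian observation noise with standard deviation 0.1). The function name `rejection_abc`, the number of simulations and the tolerance `eps` are illustrative choices.

```python
import numpy as np


def rejection_abc(y_obs, n_sims=200_000, eps=0.1, seed=0):
    """Plain rejection ABC for theta ~ N(0, I), y ~ N(theta, 0.1^2 I) (illustrative)."""
    rng = np.random.default_rng(seed)
    # 1. Draw parameters from the prior.
    theta = rng.normal(0.0, 1.0, size=(n_sims, y_obs.shape[0]))
    # 2. Run the simulator once per parameter draw.
    y_sim = theta + rng.normal(0.0, 0.1, size=theta.shape)
    # 3. Keep only draws whose simulated data lie within eps of the observation.
    dist = np.linalg.norm(y_sim - y_obs, axis=1)
    return theta[dist < eps]


y_observed = np.array([-1.0, 1.0])
samples = rejection_abc(y_observed)
print(samples.shape, samples.mean(axis=0))
```

Only a small fraction of the 200,000 simulations is accepted here, and the fraction shrinks further as `eps` decreases. This inefficiency is one motivation for the more sample-efficient ABC variants and neural methods listed above.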

[!CAUTION] ⚠️ As per the LICENSE file, there is no warranty whatsoever for this free software tool. If you discover bugs, please report them.

Examples

Sbijax implements a slim object-oriented API with functional elements stemming from JAX. All a user needs to define is a prior model, a simulator function and an inferential algorithm. For example, you can define a neural likelihood estimation method and generate posterior samples like this:

from jax import numpy as jnp, random as jr
from sbijax import NLE
from sbijax.nn import make_maf
from tensorflow_probability.substrates.jax import distributions as tfd

# Prior: theta ~ N(0, I) in two dimensions
def prior_fn():
    prior = tfd.JointDistributionNamed(dict(
        theta=tfd.Normal(jnp.zeros(2), jnp.ones(2))
    ), batch_ndims=0)
    return prior

# Simulator: y = theta + Gaussian noise with standard deviation 0.1
def simulator_fn(seed, theta):
    p = tfd.Normal(jnp.zeros_like(theta["theta"]), 0.1)
    y = theta["theta"] + p.sample(seed=seed)
    return y


fns = prior_fn, simulator_fn
# Neural likelihood estimation using a masked autoregressive flow over the 2-dimensional data
model = NLE(fns, make_maf(2))

y_observed = jnp.array([-1.0, 1.0])
# Simulate (theta, y) pairs, fit the flow, then sample the posterior given y_observed
data, _ = model.simulate_data(jr.PRNGKey(1))
params, _ = model.fit(jr.PRNGKey(2), data=data)
posterior, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)
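The toy model in the example above is conjugate, so its exact posterior is available in closed form and provides a useful sanity check for the NLE samples. Per coordinate, the posterior precision is 1/1² + 1/0.1² = 101, the posterior mean is (1/0.1²)·y/101 = 100y/101, and the posterior standard deviation is 1/√101 ≈ 0.0995. The sketch below computes these reference values with NumPy. `exact_posterior` is a hypothetical helper and not part of sbijax.

```python
import numpy as np


def exact_posterior(y_obs, prior_scale=1.0, noise_scale=0.1):
    """Closed-form posterior for theta ~ N(0, s0^2), y ~ N(theta, s^2), per coordinate."""
    prior_prec = 1.0 / prior_scale**2
    lik_prec = 1.0 / noise_scale**2
    post_prec = prior_prec + lik_prec          # Gaussian precisions add
    post_mean = lik_prec * y_obs / post_prec   # the prior mean is zero
    post_std = np.sqrt(1.0 / post_prec)
    return post_mean, post_std


mean, std = exact_posterior(np.array([-1.0, 1.0]))
print(mean, std)  # approximately [-0.990, 0.990] and 0.0995
```

Posterior draws from the fitted NLE model should concentrate around these values. A large discrepancy suggests that the surrogate likelihood has not been trained well enough.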

More self-contained examples can be found in the examples folder of the repository.

Documentation

Documentation can be found here.

Installation

Make sure to have a working JAX installation. Depending on whether you want to use a CPU, GPU or TPU, please follow the JAX installation instructions.

To install from PyPI, just call the following on the command line:

pip install sbijax

To install a release directly from GitHub, replace <RELEASE> with the desired release tag and use:

pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
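After installing, you can quickly confirm that both sbijax and a JAX backend are visible to Python with a check like the following. It uses only the standard library plus `jax.default_backend()`. The helper names `installed_version` and `jax_backend` are illustrative and not part of sbijax.

```python
import importlib.metadata
import importlib.util


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def jax_backend():
    """Return JAX's default backend name (e.g. "cpu" or "gpu"), or None if JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax
    return jax.default_backend()


sbijax_version = installed_version("sbijax")
backend = jax_backend()
print(f"sbijax: {sbijax_version or 'not installed'}")
print(f"JAX backend: {backend or 'JAX not installed'}")
```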

Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled "good first issue".

In order to contribute:

  1. Clone sbijax and install uv.

  2. Install all dependencies using uv sync --all-groups.

  3. Install pre-commit and gitlint via:

    pre-commit install
    gitlint install-hook
    
  4. Create a new local branch, e.g. git checkout -b feature/my-new-feature or git checkout -b issue/fixes-bug.

  5. Implement your contribution and ideally a test case.

  6. Test it by calling make tests, make lints and make format on the (Unix) command line.

  7. Submit a PR 🙂.

Citing sbijax

If you find our work relevant to your research, please consider citing:

@article{dirmeier2024simulation,
  title={Simulation-based inference with the Python Package sbijax},
  author={Dirmeier, Simon and Ulzega, Simone and Mira, Antonietta and Albert, Carlo},
  journal={arXiv preprint arXiv:2409.19435},
  year={2024}
}

Acknowledgements

[!NOTE] 📝 The API of the package is heavily inspired by the excellent PyTorch-based sbi package.

Download files

Source Distribution

sbijax-0.3.6.tar.gz (13.6 MB)

Uploaded Source

Built Distribution

sbijax-0.3.6-py3-none-any.whl (91.9 kB)

Uploaded Python 3

File details

Details for the file sbijax-0.3.6.tar.gz.

File metadata

  • Download URL: sbijax-0.3.6.tar.gz
  • Size: 13.6 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for sbijax-0.3.6.tar.gz
Algorithm Hash digest
SHA256 0eb4d00e7da11d36424458b7ddd3811b3bc2be79d17db04f878615e6168e446c
MD5 1bf0a33ecabdf4a5e05595b8a48d016e
BLAKE2b-256 8199807dc76659261ed3fd09b158706d45d5eaf3c110c22aa2f96c3b10311677

File details

Details for the file sbijax-0.3.6-py3-none-any.whl.

File metadata

  • Download URL: sbijax-0.3.6-py3-none-any.whl
  • Size: 91.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for sbijax-0.3.6-py3-none-any.whl
Algorithm Hash digest
SHA256 ad6366b8ff448c7b873a331756f9d0c8178d979b15a4e9404796b8b3f219c147
MD5 64dc26ebe605d9d827f572fe3ab4b76c
BLAKE2b-256 9b996fbd93d0655acd72418e0b02e1a40b919221919900e68c119823eb0f55fb
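To check a downloaded file against the SHA256 digests listed above, hash it locally and compare the results. The following is a minimal sketch using Python's standard hashlib module. The helper names and the download location in the usage comment are illustrative assumptions.

```python
import hashlib


def sha256_of_chunks(chunks):
    """Hex SHA256 digest of an iterable of byte strings."""
    digest = hashlib.sha256()
    for chunk in chunks:
        digest.update(chunk)
    return digest.hexdigest()


def sha256_of_file(path, chunk_size=1 << 20):
    """Hex SHA256 digest of a file, read in 1 MiB chunks to bound memory use."""
    with open(path, "rb") as fh:
        return sha256_of_chunks(iter(lambda: fh.read(chunk_size), b""))


def matches_digest(path, expected_hex):
    """True if the file's SHA256 digest equals the published hex digest."""
    return sha256_of_file(path) == expected_hex.strip().lower()


# Usage, assuming the wheel was downloaded into the current directory:
# matches_digest(
#     "sbijax-0.3.6-py3-none-any.whl",
#     "ad6366b8ff448c7b873a331756f9d0c8178d979b15a4e9404796b8b3f219c147",
# )
```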
