Simulation-based inference in JAX
sbijax
About
sbijax implements several algorithms for simulation-based inference in JAX using Haiku, Distrax and BlackJAX. Specifically, sbijax implements the following methods:
- Sequential Monte Carlo ABC (SMCABC)
- Neural Likelihood Estimation (SNL)
- Surjective Neural Likelihood Estimation (SSNL)
- Neural Posterior Estimation C (SNP)
- Contrastive Neural Ratio Estimation (SNR)
- Neural Approximate Sufficient Statistics (SNASS)
- Neural Approximate Slice Sufficient Statistics (SNASSS)
- Flow matching posterior estimation (SFMPE)
- Consistency model posterior estimation (SCMPE)

The acronyms in parentheses are the names of the methods in sbijax.
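To make the idea behind the ABC family concrete, here is a minimal sketch of plain rejection ABC written directly in JAX. It is not sbijax's API and not its SMCABC implementation: SMCABC refines this idea with a sequence of rounds, shrinking tolerances and resampled particles, whereas this sketch does a single round. The toy Gaussian simulator, the N(0, 3²) prior, the sample-mean summary statistic and the function names `simulator`, `summary` and `rejection_abc` are hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp


def simulator(key, theta):
    # Hypothetical toy simulator: 10 i.i.d. draws from N(theta, 1) per parameter.
    return theta[:, None] + jax.random.normal(key, (theta.shape[0], 10))


def summary(y):
    # Sample mean over the last axis; it is sufficient for theta in this toy model.
    return y.mean(axis=-1)


def rejection_abc(key, y_obs, n_sims=100_000, n_keep=1_000):
    prior_key, sim_key = jax.random.split(key)
    # 1. Draw parameters from the (assumed) prior N(0, 3^2).
    theta = 3.0 * jax.random.normal(prior_key, (n_sims,))
    # 2. Simulate one data set per parameter draw.
    y_sim = simulator(sim_key, theta)
    # 3. Distance between simulated and observed summary statistics.
    dist = jnp.abs(summary(y_sim) - summary(y_obs))
    # 4. Accept the n_keep parameters whose simulations lie closest to the data.
    idx = jnp.argsort(dist)[:n_keep]
    return theta[idx]


# Usage: "observe" data generated with theta = 2, then approximate the posterior.
key_obs, key_abc = jax.random.split(jax.random.PRNGKey(0))
y_obs = 2.0 + jax.random.normal(key_obs, (10,))
samples = rejection_abc(key_abc, y_obs)
print(samples.mean(), samples.std())
```

Keeping the `n_keep` nearest simulations sets the tolerance implicitly as a quantile of the distances, which avoids hand-tuning an absolute threshold. The neural methods in the list instead train a density or ratio estimator on the simulated pairs rather than accepting or rejecting them.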
⚠️ Caution: As per the LICENSE file, there is no warranty whatsoever for this free software tool. If you discover bugs, please report them.
Examples
You can find several self-contained examples of how to use the algorithms in the examples folder.
Documentation
Documentation can be found here.
Installation
Make sure to have a working JAX installation. Depending on whether you want to use CPU, GPU or TPU, please follow the corresponding JAX installation instructions.
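A quick way to confirm that JAX itself works before installing sbijax is to list the devices it can see and run a small jit-compiled computation. This is a generic JAX smoke test, not something sbijax provides; on a CPU-only installation, only CPU devices will be listed.

```python
import jax
import jax.numpy as jnp

# List the devices JAX can use; accelerator devices appear here on a GPU/TPU install.
devices = jax.devices()
print(jax.__version__, jax.default_backend(), devices)

# Compile and run a tiny function: 0^2 + 1^2 + 2^2 + 3^2.
x = jnp.arange(4.0)
y = jax.jit(lambda v: (v ** 2).sum())(x)
print(y)  # 14.0
```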
To install from PyPI, just call the following on the command line:
pip install sbijax
To install the latest GitHub <RELEASE>, use:
pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
Acknowledgements
📝 Note: The API of the package is heavily inspired by the excellent PyTorch-based sbi package, which is substantially more feature-complete, more user-friendly and better documented.
Author
Simon Dirmeier sfyrbnd @ pm me