Nested Sampling in JAX
Mission: To make nested sampling faster, easier, and more powerful
What is it?
JAXNS is:
- a probabilistic programming framework using nested sampling as the engine;
- coded in JAX in a manner that allows lowering the entire inference algorithm to XLA primitives, which are JIT-compiled for high performance;
- continuously improving on its mission of making nested sampling faster, easier, and more powerful; and
- citable, and you can read an (old) pre-print here: (https://arxiv.org/abs/2012.15286).
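To illustrate what lowering a whole algorithm to XLA means, here is a minimal sketch in plain JAX. It is not JAXNS code: the function `newton_sqrt` and its fixed iteration count are hypothetical, chosen only to show that a loop written with `jax.lax.fori_loop` is compiled as a single unit by `jax.jit` rather than being stepped through by the Python interpreter.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def newton_sqrt(a):
    """Approximate sqrt(a) with a fixed number of Newton iterations (toy example)."""
    a = jnp.asarray(a, dtype=jnp.float32)

    def body(_, x):
        # One Newton step for solving x**2 = a.
        return 0.5 * (x + a / x)

    # The loop is expressed as a JAX primitive, so the whole function
    # (loop included) is traced once and compiled by XLA.
    return lax.fori_loop(0, 20, body, a)


root = newton_sqrt(2.0)
```

The same idea, applied to every loop and branch in a sampler, is what allows an entire inference run to execute as compiled code.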
Documentation
You can read the documentation here.
Install
Notes:
- JAXNS requires Python >= 3.8.
- It is always highly recommended to use a unique virtual environment for each project.
To use miniconda, have it installed, and run
# To create a new env, if necessary
conda create -n jaxns_py python=3.11
conda activate jaxns_py
For end users
Install directly from PyPI:
pip install jaxns
For development
Clone the repo with git clone https://www.github.com/JoshuaAlbert/jaxns.git, and install:
cd jaxns
pip install -r requirements.txt
pip install -r requirements-tests.txt
pip install -r requirements-examples.txt
pip install .
Getting help and contributing examples
Do you have a neat Bayesian problem, and want to solve it with JAXNS? I really encourage anyone in either the scientific community or industry to get involved and join the discussion forum. Please use the GitHub discussion forum for getting help, or contributing examples/neat use cases.
Quick start
Check out the examples here.
Caveats
The caveat is that you need to be able to define your likelihood function with JAX. This is usually no big deal, because JAX provides a NumPy-like API and many likelihoods can be expressed with it. If you're unfamiliar, take a quick tour of JAX (https://jax.readthedocs.io/en/latest/notebooks/quickstart.html).
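As a rough illustration of what "defining your likelihood with JAX" looks like, here is a minimal sketch that uses only `jax` and `jax.numpy`. The data `y_obs`, the noise scale `sigma`, and the function name `log_likelihood` are made up for this example, and the snippet does not show the JAXNS model-building API; see the documentation and examples for that.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data: noisy observations of a constant.
y_obs = jnp.array([0.9, 1.1, 1.0, 1.2])
sigma = 0.1  # assumed known noise scale


def log_likelihood(mu):
    """Gaussian log-likelihood of y_obs given mean mu, written with jax.numpy only."""
    resid = (y_obs - mu) / sigma
    return jnp.sum(-0.5 * resid**2 - jnp.log(sigma) - 0.5 * jnp.log(2.0 * jnp.pi))


# Because only JAX operations are used, the function can be JIT-compiled
# like any other JAX function.
fast_log_likelihood = jax.jit(log_likelihood)
value = fast_log_likelihood(1.05)
```

If a likelihood is already written with NumPy, swapping `numpy` for `jax.numpy` is often most of the work, provided the function avoids in-place mutation and Python control flow that depends on array values.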
Speed test comparison with other nested sampling packages
JAXNS is fast because it is built on JAX. It is much faster than PolyChord, MultiNest, and dynesty, typically achieving two to three orders of magnitude improvement in speed on cheap likelihood evaluations, as shown in (https://arxiv.org/abs/2012.15286). As for how efficiently JAXNS uses likelihood evaluations, it prizes exactness over efficiency; however, because it employs an adaptive strategy, users can control efficiency by adjusting some precision parameters.
Change Log
15 June, 2023 -- JAXNS 2.2.0 released. Added support for using TFP bijectors to define transformed distributions. Other minor improvements.
15 April, 2023 -- JAXNS 2.1.0 released. pmap used on outer-most loops allowing efficient device-device communication during parallel runs.
8 March, 2023 -- JAXNS 2.0.1 released. Changed how we're doing annotations to support Python 3.8 again.
3 January, 2023 -- JAXNS 2.0 released. Complete overhaul of components. New way to build models.
5 August, 2022 -- JAXNS 1.1.1 released. Pytree shaped priors.
2 June, 2022 -- JAXNS 1.1.0 released. Dynamic sampling takes advantage of adaptive refinement. Parallelisation. Bayesian opt and global opt modules.
30 May, 2022 -- JAXNS 1.0.1 released. Improvements to speed, parallelisation, and structure of code.
9 April, 2022 -- JAXNS 1.0.0 released. Parallel sampling, dynamic search, and adaptive refinement. Global optimiser released.
2 June, 2021 -- JAXNS 0.0.7 released.
13 May, 2021 -- JAXNS 0.0.6 released.
8 March, 2021 -- JAXNS 0.0.5 released.
8 March, 2021 -- JAXNS 0.0.4 released.
7 March, 2021 -- JAXNS 0.0.3 released.
28 February, 2021 -- JAXNS 0.0.2 released.
28 February, 2021 -- JAXNS 0.0.1 released.
1 January, 2021 -- Paper submitted