Flexible and fast sampling in Python

BlackJAX

What is BlackJAX?

BlackJAX is a library of samplers for JAX that works on CPU as well as GPU.

It is not a probabilistic programming library. However, it integrates well with PPLs as long as they can provide a (potentially unnormalized) log-probability density function compatible with JAX.
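
As a rough illustration of what "compatible with JAX" can mean in practice, the sketch below defines an unnormalized log-density over a dictionary of parameters and checks that JAX can compile and differentiate it. The function name `logdensity_fn` and the key `"x"` are illustrative choices, not requirements of the library.

```python
import jax
import jax.numpy as jnp


def logdensity_fn(position):
    # Unnormalized log-density of a Gaussian with mean 3 and unit variance:
    # the constant -0.5 * log(2 * pi) is dropped on purpose.
    return -0.5 * (position["x"] - 3.0) ** 2


# Gradient-based samplers need both the value and the gradient, so the
# function must be traceable by jax.jit and differentiable by jax.grad.
value_and_grad_fn = jax.jit(jax.value_and_grad(logdensity_fn))
value, grad = value_and_grad_fn({"x": jnp.array(1.0)})
print(value, grad["x"])
```

The gradient comes back with the same dictionary structure as the input position.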

Who should use BlackJAX?

BlackJAX should appeal to those who:

  • Have a logpdf and just need a sampler;
  • Need more than a general-purpose sampler;
  • Want to sample on GPU;
  • Want to build upon robust elementary blocks for their research;
  • Are building a probabilistic programming language;
  • Want to learn how sampling algorithms work.

Quickstart

Installation

You can install BlackJAX using pip:

pip install blackjax

or via conda-forge:

conda install -c conda-forge blackjax

Nightly builds (bleeding edge) of BlackJAX can also be installed using pip:

pip install blackjax-nightly

BlackJAX is written in pure Python but depends on XLA via JAX. By default, the version of JAX installed along with BlackJAX runs your code on CPU only. To use BlackJAX on GPU/TPU, we recommend following the JAX installation instructions to install JAX with the relevant hardware acceleration support.
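
One way to check which hardware your JAX installation will use is to query JAX directly. This is a minimal sketch that relies only on JAX's own `jax.devices()` and `jax.default_backend()` functions; the variable names are illustrative.

```python
import jax

# List the devices JAX can see and the backend it uses by default.
devices = jax.devices()
backend = jax.default_backend()

print(f"default backend: {backend}")
print(f"devices: {devices}")
```

On a CPU-only installation the default backend is reported as `cpu`.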

Example

Let us look at a simple self-contained example sampling with NUTS:

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np

import blackjax

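# Simulated observations: normal with mean 10 and standard deviation 20
observed = np.random.normal(10, 20, size=1_000)


def logdensity_fn(x):
    # Sum of the normal log-densities of the observations given the parameters
    logpdf = stats.norm.logpdf(observed, x["loc"], x["scale"])
    return jnp.sum(logpdf)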

# Build the kernel
step_size = 1e-3
inverse_mass_matrix = jnp.array([1., 1.])
nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)

# Initialize the state
initial_position = {"loc": 1., "scale": 2.}
state = nuts.init(initial_position)

# Iterate
rng_key = jax.random.key(0)
for _ in range(100):
    rng_key, nuts_key = jax.random.split(rng_key)
    state, _ = nuts.step(nuts_key, state)

See the documentation for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc.
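
The Python `for` loop in the example above can be replaced by a compiled loop. The sketch below is one possible way to write such an inference loop with `jax.lax.scan`. The names `inference_loop` and `toy_kernel` are hypothetical, and `toy_kernel` is a stand-in rather than a BlackJAX sampler, so that the snippet runs with JAX alone. Any function with the signature `kernel(rng_key, state)` that returns `(new_state, info)`, such as `nuts.step` above, should be usable in its place.

```python
import jax
import jax.numpy as jnp


def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Run `kernel` for `num_samples` steps and stack the visited states."""

    def one_step(state, step_key):
        state, _ = kernel(step_key, state)
        # The first output is carried to the next step, the second is stacked.
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states


# Hypothetical stand-in kernel (not a BlackJAX sampler): a Gaussian random
# walk that always moves, used only to exercise the loop.
def toy_kernel(rng_key, state):
    new_state = state + 0.1 * jax.random.normal(rng_key)
    info = None
    return new_state, info


states = inference_loop(jax.random.key(0), toy_kernel, jnp.array(0.0), 100)
print(states.shape)
```

Because `jax.lax.scan` traces the step function once, the whole loop is compiled instead of dispatching each step from Python.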

Philosophy

What is BlackJAX?

BlackJAX bridges the gap between "one liner" frameworks and modular, customizable libraries.

Users can import the library and interact with robust, well-tested and performant samplers with a few lines of code. These samplers are aimed at PPL developers, or people who have a logpdf and just need a sampler that works.

But the true strength of BlackJAX lies in its internals and how they can be used to experiment quickly on existing or new sampling schemes. This lower level exposes the building blocks of inference algorithms (integrators, proposals, momentum generators, etc.) and makes it easy to combine them to build new algorithms. It provides an opportunity to accelerate research on sampling algorithms by providing robust, performant and reusable code.

Why BlackJAX?

Sampling algorithms are too often integrated into PPLs and not decoupled from the rest of the framework, making them hard to use for people who do not need the modeling language to build their logpdf. Their implementations are usually monolithic, so it is impossible to reuse parts of the algorithm to build custom kernels. BlackJAX solves both problems.

How does it work?

BlackJAX lets you build arbitrarily complex algorithms because it is built around a very general pattern. Everything that takes a state and returns a state is a transition kernel, and is implemented as:

new_state, info = kernel(rng_key, state)

Kernels are stateless functions and all follow the same API; the state and the information related to the transition are returned separately. They can thus be easily composed and exchanged. We specialize these kernels by closure instead of passing parameters.
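
To make the pattern concrete, here is a minimal sketch of a random-walk Metropolis kernel written in this style. It is an illustration, not BlackJAX's implementation: the names `RWState`, `RWInfo` and `build_rw_kernel` are hypothetical. The log-density and step size are captured by closure, so the resulting kernel only takes `rng_key` and `state`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RWState(NamedTuple):
    position: jax.Array
    logdensity: jax.Array


class RWInfo(NamedTuple):
    is_accepted: jax.Array


def build_rw_kernel(logdensity_fn, step_size):
    """Return a kernel specialized to `logdensity_fn` and `step_size`."""

    def kernel(rng_key, state):
        proposal_key, accept_key = jax.random.split(rng_key)
        noise = jax.random.normal(proposal_key, jnp.shape(state.position))
        proposal = state.position + step_size * noise
        proposal_logdensity = logdensity_fn(proposal)
        # Metropolis acceptance rule for a symmetric proposal.
        log_uniform = jnp.log(jax.random.uniform(accept_key))
        is_accepted = log_uniform < proposal_logdensity - state.logdensity
        new_state = RWState(
            jnp.where(is_accepted, proposal, state.position),
            jnp.where(is_accepted, proposal_logdensity, state.logdensity),
        )
        return new_state, RWInfo(is_accepted)

    return kernel


def logdensity_fn(x):
    return -0.5 * jnp.sum(x**2)  # standard normal, up to a constant


kernel = jax.jit(build_rw_kernel(logdensity_fn, step_size=0.5))
position = jnp.zeros(2)
state = RWState(position, logdensity_fn(position))
state, info = kernel(jax.random.key(1), state)
```

Since the kernel holds no mutable state, it can be wrapped in `jax.jit`, as done here, or swapped for any other function with the same signature.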

Contributions

Please follow our short guide.

Citing BlackJAX

To cite this repository:

@software{blackjax2020github,
  author = {Cabezas, Alberto and Lao, Junpeng and Louf, R\'emi},
  title = {{B}lackjax: A sampling library for {JAX}},
  url = {http://github.com/blackjax-devs/blackjax},
  version = {<insert current release tag>},
  year = {2023},
}

In the above BibTeX entry, names are in alphabetical order, and the version number should be the last tag on the main branch.

Acknowledgements

Some details of the NUTS implementation were largely inspired by NumPyro's.
