
Probabilistic programming with Gen, built on top of JAX.



Scaling probabilistic programming with programmable inference.


🔎 What is GenJAX?

Gen is a multi-paradigm (generative, differentiable, incremental) language for probabilistic programming focused on generative functions: computational objects which represent probability measures over structured sample spaces.
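To make "a probability measure over a structured sample space" concrete, here is a toy sketch in plain Python, not GenJAX's implementation. It assumes a hypothetical two-choice model (a Beta draw named "p" followed by a coin flip named "v"); the function names `simulate_toy_model` and `assess_toy_model` are made up for illustration. The point is only that a run yields a dictionary of named random choices together with the log density of those choices.

```python
import math
import random


def beta_logpdf(x, a, b):
    """Log density of Beta(a, b) at x, for 0 < x < 1."""
    log_norm = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return (a - 1) * math.log(x) + (b - 1) * math.log1p(-x) - log_norm


def assess_toy_model(choices, a, b):
    """Score a full set of named choices under the toy model."""
    p, v = choices["p"], choices["v"]
    return beta_logpdf(p, a, b) + math.log(p if v else 1.0 - p)


def simulate_toy_model(rng, a, b):
    """Run the toy model forward and record every random choice by name."""
    p = rng.betavariate(a, b)   # choice named "p"
    v = rng.random() < p        # choice named "v"
    choices = {"p": p, "v": v}
    return choices, assess_toy_model(choices, a, b), v


choices, log_prob, retval = simulate_toy_model(random.Random(0), 2.0, 2.0)
print(choices, log_prob, retval)
```

The dictionary of named choices is the "structured sample": inference algorithms can constrain some names (for example "v") and propose values for the rest.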

GenJAX is an implementation of Gen on top of JAX. It exposes the ability to programmatically construct and manipulate generative functions, and to JIT-compile and auto-batch inference computations that use generative functions onto GPU devices.
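As a rough illustration of what "JIT compile + auto-batch" means at the JAX level, the sketch below uses only plain JAX (`jax.jit` and `jax.vmap`) and no GenJAX API. The function `bernoulli_log_likelihood` is a hypothetical stand-in for a per-particle inference computation.

```python
import jax
import jax.numpy as jnp


def bernoulli_log_likelihood(p, obs):
    """Log probability of one coin flip `obs` given heads probability `p`."""
    return jnp.where(obs, jnp.log(p), jnp.log1p(-p))


# vmap batches the scalar function over the first argument only;
# jit compiles the batched function for the available backend.
batched = jax.jit(jax.vmap(bernoulli_log_likelihood, in_axes=(0, None)))

ps = jnp.array([0.25, 0.5, 0.75])
out = batched(ps, True)
print(out)
```

The scalar function is written once; batching and compilation are applied as transformations, which is the same pattern the quick example applies to a whole inference routine.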

[!TIP] GenJAX is part of a larger ecosystem of probabilistic programming tools based upon Gen. Explore more...

Quickstart

To install GenJAX, run

pip install genjax

Then install JAX using this guide to choose the command for the architecture you're targeting. To run GenJAX without GPU support:

pip install "jax[cpu]~=0.4.24"

On a Linux machine with a GPU, run the following command:

pip install "jax[cuda12]~=0.4.24"
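After installing, a quick way to check which backend JAX picked up is to list its devices. This is a minimal sketch using standard JAX calls; the printed values depend on your machine.

```python
import jax

# Report which backend JAX selected and the devices it can see.
backend = jax.default_backend()
devices = jax.devices()
print(f"backend={backend}, devices={devices}")
```

If only CPU devices are listed on a GPU machine, the CUDA-enabled JAX install probably did not take effect.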

Quick example

The following code snippet defines a generative function called beta_bernoulli that

  • takes two shape parameters, α and β
  • uses them to draw a value p from a Beta distribution
  • flips a coin that returns 1 with probability p and 0 with probability 1-p, and returns that value

Then, we create an inference problem (by specifying a posterior target) and use sampling importance resampling (SIR) to produce a single-sample estimator of p.

We can JIT compile that entire process, run it in parallel, and so on. Here we use this to estimate p over 50 independent trials of SIR (with K = 50 particles).

import jax
import jax.numpy as jnp
import genjax
from genjax import beta, flip, gen, Target, ChoiceMap
from genjax.inference.smc import ImportanceK

# Create a generative model.
@gen
def beta_bernoulli(α, β):
    p = beta(α, β) @ "p"
    v = flip(p) @ "v"
    return v

@jax.jit
def run_inference(obs: bool):
    # Create an inference query - a posterior target - by specifying
    # the model, arguments to the model, and constraints.
    posterior_target = Target(beta_bernoulli, # the model
                              (2.0, 2.0), # arguments to the model
                              ChoiceMap.d({"v": obs}), # constraints
                            )

    # Use a library algorithm, or design your own - more on that in the docs!
    alg = ImportanceK(posterior_target, k_particles=50)

    # Everything is JAX compatible by default.
    # JIT, vmap, to your heart's content.
    key = jax.random.key(314159)
    sub_keys = jax.random.split(key, 50)
    _, p_chm = jax.vmap(alg.random_weighted, in_axes=(0, None))(
        sub_keys, posterior_target
    )

    # An estimate of `p` over 50 independent trials of SIR (with K = 50 particles).
    return jnp.mean(p_chm["p"])

(run_inference(True), run_inference(False))
(Array(0.6039314, dtype=float32), Array(0.3679334, dtype=float32))
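For intuition about what the example computes, here is a standard-library sketch of sampling importance resampling for the same Beta-Bernoulli model. It is not GenJAX's implementation: it assumes the proposal is the Beta prior, and the helper names (`sir_sample_p`, `estimate_p`, `exact_posterior_mean`) are hypothetical. Because the model is conjugate, the exact posterior means are available for comparison: 3/5 after observing heads and 2/5 after observing tails, which the estimates reported above approximate.

```python
import random


def sir_sample_p(rng, obs, alpha=2.0, beta=2.0, k_particles=50):
    """Draw one approximate posterior sample of p via SIR with a prior proposal."""
    # 1. Propose K particles from the Beta prior.
    particles = [rng.betavariate(alpha, beta) for _ in range(k_particles)]
    # 2. Weight each particle by the likelihood of the observed flip.
    weights = [p if obs else 1.0 - p for p in particles]
    # 3. Resample a single particle in proportion to its weight.
    return rng.choices(particles, weights=weights, k=1)[0]


def estimate_p(obs, n_trials=50, seed=0):
    """Average the SIR samples from n_trials independent runs."""
    rng = random.Random(seed)
    return sum(sir_sample_p(rng, obs) for _ in range(n_trials)) / n_trials


def exact_posterior_mean(obs, alpha=2.0, beta=2.0):
    """Beta-Bernoulli conjugacy: one flip adds 1 to alpha (heads) or beta (tails)."""
    a, b = (alpha + 1.0, beta) if obs else (alpha, beta + 1.0)
    return a / (a + b)


print(exact_posterior_mean(True), exact_posterior_mean(False))  # 0.6 0.4
print(estimate_p(True), estimate_p(False))
```

With only 50 trials the Monte Carlo estimate fluctuates by a few hundredths around the exact value; raising `n_trials` tightens it.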

References

Many bits of knowledge have gone into this project; you can find many of them at the MIT Probabilistic Computing Project page under publications.

JAX influences

This project has several JAX-based influences.

Acknowledgements

The maintainers of this library would like to acknowledge the JAX and Oryx maintainers for useful discussions and reference code for interpreter-based transformation patterns.

Disclaimer

This is a research project. Expect bugs and sharp edges. Please help by trying out GenJAX, reporting bugs, and letting us know what you think!

Get Involved + Get Support

Pull requests and bug reports are always welcome! Check out our Contributor's Guide for information on how to get started contributing to GenJAX.

The TL;DR is:

  • send us a pull request,
  • iterate on the feedback + discussion, and
  • get a +1 from a maintainer

in order to get your PR accepted.

Issues should be reported on the GitHub issue tracker.

If you want to discuss an idea for a new feature or ask us a question, discussion occurs primarily in the body of GitHub Issues.

Created and maintained by the MIT Probabilistic Computing Project. All code is licensed under the Apache 2.0 License.

