
A unified interface for probabilistic inference in JAX + Equinox


Inferix
Author Gary Allen
Homepage github.com/gvcallen/inferix

Installation

Inferix can be installed using pip directly:

pip install inferix

Motivation

In the JAX ecosystem, you typically have to choose between two extremes for Bayesian inference:

  • Low-level sampling libraries (like BlackJAX or PolyChord), which force you to manually manage while-loops, PRNG keys, buffers, and algorithm state.
  • High-level Probabilistic Programming Languages (PPLs) (like NumPyro or PyMC), which are user-friendly but force you to rewrite your models in their own domain-specific languages and distribution primitives.
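To make the first point concrete, here is a sketch of the kind of loop you end up writing by hand at the lowest level: a random-walk Metropolis sampler in pure JAX in which the loop (jax.lax.scan), PRNG key splitting, and sampler state are all managed explicitly. This is illustrative only; it is neither BlackJAX's API nor Inferix code, and the standard-normal target is a hypothetical stand-in.

```python
import jax
import jax.numpy as jnp


def log_density(theta):
    # Hypothetical target: a standard normal in len(theta) dimensions.
    return -0.5 * jnp.sum(theta**2)


def random_walk_metropolis(key, init, num_steps, step_size):
    """Hand-rolled random-walk Metropolis: loop, keys and state are all manual."""

    def step(carry, key):
        theta, logp = carry
        key_prop, key_accept = jax.random.split(key)
        # Propose a Gaussian perturbation of the current state.
        proposal = theta + step_size * jax.random.normal(key_prop, theta.shape)
        logp_prop = log_density(proposal)
        # Accept with probability min(1, p(proposal) / p(theta)).
        accept = jnp.log(jax.random.uniform(key_accept)) < logp_prop - logp
        theta = jnp.where(accept, proposal, theta)
        logp = jnp.where(accept, logp_prop, logp)
        # The second output is stacked by scan into the sample buffer.
        return (theta, logp), theta

    # One PRNG key per step; num_steps must be static for this split.
    keys = jax.random.split(key, num_steps)
    _, samples = jax.lax.scan(step, (init, log_density(init)), keys)
    return samples


samples = jax.jit(random_walk_metropolis, static_argnums=(2,))(
    jax.random.PRNGKey(0), jnp.zeros(2), 5000, 1.0
)
print(samples.shape)  # (5000, 2)
```

Every concern here (key management, carrying the current log-density, buffering samples, making the step count static for compilation) is exactly the boilerplate described above.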

Inferix aims to sit between these two extremes, with an API that mirrors that of Optimistix. It is designed for engineers and scientists who already have a forward model written in pure JAX and just want to sample from it without managing boilerplate or adopting a heavy PPL framework.

Inferix wraps low-level algorithms in declarative Equinox modules and handles all the XLA-compiled control flow, hypercube reparameterizations, and data packaging under the hood. The library is built around a unified API: you instantiate an algorithm (for example inferix.MCMC) with your desired hyperparameters and pass it to a common runner (inferix.mcmc_sample or inferix.nested_sample). Inferix currently supports both the MCMC and nested sampling paradigms, with native bridges to BlackJAX (for NSS and NUTS) and PolyChord.
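The hypercube reparameterization is what a prior transform provides: nested samplers draw points u uniformly in the unit hypercube [0, 1]^d, and the prior transform maps them to physical parameters theta, typically by applying each parameter's inverse CDF. The sketch below follows the (u, args) signature used in the README example; the specific two-parameter prior (one uniform, one Gaussian) and the packing of its bounds into args are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtri  # inverse CDF of the standard normal


def prior_transform(u, args):
    """Map u in [0, 1]^2 to theta (hypothetical prior).

    theta[0] ~ Uniform(low, high), theta[1] ~ Normal(mu, sigma).
    """
    low, high, mu, sigma = args
    theta0 = low + (high - low) * u[0]  # inverse CDF of a uniform
    theta1 = mu + sigma * ndtri(u[1])   # inverse CDF of a Gaussian
    return jnp.stack([theta0, theta1])


args = (1.0, 3.0, 0.0, 2.0)
centre = prior_transform(jnp.array([0.5, 0.5]), args)

# Uniform hypercube draws become draws from the prior. minval keeps u away
# from 0, where the Gaussian inverse CDF would return -inf.
u_batch = jax.random.uniform(jax.random.PRNGKey(1), (10_000, 2), minval=1e-6)
theta_batch = jax.vmap(prior_transform, in_axes=(0, None))(u_batch, args)
```

Because the sampler only ever sees the hypercube, the same algorithm can be reused for any prior that can be written this way.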

```python
import jax
import jax.numpy as jnp
import inferix

# 1. Define your target functions (pure JAX)

def my_circuit_likelihood(theta, args):
    # e.g., a complex differentiable physics simulation
    return ...

def my_prior_transform(u, args):
    # A mapping from the uniform unit hypercube coordinates u to physical parameters theta
    return ...

# 2. Instantiate your sampler of choice, e.g. inferix.NSS or inferix.PolyChord
sampler = inferix.NSS(num_delete=10, num_inner_steps=20)

# 3. Execute the run
key = jax.random.PRNGKey(42)
solution = inferix.nested_sample(
    log_likelihood_fn=my_circuit_likelihood,
    prior_transform_fn=my_prior_transform,
    sampler=sampler,
    ndims=5,
    key=key,
    logZ_convergence=1e-3,
)

# 4. Access the results
print(f"Final log-Evidence (logZ): {solution.logZ} ± {solution.logZ_err}")
print(f"Total steps taken: {solution.num_steps}")
physical_samples = solution.dead_points
```
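The return ... bodies in the example above are placeholders. As a purely hypothetical stand-in for ndims=5, the sketch below uses a unit Gaussian log-likelihood and an independent Uniform(-5, 5) prior on each parameter. Because almost all of the Gaussian's mass lies inside the prior box, the evidence is known analytically, logZ ≈ -5 log(10), which gives a reference value for solution.logZ. The brute-force Monte Carlo estimate at the end is a swapped-in sanity check (averaging the likelihood over prior draws), not nested sampling and not what Inferix does internally; it is only practical in low dimensions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

NDIMS = 5
HALF_WIDTH = 5.0  # hypothetical prior: theta_i ~ Uniform(-5, 5) independently


def my_circuit_likelihood(theta, args):
    # Hypothetical stand-in for a simulation: unit Gaussian log-likelihood at 0.
    return -0.5 * jnp.sum(theta**2) - 0.5 * theta.shape[0] * jnp.log(2 * jnp.pi)


def my_prior_transform(u, args):
    # Map u in [0, 1]^NDIMS to theta in [-HALF_WIDTH, HALF_WIDTH]^NDIMS.
    return HALF_WIDTH * (2.0 * u - 1.0)


# The Gaussian sits almost entirely inside the box, so Z ~= (1 / (2 * HALF_WIDTH))**NDIMS.
analytic_logZ = -NDIMS * jnp.log(2 * HALF_WIDTH)

# Brute-force check (not nested sampling): average the likelihood over prior draws.
num_draws = 200_000
u = jax.random.uniform(jax.random.PRNGKey(0), (num_draws, NDIMS))
theta = jax.vmap(my_prior_transform, in_axes=(0, None))(u, None)
log_l = jax.vmap(my_circuit_likelihood, in_axes=(0, None))(theta, None)
mc_logZ = logsumexp(log_l) - jnp.log(num_draws)

print(f"analytic logZ ~ {float(analytic_logZ):.3f}, Monte Carlo logZ ~ {float(mc_logZ):.3f}")
```

These two functions could replace the placeholders in the inferix.nested_sample call above (the args parameter is simply ignored here); if the run converges, one would expect solution.logZ to land close to analytic_logZ.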

Download files


Source Distribution

inferix-0.1.0.tar.gz (3.3 MB)

Uploaded Source

Built Distribution


inferix-0.1.0-py3-none-any.whl (6.7 kB)

Uploaded Python 3

File details

Details for the file inferix-0.1.0.tar.gz.

File metadata

  • Download URL: inferix-0.1.0.tar.gz
  • Upload date:
  • Size: 3.3 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.3

File hashes

Hashes for inferix-0.1.0.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | b671357fc2f781f6926d91de867eff6dc34fcb043e7495e129a7c30d73a61f00 |
| MD5 | 69e8834795f91ede20309773392f1f71 |
| BLAKE2b-256 | 41324f899f615c113bca1e9bf3e599d88df6f929251d396562925b3d357a6a36 |


File details

Details for the file inferix-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: inferix-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 6.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.3

File hashes

Hashes for inferix-0.1.0-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | 6f7d7272130f838a2aa45daead0e540b6cc358ffdf8f8834cf1d0efe625f96e6 |
| MD5 | 0e048a3d4709dbfeeb58cb5ea39505ed |
| BLAKE2b-256 | c43934440b3699b8974c5e05c7754cfec60dc6ec53cb82eeec24566e10554556 |

