
Bayesian Inference with JAX


Bayinx: Bayesian Inference with JAX

Bayinx is an embedded probabilistic programming language in Python, powered by JAX. It is heavily inspired by Stan and aims for feature parity with it, but it extends the types of objects you can work with and focuses on variational inference with normalizing flows as its means of sampling from the posterior.

Coming From Stan

There are a few differences between Bayinx and Stan syntax. First, because Bayinx is embedded in Python, model definitions are Pythonic: you define a class that inherits from the Model base class:

class MyModel(Model, init=False):
    ...  # attribute definitions and the model method go here

Note: Specify init=False so that static type checkers do not raise irrelevant errors; more importantly, it is a reminder that you should NOT implement your own __init__ method!
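Python delivers keywords written in a class statement, such as init=False, to the base class's __init_subclass__ hook. The toy sketch below shows how a base class can accept that keyword, refuse a user-defined __init__, and construct instances itself. ToyModel and its build method are hypothetical illustrations, not Bayinx's Model.

```python
class ToyModel:
    """Hypothetical base class (not Bayinx's Model) that takes a class keyword."""

    def __init_subclass__(cls, init: bool = True, **kwargs):
        # Keywords in `class Sub(ToyModel, init=False)` are delivered here.
        super().__init_subclass__(**kwargs)
        if not init and "__init__" in vars(cls):
            # Enforce the "do NOT implement your own __init__" rule.
            raise TypeError(f"{cls.__name__} must not define __init__")
        cls._generate_init = init  # recorded for illustration only

    @classmethod
    def build(cls, **values):
        # The framework, not the user, creates instances and sets attributes.
        obj = object.__new__(cls)
        for name, value in values.items():
            setattr(obj, name, value)
        return obj
```

With this design, `class A(ToyModel, init=False)` is accepted and instances come from `A.build(...)`, while adding an `__init__` to such a subclass fails at class-definition time.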

Stan's data and parameters blocks are combined into class attribute definitions in Bayinx. For example, to model a normal distribution with an unknown mean and a known variance of 1, we might write:

class MyModel(Model, init=False):
    mean: Continuous[Array] = define(shape = ()) # a scalar mean parameter
    x: Observed[Array] = define(shape = 'n_obs') # a vector of observed values

    # ...
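Shapes here are given either as tuples (a scalar is ()) or as names like 'n_obs' whose values are supplied when the posterior is constructed (n_obs = 3 further down). A minimal sketch of how such specs could be resolved into concrete shapes follows; resolve_shape is a hypothetical helper, not part of Bayinx.

```python
from typing import Mapping, Tuple, Union

Dim = Union[int, str]


def resolve_shape(spec: Union[Dim, Tuple[Dim, ...]], dims: Mapping[str, int]) -> Tuple[int, ...]:
    """Turn a shape spec like (), 'n_obs', or (2, 'n_obs') into a tuple of ints."""
    if isinstance(spec, (int, str)):
        spec = (spec,)  # a single dimension becomes a 1-tuple
    resolved = []
    for dim in spec:
        if isinstance(dim, str):
            # Named dimensions are looked up in the user-supplied sizes.
            if dim not in dims:
                raise KeyError(f"shape dimension {dim!r} was not supplied")
            resolved.append(dims[dim])
        else:
            resolved.append(dim)
    return tuple(resolved)
```

For example, `resolve_shape('n_obs', {'n_obs': 3})` gives `(3,)`, and `resolve_shape((), {})` gives `()` for a scalar.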

Stan's model block corresponds to the model method, which you implement on the class:

class MyModel(Model, init=False):
    mean: Continuous[Array] = define(shape = ())
    x: Observed[Array] = define(shape = 'n_obs')

    def model(self, target):
        # Equivalent to 'x ~ normal(mean, 1.0)' in Stan
        self.x << Normal(self.mean, 1.0)

        return target
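Written out, and assuming Bayinx follows Stan in summing the log density over the elements of x, the statement self.x << Normal(self.mean, 1.0) adds the following log likelihood to target, with μ standing for mean and n for n_obs:

$$\log p(x \mid \mu) = \sum_{i=1}^{n} \log \mathcal{N}(x_i \mid \mu, 1) = -\frac{n}{2}\log(2\pi) - \frac{1}{2}\sum_{i=1}^{n} (x_i - \mu)^2$$

For the data used below, x = (-1, 0, 1), this is maximized at μ = 0, where it equals $-\tfrac{3}{2}\log(2\pi) - 1 \approx -3.757$. The first term does not depend on μ; Stan's ~ drops such constants, and whether Bayinx keeps them is not stated here, but either way they do not affect inference.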

Notice that Stan's ~ operator is replaced with <<, and that nodes of the model (such as mean and x) must be referenced through self.
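Python has no ~ statement for this purpose, but << can be overloaded through __lshift__. The toy sketch below shows one way `node << distribution` can add a log density to an accumulator. Target, ToyNode, and ToyNormal are hypothetical classes; how Bayinx actually wires nodes to target is not described here.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class Target:
    """Hypothetical log-density accumulator, like Stan's target."""
    value: float = 0.0


@dataclass
class ToyNormal:
    mu: float
    sigma: float

    def logpdf(self, x: float) -> float:
        # Full normalized log density of Normal(mu, sigma) at x.
        z = (x - self.mu) / self.sigma
        return -0.5 * z * z - math.log(self.sigma) - 0.5 * math.log(2 * math.pi)


@dataclass
class ToyNode:
    values: List[float]
    target: Target

    def __lshift__(self, dist: ToyNormal) -> "ToyNode":
        # `node << dist` adds the summed log density of the node's values.
        self.target.value += sum(dist.logpdf(v) for v in self.values)
        return self


target = Target()
x = ToyNode([-1.0, 0.0, 1.0], target)
x << ToyNormal(0.0, 1.0)  # analogous to `x ~ normal(0, 1)` in Stan
```

After the last line, target.value holds the summed normal log density of the three observations at mu = 0.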

Note: Bayinx does not currently have equivalents of Stan's transformed data or transformed parameters blocks, although these are likely to be added in a future release.

You can then construct the variational approximation to the posterior:

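import bayinx as byx
from bayinx.flows import DiagAffine
import jax.numpy as jnp

# Bind the data: n_obs sets the 'n_obs' shape of x, and x holds the observations
posterior = byx.Posterior(MyModel, n_obs = 3, x = jnp.array([-1.0, 0.0, 1.0]))
# Use a single diagonal affine flow as the variational family
posterior.configure(flowspecs = [DiagAffine()])
# Fit the variational approximation
posterior.fit()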
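To give a sense of what fitting involves, here is a minimal plain-Python sketch of variational inference for MyModel with a single diagonal affine flow. It assumes, since the text does not say, that DiagAffine maps a standard normal draw eps to loc + scale * eps, and that mean has an implicit flat prior because MyModel declares none. Under those assumptions the exact posterior of mean is Normal(0, 1/3) (variance 1/3), so the fit should recover loc close to 0 and scale close to 0.577. fit_diag_affine is a hypothetical helper; Bayinx's own fit() very likely differs in optimizer, Monte Carlo sample size, and parameterization.

```python
import math
import random


def fit_diag_affine(x, steps=3000, n_mc=64, lr=0.02, seed=0):
    """Fit q(mu) = loc + exp(log_scale) * eps, eps ~ N(0, 1), by stochastic
    gradient ascent on a Monte Carlo ELBO with reparameterization gradients.

    Assumed model: x_i ~ Normal(mu, 1) with a flat (improper) prior on mu.
    """
    rng = random.Random(seed)
    loc, log_scale = 0.5, 0.0
    for _ in range(steps):
        scale = math.exp(log_scale)
        g_loc = 0.0
        g_log_scale = 0.0
        for _ in range(n_mc):
            eps = rng.gauss(0.0, 1.0)
            mu = loc + scale * eps              # push the base draw through the flow
            dlogp = sum(xi - mu for xi in x)    # d/dmu of sum_i log N(x_i | mu, 1)
            g_loc += dlogp
            g_log_scale += dlogp * scale * eps  # chain rule through mu
        g_loc /= n_mc
        # The entropy of q is log_scale + const, so it adds +1 to this gradient.
        g_log_scale = g_log_scale / n_mc + 1.0
        loc += lr * g_loc
        log_scale += lr * g_log_scale
    return loc, math.exp(log_scale)


loc, scale = fit_diag_affine([-1.0, 0.0, 1.0])
```

Optimizing log_scale rather than scale keeps the scale positive without constraints, which is a common choice for affine flows.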

You can then work with this approximation by drawing samples of its nodes:

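# Draw 10,000 samples of the 'mean' node from the fitted approximation
mean_draws = posterior.sample('mean', 10000)
# Estimate the posterior mean
print(mean_draws.mean())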
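Draws can be summarized beyond their mean. The sketch below assumes only that the draws can be iterated as numbers. Because it does not call Bayinx, it uses stand-in draws from Normal(0, 1/3) (variance 1/3), which is the exact posterior of mean for the example data if mean has a flat prior (an assumption). summarize is a hypothetical helper.

```python
import math
import random
import statistics


def summarize(draws, level=0.9):
    """Return the mean, standard deviation, and a central credible interval."""
    ordered = sorted(draws)
    n = len(ordered)
    tail = (1.0 - level) / 2.0
    lo = ordered[int(tail * (n - 1))]          # lower empirical quantile
    hi = ordered[int((1.0 - tail) * (n - 1))]  # upper empirical quantile
    return statistics.mean(draws), statistics.stdev(draws), (lo, hi)


# Stand-in for posterior.sample('mean', 10000) under the flat-prior assumption.
rng = random.Random(0)
draws = [rng.gauss(0.0, math.sqrt(1.0 / 3.0)) for _ in range(10_000)]
est_mean, est_sd, (lo, hi) = summarize(draws)
```

With 10,000 draws, est_sd should land near sqrt(1/3) ≈ 0.577 and the 90% interval near ±0.95.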

Roadmap

  • Implement OT-Flow: https://arxiv.org/abs/2006.00104
  • Allow shape definitions to include expressions (e.g., shape = 'n_obs + 1' will evaluate to the correct specification)
  • Figure out how to construct distributions dynamically so that a parameterization is selected by keyword, e.g. Exponential(rate = ...) vs. Exponential(scale = ...), instead of by calling a separate function
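The last roadmap item can be illustrated with keyword-only arguments: one constructor accepts either rate = ... or scale = ... and normalizes to a single internal parameterization. This Exponential class is a hypothetical sketch, not Bayinx's distribution API.

```python
import math
from typing import Optional


class Exponential:
    """Hypothetical distribution taking either rate= or scale= (scale = 1 / rate)."""

    def __init__(self, *, rate: Optional[float] = None, scale: Optional[float] = None):
        # Keyword-only arguments force callers to name the parameterization.
        if (rate is None) == (scale is None):
            raise ValueError("specify exactly one of rate= or scale=")
        # Normalize to a single internal parameterization (the rate).
        self.rate = rate if rate is not None else 1.0 / scale

    def logpdf(self, x: float) -> float:
        if x < 0:
            return -math.inf  # no support on negative values
        return math.log(self.rate) - self.rate * x
```

Both `Exponential(rate = 2.0)` and `Exponential(scale = 0.5)` then describe the same distribution, and passing both or neither is rejected.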


This version

0.5.6

Download files


Source Distribution

bayinx-0.5.6.tar.gz (53.3 kB view details)

Uploaded Source

Built Distribution


bayinx-0.5.6-py3-none-any.whl (56.8 kB view details)

Uploaded Python 3

File details

Details for the file bayinx-0.5.6.tar.gz.

File metadata

  • Download URL: bayinx-0.5.6.tar.gz
  • Upload date:
  • Size: 53.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.9.16 {"installer":{"name":"uv","version":"0.9.16","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":true}

File hashes

Hashes for bayinx-0.5.6.tar.gz
Algorithm Hash digest
SHA256 192ca0b86e7596f15406c16a09e4d9126314f0dc0aeb0b6e9aa6e1ab6aab2daf
MD5 4eb01341be117f5bda5134f9def48867
BLAKE2b-256 20901a242d4e2ff3e516cfa7c3d1b4d3bd3d8331fb956682919f1719379a83c8


File details

Details for the file bayinx-0.5.6-py3-none-any.whl.

File metadata

  • Download URL: bayinx-0.5.6-py3-none-any.whl
  • Upload date:
  • Size: 56.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.9.16 {"installer":{"name":"uv","version":"0.9.16","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":true}

File hashes

Hashes for bayinx-0.5.6-py3-none-any.whl
Algorithm Hash digest
SHA256 ae4bfcad107efce0b8cacd7e012ad9733d8b534fb8041b6519c48cbb68bff0d8
MD5 0f32619b3e0626b7d737fd1b93d4a14f
BLAKE2b-256 aa7fa4f71e64e72c96b7071374c0144091eecb5e585ff53e12f9c858dc551453

