Bayinx: Bayesian Inference with JAX

Bayinx is a probabilistic programming language embedded in Python and powered by JAX. It is heavily inspired by Stan and aims for feature parity with it, but it extends the types of objects you can work with and focuses on variational inference with normalizing flows for sampling.

Coming From Stan

There are a few differences between the syntax of Bayinx and Stan. First, because Bayinx is embedded in Python, a model is defined as a Python class that inherits from the Model base class:

class MyModel(Model):
    # ...

Note: You should NOT implement your own __init__ method!

Stan's data and parameters blocks are combined into attribute definitions in Bayinx. For example, to model a normal distribution with an unknown mean and a variance of 1, we might write:

class MyModel(Model):
    mean: Continuous[Array] = define(shape = ()) # a scalar mean parameter
    x: Observed[Array] = define(shape = 'n_obs') # a vector of observed values

    # ...

Stan's model block corresponds to the model method in Bayinx:

class MyModel(Model):
    mean: Continuous[Array] = define(shape = ())
    x: Observed[Array] = define(shape = 'n_obs')

    def model(self, target):
        # Equivalent to 'x ~ normal(mean, 1.0)' in Stan
        self.x << Normal(self.mean, 1.0)

        return target

Notice that Stan's ~ operator has been replaced with <<, and that the nodes of a model are referenced through self.
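For readers less familiar with Stan, a sampling statement such as 'x ~ normal(mean, 1.0)' adds the log-density of x to the model's running target. The sketch below reproduces that accumulation in plain JAX. It does not use Bayinx, and whether << accumulates the target in exactly this way (including constant terms) is an assumption; log_density is a hypothetical helper name.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_density(mean, x):
    # Hypothetical helper: mimics the effect of 'x ~ normal(mean, 1.0)',
    # i.e. adding the summed normal log-density of x to the target.
    target = jnp.array(0.0)
    target = target + jnp.sum(norm.logpdf(x, loc=mean, scale=1.0))
    return target


x = jnp.array([-1.0, 0.0, 1.0])

# The target is larger when mean is close to the centre of the data.
print(log_density(0.0, x), log_density(1.0, x))
```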

Note: Bayinx does not currently have an equivalent of Stan's transformed data or transformed parameters blocks, but one is likely to be included in a future release.

You can then construct the variational approximation to the posterior:

import bayinx as byx
from bayinx.flows import DiagAffine
import jax.numpy as jnp

# Fit variational approximation
posterior = byx.Posterior(MyModel, n_obs = 3, x = jnp.array([-1.0, 0.0, 1.0]))
posterior.configure(flowspecs = [DiagAffine()])
posterior.fit()
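
To give a feel for what fitting a diagonal affine flow involves, here is a minimal sketch in plain JAX. It is not Bayinx's implementation. It assumes that a diagonal affine flow shifts and scales standard normal draws elementwise, that the mean has a flat prior, and that the ELBO is maximised by plain stochastic gradient ascent; the names log_joint, negative_elbo and step are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

x = jnp.array([-1.0, 0.0, 1.0])
LEARNING_RATE = 0.05


def log_joint(mean):
    # Log-likelihood of x under Normal(mean, 1); a flat prior on mean is assumed.
    return jnp.sum(norm.logpdf(x, loc=mean, scale=1.0))


def negative_elbo(params, eps):
    loc, log_scale = params
    # Affine transform of base draws eps ~ N(0, 1): z = loc + scale * eps.
    z = loc + jnp.exp(log_scale) * eps
    expected_log_joint = jnp.mean(jax.vmap(log_joint)(z))
    # Entropy of N(loc, scale^2) in closed form.
    entropy = log_scale + 0.5 * (1.0 + jnp.log(2.0 * jnp.pi))
    return -(expected_log_joint + entropy)


@jax.jit
def step(params, key):
    # One gradient step on a Monte Carlo estimate of the negative ELBO.
    eps = jax.random.normal(key, (256,))
    grads = jax.grad(negative_elbo)(params, eps)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


params = (jnp.array(1.0), jnp.array(0.0))  # initial location and log-scale
key = jax.random.PRNGKey(0)
for _ in range(500):
    key, subkey = jax.random.split(key)
    params = step(params, subkey)

loc, log_scale = params
print(loc, jnp.exp(log_scale))
```

Under these assumptions the fitted location should settle near the sample mean of x and the fitted scale near 1 / sqrt(3), up to Monte Carlo noise.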

You can then work with the approximation by drawing samples for a node:

mean_draws = posterior.sample('mean', 10000)
print(mean_draws.mean())
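
As a sanity check on the printed value, this model has a closed-form posterior if the mean is given a flat prior: a normal distribution centred on the sample mean with standard deviation 1 / sqrt(n_obs). The flat prior is an assumption made here for illustration, not something the description above states. A short sketch of the reference values:

```python
import jax.numpy as jnp

x = jnp.array([-1.0, 0.0, 1.0])
n_obs = x.shape[0]

# Exact posterior for the mean under a flat prior and unit variance (assumed).
exact_mean = jnp.mean(x)
exact_sd = 1.0 / jnp.sqrt(n_obs)

print(exact_mean, exact_sd)
```

If the variational fit is good, the mean and standard deviation of the draws should be close to these two values.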

Roadmap

  • Implement OT-Flow: https://arxiv.org/abs/2006.00104
  • Allow shape definitions to include expressions (e.g., shape = 'n_obs + 1' will evaluate to the correct specification)
  • Find a nice way to track the ELBO trajectory to implement early stopping (tolerance currently does nothing).
  • Allow users to specify custom tolerance criteria.
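
One common way to turn an ELBO trajectory into a stopping rule is to compare the average ELBO over two consecutive windows and stop once the relative change falls below a tolerance. The sketch below illustrates that idea only. It is not part of Bayinx, and has_converged, window and tolerance are hypothetical names.

```python
def has_converged(elbo_history, window=50, tolerance=1e-3):
    """Return True when the windowed mean ELBO has stopped changing.

    Compares the mean of the most recent `window` values with the mean of
    the `window` values before them, using a relative tolerance.
    """
    if len(elbo_history) < 2 * window:
        return False  # not enough history to compare two full windows
    previous = sum(elbo_history[-2 * window:-window]) / window
    recent = sum(elbo_history[-window:]) / window
    return abs(recent - previous) <= tolerance * max(abs(previous), 1e-12)


# A flat trajectory counts as converged; a steadily rising one does not.
print(has_converged([-3.0] * 100))
print(has_converged([float(i) for i in range(100)]))
```

Averaging over windows rather than comparing single iterations makes the rule less sensitive to Monte Carlo noise in the ELBO estimates.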

Project details

This version: 0.5.7
