
Bayesian Inference with JAX


Bayinx: Bayesian Inference with JAX

Bayinx is an embedded probabilistic programming language in Python, powered by JAX. It is heavily inspired by Stan and aims for feature parity with it, but it extends the types of objects you can work with and focuses on variational inference with normalizing flows as its way of sampling from the posterior.

Coming From Stan

There are a few differences between the syntax of Bayinx and Stan. First, because Bayinx is embedded in Python, model definitions are Pythonic: you define a class that inherits from the Model base class:

class MyModel(Model, init=False):
    ...  # attribute definitions and the model method go here

Note: Specify init=False to prevent static type checkers from raising irrelevant errors. More importantly, it reminds you that you should NOT implement your own __init__ method!

Stan's data and parameters blocks are combined into attribute definitions in Bayinx. For example, to model a normal distribution with an unknown mean and a known variance of 1, we might write:

class MyModel(Model, init=False):
    mean: Continuous[Array] = define(shape = ()) # a scalar mean parameter
    x: Observed[Array] = define(shape = 'n_obs') # a vector of observed values

    # ...
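
Declaring shape = 'n_obs' instead of a fixed integer suggests that dimensions can be named and filled in later; in the fitting example on this page, n_obs = 3 is passed alongside the data. A minimal sketch of how such a named dimension might be resolved, assuming names are looked up among the keyword arguments given at construction (resolve_shape is a hypothetical helper, not part of Bayinx):

```python
from typing import Mapping, Union

Dim = Union[int, str]


def resolve_shape(shape: Union[Dim, tuple[Dim, ...]], dims: Mapping[str, int]) -> tuple[int, ...]:
    """Turn a declared shape such as (), 'n_obs' or (2, 'n_obs') into concrete ints.

    Hypothetical: string entries are looked up in `dims`, which plays the
    role of keyword arguments such as n_obs = 3.
    """
    if not isinstance(shape, tuple):
        shape = (shape,)  # a bare int or name means a 1-D shape
    resolved = []
    for d in shape:
        if isinstance(d, str):
            if d not in dims:
                raise KeyError(f"dimension {d!r} was not supplied")
            resolved.append(dims[d])
        else:
            resolved.append(int(d))
    return tuple(resolved)


dims = {"n_obs": 3}
print(resolve_shape((), dims))       # ()
print(resolve_shape("n_obs", dims))  # (3,)
```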

Stan's model block corresponds to the model method, which you implement in Bayinx:

class MyModel(Model, init=False):
    mean: Continuous[Array] = define(shape = ())
    x: Observed[Array] = define(shape = 'n_obs')

    def model(self, target):
        # Equivalent to 'x ~ normal(mean, 1.0)' in Stan
        self.x << Normal(self.mean, 1.0)

        return target

Notice that Stan's ~ operator is replaced with <<, and that nodes of the model must be referenced through self (for example, self.x and self.mean).
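
Read as in Stan, the statement adds the log density of every element of x under Normal(mean, 1.0) to the accumulated log density, target. The plain-Python sketch below shows that bookkeeping, assuming << carries the same meaning as Stan's ~. Stan's ~ also drops constant terms, which does not change the posterior; whether Bayinx keeps them is not stated here. normal_logpdf and model_target are hypothetical names.

```python
import math
from typing import Sequence


def normal_logpdf(x: float, mu: float, sigma: float) -> float:
    """Log density of Normal(mu, sigma) at x, including the normalizing constant."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def model_target(mean: float, x: Sequence[float], target: float = 0.0) -> float:
    # A vectorized sampling statement contributes one term per observation.
    for xi in x:
        target += normal_logpdf(xi, mean, 1.0)
    return target


print(model_target(0.0, [-1.0, 0.0, 1.0]))  # about -3.757
```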

Note: Bayinx does not currently have an equivalent of Stan's transformed data or transformed parameters blocks, but these are likely to be added in a future release.

You can then construct the variational approximation to the posterior:

import bayinx as byx
from bayinx.flows import DiagAffine
import jax.numpy as jnp

# Fit variational approximation
posterior = byx.Posterior(MyModel, n_obs = 3, x = jnp.array([-1.0, 0.0, 1.0]))
posterior.configure(flowspecs = [DiagAffine()])
posterior.fit()
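
DiagAffine presumably denotes a flow that independently scales and shifts each coordinate of a standard normal base draw. For the single scalar parameter in MyModel that family is Normal(mu, sigma), so the fit can be sketched by hand. This is not Bayinx's implementation. It assumes a flat prior on mean (what Stan implies when a parameter has no sampling statement) and uses the closed-form ELBO gradients available in this Gaussian case instead of the stochastic gradients a general flow would need. DiagAffine1D, elbo and fit are hypothetical names.

```python
import math
from typing import Sequence


class DiagAffine1D:
    """Hypothetical 1-D diagonal affine flow: z = mu + exp(log_sigma) * eps."""

    def __init__(self, mu: float = 0.0, log_sigma: float = 0.0) -> None:
        self.mu = mu
        self.log_sigma = log_sigma

    def forward(self, eps: float) -> float:
        return self.mu + math.exp(self.log_sigma) * eps

    def log_prob(self, z: float) -> float:
        # Change of variables: log N(eps; 0, 1) - log |dz/deps|.
        eps = (z - self.mu) / math.exp(self.log_sigma)
        return -0.5 * eps * eps - 0.5 * math.log(2.0 * math.pi) - self.log_sigma


def elbo(flow: DiagAffine1D, x: Sequence[float]) -> float:
    # E_q[log p(x | mean)] + H[q], with a flat prior on mean.
    sigma2 = math.exp(2.0 * flow.log_sigma)
    expected_loglik = sum(
        -0.5 * math.log(2.0 * math.pi) - 0.5 * ((xi - flow.mu) ** 2 + sigma2) for xi in x
    )
    entropy = flow.log_sigma + 0.5 * math.log(2.0 * math.pi * math.e)
    return expected_loglik + entropy


def fit(flow: DiagAffine1D, x: Sequence[float], lr: float = 0.05, steps: int = 2000) -> DiagAffine1D:
    """Gradient ascent on the closed-form ELBO above."""
    n = len(x)
    for _ in range(steps):
        sigma2 = math.exp(2.0 * flow.log_sigma)
        grad_mu = sum(xi - flow.mu for xi in x)  # d ELBO / d mu
        grad_log_sigma = 1.0 - n * sigma2        # d ELBO / d log_sigma
        flow.mu += lr * grad_mu
        flow.log_sigma += lr * grad_log_sigma
    return flow


flow = fit(DiagAffine1D(), [-1.0, 0.0, 1.0])
print(flow.mu, math.exp(flow.log_sigma))  # approaches 0.0 and 1/sqrt(3)
```

Under these assumptions the exact posterior is Normal(0, 1/3), which lies inside the variational family, so the fitted parameters approach mu = 0 and sigma = 1/sqrt(3) ≈ 0.577.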

You can then work with this approximation by sampling its nodes:

mean_draws = posterior.sample('mean', 10000)
print(mean_draws.mean())
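
For a fitted Normal(mu, sigma) approximation, sampling a node amounts to pushing fresh standard normal noise through the flow, and averaging the draws gives a Monte Carlo estimate of the posterior mean. The sketch assumes the fitted values the exact posterior would give for x = [-1, 0, 1] under a flat prior. sample_affine is a hypothetical helper, not posterior.sample.

```python
import math
import random
from statistics import fmean, stdev


def sample_affine(mu: float, sigma: float, n: int, seed: int = 0) -> list[float]:
    """Draw n samples by pushing standard normal noise through z = mu + sigma * eps."""
    rng = random.Random(seed)
    return [mu + sigma * rng.gauss(0.0, 1.0) for _ in range(n)]


draws = sample_affine(mu=0.0, sigma=1.0 / math.sqrt(3.0), n=10_000)
print(fmean(draws), stdev(draws))  # mean near 0, spread near 0.577
```

With 10,000 draws the Monte Carlo standard error of the mean is about 0.577 / 100 ≈ 0.006, so the printed mean should sit close to the approximation's mean.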



This version: 0.5.4

Download files

Download the file for your platform.

Source Distribution

bayinx-0.5.4.tar.gz (51.8 kB)

Uploaded Source

Built Distribution


bayinx-0.5.4-py3-none-any.whl (41.1 kB)

Uploaded Python 3

File details

Details for the file bayinx-0.5.4.tar.gz.

File metadata

  • Download URL: bayinx-0.5.4.tar.gz
  • Size: 51.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv 0.9.15 (publish, CI, Ubuntu 24.04)

File hashes

Hashes for bayinx-0.5.4.tar.gz
Algorithm Hash digest
SHA256 97906a0a3585bed15a51fc4b1b7cc6a46b27b6dedcf58e6bd6009f439a42ce08
MD5 de9c9952a794a8e7aa39c2f51adbdb93
BLAKE2b-256 4a3325b325c170fb54d1d2ff946ceb85b119b895d4385a0b9e19ac0e61c0f796


File details

Details for the file bayinx-0.5.4-py3-none-any.whl.

File metadata

  • Download URL: bayinx-0.5.4-py3-none-any.whl
  • Size: 41.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv 0.9.15 (publish, CI, Ubuntu 24.04)

File hashes

Hashes for bayinx-0.5.4-py3-none-any.whl
Algorithm Hash digest
SHA256 9045ee59485ce0182416c2b0cc020499e3a325aa7140df3f3fb2ce63ed2427e9
MD5 f0178fe132fb1a86a64b4cebd7004372
BLAKE2b-256 8005cf4f700946a4d751330d27a5f58bf0079adf15d1b1f80b29f356c85647a3

