Bayesian Inference with JAX

Bayinx: Bayesian Inference with JAX

Bayinx is an embedded probabilistic programming language in Python, powered by JAX. It is heavily inspired by Stan and aims for feature parity with it, but it extends the types of objects you can work with and focuses on variational inference with normalizing flows for sampling from the posterior.

Coming From Stan

There are a few syntactic differences between Bayinx and Stan. First, because Bayinx is embedded in Python, model definitions are Pythonic: you define a model as a class that inherits from the Model base class:

class MyModel(Model, init=False):
    ...  # attribute definitions and the model method go here

Note: Specify init=False to prevent static type checkers from raising irrelevant errors. More importantly, it serves as a reminder that you should NOT implement your own __init__ method!

In Bayinx, Stan's data and parameters blocks are combined into the class's attribute definitions. For example, to model observations from a normal distribution with an unknown mean and variance 1, we might write:

class MyModel(Model, init=False):
    mean: Continuous[Array] = define(shape=())  # a scalar mean parameter
    x: Observed[Array] = define(shape='n_obs')  # a vector of observed values

    # ...
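Written out, writing μ for `mean`, the example model is the following. If, as in Stan, a parameter with no explicit prior receives an implicit flat prior (an assumption about Bayinx's behavior, not something stated here), completing the square gives the posterior of the mean in closed form, which is handy for sanity-checking a fitted approximation:

```latex
x_i \mid \mu \sim \mathcal{N}(\mu, 1), \qquad i = 1, \dots, n_{\mathrm{obs}}

p(\mu \mid x) \propto \prod_{i=1}^{n_{\mathrm{obs}}} \exp\!\left(-\tfrac{1}{2}(x_i - \mu)^2\right)
  \propto \exp\!\left(-\tfrac{n_{\mathrm{obs}}}{2}(\mu - \bar{x})^2\right)
  \quad\Longrightarrow\quad
  \mu \mid x \sim \mathcal{N}\!\left(\bar{x}, \tfrac{1}{n_{\mathrm{obs}}}\right)
```

The second proportionality uses the identity that the sum of (x_i − μ)² equals the sum of (x_i − x̄)² plus n_obs(μ − x̄)², where the first term does not depend on μ. For the data x = [-1, 0, 1] used below, this is N(0, 1/3).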

Stan's model block corresponds to the model method, which you implement on the class:

class MyModel(Model, init=False):
    mean: Continuous[Array] = define(shape=())
    x: Observed[Array] = define(shape='n_obs')

    def model(self, target):
        # Equivalent to 'x ~ normal(mean, 1.0)' in Stan
        self.x << Normal(self.mean, 1.0)

        return target

Notice that Stan's ~ operator is replaced by <<, and that nodes of the model are referenced through self.
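In Stan, `x ~ normal(mean, 1.0)` adds the normal log density of `x` to the model's `target`. As a minimal plain-JAX sketch of that quantity, assuming `<<` in Bayinx plays the same role, the contribution can be computed as below. The helper `log_density` is hypothetical and not part of Bayinx; also note that Stan's `~` drops constant terms, so this sketch matches it only up to an additive constant.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_density(mean, x):
    # Hypothetical helper: the contribution of 'x ~ normal(mean, 1.0)',
    # i.e. the sum of the elementwise normal log densities of x.
    target = 0.0
    target += jnp.sum(norm.logpdf(x, loc=mean, scale=1.0))
    return target


x = jnp.array([-1.0, 0.0, 1.0])
print(log_density(0.0, x))
```

Evaluating the sketch at different values of `mean` shows the usual behavior: the log density is highest near the sample mean of `x`.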

Note: Bayinx does not currently have an equivalent of Stan's transformed data or transformed parameters blocks, but these are likely to be included in a future release.

You can then construct the variational approximation to the posterior:

import bayinx as byx
from bayinx.flows import DiagAffine
import jax.numpy as jnp

# Fit variational approximation
posterior = byx.Posterior(MyModel, n_obs=3, x=jnp.array([-1.0, 0.0, 1.0]))
posterior.configure(flowspecs = [DiagAffine()])
posterior.fit()
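To make concrete what fitting a single affine flow involves, here is a minimal plain-JAX sketch of the underlying technique: maximizing a reparameterized Monte Carlo estimate of the ELBO over the location and log-scale of a diagonal affine transform of standard normal noise. This is not Bayinx's implementation of `DiagAffine` or `fit()`; the flat prior on `mean`, the optimizer (plain gradient ascent), the step size, the number of draws, and the iteration count are all assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Observed data from the example above.
x = jnp.array([-1.0, 0.0, 1.0])

LEARNING_RATE = 0.05  # assumed step size for plain gradient ascent
N_MC = 64             # assumed number of Monte Carlo draws per ELBO estimate


def log_joint(mean):
    # log p(x | mean); a flat prior on `mean` is assumed, so no prior term.
    return jnp.sum(norm.logpdf(x, loc=mean, scale=1.0))


def elbo(params, eps):
    # Diagonal affine flow: z = loc + exp(log_scale) * eps, with eps ~ N(0, 1).
    loc, log_scale = params
    z = loc + jnp.exp(log_scale) * eps
    # Monte Carlo estimate of E_q[log p(x, z)] plus the entropy of q,
    # which for a Gaussian equals log_scale up to an additive constant.
    return jnp.mean(jax.vmap(log_joint)(z)) + log_scale


@jax.jit
def step(params, key):
    # One reparameterized stochastic-gradient *ascent* step on the ELBO.
    eps = jax.random.normal(key, (N_MC,))
    grads = jax.grad(elbo)(params, eps)
    return tuple(p + LEARNING_RATE * g for p, g in zip(params, grads))


key = jax.random.PRNGKey(0)
params = (jnp.array(2.0), jnp.array(0.0))  # initial (loc, log_scale)
for _ in range(1000):
    key, subkey = jax.random.split(key)
    params = step(params, subkey)

loc, log_scale = params
print(float(loc), float(jnp.exp(log_scale)))
```

Under the flat-prior assumption, the exact posterior of `mean` for this data is N(0, 1/3), so the fitted location should land near 0 and the scale near 1/sqrt(3) ≈ 0.577. A diagonal affine flow can represent this posterior exactly because that posterior is itself Gaussian; richer flows matter when it is not.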

You can then work with this approximation by drawing samples of its nodes:

mean_draws = posterior.sample('mean', 10000)
print(mean_draws.mean())
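Conceptually, drawing from an affine flow amounts to pushing standard normal noise through the fitted transform. A minimal sketch, in which the hypothetical helper `sample_affine` and the hypothetical fitted values `loc = 0.0` and `scale = 1/sqrt(3)` (the exact posterior standard deviation for this data if `mean` has an implicit flat prior) stand in for whatever `posterior.fit()` actually produced; this is not the implementation of `posterior.sample`:

```python
import jax
import jax.numpy as jnp


def sample_affine(key, loc, scale, n):
    # Push n standard-normal draws through z = loc + scale * eps.
    eps = jax.random.normal(key, (n,))
    return loc + scale * eps


# Hypothetical fitted parameters of the approximation for 'mean'.
loc, scale = 0.0, 1.0 / jnp.sqrt(3.0)

mean_draws = sample_affine(jax.random.PRNGKey(0), loc, scale, 10000)
print(mean_draws.mean())  # Monte Carlo estimate of the posterior mean
```

With 10,000 draws, the Monte Carlo standard error of the mean is about 0.577 / 100 ≈ 0.006, so the printed estimate should sit close to 0.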

Roadmap

This version

0.5.3

Download files

Source Distribution

bayinx-0.5.3.tar.gz (51.7 kB)

Built Distribution

bayinx-0.5.3-py3-none-any.whl (41.1 kB)

File details

Details for the file bayinx-0.5.3.tar.gz.

File metadata

  • Download URL: bayinx-0.5.3.tar.gz
  • Upload date:
  • Size: 51.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.9.15 {"installer":{"name":"uv","version":"0.9.15","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":true}

File hashes

Hashes for bayinx-0.5.3.tar.gz
Algorithm Hash digest
SHA256 ab418cfeed5f04f7655c2da5b30a3b301171bd314a613e2fcdaca52565a38c32
MD5 8b7c76e3253045f2787d00acebc02677
BLAKE2b-256 df40d492a665fafcc2bbcfedb1a3f3a0a91477bd32110ffe2138d0c80feee187

File details

Details for the file bayinx-0.5.3-py3-none-any.whl.

File metadata

  • Download URL: bayinx-0.5.3-py3-none-any.whl
  • Upload date:
  • Size: 41.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.9.15 {"installer":{"name":"uv","version":"0.9.15","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":true}

File hashes

Hashes for bayinx-0.5.3-py3-none-any.whl
Algorithm Hash digest
SHA256 9abb4c4dcc38c63d607123551fc772644a4cc72d160ad3d1a40366cb513831ed
MD5 c6d8dac9aed39d48e5b366e5b8a3bd8a
BLAKE2b-256 3bdd813df102c5424fde863ef3c4267bbc290223bd47b77d5e661898e2a47778
