
Brancher: A user-centered Python package for differentiable probabilistic inference

Brancher lets you design and train differentiable Bayesian models using stochastic variational inference. It is built on the deep learning framework PyTorch.
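
For orientation, stochastic variational inference fits an approximating distribution q_φ(z) to the posterior p(z | x) by maximizing the evidence lower bound (ELBO), whose expectation is usually estimated with S Monte Carlo samples drawn from q_φ. The standard form is below; z (latent variables), x (observations), φ (variational parameters) and S are generic notation, not Brancher identifiers.

```latex
\log p(x) = \mathcal{L}(\phi) + \mathrm{KL}\big(q_\phi(z)\,\|\,p(z \mid x)\big) \;\ge\; \mathcal{L}(\phi),
\qquad
\mathcal{L}(\phi) = \mathbb{E}_{q_\phi(z)}\big[\log p(x, z) - \log q_\phi(z)\big]
\approx \frac{1}{S}\sum_{s=1}^{S}\big[\log p(x, z^{(s)}) - \log q_\phi(z^{(s)})\big],
\quad z^{(s)} \sim q_\phi(z)
```

Because the gap between log p(x) and the ELBO is the KL divergence to the true posterior, maximizing the ELBO over φ pulls q_φ toward p(z | x).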

Building probabilistic models

Probabilistic models are defined symbolically. Random variables can be created as follows:

from brancher.standard_variables import NormalVariable

a = NormalVariable(loc=0., scale=1., name='a')
b = NormalVariable(loc=0., scale=1., name='b')

Random variables can be chained together using arithmetic operators and mathematical functions (the latter from brancher.functions, imported as BF):

import brancher.functions as BF

# c depends on a and b through its location and scale
c = NormalVariable(loc=a**2 + BF.sin(b),
                   scale=BF.exp(b),
                   name='c')
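
To illustrate what "defined symbolically" means, here is a toy sketch (not Brancher's implementation) in which each variable is a node that stores its parents. Operator overloading turns arithmetic on nodes into new deterministic nodes, and a joint sample is drawn by evaluating parents before children (ancestral sampling). The classes Node, Normal, Deterministic and the helpers sin/exp are hypothetical names used only for this illustration, and scale is assumed to be a standard deviation.

```python
import math
import random


class Node:
    """A symbolic node; sample() evaluates it given a cache of already-sampled nodes."""

    def sample(self, cache):
        # Reuse a value if this node was already sampled in this joint draw,
        # so shared parents (like b below) take a single consistent value.
        if self not in cache:
            cache[self] = self._draw(cache)
        return cache[self]

    # Arithmetic on nodes returns new deterministic nodes instead of numbers.
    def __add__(self, other):
        return Deterministic(lambda u, v: u + v, self, other)

    def __pow__(self, power):
        return Deterministic(lambda u, v: u ** v, self, power)


def _value(x, cache):
    # Constants pass through unchanged; nodes are sampled recursively.
    return x.sample(cache) if isinstance(x, Node) else x


class Deterministic(Node):
    def __init__(self, fn, *parents):
        self.fn, self.parents = fn, parents

    def _draw(self, cache):
        return self.fn(*(_value(p, cache) for p in self.parents))


class Normal(Node):
    def __init__(self, loc, scale, name):
        self.loc, self.scale, self.name = loc, scale, name

    def _draw(self, cache):
        return random.gauss(_value(self.loc, cache), _value(self.scale, cache))


def sin(x):
    return Deterministic(math.sin, x)


def exp(x):
    return Deterministic(math.exp, x)


a = Normal(0.0, 1.0, "a")
b = Normal(0.0, 1.0, "b")
c = Normal(a ** 2 + sin(b), exp(b), "c")

cache = {}
sample_c = c.sample(cache)  # draws a and b first, then c
```

Caching per joint draw is the design choice that matters here: b feeds both the location and the scale of c, and both must see the same sampled value.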

In this way, arbitrarily complex probabilistic models can be built. PyTorch's deep learning tools can also be used to define probabilistic models that include deep neural networks.

Example: Autoregressive modeling

Probabilistic model

The model is a latent autoregressive process x_t observed through noisy measurements y_t, with an unknown autoregressive coefficient b:

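from brancher.variables import ProbabilisticModel
from brancher.standard_variables import NormalVariable, LogitNormalVariable

T = 20               # number of time steps
driving_noise = 1.   # scale of the latent process noise
measure_noise = 0.3  # scale of the measurement noise
x0 = NormalVariable(0., driving_noise, 'x0')
y0 = NormalVariable(x0, measure_noise, 'y0')
b = LogitNormalVariable(0.5, 1., 'b')  # autoregressive coefficient, support (0, 1)

x = [x0]
y = [y0]
x_names = ["x0"]
y_names = ["y0"]
for t in range(1, T):
    x_names.append("x{}".format(t))
    y_names.append("y{}".format(t))
    # x_t depends on x_{t-1}; y_t is a noisy measurement of x_t
    x.append(NormalVariable(b*x[t-1], driving_noise, x_names[t]))
    y.append(NormalVariable(x[t], measure_noise, y_names[t]))
AR_model = ProbabilisticModel(x + y)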
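
Written out, the model is x_0 ~ N(0, driving_noise), x_t ~ N(b x_{t-1}, driving_noise) and y_t ~ N(x_t, measure_noise), with a logit-normal prior on b. The stdlib sketch below draws one trajectory from this generative process, which is one way to produce synthetic data for the observation step. It assumes that the second argument of NormalVariable is a standard deviation and that LogitNormalVariable(0.5, 1.) means b = sigmoid(u) with u ~ N(0.5, 1); simulate_ar is a hypothetical helper, not part of Brancher.

```python
import math
import random


def sigmoid(u):
    return 1.0 / (1.0 + math.exp(-u))


def simulate_ar(T=20, driving_noise=1.0, measure_noise=0.3, seed=None):
    """Draw (b, x, y) from the latent AR(1) model with noisy measurements."""
    rng = random.Random(seed)
    # Logit-normal prior on the coefficient (assumed parameterization), so 0 < b < 1.
    b = sigmoid(rng.gauss(0.5, 1.0))
    x = [rng.gauss(0.0, driving_noise)]  # x_0 ~ N(0, driving_noise)
    for t in range(1, T):
        x.append(rng.gauss(b * x[t - 1], driving_noise))  # x_t ~ N(b * x_{t-1}, driving_noise)
    y = [rng.gauss(xt, measure_noise) for xt in x]  # y_t ~ N(x_t, measure_noise)
    return b, x, y


b_true, x_true, y_obs = simulate_ar(seed=0)
```

Passing a seed makes the simulated dataset reproducible, which helps when comparing inference runs.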

Observe data

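Once the probabilistic model is defined, we can decide which variables are observed. Here every measurement y_t is observed; data, which is not defined in this example, is assumed to map each measurement variable to an array of its recorded values: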

# condition the model on the measured values of every y_t
for yt in y:
    yt.observe(data[yt][:, 0, :])
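
Observing the y_t fixes them to data, so for this model they enter inference through the log-likelihood of the measurements given the latent states. A stdlib sketch of that term, assuming Gaussian measurement noise with standard deviation measure_noise; gaussian_logpdf and measurement_loglik are hypothetical helpers, not Brancher functions.

```python
import math


def gaussian_logpdf(value, loc, scale):
    """Log-density of N(loc, scale**2) evaluated at value."""
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def measurement_loglik(y, x, measure_noise=0.3):
    """Sum over t of log N(y_t | x_t, measure_noise**2) for observed y and latent x."""
    return sum(gaussian_logpdf(yt, xt, measure_noise) for yt, xt in zip(y, x))
```

For example, measurement_loglik([0.1, -0.2], [0.0, 0.0]) scores two measurements against latent states at zero; latent trajectories closer to the data receive a higher score.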

Autoregressive variational distribution

The variational distribution can have an arbitrary structure. Here it mirrors the autoregressive form of the model, with learnable parameters:

from brancher.variables import DeterministicVariable  # import path assumed for this version
import brancher.functions as BF

Qb = LogitNormalVariable(0.5, 0.5, "b", learnable=True)
logit_b_post = DeterministicVariable(0., 'logit_b_post', learnable=True)
Qx = [NormalVariable(0., 1., 'x0', learnable=True)]
Qx_mean = [DeterministicVariable(0., 'x0_mean', learnable=True)]
for t in range(1, T):
    Qx_mean.append(DeterministicVariable(0., x_names[t] + "_mean", learnable=True))
    # each posterior state depends on the previous one, as in the model
    Qx.append(NormalVariable(BF.sigmoid(logit_b_post)*Qx[t-1] + Qx_mean[t], 1., x_names[t], learnable=True))
variational_posterior = ProbabilisticModel([Qb] + Qx)
AR_model.set_posterior_model(variational_posterior)
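
In this variational family each latent state is conditioned on the previously sampled one: q(x_t | x_{t-1}) = N(sigmoid(logit_b_post) x_{t-1} + m_t, 1), where logit_b_post and the offsets m_t (the *_mean variables) are learnable. The stdlib sketch below draws one trajectory from that family, with plain floats standing in for the learnable parameters. sample_posterior_chain is hypothetical, and treating the first state as N(x0_loc, x0_scale) is an assumption about the learnable Qx[0].

```python
import math
import random


def sample_posterior_chain(logit_b_post, offsets, x0_loc=0.0, x0_scale=1.0, seed=None):
    """Draw x_0 .. x_{T-1} from the autoregressive variational family.

    offsets[t - 1] plays the role of the learnable mean m_t for x_t, t >= 1.
    """
    rng = random.Random(seed)
    coef = 1.0 / (1.0 + math.exp(-logit_b_post))  # sigmoid keeps the coefficient in (0, 1)
    x = [rng.gauss(x0_loc, x0_scale)]
    for m_t in offsets:
        # Mean depends on the previously sampled state; scale is fixed at 1.
        x.append(rng.gauss(coef * x[-1] + m_t, 1.0))
    return x


T = 20
trajectory = sample_posterior_chain(0.0, [0.0] * (T - 1), seed=1)
```

Parameterizing the coefficient through a sigmoid lets the optimizer update an unconstrained number while the coefficient itself stays in (0, 1), matching the support of the prior on b.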

Inference

Now that both models are specified, we can perform approximate inference using stochastic gradient descent:

from brancher import inference

inference.perform_inference(AR_model,
                            number_iterations=500,
                            number_samples=300,
                            optimizer="SGD",
                            lr=0.001)
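
The number_iterations, number_samples and lr arguments presumably correspond to the usual ingredients of stochastic variational inference: the number of gradient steps, the number of Monte Carlo samples used to estimate the ELBO gradient at each step, and the step size. The sketch below shows that loop in plain Python on a toy conjugate model (z ~ N(0, 1), x_i ~ N(z, 1)) whose exact posterior is known, so the result can be checked. It illustrates the technique with hand-derived reparameterization gradients; it is not Brancher's implementation, and the model and svi_normal_mean are hypothetical.

```python
import math
import random


def svi_normal_mean(data, number_iterations=2000, number_samples=50, lr=0.02, seed=0):
    """Fit q(z) = N(m, s**2) to the posterior of z in: z ~ N(0, 1), x_i ~ N(z, 1).

    Plain gradient ascent on a Monte Carlo ELBO estimate with reparameterized
    samples z = m + s * eps, eps ~ N(0, 1).
    """
    rng = random.Random(seed)
    n, sum_x = len(data), sum(data)
    m, log_s = 0.0, 0.0  # variational parameters; s = exp(log_s) stays positive
    for _ in range(number_iterations):
        s = math.exp(log_s)
        grad_m, grad_log_s = 0.0, 0.0
        for _ in range(number_samples):
            eps = rng.gauss(0.0, 1.0)
            z = m + s * eps
            g = sum_x - (n + 1) * z  # d/dz log p(x, z) for this model
            grad_m += g
            grad_log_s += g * eps * s  # chain rule through z = m + exp(log_s) * eps
        # Average over samples; the +1 is d/dlog_s of the Gaussian entropy (log s + const).
        m += lr * grad_m / number_samples
        log_s += lr * (grad_log_s / number_samples + 1.0)
    return m, math.exp(log_s)


data = [1.2, 0.8, 1.5, 0.9, 1.1]
m, s = svi_normal_mean(data)
exact_mean = sum(data) / (len(data) + 1)    # conjugate posterior mean
exact_std = 1.0 / math.sqrt(len(data) + 1)  # conjugate posterior standard deviation
```

More samples per step lower the variance of each gradient estimate at a higher cost per iteration, while a smaller learning rate trades convergence speed for stability; the same trade-offs apply when choosing number_samples and lr above.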


Download files

Source Distribution

brancher-0.3.5.tar.gz (37.9 kB)

Built Distribution

brancher-0.3.5-py3-none-any.whl (44.2 kB)

File details

Details for the file brancher-0.3.5.tar.gz.

File metadata

  • Download URL: brancher-0.3.5.tar.gz
  • Size: 37.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.29.1 CPython/3.7.2

File hashes

Hashes for brancher-0.3.5.tar.gz
Algorithm     Hash digest
SHA256        a153e942a558098537ce51d34cb6792e6331d62ee75ba36ffc4b70a4a1b77331
MD5           2b7d1212262fcba39e285780642c86fa
BLAKE2b-256   f3a4655141a4913cd5e00fcf1c279d5ce4b56acbb6d43ecd46255449c62f1601


File details

Details for the file brancher-0.3.5-py3-none-any.whl.

File metadata

  • Download URL: brancher-0.3.5-py3-none-any.whl
  • Size: 44.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.29.1 CPython/3.7.2

File hashes

Hashes for brancher-0.3.5-py3-none-any.whl
Algorithm     Hash digest
SHA256        1fe1c23be6e68c1d8f44763c1429e5c79b2273c0092461af5def9430a9c8beb9
MD5           7046416ea30913e068e785dd98d73c17
BLAKE2b-256   9e9e9cb048f19ec0601a2c7a47c834a759b1011fe9228b471940daad1e0c7904

