
A user-centered Python package for differentiable probabilistic inference

Project description

Brancher: An Object-Oriented Variational Probabilistic Programming Library

Brancher allows you to design and train differentiable Bayesian models using stochastic variational inference. It is built on top of the deep learning framework PyTorch.
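
As general background (the standard textbook formulation, not something stated in Brancher's own documentation), stochastic variational inference picks a parametric distribution q_φ(z) over the latent variables z and maximizes the evidence lower bound (ELBO). It uses gradient steps on a Monte Carlo estimate built from S samples of q_φ, where x denotes the observed variables:

```latex
\log p(x) \;\ge\; \mathcal{L}(\phi)
  = \mathbb{E}_{q_\phi(z)}\big[\log p(x, z) - \log q_\phi(z)\big]
  \approx \frac{1}{S} \sum_{s=1}^{S} \big[\log p(x, z^{(s)}) - \log q_\phi(z^{(s)})\big],
  \qquad z^{(s)} \sim q_\phi(z)
```

The gap between log p(x) and the bound is KL(q_φ(z) || p(z | x)), so maximizing the ELBO pulls q_φ toward the true posterior.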

Building probabilistic models

Probabilistic models are defined symbolically. Random variables can be created as follows:

from brancher.standard_variables import NormalVariable

a = NormalVariable(loc=0., scale=1., name='a')
b = NormalVariable(loc=0., scale=1., name='b')

It is possible to chain together random variables by using arithmetic and mathematical functions:

import brancher.functions as BF

# the location and scale of c are deterministic functions of a and b
c = NormalVariable(loc=a**2 + BF.sin(b),
                   scale=BF.exp(b),
                   name='c')

In this way, arbitrarily complex probabilistic models can be built. PyTorch's deep learning tools can also be used to define probabilistic models that contain deep neural networks.
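
To make the meaning of the chained model concrete, the sketch below draws joint samples of a, b and c by ancestral sampling in plain Python, using only the standard library and neither Brancher nor PyTorch. It assumes that `scale` denotes a standard deviation. The helper `sample_abc` is hypothetical.

```python
import math
import random


def sample_abc(rng):
    """Draw one joint sample (a, b, c) by ancestral sampling (hypothetical helper)."""
    a = rng.gauss(0.0, 1.0)  # a ~ Normal(0, 1)
    b = rng.gauss(0.0, 1.0)  # b ~ Normal(0, 1)
    # assumption: scale is a standard deviation; c's parameters depend on its parents
    c = rng.gauss(a**2 + math.sin(b), math.exp(b))
    return a, b, c


rng = random.Random(0)
samples = [sample_abc(rng) for _ in range(10_000)]
mean_c = sum(c for _, _, c in samples) / len(samples)
```

Since E[a²] = 1 and E[sin b] = 0, the sample mean of c should come out close to 1.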

Example: Autoregressive modeling

Probabilistic model

The autoregressive model is defined in the same symbolic way, one time step at a time:

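from brancher.variables import ProbabilisticModel
from brancher.standard_variables import NormalVariable, LogitNormalVariable

T = 20                 # number of time steps
driving_noise = 1.     # scale of the latent AR(1) innovations
measure_noise = 0.3    # scale of the measurement noise
x0 = NormalVariable(0., driving_noise, 'x0')
y0 = NormalVariable(x0, measure_noise, 'y0')
b = LogitNormalVariable(0.5, 1., 'b')   # AR coefficient, supported on (0, 1)

x = [x0]
y = [y0]
x_names = ["x0"]
y_names = ["y0"]
for t in range(1, T):
    x_names.append("x{}".format(t))
    y_names.append("y{}".format(t))
    # latent state: x_t ~ Normal(b * x_{t-1}, driving_noise)
    x.append(NormalVariable(b*x[t-1], driving_noise, x_names[t]))
    # measurement: y_t ~ Normal(x_t, measure_noise)
    y.append(NormalVariable(x[t], measure_noise, y_names[t]))
AR_model = ProbabilisticModel(x + y)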
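
Before observing any data, it can help to simulate the generative process this model describes. The following standard-library sketch is hypothetical: `simulate_ar` is not part of Brancher. It assumes that `LogitNormalVariable(0.5, 1., 'b')` means b = sigmoid(u) with u ~ Normal(0.5, 1), and that all scales are standard deviations.

```python
import math
import random


def simulate_ar(T=20, driving_noise=1.0, measure_noise=0.3, seed=0):
    """Sample (b, x, y) from the AR generative process (hypothetical helper)."""
    rng = random.Random(seed)
    # assumption: LogitNormal(0.5, 1) = sigmoid of a Normal(0.5, 1) draw, so 0 < b < 1
    b = 1.0 / (1.0 + math.exp(-rng.gauss(0.5, 1.0)))
    x = [rng.gauss(0.0, driving_noise)]  # x0 ~ Normal(0, driving_noise)
    for t in range(1, T):
        x.append(rng.gauss(b * x[t - 1], driving_noise))  # latent AR(1) step
    y = [rng.gauss(xt, measure_noise) for xt in x]  # noisy measurement of each x_t
    return b, x, y


b_true, x_sim, y_sim = simulate_ar()
```

Sampling all y_t after the x_t gives the same joint distribution as interleaving them, because each y_t depends only on its own x_t.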

Observe data

Once the probabilistic model is defined, we can specify which variables are observed:

# `data` is not defined in this snippet: it must map each y_t to its observed values
for yt in y:
    yt.observe(data[yt][:, 0, :])
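
Observing the y_t conditions the model on the data. The joint log density used during inference then contains the measurement terms log Normal(y_t | x_t, measure_noise). The sketch below computes that sum in plain Python. The helper names are hypothetical, and `scale` is read as a standard deviation.

```python
import math


def normal_logpdf(v, loc, scale):
    """Log density of Normal(loc, scale) at v, with scale a standard deviation."""
    return -0.5 * ((v - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def measurement_loglik(y, x, measure_noise=0.3):
    """Hypothetical helper: sum over t of log p(y_t | x_t) for y_t ~ Normal(x_t, measure_noise)."""
    return sum(normal_logpdf(yt, xt, measure_noise) for yt, xt in zip(y, x))
```

The closer the latent path x follows the observations, the larger this term, which is how observed data pull the posterior over x toward the measurements.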

Autoregressive variational distribution

The variational distribution can have an arbitrary structure:

# variational posterior over the AR coefficient b
Qb = LogitNormalVariable(0.5, 0.5, "b", learnable=True)
# learnable logit of the coefficient used inside the variational chain
logit_b_post = DeterministicVariable(0., 'logit_b_post', learnable=True)
Qx = [NormalVariable(0., 1., 'x0', learnable=True)]
Qx_mean = [DeterministicVariable(0., 'x0_mean', learnable=True)]
for t in range(1, T):
    Qx_mean.append(DeterministicVariable(0., x_names[t] + "_mean", learnable=True))
    # each Qx_t depends on Qx_{t-1}, so the posterior keeps an autoregressive structure
    Qx.append(NormalVariable(BF.sigmoid(logit_b_post)*Qx[t-1] + Qx_mean[t], 1.,
                             x_names[t], learnable=True))
variational_posterior = ProbabilisticModel([Qb] + Qx)
AR_model.set_posterior_model(variational_posterior)
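
Because the location of each Qx_t depends on the previous sample, a draw from this variational posterior can be written as a deterministic function of the parameters plus independent standard normal noise. This is the reparameterization trick, and it is what makes a Monte Carlo ELBO estimate differentiable with respect to those parameters. The plain-Python sketch below mirrors the chain above. The function name, the `means` layout and the fixed scale of 1 are illustrative assumptions, and Brancher's internals may differ.

```python
import math
import random


def sample_variational_chain(means, logit_b_post, rng, scale=1.0):
    """One reparameterized draw of (Qx_0, ..., Qx_{T-1}) (hypothetical helper).

    Hypothetical layout: means[0] stands in for Qx[0]'s location and
    means[t] for t >= 1 stands in for Qx_mean[t].
    """
    rho = 1.0 / (1.0 + math.exp(-logit_b_post))  # sigmoid(logit_b_post)
    x = [means[0] + scale * rng.gauss(0.0, 1.0)]
    for t in range(1, len(means)):
        eps = rng.gauss(0.0, 1.0)  # fresh standard normal noise
        # location = rho * previous sample + per-step mean, as in the Qx loop
        x.append(rho * x[t - 1] + means[t] + scale * eps)
    return x


rng = random.Random(0)
draw = sample_variational_chain([0.0] * 20, logit_b_post=0.0, rng=rng)
```

With `scale=0.0` the noise vanishes and the chain reduces to the deterministic recursion x_t = rho * x_{t-1} + means[t].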

Inference

Now that the models are specified, we can perform approximate inference using stochastic gradient descent:

from brancher import inference

inference.perform_inference(AR_model,
                            number_iterations=500,
                            number_samples=300,
                            optimizer="SGD",
                            lr=0.001)
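
To illustrate what a call like this does, here is a self-contained toy version of stochastic variational inference in plain Python. It is not Brancher's implementation. The gradients are derived by hand instead of by PyTorch autograd. The model is a hypothetical conjugate example (z ~ Normal(0, 1), y_i ~ Normal(z, noise_sd)), chosen so that the result can be checked against the exact posterior. Each iteration draws `num_samples` reparameterized samples from q(z) = Normal(m, exp(omega)), averages the ELBO gradient estimates and takes a plain gradient step. These steps play roles loosely analogous to `number_samples`, `optimizer="SGD"` and `lr` above.

```python
import math
import random


def svi_toy(y, noise_sd=1.0, lr=0.01, iterations=2000, num_samples=10, seed=0):
    """Fit q(z) = Normal(m, exp(omega)) to p(z | y) for the hypothetical toy model
    z ~ Normal(0, 1), y_i ~ Normal(z, noise_sd), by stochastic gradient ascent on the ELBO."""
    rng = random.Random(seed)
    m, omega = 0.0, 0.0  # variational parameters; the scale exp(omega) stays positive
    for _ in range(iterations):
        s = math.exp(omega)
        grad_m = 0.0
        grad_omega = 0.0
        for _ in range(num_samples):
            eps = rng.gauss(0.0, 1.0)
            z = m + s * eps  # reparameterized draw from q
            # d/dz log p(y, z): prior term plus Gaussian likelihood terms
            dlogp = -z + sum(yi - z for yi in y) / noise_sd ** 2
            grad_m += dlogp                # dz/dm = 1
            grad_omega += dlogp * s * eps  # dz/domega = s * eps
        grad_m /= num_samples
        # the entropy of Normal(m, s) is omega + const, so it contributes exactly 1
        grad_omega = grad_omega / num_samples + 1.0
        m += lr * grad_m  # ascent on the ELBO, i.e. descent on its negative
        omega += lr * grad_omega
    return m, math.exp(omega)


def exact_posterior(y, noise_sd=1.0):
    """Closed-form Normal posterior of z for the same conjugate toy model."""
    precision = 1.0 + len(y) / noise_sd ** 2
    mean = sum(y) / noise_sd ** 2 / precision
    return mean, math.sqrt(1.0 / precision)


y_obs = [1.0, 2.0, 1.5, 0.5, 1.0]
m_fit, s_fit = svi_toy(y_obs)
```

For these observations the exact posterior is Normal(1.0, 1/sqrt(6) ≈ 0.408), and the fitted (m, s) should land close to it. Parameterizing the scale as exp(omega) keeps it positive without constrained optimization.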



Download files


Source Distribution

brancher-0.2.1.tar.gz (28.7 kB)

Uploaded Source

Built Distribution

brancher-0.2.1-py3-none-any.whl (35.4 kB)

Uploaded Python 3

File details

Details for the file brancher-0.2.1.tar.gz.

File metadata

  • Download URL: brancher-0.2.1.tar.gz
  • Upload date:
  • Size: 28.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.4.2 requests/2.21.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.28.1 CPython/3.7.1

File hashes

Hashes for brancher-0.2.1.tar.gz
Algorithm     Hash digest
SHA256        e5e1722e5290e4414386f8ccaa1d7e54c5530e905dce99a25706f18a1a568ddf
MD5           ef3ee9fb25ec30dfccd54f1f9bd155b3
BLAKE2b-256   ebcc0f5fac9eefb59df7e00cd14cd34c558d95425f9423ec32ad0b1c55962c7a


File details

Details for the file brancher-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: brancher-0.2.1-py3-none-any.whl
  • Upload date:
  • Size: 35.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.4.2 requests/2.21.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.28.1 CPython/3.7.1

File hashes

Hashes for brancher-0.2.1-py3-none-any.whl
Algorithm     Hash digest
SHA256        035754109184ea26145f45907e5f6dac5e1adde798cf49abdf624275018ff537
MD5           3cfbf2e199cc8a05185714fb9f3f3e6f
BLAKE2b-256   d0b1478d0a1e0c696f27d10e8bd481e934d611692fde2a207864d1856b257bd3

