
Project description

Brancher: An Object-Oriented Variational Probabilistic Programming Library

Brancher lets you design and train differentiable Bayesian models using stochastic variational inference. It is built on the deep learning framework PyTorch.

Building probabilistic models

Probabilistic models are defined symbolically. Random variables can be created as follows:

from brancher.standard_variables import NormalVariable

a = NormalVariable(loc=0., scale=1., name='a')
b = NormalVariable(loc=0., scale=1., name='b')

It is possible to chain together random variables by using arithmetic and mathematical functions:

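import brancher.functions as BF  # Brancher's differentiable math functions

c = NormalVariable(loc=a**2 + BF.sin(b),  # location depends on both a and b
                   scale=BF.exp(b),        # exp keeps the scale positive
                   name='c')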
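To make the meaning of this symbolic definition concrete, here is a plain-Python sketch (standard library only, not Brancher's API) of drawing one joint sample of a, b and c: each variable is sampled after its parents, and the location and scale of c are computed from the values just drawn for a and b. The function name sample_abc is hypothetical.

```python
import math
import random


def sample_abc(rng):
    """Draw one joint sample (a, b, c) by ancestral sampling."""
    a = rng.gauss(0.0, 1.0)  # a ~ Normal(0, 1)
    b = rng.gauss(0.0, 1.0)  # b ~ Normal(0, 1)
    # c ~ Normal(a**2 + sin(b), exp(b)), using the sampled parent values
    c = rng.gauss(a ** 2 + math.sin(b), math.exp(b))
    return a, b, c


rng = random.Random(0)
samples = [sample_abc(rng) for _ in range(10_000)]
mean_c = sum(c for _, _, c in samples) / len(samples)
print(f"Monte Carlo estimate of E[c]: {mean_c:.3f}")
```

Since E[a**2] = 1 and E[sin(b)] = 0 by symmetry, the Monte Carlo mean of c should come out close to 1.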

In this way, it is possible to create arbitrarily complex probabilistic models. Because Brancher is built on PyTorch, its deep learning tools can also be used to define probabilistic models that contain deep neural networks.

Example: Autoregressive modeling

Probabilistic model

The model is defined symbolically as a chain of latent states x_t, each observed through a noisy measurement y_t:

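from brancher.variables import ProbabilisticModel
from brancher.standard_variables import NormalVariable, LogitNormalVariable

T = 20                # number of time steps
driving_noise = 1.    # scale of the latent transition noise
measure_noise = 0.3   # scale of the measurement noise
x0 = NormalVariable(0., driving_noise, 'x0')
y0 = NormalVariable(x0, measure_noise, 'y0')
b = LogitNormalVariable(0.5, 1., 'b')  # autoregressive coefficient, supported on (0, 1)

x = [x0]
y = [y0]
x_names = ["x0"]
y_names = ["y0"]
for t in range(1, T):
    x_names.append("x{}".format(t))
    y_names.append("y{}".format(t))
    # Latent state: x_t ~ Normal(b * x_{t-1}, driving_noise)
    x.append(NormalVariable(b*x[t-1], driving_noise, x_names[t]))
    # Measurement: y_t ~ Normal(x_t, measure_noise)
    y.append(NormalVariable(x[t], measure_noise, y_names[t]))
AR_model = ProbabilisticModel(x + y)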
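Before fitting anything, it can help to simulate from the model. The sketch below is plain Python rather than Brancher, and it assumes that LogitNormalVariable(0.5, 1.) denotes b = sigmoid(z) with z ~ Normal(0.5, 1); the function simulate_ar is hypothetical. It follows the same ancestral order as the symbolic definition: b, then x0 and y0, then each (x_t, y_t) pair.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def simulate_ar(T=20, driving_noise=1.0, measure_noise=0.3, seed=0):
    """Simulate the coefficient b, the latent chain x and the measurements y."""
    rng = random.Random(seed)
    # Assumption: b ~ LogitNormal(0.5, 1) means b = sigmoid(z), z ~ Normal(0.5, 1)
    b = sigmoid(rng.gauss(0.5, 1.0))
    x = [rng.gauss(0.0, driving_noise)]   # x0 ~ Normal(0, driving_noise)
    y = [rng.gauss(x[0], measure_noise)]  # y0 ~ Normal(x0, measure_noise)
    for t in range(1, T):
        x.append(rng.gauss(b * x[t - 1], driving_noise))  # x_t given x_{t-1}
        y.append(rng.gauss(x[t], measure_noise))          # y_t given x_t
    return b, x, y


b_true, x_sim, y_sim = simulate_ar()
print(f"b = {b_true:.3f}, first measurements: {[round(v, 2) for v in y_sim[:3]]}")
```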

Observe data

Once the probabilistic model is defined, we can specify which variables are observed. Here every measurement y_t is set to its observed value:

for yt in y:  # `data` (defined elsewhere) maps each variable to its observed values
    yt.observe(data[yt][:, 0, :])
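Observing the y_t fixes them to data, so inference then works with the joint log density log p(b, x, y) evaluated at the observed y. As a library-independent sketch (the helper names are hypothetical and this is not how Brancher computes it internally), that density can be written in plain Python, again assuming LogitNormalVariable(0.5, 1.) means b = sigmoid(z) with z ~ Normal(0.5, 1), which brings in a change-of-variables term:

```python
import math


def normal_logpdf(v, loc, scale):
    """Log density of Normal(loc, scale) at v."""
    return -0.5 * ((v - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def logit_normal_logpdf(b, loc, scale):
    # b = sigmoid(z), so p(b) = Normal(logit(b); loc, scale) / (b * (1 - b))
    z = math.log(b / (1.0 - b))
    return normal_logpdf(z, loc, scale) - math.log(b * (1.0 - b))


def ar_log_joint(b, x, y, driving_noise=1.0, measure_noise=0.3):
    """log p(b, x, y) for the AR model; y holds the observed measurements."""
    lp = logit_normal_logpdf(b, 0.5, 1.0)
    lp += normal_logpdf(x[0], 0.0, driving_noise)                # x0
    for t in range(1, len(x)):
        lp += normal_logpdf(x[t], b * x[t - 1], driving_noise)  # transitions
    for t in range(len(y)):
        lp += normal_logpdf(y[t], x[t], measure_noise)          # measurements
    return lp
```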

Autoregressive variational distribution

The variational distribution can have an arbitrary structure:

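import brancher.functions as BF
from brancher.variables import ProbabilisticModel
from brancher.standard_variables import NormalVariable, LogitNormalVariable, DeterministicVariable

# Learnable approximate posterior for the coefficient b
Qb = LogitNormalVariable(0.5, 0.5, "b", learnable=True)
# Learnable unconstrained parameter; its sigmoid is the AR coefficient used in Qx
logit_b_post = DeterministicVariable(0., 'logit_b_post', learnable=True)
Qx = [NormalVariable(0., 1., 'x0', learnable=True)]
Qx_mean = [DeterministicVariable(0., 'x0_mean', learnable=True)]
for t in range(1, T):
    # Learnable offset for time step t
    Qx_mean.append(DeterministicVariable(0., x_names[t] + "_mean", learnable=True))
    # Q(x_t | x_{t-1}) = Normal(sigmoid(logit_b_post) * x_{t-1} + mean_t, 1.)
    Qx.append(NormalVariable(BF.sigmoid(logit_b_post)*Qx[t-1] + Qx_mean[t], 1., x_names[t], learnable=True))
variational_posterior = ProbabilisticModel([Qb] + Qx)
AR_model.set_posterior_model(variational_posterior)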
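The structure of this variational family can be illustrated without Brancher. The plain-Python sketch below (an illustration, not the library's sampler) draws one trajectory from a family of the same shape: a shared coefficient sigmoid(logit_b_post) couples consecutive states and each step adds its own offset. As simplifications, it uses x_means[0] as the location of x0 and a single fixed scale for every step; the function sample_q_x and its arguments are hypothetical.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sample_q_x(logit_b_post, x_means, scale, rng):
    """Draw one latent trajectory from an autoregressive variational family."""
    coef = sigmoid(logit_b_post)  # shared AR coefficient, kept in (0, 1)
    x = [rng.gauss(x_means[0], scale)]  # simplification: x0 ~ Normal(x_means[0], scale)
    for t in range(1, len(x_means)):
        # x_t ~ Normal(coef * x_{t-1} + x_means[t], scale)
        x.append(rng.gauss(coef * x[t - 1] + x_means[t], scale))
    return x


rng = random.Random(0)
trajectory = sample_q_x(logit_b_post=0.0, x_means=[0.0] * 20, scale=1.0, rng=rng)
```

Sharing one coefficient across time mirrors the generative model's own dynamics, so this family can represent correlations between consecutive x_t that a fully factorized (mean-field) Gaussian would ignore.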

Inference

Now that both models are specified, we can perform approximate inference using stochastic gradient descent:

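from brancher import inference

# Fit the variational parameters by stochastic gradient descent
inference.perform_inference(AR_model,
                            number_iterations=500,  # optimization steps
                            number_samples=300,     # samples per gradient estimate
                            optimizer="SGD",
                            lr=0.001)               # learning rate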
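The call above hides the optimization loop. As a rough, library-independent illustration of stochastic variational inference with SGD (not Brancher's implementation), here is a pure-Python sketch for a one-dimensional model, z ~ Normal(0, 1) and y ~ Normal(z, noise), whose exact posterior is known, so the result can be checked. It fits q(z) = Normal(m, exp(omega)) using reparameterized Monte Carlo gradients of the evidence lower bound; the argument names number_iterations, number_samples and lr echo the call above, but the toy model, the function fit_gaussian_vi and the gradient formulas are chosen for this illustration.

```python
import math
import random


def fit_gaussian_vi(y_obs, noise, lr=0.01, number_iterations=2000,
                    number_samples=50, seed=0):
    """Fit q(z) = Normal(m, exp(omega)) to p(z | y) by SGD on a Monte Carlo ELBO."""
    rng = random.Random(seed)
    m, omega = 0.0, 0.0
    for _ in range(number_iterations):
        s = math.exp(omega)
        grad_m, grad_omega = 0.0, 0.0
        for _ in range(number_samples):
            eps = rng.gauss(0.0, 1.0)
            z = m + s * eps                          # reparameterization trick
            dlogp_dz = (y_obs - z) / noise ** 2 - z  # d/dz [log p(y | z) + log p(z)]
            grad_m += dlogp_dz
            grad_omega += dlogp_dz * s * eps
        grad_m /= number_samples
        grad_omega = grad_omega / number_samples + 1.0  # +1 from the entropy of q
        m += lr * grad_m                                # gradient ascent on the ELBO
        omega += lr * grad_omega
    return m, math.exp(omega)


m, s = fit_gaussian_vi(y_obs=1.0, noise=0.3)
# Exact conjugate posterior for comparison
precision = 1.0 + 1.0 / 0.3 ** 2
post_mean, post_std = (1.0 / 0.3 ** 2) / precision, precision ** -0.5
print(f"VI: m={m:.3f}, s={s:.3f}  exact: mean={post_mean:.3f}, std={post_std:.3f}")
```

For y_obs = 1 and noise = 0.3, the fitted m and s should land close to the exact posterior mean (about 0.917) and standard deviation (about 0.287).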



Download files


Source Distribution

brancher-0.2.0.tar.gz (29.7 kB)

Built Distribution

brancher-0.2.0-py3-none-any.whl (35.4 kB)

File details

Details for the file brancher-0.2.0.tar.gz.

File metadata

  • Download URL: brancher-0.2.0.tar.gz
  • Size: 29.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.4.2 requests/2.21.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.28.1 CPython/3.7.1

File hashes

Hashes for brancher-0.2.0.tar.gz
Algorithm Hash digest
SHA256 58c5a1c3f7d020e811b4b6bb0902442736d332a8e0d9d38bbd09daabcd202c92
MD5 0dc5908fd205336db9e3b3c09b89e03a
BLAKE2b-256 1b58e0cef668c74c6cf5b73436a66e5c092afb01b5658db0ff1dc39a129c3c44
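The published digests can be checked locally before installation. A minimal sketch using Python's standard hashlib module; the helper names sha256_of_file and verify_sha256 are illustrative, and the archive is assumed to have been downloaded to the working directory:

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path, expected):
    """True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of_file(path) == expected.lower()


# Usage, assuming brancher-0.2.0.tar.gz is in the current directory:
# verify_sha256("brancher-0.2.0.tar.gz",
#               "58c5a1c3f7d020e811b4b6bb0902442736d332a8e0d9d38bbd09daabcd202c92")
```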


File details

Details for the file brancher-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: brancher-0.2.0-py3-none-any.whl
  • Size: 35.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.4.2 requests/2.21.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.28.1 CPython/3.7.1

File hashes

Hashes for brancher-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 f797c69e061779dbad468bbaffcf68ef27ff515592cb1e5daf91998a34279b1f
MD5 3abd33fb747ff27bf92f111dbe46d5e1
BLAKE2b-256 29fbad8ce722a6251411cb667af5c2a86ea34e068e8c3f4939efbcc19e20292f

