A user-centered Python package for differentiable probabilistic inference

Brancher: A user-centered Python package for differentiable probabilistic inference

Brancher lets you design and train differentiable Bayesian models using stochastic variational inference. It is built on top of the deep learning framework PyTorch.

Building probabilistic models

Probabilistic models are defined symbolically. Random variables can be created as follows:

from brancher.standard_variables import NormalVariable

a = NormalVariable(loc=0., scale=1., name='a')
b = NormalVariable(loc=0., scale=1., name='b')

Random variables can be chained together using arithmetic operations and mathematical functions:

import brancher.functions as BF

c = NormalVariable(loc=a**2 + BF.sin(b),
                   scale=BF.exp(b),
                   name='c')

In this way, arbitrarily complex probabilistic models can be built. PyTorch's deep learning tools can also be used to define probabilistic models that include deep neural networks.
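To make "defined symbolically" concrete, here is a minimal standard-library sketch that mimics the a, b, c model above with a hypothetical ToyNormal class (it is not part of Brancher and is not how Brancher is implemented). Creating a node only records its parameters, which may be functions of parent values; numbers are drawn by ancestral sampling only when sample is called. Under this model E[c] = E[a**2] + E[sin(b)] = 1 + 0 = 1, so the Monte Carlo average should land close to 1.

```python
import math
import random


class ToyNormal:
    """Hypothetical symbolic normal node for illustration; not Brancher's NormalVariable.

    `loc` and `scale` are constants or functions of a dict of already-sampled parent values.
    """

    def __init__(self, loc, scale, name, parents=()):
        # Nothing is sampled here: the node only records how to sample later.
        self.loc = loc
        self.scale = scale
        self.name = name
        self.parents = parents

    def sample(self, rng, values=None):
        # Ancestral sampling: draw each parent once, then this node given its parents.
        values = {} if values is None else values
        if self.name not in values:
            for parent in self.parents:
                parent.sample(rng, values)
            loc = self.loc(values) if callable(self.loc) else self.loc
            scale = self.scale(values) if callable(self.scale) else self.scale
            values[self.name] = rng.gauss(loc, scale)
        return values[self.name]


a = ToyNormal(0., 1., 'a')
b = ToyNormal(0., 1., 'b')
# c mirrors the chained variable above: loc = a**2 + sin(b), scale = exp(b)
c = ToyNormal(lambda v: v['a'] ** 2 + math.sin(v['b']),
              lambda v: math.exp(v['b']),
              'c', parents=(a, b))

rng = random.Random(0)
n = 100_000
# Each call starts from an empty dict, so every draw of c uses fresh draws of a and b
mean_c = sum(c.sample(rng) for _ in range(n)) / n
```

The dict passed through sample plays the role of a shared sample of the whole graph, so a parent used by several children is drawn only once per joint sample.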

Example: Autoregressive modeling

Probabilistic model

Probabilistic models are defined symbolically:

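from brancher.variables import ProbabilisticModel
from brancher.standard_variables import NormalVariable, LogitNormalVariable

# Series length and noise scales
T = 20
driving_noise = 1.
measure_noise = 0.3

# Initial latent state x0, its noisy measurement y0, and the AR coefficient b in (0, 1)
x0 = NormalVariable(0., driving_noise, 'x0')
y0 = NormalVariable(x0, measure_noise, 'y0')
b = LogitNormalVariable(0.5, 1., 'b')

x = [x0]
y = [y0]
x_names = ["x0"]
y_names = ["y0"]
for t in range(1, T):
    x_names.append("x{}".format(t))
    y_names.append("y{}".format(t))
    # Latent dynamics: x_t ~ N(b * x_{t-1}, driving_noise)
    x.append(NormalVariable(b*x[t-1], driving_noise, x_names[t]))
    # Measurement: y_t ~ N(x_t, measure_noise)
    y.append(NormalVariable(x[t], measure_noise, y_names[t]))
AR_model = ProbabilisticModel(x + y)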
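The generative process encoded by AR_model can also be written out directly. The sketch below uses only the standard library to draw one trajectory. simulate_ar is a hypothetical helper, not a Brancher function, and it assumes that LogitNormalVariable(0.5, 1., 'b') means b = sigmoid(z) with z ~ N(0.5, 1), and that the second argument of each variable is a standard deviation. A simulator like this is handy for producing synthetic series to test inference on.

```python
import math
import random


def simulate_ar(T=20, driving_noise=1., measure_noise=0.3, seed=0):
    """Hypothetical stdlib simulator of the AR model above (not part of Brancher)."""
    rng = random.Random(seed)
    # Assumed logit-normal prior: b = sigmoid(z), z ~ N(0.5, 1), so 0 < b < 1
    b = 1. / (1. + math.exp(-rng.gauss(0.5, 1.)))
    # Latent chain: x0 ~ N(0, driving_noise), x_t ~ N(b * x_{t-1}, driving_noise)
    x = [rng.gauss(0., driving_noise)]
    for t in range(1, T):
        x.append(rng.gauss(b * x[t - 1], driving_noise))
    # Noisy measurements: y_t ~ N(x_t, measure_noise)
    y = [rng.gauss(x_t, measure_noise) for x_t in x]
    return b, x, y


b_true, x_true, y_obs = simulate_ar()
```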

Observe data

Once the probabilistic model is defined, we can decide which variables are observed. Here data is assumed to hold the observed values, indexed by the variables in y:

for yt in y:
    yt.observe(data[yt][:, 0, :])
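Observing the y variables fixes their values, so inference targets the posterior over the remaining latent variables b and x_0, ..., x_{T-1}. Up to a normalizing constant, that posterior equals the joint density of the model above. Writing N(v; mu, sigma^2) for a normal density, sigma_d for driving_noise and sigma_m for measure_noise, and assuming the second argument of NormalVariable is a standard deviation, the log joint reads:

$$
\log p(b, x_{0:T-1}, y_{0:T-1}) = \log p(b) + \log \mathcal{N}(x_0; 0, \sigma_d^2) + \sum_{t=1}^{T-1} \log \mathcal{N}(x_t; b\,x_{t-1}, \sigma_d^2) + \sum_{t=0}^{T-1} \log \mathcal{N}(y_t; x_t, \sigma_m^2)
$$

$$
p(b, x_{0:T-1} \mid y_{0:T-1}) \propto p(b, x_{0:T-1}, y_{0:T-1})
$$

Here p(b) is the logit-normal prior with parameters 0.5 and 1.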

Autoregressive variational distribution

The variational distribution can have an arbitrary structure:

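from brancher.variables import DeterministicVariable
import brancher.functions as BF

# Variational factor for b, plus a learnable logit used as the transition coefficient in Q(x)
Qb = LogitNormalVariable(0.5, 0.5, "b", learnable=True)
logit_b_post = DeterministicVariable(0., 'logit_b_post', learnable=True)
# Q(x0) and one learnable offset per time step
Qx = [NormalVariable(0., 1., 'x0', learnable=True)]
Qx_mean = [DeterministicVariable(0., 'x0_mean', learnable=True)]
for t in range(1, T):
    Qx_mean.append(DeterministicVariable(0., x_names[t] + "_mean", learnable=True))
    # Q(x_t | x_{t-1}) = N(sigmoid(logit_b_post) * x_{t-1} + offset_t, 1)
    Qx.append(NormalVariable(BF.sigmoid(logit_b_post)*Qx[t-1] + Qx_mean[t], 1., x_names[t], learnable=True))
variational_posterior = ProbabilisticModel([Qb] + Qx)
# Attach the variational posterior to the joint model defined earlier
AR_model.set_posterior_model(variational_posterior)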
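Read as formulas, the code above defines roughly the following variational family. This is a sketch of one reading of the code, with sigma the logistic sigmoid, l = logit_b_post and m_t = Qx_mean[t]; the scales are shown at their initial value of 1, and whether learnable=True also makes them trainable should be checked against the Brancher documentation.

$$
q(b, x_{0:T-1}) = q(b)\, q(x_0) \prod_{t=1}^{T-1} \mathcal{N}\big(x_t;\ \sigma(l)\,x_{t-1} + m_t,\ 1\big)
$$

Two details stand out. The transition coefficient sigma(l) is shared across all time steps and is a separate learnable quantity from the factor q(b) (Qb), so the posterior over b and the coefficient used inside q(x) are not tied together. Also, Qx_mean[0] is created but not used by any factor, since Qx[0] has its own location parameter.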

Inference

Now that the models are specified, we can perform approximate inference using stochastic gradient descent:

from brancher import inference

inference.perform_inference(AR_model,
                            number_iterations=500,
                            number_samples=300,
                            optimizer="SGD",
                            lr=0.001)
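Brancher computes the gradients internally through PyTorch. To illustrate the underlying idea only, the standard-library sketch below runs stochastic variational inference on a hypothetical one-dimensional model, z ~ N(0, 1) and y ~ N(z, 1) with y observed, whose exact posterior is N(y/2, 1/2). It fits q(z) = N(m, s^2) by stochastic gradient ascent on the ELBO, using reparameterized samples z = m + s * eps and hand-derived gradients. The function svi_gaussian and its hyperparameters are illustrative and unrelated to the arguments of perform_inference.

```python
import math
import random


def svi_gaussian(y_obs, iterations=2000, samples=50, lr=0.01, seed=0):
    """Hypothetical toy SVI: fit q(z) = N(m, s^2) to p(z | y) for z ~ N(0, 1), y ~ N(z, 1)."""
    rng = random.Random(seed)
    m, log_s = 0.0, 0.0  # variational parameters; s is kept positive via log_s
    for _ in range(iterations):
        s = math.exp(log_s)
        grad_m = 0.0
        grad_log_s = 0.0
        for _ in range(samples):
            eps = rng.gauss(0.0, 1.0)
            z = m + s * eps  # reparameterized sample from q
            # d/dz log p(z, y) = -z + (y - z) = y - 2z
            dlogp_dz = y_obs - 2.0 * z
            grad_m += dlogp_dz  # chain rule: dz/dm = 1
            grad_log_s += dlogp_dz * eps * s  # chain rule: dz/dlog_s = eps * s
        grad_m /= samples
        # The entropy of q contributes log s + const, whose derivative w.r.t. log_s is 1
        grad_log_s = grad_log_s / samples + 1.0
        # Gradient ascent on the Monte Carlo ELBO estimate
        m += lr * grad_m
        log_s += lr * grad_log_s
    return m, math.exp(log_s)


m, s = svi_gaussian(2.0)
```

With y = 2 the fitted parameters should approach the exact posterior mean 1 and standard deviation sqrt(1/2), about 0.707, up to Monte Carlo noise.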

Download files

Source Distribution

brancher-0.3.0.tar.gz (30.0 kB)

Uploaded Source

Built Distribution

brancher-0.3.0-py3-none-any.whl (35.7 kB)

Uploaded Python 3

File details

Details for the file brancher-0.3.0.tar.gz.

File metadata

  • Download URL: brancher-0.3.0.tar.gz
  • Size: 30.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.4.2 requests/2.21.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.28.1 CPython/3.7.1

File hashes

Hashes for brancher-0.3.0.tar.gz
Algorithm Hash digest
SHA256 5164a5299848bb5c6a8507e0f089981fa32816bb0933ed61c20309cd32105f95
MD5 480bce9dc407769640653fcc800bc964
BLAKE2b-256 027fc676e00d01fb2c164395cf4463e8457f5a9db0516c8df2fa402f2304fe45

File details

Details for the file brancher-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: brancher-0.3.0-py3-none-any.whl
  • Size: 35.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.4.2 requests/2.21.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.28.1 CPython/3.7.1

File hashes

Hashes for brancher-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 cb749e33b4bab7a19f3c3419e1081d70595aa0636d99b92498f105e287369cf2
MD5 174671788f187f0e7af527728c4da165
BLAKE2b-256 cc3f1e77914bd16a07965cd1f023290227b6b5c2e32ab5aa1946109e4e76e05f
