
A PyMC3-like Interface for Pyro Stochastic Functions


PyMC3-like abstractions for Pyro stochastic functions. Define a model as a stochastic function in Pyro, then use the pm_like wrapper to create a PyMC3-esque Model whose random variables are exposed to the user as attributes. pm-pyro provides abstractions for inference (NUTS: the No-U-Turn Sampler), trace plots, posterior plots and posterior predictive plots.

Install

Install from pypi

pip install pm-pyro

Developer setup

# install requirements
pip install -r requirements-dev.txt
# run tests
python -m pytest pmpyro/tests.py

Example

This example is borrowed from a PyMC3 tutorial. The outcome variable Y depends on two features, X_1 and X_2. The notebook for this example is available here.
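The data X1, X2 and Y used in the snippets below are not defined on this page. One way to simulate them is sketched here; the settings (intercept 1, coefficients 1 and 2.5, noise scale 1, 100 points, X2 scaled by 0.2) mirror the synthetic data in PyMC3's getting-started tutorial, but treat them as illustrative assumptions rather than the notebook's exact values. Torch tensors are used so the arithmetic inside the Pyro model works directly.

```python
import torch

torch.manual_seed(123)

# Assumed "true" parameters for the synthetic data.
true_alpha, true_sigma = 1.0, 1.0
true_beta = torch.tensor([1.0, 2.5])
size = 100

# Predictor variables
X1 = torch.randn(size)
X2 = torch.randn(size) * 0.2

# Simulated outcome: linear signal plus Gaussian noise
Y = true_alpha + true_beta[0] * X1 + true_beta[1] * X2 + torch.randn(size) * true_sigma
```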

Model Specification

We design a simple Bayesian Linear Regression model.
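Written out, the priors and likelihood in pyro_model correspond to the model below. Pyro's Normal is parameterized by its standard deviation, so a scale of 10 is a variance of 10^2; the code's beta[0] and beta[1] appear here as \beta_1 and \beta_2.

```latex
\begin{aligned}
\alpha &\sim \mathcal{N}(0,\, 10^2) \\
\beta_i &\sim \mathcal{N}(0,\, 10^2), \quad i = 1, 2 \\
\sigma &\sim \mathrm{HalfNormal}(1) \\
\mu &= \alpha + \beta_1 X_1 + \beta_2 X_2 \\
Y &\sim \mathcal{N}(\mu,\, \sigma^2)
\end{aligned}
```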

Stochastic Function

The model specification is implemented as a stochastic function.

import pyro
import pyro.distributions as dist
import torch

def pyro_model(x1, x2, y):
    # Priors: intercept, two regression coefficients (one vector site), noise scale
    alpha = pyro.sample('alpha', dist.Normal(0., 10.))
    beta = pyro.sample('beta', dist.Normal(torch.zeros(2), torch.ones(2) * 10.).to_event(1))
    sigma = pyro.sample('sigma', dist.HalfNormal(1.))

    # Expected value of outcome
    mu = alpha + beta[0] * x1 + beta[1] * x2

    # Likelihood (sampling distribution) of observations; with y=None, y_obs is sampled
    return pyro.sample('y_obs', dist.Normal(mu, sigma), obs=y)

Context-manager Syntax

The pm_like wrapper creates a PyMC3-esque Model, so inference can run inside a context manager. pm.sample draws samples from the model with the NUTS sampler. The returned trace is a Python dictionary that contains the samples.

from pmpyro import pm_like
import pmpyro as pm

with pm_like(pyro_model, X1, X2, Y) as model:
    trace = pm.sample(1000)
sample: 100%|██████████| 1300/1300 [00:16, 80.42it/s, step size=7.49e-01, acc. prob=0.911] 

Traceplot

We can visualize the samples with traceplot. To plot only selected random variables, pass their names as a list through the var_names argument, e.g. var_names=['alpha', ...].

pm.traceplot(trace)
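As a rough illustration of what a trace plot shows, here is a simplified, hypothetical stand-in written with matplotlib: one row per selected variable, with the marginal distribution on the left and the draws in sampling order on the right. It is not pm.traceplot, whose exact layout is not documented here. It assumes the trace is a dictionary of tensors shaped (num_draws, ...), and the demo trace holds random stand-in draws, not real posterior samples.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt
import torch


def simple_traceplot(trace, var_names=None):
    """Hypothetical trace plot: histogram + draw sequence per variable."""
    names = var_names if var_names is not None else list(trace)
    fig, axes = plt.subplots(len(names), 2, figsize=(10, 2.5 * len(names)), squeeze=False)
    for row, name in enumerate(names):
        # Flatten vector variables (e.g. beta) into one column per component
        draws = trace[name].reshape(trace[name].shape[0], -1).numpy()
        axes[row, 0].hist(draws, bins=30, histtype="step")
        axes[row, 0].set_title(f"{name} (marginal)")
        axes[row, 1].plot(draws, linewidth=0.5)
        axes[row, 1].set_title(f"{name} (draws)")
    fig.tight_layout()
    return fig, axes


torch.manual_seed(0)
# Stand-in draws with the same layout as a real trace
trace = {"alpha": torch.randn(1000), "beta": torch.randn(1000, 2), "sigma": torch.rand(1000)}
fig, axes = simple_traceplot(trace, var_names=["alpha", "beta"])
```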

Plot Posterior

Visualize the posterior of random variables with plot_posterior.

pm.plot_posterior(trace, var_names=['beta'])
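Posterior plots in the PyMC3/ArviZ style usually annotate the mean and a highest-density interval (HDI), the narrowest interval containing a given share of the draws. Whether pm.plot_posterior uses exactly this interval is an assumption. The sketch below computes an HDI from a 1-D tensor of draws, such as trace['alpha'] or one column of trace['beta']; the hdi helper is hypothetical and the demo draws are stand-ins.

```python
import math
from typing import Tuple

import torch


def hdi(samples: torch.Tensor, prob: float = 0.94) -> Tuple[float, float]:
    """Narrowest interval containing `prob` of the 1-D samples."""
    sorted_s, _ = torch.sort(samples.flatten())
    n = sorted_s.numel()
    n_keep = math.ceil(prob * n)  # number of draws inside the interval
    if n_keep >= n:
        return sorted_s[0].item(), sorted_s[-1].item()
    # Width of every window of n_keep consecutive sorted draws
    widths = sorted_s[n_keep - 1:] - sorted_s[: n - n_keep + 1]
    i = int(torch.argmin(widths))  # first (leftmost) narrowest window
    return sorted_s[i].item(), sorted_s[i + n_keep - 1].item()


torch.manual_seed(0)
draws = torch.randn(4000)  # stand-in draws, not real posterior samples
print(draws.mean().item(), hdi(draws, 0.94))
```

Unlike a central interval cut at fixed quantiles, the HDI shifts toward the dense region of a skewed posterior, such as that of sigma.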

Posterior Predictive Samples

We can draw posterior predictive samples with plot_posterior_predictive or sample_posterior_predictive. Both take the same positional arguments as the stochastic function def pyro_model(x1, x2, y), with the observed variable Y replaced by None.

ppc = pm.plot_posterior_predictive(X1, X2, None,
                                   trace=trace, model=model, samples=60,
                                   alpha=0.08, obs={'y_obs': Y})

Trace Summary

The summary of random variables is available as a pandas DataFrame.

pm.summary()
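The exact columns of pm.summary are not documented here. As a rough, hypothetical stand-in, the sketch below builds a similar table with pandas from a trace dictionary (site name mapped to a tensor of draws, the layout returned by Pyro's MCMC.get_samples()). Vector variables such as beta are split into one row per component. The summarize helper and its column names are invented for illustration, and the demo trace is tiny hand-written data.

```python
import pandas as pd
import torch


def summarize(trace: dict) -> pd.DataFrame:
    """Hypothetical summary: mean, sd and a 94% central interval per component."""
    rows = {}
    for name, draws in trace.items():
        flat = draws.reshape(draws.shape[0], -1)  # (num_draws, num_components)
        for j in range(flat.shape[1]):
            label = name if draws.dim() == 1 else f"{name}[{j}]"
            col = flat[:, j]
            rows[label] = {
                "mean": col.mean().item(),
                "sd": col.std().item(),  # sample (n - 1) standard deviation
                "q3%": torch.quantile(col, 0.03).item(),
                "q97%": torch.quantile(col, 0.97).item(),
            }
    return pd.DataFrame.from_dict(rows, orient="index")


trace = {
    "alpha": torch.tensor([1.0, 2.0, 3.0]),
    "beta": torch.tensor([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]),
}
print(summarize(trace))
```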

License

This project is licensed under the GPL v3 License; see the LICENSE.md file for details.


Download files


Source Distribution

pm-pyro-0.3.3.tar.gz (8.9 kB)

Built Distribution

pm_pyro-0.3.3-py3-none-any.whl (23.1 kB)

File details

Details for the file pm-pyro-0.3.3.tar.gz.

File metadata

  • Download URL: pm-pyro-0.3.3.tar.gz
  • Size: 8.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.6.0 requests-toolbelt/0.9.1 tqdm/4.38.0 CPython/3.7.5

File hashes

Hashes for pm-pyro-0.3.3.tar.gz

  • SHA256: 0b9841d7363c1805a7794a20ca342ce3bf67f13d56beb7747ea0ff0c245f1297
  • MD5: dff7373e4b90244f05c2899b3ee9745f
  • BLAKE2b-256: 8f8807fc239ccbbabe5cc79180ed6e4da77227c70fc9de163a661e61dee3467b


File details

Details for the file pm_pyro-0.3.3-py3-none-any.whl.

File metadata

  • Download URL: pm_pyro-0.3.3-py3-none-any.whl
  • Size: 23.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.6.0 requests-toolbelt/0.9.1 tqdm/4.38.0 CPython/3.7.5

File hashes

Hashes for pm_pyro-0.3.3-py3-none-any.whl

  • SHA256: c7c6475fbc62b486c2dca5fde9c1369546ab97601433663291fd642f916226ae
  • MD5: c011194288e08d60ff3bdb67e3ce764d
  • BLAKE2b-256: 207d8c35154eeecd3c05079acdfe94ac3fea0089734ff8f56819027de0c27ad7

