
CUQIpy plugin for PyTorch


CUQIpy-PyTorch

CUQIpy-PyTorch is a plugin for the CUQIpy software package.

It adds a PyTorch backend to CUQIpy, so that models, distributions, and other components can be defined with the PyTorch API.

It also provides an interface to Pyro's No-U-Turn Sampler (NUTS), a Hamiltonian Monte Carlo method, for efficient sampling from the joint posterior.

Installation

For optimal performance, consider installing PyTorch with conda first, then install CUQIpy-PyTorch with pip:

pip install cuqipy-pytorch

If PyTorch, Pyro, or CUQIpy is not already installed, the command above installs it automatically as a dependency.
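To check which of these dependencies ended up in the environment, you can query installed package metadata. The following sketch uses the standard-library `importlib.metadata`. The distribution names `torch`, `pyro-ppl`, and `CUQIpy` are assumed PyPI names and are not taken from this page. The helper `installed_versions` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names; "cuqipy-pytorch" matches the pip command above
PACKAGES = ["torch", "pyro-ppl", "CUQIpy", "cuqipy-pytorch"]


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version string, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```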

Quickstart

Example: sampling from the eight schools model, written in a non-centered parameterization:

$$
\begin{align*}
\mu &\sim \mathcal{N}(0, 10^2)\\
\tau &\sim \log\mathcal{N}(5, 1)\\
\boldsymbol \theta' &\sim \mathcal{N}(\mathbf{0}, \mathbf{I}_m)\\
\boldsymbol \theta &= \mu + \tau \boldsymbol \theta'\\
\mathbf{y} &\sim \mathcal{N}(\boldsymbol \theta, \boldsymbol \sigma^2 \mathbf{I}_m)
\end{align*}
$$

where $\mathbf{y}\in\mathbb{R}^m$ and $\boldsymbol \sigma\in\mathbb{R}^m$ are observed data ($m = 8$ schools in the code below).
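The following standard-library sketch makes the generative model concrete. It draws one parameter set from the prior and evaluates the Gaussian log-likelihood of the observations. It only illustrates the equations and does not show how CUQIpy-PyTorch evaluates densities internally. The sketch makes two assumptions about the notation. First, it reads $\log\mathcal{N}(5, 1)$ as the exponential of a normal variable with mean 5 and variance 1. Second, it treats $\boldsymbol\sigma^2\mathbf{I}_m$ as a diagonal covariance with entries $\sigma_i^2$. The helpers `sample_prior` and `log_likelihood` are hypothetical.

```python
import math
import random

# Observed data from the Quickstart (m = 8 schools)
y_obs = [28, 8, -3, 7, -1, 1, 18, 12]
sigma_obs = [15, 10, 16, 11, 9, 11, 10, 18]


def sample_prior(m, rng):
    """Draw (mu, tau, theta_p, theta) once from the prior (hypothetical helper)."""
    mu = rng.gauss(0.0, 10.0)                 # mu ~ N(0, 10^2), i.e. standard deviation 10
    tau = math.exp(rng.gauss(5.0, 1.0))       # tau ~ logN(5, 1), assumed to mean exp(N(5, 1))
    theta_p = [rng.gauss(0.0, 1.0) for _ in range(m)]  # theta' ~ N(0, I_m)
    theta = [mu + tau * t for t in theta_p]   # non-centered transform: theta = mu + tau * theta'
    return mu, tau, theta_p, theta


def log_likelihood(y, theta, sigma):
    """Sum over schools of log N(y_i | theta_i, sigma_i^2) (diagonal covariance assumed)."""
    return sum(
        -0.5 * math.log(2 * math.pi * s**2) - 0.5 * ((yi - ti) / s) ** 2
        for yi, ti, s in zip(y, theta, sigma)
    )


rng = random.Random(0)
mu, tau, theta_p, theta = sample_prior(len(y_obs), rng)
print(log_likelihood(y_obs, theta, sigma_obs))
```

Sampling $\boldsymbol\theta'$ and forming $\boldsymbol\theta$ deterministically reduces the strong coupling between $\tau$ and $\boldsymbol\theta$. Samplers such as NUTS otherwise tend to struggle with that coupling in hierarchical models.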

import torch as xp
from cuqi.distribution import JointDistribution
from cuqipy_pytorch.distribution import Gaussian, Lognormal
from cuqipy_pytorch.sampler import NUTS

# Observations
y_obs = xp.tensor([28, 8, -3,  7, -1, 1,  18, 12], dtype=xp.float32)
σ_obs = xp.tensor([15, 10, 16, 11, 9, 11, 10, 18], dtype=xp.float32)

# Bayesian model
μ     = Gaussian(0, 10**2)
τ     = Lognormal(5, 1)
θp    = Gaussian(xp.zeros(8), 1)
θ     = lambda μ, τ, θp: μ+τ*θp
y     = Gaussian(θ, cov=σ_obs**2)

# Posterior sampling
joint = JointDistribution(μ, τ, θp, y)   # Define joint distribution 
posterior = joint(y=y_obs)               # Define posterior distribution
sampler = NUTS(posterior)                # Define sampling strategy
samples = sampler.sample(N=500, Nb=500)  # Sample from posterior

# Plot posterior samples
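samples["θp"].plot_violin()  # Violin plot of the standardized school effects θ'
print(samples["μ"].mean())    # Posterior mean of the overall effect μ
print(samples["τ"].mean())    # Posterior mean of the between-school scale τ (a standard deviation, not a variance)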
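In the Quickstart, θ is defined by a lambda rather than a distribution, so the posterior contains draws of μ, τ, and θ' but not of θ. To summarize the school effects θ themselves, apply $\boldsymbol\theta = \mu + \tau\boldsymbol\theta'$ draw by draw. The following sketch works on plain Python sequences. The helpers `reconstruct_theta` and `posterior_mean` are hypothetical. Extracting such sequences from CUQIpy's sample objects is not shown, and the array layout there is an assumption you should check against the CUQIpy documentation.

```python
def reconstruct_theta(mu_draws, tau_draws, theta_p_draws):
    """Apply theta = mu + tau * theta' to each posterior draw (hypothetical helper).

    mu_draws, tau_draws: sequences of length N, one value per posterior draw.
    theta_p_draws: m rows, each a sequence of N draws for one school.
    Returns m rows of N draws of theta.
    """
    n = len(mu_draws)
    if len(tau_draws) != n or any(len(row) != n for row in theta_p_draws):
        raise ValueError("all inputs must contain the same number of draws")
    return [
        [mu + tau * tp for mu, tau, tp in zip(mu_draws, tau_draws, row)]
        for row in theta_p_draws
    ]


def posterior_mean(draws):
    """Monte Carlo estimate of the posterior mean from a sequence of draws."""
    return sum(draws) / len(draws)


# Toy draws (N = 2, m = 2) just to show the shapes involved
theta_draws = reconstruct_theta([1.0, 2.0], [2.0, 3.0], [[0.0, 1.0], [1.0, -1.0]])
print([posterior_mean(row) for row in theta_draws])
```

The transform is applied per draw and averaged afterwards. Plugging posterior means of μ, τ, and θ' into the formula is not equivalent, because the product τθ' is nonlinear in the parameters.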

For more examples, see the demos folder.


Download files


Source distribution: cuqipy_pytorch-0.4.0.post0.dev7.tar.gz (14.5 kB)

Built distribution: cuqipy_pytorch-0.4.0.post0.dev7-py3-none-any.whl (15.7 kB, Python 3)

File details

Hashes for cuqipy_pytorch-0.4.0.post0.dev7.tar.gz:

SHA256: fd410300dc9638dccc4c4b0fb36daeb05ebec066f62ae52332308e71d1be0716
MD5: bfb73064da41db11434b9d6f85eb35e7
BLAKE2b-256: b4e4ec19936a0e27ef6bdaf3b3455723f4dc0ae651ac31cb54daff8580010700

Hashes for cuqipy_pytorch-0.4.0.post0.dev7-py3-none-any.whl:

SHA256: 5c1e2ec0e56f77228b0b522853ca4ee0a9a7414b487ad0f0248291b4262ea1c3
MD5: 247a7e7819d2e7db82fb1d3320f3cfbd
BLAKE2b-256: 8b56b901199aaf56e0abeeabe5a38906a681feaa7a22e0dddf5eef5c7c3f4ebf
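You can check a downloaded file against the SHA256 digests listed above before installing it. The following standard-library sketch hashes the file in chunks and compares the result with the published digest for that file name. The helpers `sha256_of` and `verify` are hypothetical.

```python
import hashlib
from pathlib import Path

# SHA256 digests copied from the hash listing for version 0.4.0.post0.dev7
EXPECTED_SHA256 = {
    "cuqipy_pytorch-0.4.0.post0.dev7.tar.gz":
        "fd410300dc9638dccc4c4b0fb36daeb05ebec066f62ae52332308e71d1be0716",
    "cuqipy_pytorch-0.4.0.post0.dev7-py3-none-any.whl":
        "5c1e2ec0e56f77228b0b522853ca4ee0a9a7414b487ad0f0248291b4262ea1c3",
}


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path):
    """True if the file's SHA256 matches the published digest for its file name."""
    path = Path(path)
    expected = EXPECTED_SHA256.get(path.name)
    if expected is None:
        raise KeyError(f"no published hash for {path.name}")
    return sha256_of(path) == expected
```

pip can also enforce these digests on its own. Add a `--hash=sha256:...` entry next to the pinned requirement in a requirements file and install with `--require-hashes`.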
