
CUQIpy plugin for PyTorch


CUQIpy-PyTorch

CUQIpy-PyTorch is a plugin for the CUQIpy software package.

It adds a PyTorch backend to CUQIpy, so that models, distributions, etc. can be defined using the PyTorch API.

It also provides access to Pyro's No-U-Turn Sampler (NUTS), a Hamiltonian Monte Carlo method, for efficient sampling from the joint posterior.
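NUTS builds on Hamiltonian Monte Carlo (HMC). As a rough illustration of the underlying mechanism, the sketch below implements plain HMC for a single scalar parameter in pure Python. It is not Pyro's implementation: it uses a fixed step size and a fixed number of leapfrog steps, whereas NUTS chooses the trajectory length adaptively. The functions `leapfrog` and `hmc` and the tuning values are illustrative choices.

```python
import math
import random


def leapfrog(q, p, grad_logp, step, n_steps):
    """Approximate Hamiltonian dynamics with n_steps leapfrog steps."""
    p = p + 0.5 * step * grad_logp(q)      # initial half step for momentum
    for i in range(n_steps):
        q = q + step * p                   # full step for position
        if i < n_steps - 1:
            p = p + step * grad_logp(q)    # full momentum step between position steps
    p = p + 0.5 * step * grad_logp(q)      # final half step for momentum
    return q, p


def hmc(logp, grad_logp, q0, n_samples, step=0.2, n_steps=20, seed=0):
    """Plain HMC for a scalar parameter: fixed step size and path length, no adaptation."""
    rng = random.Random(seed)
    q, draws = q0, []
    for _ in range(n_samples):
        p0 = rng.gauss(0.0, 1.0)           # fresh momentum from N(0, 1)
        q_new, p_new = leapfrog(q, p0, grad_logp, step, n_steps)
        # Metropolis correction with Hamiltonian H(q, p) = -logp(q) + p**2 / 2
        log_ratio = (logp(q_new) - 0.5 * p_new ** 2) - (logp(q) - 0.5 * p0 ** 2)
        if rng.random() < math.exp(min(0.0, log_ratio)):
            q = q_new                      # accept the proposal
        draws.append(q)                    # on rejection, the current state is repeated
    return draws


# Target: standard normal, log p(q) = -q**2 / 2 up to an additive constant
draws = hmc(lambda q: -0.5 * q * q, lambda q: -q, q0=0.0, n_samples=5000)
mean = sum(draws) / len(draws)
var = sum((d - mean) ** 2 for d in draws) / len(draws)
print(f"mean={mean:.3f}, var={var:.3f}")
```

For this standard-normal target the sample mean and variance should land close to 0 and 1. Picking `n_steps` by hand is the awkward part of plain HMC; removing that choice is what NUTS adds.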

Installation

For optimal performance, consider installing PyTorch with conda first, and then install CUQIpy-PyTorch with pip:

pip install cuqipy-pytorch

If PyTorch, Pyro, or CUQIpy is not already installed, the command above installs it automatically.
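To check which versions pip actually resolved, the standard library's `importlib.metadata` can be queried. This is a minimal sketch; the distribution names below (`cuqipy-pytorch`, `CUQIpy`, `torch`, `pyro-ppl`) are assumed to match the names published on PyPI, and `installed_versions` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names as assumed to be published on PyPI
PACKAGES = ("cuqipy-pytorch", "CUQIpy", "torch", "pyro-ppl")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```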

Quickstart

Example: sampling from the posterior of the eight schools model:

$$
\begin{align*}
\mu &\sim \mathcal{N}(0, 10^2)\\
\tau &\sim \log\mathcal{N}(5, 1)\\
\boldsymbol \theta' &\sim \mathcal{N}(\mathbf{0}, \mathbf{I}_m)\\
\boldsymbol \theta &= \mu + \tau \boldsymbol \theta'\\
\mathbf{y} &\sim \mathcal{N}\big(\boldsymbol \theta, \operatorname{diag}(\boldsymbol \sigma^2)\big)
\end{align*}
$$

where $\mathbf{y}\in\mathbb{R}^m$ are the observed treatment effects and $\boldsymbol \sigma\in\mathbb{R}^m$ their standard errors, with $m = 8$ schools; both are observed data.
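To check the model outside CUQIpy, the unnormalized log posterior (the log joint density evaluated at the observed $\mathbf{y}$) can be written out in plain Python. This is a minimal sketch, not the plugin's implementation: the helper names are hypothetical, and it assumes that $\log\mathcal{N}(5, 1)$ means $\log\tau \sim \mathcal{N}(5, 1)$ and that the second argument of each $\mathcal{N}$ is a variance, as in $\mathcal{N}(0, 10^2)$. A gradient-based sampler such as NUTS also needs the gradient of this function, which in a PyTorch setting would come from automatic differentiation.

```python
import math

# Eight schools data from the quickstart: observed effects and their standard errors
Y_OBS = (28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0)
SIGMA_OBS = (15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0)


def normal_logpdf(x, mean, var):
    """Log-density of a univariate Gaussian with the given mean and variance."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def lognormal_logpdf(x, mean_log, var_log):
    """Log-density of x > 0 when log(x) ~ N(mean_log, var_log)."""
    return normal_logpdf(math.log(x), mean_log, var_log) - math.log(x)


def log_joint(mu, tau, theta_p, y=Y_OBS, sigma=SIGMA_OBS):
    """log p(mu, tau, theta', y) at fixed y, i.e. the unnormalized log posterior."""
    if tau <= 0:
        return -math.inf                                     # outside the support of tau
    lp = normal_logpdf(mu, 0.0, 10.0 ** 2)                   # mu ~ N(0, 10^2)
    lp += lognormal_logpdf(tau, 5.0, 1.0)                    # log tau ~ N(5, 1) (assumed)
    lp += sum(normal_logpdf(t, 0.0, 1.0) for t in theta_p)   # theta' ~ N(0, I_m)
    theta = [mu + tau * t for t in theta_p]                  # theta = mu + tau * theta'
    lp += sum(normal_logpdf(yi, ti, si ** 2)                 # y_i ~ N(theta_i, sigma_i^2)
              for yi, ti, si in zip(y, theta, sigma))
    return lp


print(log_joint(mu=5.0, tau=2.0, theta_p=[0.0] * 8))
```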

import torch as xp
from cuqi.distribution import JointDistribution
from cuqipy_pytorch.distribution import Gaussian, Lognormal
from cuqipy_pytorch.sampler import NUTS

# Observations
y_obs = xp.tensor([28, 8, -3,  7, -1, 1,  18, 12], dtype=xp.float32)
σ_obs = xp.tensor([15, 10, 16, 11, 9, 11, 10, 18], dtype=xp.float32)

# Bayesian model
μ     = Gaussian(0, 10**2)
τ     = Lognormal(5, 1)
θp    = Gaussian(xp.zeros(8), 1)
θ     = lambda μ, τ, θp: μ+τ*θp        # Non-centred parameterization of θ
y     = Gaussian(θ, cov=σ_obs**2)      # Likelihood with known standard errors

# Posterior sampling
joint = JointDistribution(μ, τ, θp, y)   # Define joint distribution
posterior = joint(y=y_obs)               # Define posterior distribution
sampler = NUTS(posterior)                # Define sampling strategy
samples = sampler.sample(N=500, Nb=500)  # 500 samples after 500 burn-in steps

# Plot posterior samples
samples["θp"].plot_violin()
print(samples["μ"].mean())  # Posterior mean of μ (average effect)
print(samples["τ"].mean())  # Posterior mean of τ (between-school standard deviation)
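To see how the distributions in the Bayesian model fit together, the generative process can also be simulated forward with the standard library. This prior-predictive sketch is independent of CUQIpy; `simulate_eight_schools` is a hypothetical helper, and it again assumes that `Lognormal(5, 1)` means $\log\tau \sim \mathcal{N}(5, 1)$.

```python
import random


def simulate_eight_schools(sigma, seed=None):
    """Draw one realization of (mu, tau, theta, y) from the prior predictive distribution."""
    rng = random.Random(seed)
    mu = rng.gauss(0.0, 10.0)                            # mu ~ N(0, 10^2); gauss takes a std dev
    tau = rng.lognormvariate(5.0, 1.0)                   # log tau ~ N(5, 1) (assumed reading)
    theta_p = [rng.gauss(0.0, 1.0) for _ in sigma]       # theta' ~ N(0, I_m)
    theta = [mu + tau * t for t in theta_p]              # theta = mu + tau * theta'
    y = [rng.gauss(t, s) for t, s in zip(theta, sigma)]  # y_i ~ N(theta_i, sigma_i^2)
    return {"mu": mu, "tau": tau, "theta": theta, "y": y}


sigma_obs = [15, 10, 16, 11, 9, 11, 10, 18]
draw = simulate_eight_schools(sigma_obs, seed=1)
print(draw["tau"], draw["y"])
```

The `theta` line plays the role of the `θ` lambda in the quickstart: a deterministic function of μ, τ, and θ' that supplies the mean of `y`, rather than a distribution of its own.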

For more examples, see the demos folder.

Download files


Source Distribution

cuqipy_pytorch-0.4.0.post0.dev5.tar.gz (14.5 kB)


Built Distribution


cuqipy_pytorch-0.4.0.post0.dev5-py3-none-any.whl (15.7 kB)


File details

Details for the file cuqipy_pytorch-0.4.0.post0.dev5.tar.gz.

File hashes

Hashes for cuqipy_pytorch-0.4.0.post0.dev5.tar.gz
Algorithm Hash digest
SHA256 7a212cc5d4199df6fb2c2f8677a7772f6d068e7a7978fef8556236b24b733910
MD5 14afac9b16e6ab3a9d1a1467d4beb409
BLAKE2b-256 668ee2a680573e13f150a969de9cd9d377a379afb9a102a0acbd1716b7fa6fb6

File details

Details for the file cuqipy_pytorch-0.4.0.post0.dev5-py3-none-any.whl.

File hashes

Hashes for cuqipy_pytorch-0.4.0.post0.dev5-py3-none-any.whl
Algorithm Hash digest
SHA256 2496c59e902bd027a1b3d29581b76f48a3f437b343d31b3f297c3e4ad625e33b
MD5 f16e2a37fa39d66aaf5046fb46615262
BLAKE2b-256 60616d44d1bc0690a7e19002d6d0de6a2a809e2a8544d0a92dc69ccacbd135d8
