CUQIpy plugin for PyTorch

CUQIpy-PyTorch

CUQIpy-PyTorch is a plugin for the CUQIpy software package.

It adds a PyTorch backend to CUQIpy, so that models, distributions, etc. can be defined with the PyTorch API.

It also provides access to Pyro's No-U-Turn Sampler (NUTS), an adaptive Hamiltonian Monte Carlo method, for efficient sampling from the joint posterior.
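NUTS builds on Hamiltonian Monte Carlo (HMC). For intuition only, here is a minimal plain-HMC sketch in pure Python for a one-dimensional target. It is not the Pyro implementation and not NUTS: it uses a fixed step size and a fixed number of leapfrog steps, whereas NUTS chooses the trajectory length automatically. The function `hmc_sample` and its parameters are hypothetical; `n_samples` and `n_burn` are meant to loosely mirror what appear to be the sample count `N` and burn-in `Nb` of `sampler.sample` in the quickstart. In CUQIpy-PyTorch the gradient would presumably come from PyTorch's automatic differentiation; here it is passed in by hand.

```python
import math
import random
import statistics


def hmc_sample(log_prob, grad_log_prob, x0, n_samples, n_burn,
               step_size=0.1, n_steps=20, seed=0):
    """Plain HMC for a 1-D target (fixed trajectory length; NOT NUTS)."""
    rng = random.Random(seed)
    x = x0
    draws = []
    for i in range(n_burn + n_samples):
        p = rng.gauss(0.0, 1.0)  # fresh momentum ~ N(0, 1)
        x_new, p_new = x, p
        # Leapfrog integration: half momentum step, alternating full steps,
        # then a final half momentum step
        p_new += 0.5 * step_size * grad_log_prob(x_new)
        for _ in range(n_steps - 1):
            x_new += step_size * p_new
            p_new += step_size * grad_log_prob(x_new)
        x_new += step_size * p_new
        p_new += 0.5 * step_size * grad_log_prob(x_new)
        # Metropolis accept/reject on H(x, p) = -log_prob(x) + p^2 / 2
        h_old = -log_prob(x) + 0.5 * p * p
        h_new = -log_prob(x_new) + 0.5 * p_new * p_new
        if rng.random() < math.exp(min(0.0, h_old - h_new)):
            x = x_new
        if i >= n_burn:  # discard the first n_burn iterations as burn-in
            draws.append(x)
    return draws


# Usage: a standard normal target, log p(x) = -x^2 / 2 up to a constant
draws = hmc_sample(lambda x: -0.5 * x * x, lambda x: -x, x0=0.0,
                   n_samples=2000, n_burn=500)
print(statistics.fmean(draws), statistics.pvariance(draws))
```

The sample mean and variance should come out close to 0 and 1. The Metropolis step makes the chain exact despite the leapfrog discretization error; NUTS keeps that property while removing the need to hand-tune `n_steps`.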

Installation

For optimal performance, consider installing PyTorch with conda first, then install CUQIpy-PyTorch with pip:

pip install cuqipy-pytorch

If PyTorch, Pyro, or CUQIpy is not already installed, the command above installs it automatically.
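To confirm afterwards which of these packages ended up installed, a small standard-library check can help. This is a sketch, not part of CUQIpy-PyTorch; the helper `installed_versions` is hypothetical, and the PyPI distribution names listed (in particular `pyro-ppl` for Pyro) are assumptions.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names for PyTorch, Pyro, CUQIpy and this plugin
PACKAGES = ["torch", "pyro-ppl", "CUQIpy", "CUQIpy-PyTorch"]


def installed_versions(packages=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```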

Quickstart

Example: sampling from the posterior of the eight schools model:

$$
\begin{align*}
\mu &\sim \mathcal{N}(0, 10^2)\\
\tau &\sim \log\mathcal{N}(5, 1)\\
\boldsymbol \theta' &\sim \mathcal{N}(\mathbf{0}, \mathbf{I}_m)\\
\boldsymbol \theta &= \mu + \tau \boldsymbol \theta'\\
\mathbf{y} &\sim \mathcal{N}(\boldsymbol \theta, \operatorname{diag}(\boldsymbol \sigma^2))
\end{align*}
$$

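where $\mathbf{y}\in\mathbb{R}^m$ and $\boldsymbol \sigma\in\mathbb{R}^m$ are observed data ($\boldsymbol \sigma$ holds the per-school standard deviations), with $m = 8$ schools in the example below.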
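To make the model concrete, the following pure-Python sketch evaluates the log joint density $\log p(\mu, \tau, \boldsymbol\theta', \mathbf{y})$ for the eight schools data, written directly from the equations above. It is not how CUQIpy-PyTorch computes densities internally; the helper names are hypothetical, and the lognormal is assumed to be parameterized so that $\log\tau \sim \mathcal{N}(5, 1)$. With $\mathbf{y}$ held at the observed values, the posterior $p(\mu, \tau, \boldsymbol\theta' \mid \mathbf{y})$ is proportional to this quantity, which is all a sampler such as NUTS needs.

```python
import math

# Eight schools data, copied from the quickstart
y_obs = [28, 8, -3, 7, -1, 1, 18, 12]
sigma_obs = [15, 10, 16, 11, 9, 11, 10, 18]


def normal_logpdf(x, mean, sd):
    """Log density of N(mean, sd^2) at x."""
    return -0.5 * math.log(2 * math.pi) - math.log(sd) - 0.5 * ((x - mean) / sd) ** 2


def lognormal_logpdf(x, mean_log, sd_log):
    """Log density at x of a lognormal with log(x) ~ N(mean_log, sd_log^2)."""
    if x <= 0:
        return -math.inf
    # Change of variables: f(x) = phi(log x) / x
    return normal_logpdf(math.log(x), mean_log, sd_log) - math.log(x)


def log_joint(mu, tau, theta_p):
    """log p(mu, tau, theta', y_obs) for the eight schools model."""
    lp = normal_logpdf(mu, 0.0, 10.0)                        # mu ~ N(0, 10^2)
    lp += lognormal_logpdf(tau, 5.0, 1.0)                    # tau ~ logN(5, 1)
    lp += sum(normal_logpdf(t, 0.0, 1.0) for t in theta_p)   # theta' ~ N(0, I)
    theta = [mu + tau * t for t in theta_p]                  # theta = mu + tau * theta'
    # y_i ~ N(theta_i, sigma_i^2), independently across schools
    lp += sum(normal_logpdf(y, th, s) for y, th, s in zip(y_obs, theta, sigma_obs))
    return lp


print(log_joint(8.0, 5.0, [0.0] * 8))
```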

import torch as xp
from cuqi.distribution import JointDistribution
from cuqipy_pytorch.distribution import Gaussian, Lognormal
from cuqipy_pytorch.sampler import NUTS

# Observations
y_obs = xp.tensor([28, 8, -3,  7, -1, 1,  18, 12], dtype=xp.float32)
σ_obs = xp.tensor([15, 10, 16, 11, 9, 11, 10, 18], dtype=xp.float32)

# Bayesian model
μ     = Gaussian(0, 10**2)
τ     = Lognormal(5, 1)
θp    = Gaussian(xp.zeros(8), 1)
θ     = lambda μ, τ, θp: μ+τ*θp
y     = Gaussian(θ, cov=σ_obs**2)

# Posterior sampling
joint = JointDistribution(μ, τ, θp, y)   # Define joint distribution 
posterior = joint(y=y_obs)               # Define posterior distribution
sampler = NUTS(posterior)                # Define sampling strategy
samples = sampler.sample(N=500, Nb=500)  # Sample from posterior

# Plot posterior samples
samples["θp"].plot_violin(); 
print(samples["μ"].mean()) # Average effect
print(samples["τ"].mean()) # Posterior mean of τ (between-school standard deviation)
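The quickstart samples the standardized effects θ' rather than the school effects θ themselves (the non-centered parameterization of the model). Because θ = μ + τθ' holds draw by draw, posterior draws of θ can be recovered afterwards. The pure-Python sketch below assumes the draws have already been extracted into plain lists, one entry per draw; how to pull them out of the `samples` objects is not shown and may depend on the CUQIpy version. The helpers `school_effects` and `per_school_means` are hypothetical.

```python
import statistics


def school_effects(mu_draws, tau_draws, theta_p_draws):
    """Apply theta = mu + tau * theta' draw by draw; one list of m effects per draw."""
    return [
        [mu + tau * t for t in theta_p]
        for mu, tau, theta_p in zip(mu_draws, tau_draws, theta_p_draws)
    ]


def per_school_means(theta_draws):
    """Posterior mean of each school's effect across all draws."""
    return [statistics.fmean(col) for col in zip(*theta_draws)]


# Tiny hand-made input (2 draws, 3 schools) purely to show the shapes
theta = school_effects([1.0, 2.0], [0.5, 1.0],
                       [[0.0, 2.0, -2.0], [1.0, 0.0, -1.0]])
# draw 1: 1 + 0.5 * [0, 2, -2] = [1.0, 2.0, 0.0]
# draw 2: 2 + 1.0 * [1, 0, -1] = [3.0, 2.0, 1.0]
print(per_school_means(theta))  # [2.0, 2.0, 0.5]
```

Transforming each draw before averaging, rather than plugging posterior means of μ, τ and θ' into the formula, keeps any downstream summary (quantiles, credible intervals) consistent with the joint posterior.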

For more examples, see the demos folder.

Download files

Source Distribution

CUQIpy-PyTorch-0.4.0.tar.gz (14.3 kB)

Built Distribution

CUQIpy_PyTorch-0.4.0-py3-none-any.whl (15.5 kB)

File details

Details for the file CUQIpy-PyTorch-0.4.0.tar.gz.

File metadata

  • Download URL: CUQIpy-PyTorch-0.4.0.tar.gz
  • Size: 14.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.8

File hashes

Hashes for CUQIpy-PyTorch-0.4.0.tar.gz
  • SHA256: 0217564ff2dc4de5d8668eb6c09b2ab05b94ba7941458114c430f87729acba25
  • MD5: 231669aeacf1e592a4bd99366984255b
  • BLAKE2b-256: 7954a54b3f5cab8b81878c3add233f032a58e86b2081d134efc3ef82aad55151
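A downloaded file can be checked against the published SHA256 digest with Python's standard `hashlib` module. The sketch below is not part of CUQIpy-PyTorch; the helpers `sha256_of` and `matches` are hypothetical, and the commented usage assumes the source distribution was saved in the current directory.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 16):
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path, expected_hex):
    """True if the file's SHA256 digest equals the published one (case-insensitive)."""
    return sha256_of(path) == expected_hex.strip().lower()


# Usage (assumes the sdist was downloaded into the current directory):
# matches("CUQIpy-PyTorch-0.4.0.tar.gz",
#         "0217564ff2dc4de5d8668eb6c09b2ab05b94ba7941458114c430f87729acba25")
```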

File details

Details for the file CUQIpy_PyTorch-0.4.0-py3-none-any.whl.

File hashes

Hashes for CUQIpy_PyTorch-0.4.0-py3-none-any.whl
  • SHA256: 03e7d9f9669ceb3573026710e0f7a4b8bf98e33ac560394260f582938f121c16
  • MD5: c3e338bbc5b5f073e2e20510bd60fe5e
  • BLAKE2b-256: 34edd1784679806c76eb2d51ff1e9eb0cb955b407624a1389ca9caf293217ec2
