CUQIpy plugin for PyTorch
CUQIpy-PyTorch
CUQIpy-PyTorch is a plugin for the CUQIpy software package.
It adds a PyTorch backend to CUQIpy, allowing users to define models, distributions, etc. with the PyTorch API.
It also provides access to Pyro's No-U-Turn Sampler (NUTS), an adaptive variant of Hamiltonian Monte Carlo, for efficient sampling from the joint posterior.
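NUTS builds on Hamiltonian Monte Carlo, which proposes new states by simulating Hamiltonian dynamics with the leapfrog integrator; NUTS additionally chooses the trajectory length automatically. The sketch below shows only that leapfrog step for a 1-D standard normal target, to illustrate the mechanism. It is not Pyro's implementation, and the function names are hypothetical.

```python
import math
from typing import Callable, Tuple


def leapfrog(q: float, p: float, grad_logp: Callable[[float], float],
             step: float, n_steps: int) -> Tuple[float, float]:
    """Simulate Hamiltonian dynamics (unit mass) with n_steps leapfrog steps."""
    p = p + 0.5 * step * grad_logp(q)        # initial half step for momentum
    for i in range(n_steps):
        q = q + step * p                     # full step for position
        if i < n_steps - 1:
            p = p + step * grad_logp(q)      # full step for momentum
    p = p + 0.5 * step * grad_logp(q)        # final half step for momentum
    return q, p


def hamiltonian(q: float, p: float, logp: Callable[[float], float]) -> float:
    """Potential energy -log p(q) plus kinetic energy p^2 / 2."""
    return -logp(q) + 0.5 * p * p


# Hypothetical target: standard normal, log p(q) = -q^2 / 2 (up to a constant)
def log_std_normal(q: float) -> float:
    return -0.5 * q * q


def grad_log_std_normal(q: float) -> float:
    return -q


q1, p1 = leapfrog(1.0, 0.0, grad_log_std_normal, step=0.1, n_steps=20)
# The energy is nearly conserved, so the proposal is accepted with high probability
print(abs(hamiltonian(q1, p1, log_std_normal) - hamiltonian(1.0, 0.0, log_std_normal)))
```

Near-conservation of energy and exact time reversibility are what make leapfrog trajectories usable as Metropolis proposals in HMC-type samplers.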
Installation
For optimal performance, consider installing PyTorch with conda first, then install CUQIpy-PyTorch with pip:
```bash
pip install cuqipy-pytorch
```
If PyTorch, Pyro, or CUQIpy are not already installed, the command above installs them automatically as dependencies.
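To confirm which of the required packages ended up installed, a small standard-library check can help. This is a sketch, assuming the PyPI distribution names `torch`, `pyro-ppl`, `CUQIpy`, and `CUQIpy-PyTorch`; a conda-installed PyTorch normally registers as `torch` as well, but that is an assumption worth verifying on your system. The function name `installed_versions` is hypothetical.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Assumed PyPI distribution names for PyTorch, Pyro, CUQIpy and this plugin
PACKAGES = ("torch", "pyro-ppl", "CUQIpy", "CUQIpy-PyTorch")


def installed_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:15s} {version or 'not installed'}")
```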
Quickstart
Example for sampling from the eight schools model:
$$
\begin{align*}
\mu &\sim \mathcal{N}(0, 10^2)\\
\tau &\sim \log\mathcal{N}(5, 1)\\
\boldsymbol \theta' &\sim \mathcal{N}(\mathbf{0}, \mathbf{I}_m)\\
\boldsymbol \theta &= \mu + \tau \boldsymbol \theta'\\
\mathbf{y} &\sim \mathcal{N}\big(\boldsymbol \theta, \operatorname{diag}(\boldsymbol \sigma^2)\big)
\end{align*}
$$
where $\mathbf{y}\in\mathbb{R}^m$ are the observed treatment effects and $\boldsymbol \sigma\in\mathbb{R}^m$ their standard errors, both known data (here $m = 8$ schools).
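To make the model concrete, the sketch below writes out the log joint density $\log p(\mu, \tau, \boldsymbol\theta', \mathbf{y})$ term by term in plain Python. The joint distribution in the quickstart represents this density; the hand-written functions here (all names hypothetical) are only an illustration of the formulas, with $\log\mathcal{N}(5, 1)$ read as "$\log\tau \sim \mathcal{N}(5, 1)$".

```python
import math
from typing import Sequence

# Eight schools data from the quickstart
Y_OBS = [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
SIGMA_OBS = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]


def normal_logpdf(x: float, mean: float, var: float) -> float:
    """Log density of N(mean, var) at x, where var is the variance."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def lognormal_logpdf(x: float, mean: float, var: float) -> float:
    """Log density at x of a variable whose logarithm is N(mean, var)."""
    if x <= 0:
        return -math.inf
    return normal_logpdf(math.log(x), mean, var) - math.log(x)


def log_joint(mu: float, tau: float, theta_p: Sequence[float],
              y: Sequence[float], sigma: Sequence[float]) -> float:
    """Sum of the log prior and log likelihood terms of the model."""
    if not (len(theta_p) == len(y) == len(sigma)):
        raise ValueError("theta_p, y and sigma must have the same length")
    if tau <= 0:
        return -math.inf                           # outside the support of τ
    lp = normal_logpdf(mu, 0.0, 10.0 ** 2)         # μ ~ N(0, 10^2)
    lp += lognormal_logpdf(tau, 5.0, 1.0)          # τ ~ logN(5, 1)
    for tp, yi, si in zip(theta_p, y, sigma):
        lp += normal_logpdf(tp, 0.0, 1.0)          # θ'_i ~ N(0, 1)
        theta_i = mu + tau * tp                    # θ_i = μ + τ θ'_i
        lp += normal_logpdf(yi, theta_i, si ** 2)  # y_i ~ N(θ_i, σ_i^2)
    return lp


print(log_joint(0.0, 1.0, [0.0] * 8, Y_OBS, SIGMA_OBS))
```

With $\mathbf{y}$ fixed to the observations, this function is the unnormalized log posterior over $(\mu, \tau, \boldsymbol\theta')$ that a sampler explores.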
```python
import torch as xp
from cuqi.distribution import JointDistribution
from cuqipy_pytorch.distribution import Gaussian, Lognormal
from cuqipy_pytorch.sampler import NUTS

# Observations: estimated treatment effects and their standard errors
y_obs = xp.tensor([28, 8, -3, 7, -1, 1, 18, 12], dtype=xp.float32)
σ_obs = xp.tensor([15, 10, 16, 11, 9, 11, 10, 18], dtype=xp.float32)

# Bayesian model
μ = Gaussian(0, 10**2)                  # Prior on the average effect
τ = Lognormal(5, 1)                     # Prior on the between-school spread
θp = Gaussian(xp.zeros(8), 1)           # Standardized school effects θ'
θ = lambda μ, τ, θp: μ + τ*θp           # School effects θ = μ + τθ'
y = Gaussian(θ, cov=σ_obs**2)           # Likelihood of the observed effects

# Posterior sampling
joint = JointDistribution(μ, τ, θp, y)  # Define joint distribution
posterior = joint(y=y_obs)              # Condition on the observed data
sampler = NUTS(posterior)               # Define sampling strategy
samples = sampler.sample(N=500, Nb=500) # Draw 500 samples after 500 burn-in steps

# Plot posterior samples
samples["θp"].plot_violin()
print(samples["μ"].mean())  # Average effect
print(samples["τ"].mean())  # Between-school standard deviation
```
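The quickstart samples the standardized effects $\boldsymbol\theta'$ rather than $\boldsymbol\theta$ (a non-centered parameterization, which typically eases sampling in hierarchical models). To obtain the school effects themselves, apply $\boldsymbol\theta = \mu + \tau\boldsymbol\theta'$ to each posterior draw. The sketch below does this on plain Python lists; the helper names are hypothetical, and converting CUQIpy sample objects into such lists is left to the reader.

```python
from typing import List, Sequence


def school_effects(mu: Sequence[float], tau: Sequence[float],
                   theta_p: Sequence[Sequence[float]]) -> List[List[float]]:
    """Map each posterior draw (μ, τ, θ') to θ = μ + τ θ'.

    mu, tau: N draws each; theta_p: N draws, each of length m.
    Returns N draws of θ, each of length m.
    """
    if not (len(mu) == len(tau) == len(theta_p)):
        raise ValueError("all inputs must contain the same number of draws")
    return [[m + t * tp for tp in draw] for m, t, draw in zip(mu, tau, theta_p)]


def column_means(draws: Sequence[Sequence[float]]) -> List[float]:
    """Posterior mean of each component across draws."""
    n = len(draws)
    return [sum(col) / n for col in zip(*draws)]


# Toy example with two draws for two schools
draws = school_effects([1.0, 3.0], [2.0, 1.0], [[0.5, -1.0], [0.0, 2.0]])
print(draws)                # → [[2.0, -1.0], [3.0, 5.0]]
print(column_means(draws))  # → [2.5, 2.0]
```

Transforming per draw, rather than plugging posterior means of $\mu$, $\tau$ and $\boldsymbol\theta'$ into the formula, keeps the dependence between the parameters in each sample.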
For more examples, see the demos folder.
Download files
File details
Details for the file CUQIpy-PyTorch-0.4.0.tar.gz.
File metadata
- Download URL: CUQIpy-PyTorch-0.4.0.tar.gz
- Upload date:
- Size: 14.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 0217564ff2dc4de5d8668eb6c09b2ab05b94ba7941458114c430f87729acba25
MD5 | 231669aeacf1e592a4bd99366984255b
BLAKE2b-256 | 7954a54b3f5cab8b81878c3add233f032a58e86b2081d134efc3ef82aad55151
File details
Details for the file CUQIpy_PyTorch-0.4.0-py3-none-any.whl.
File metadata
- Download URL: CUQIpy_PyTorch-0.4.0-py3-none-any.whl
- Upload date:
- Size: 15.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 03e7d9f9669ceb3573026710e0f7a4b8bf98e33ac560394260f582938f121c16
MD5 | c3e338bbc5b5f073e2e20510bd60fe5e
BLAKE2b-256 | 34edd1784679806c76eb2d51ff1e9eb0cb955b407624a1389ca9caf293217ec2