CUQIpy plugin for PyTorch
CUQIpy-PyTorch
CUQIpy-PyTorch is a plugin for the CUQIpy software package.
It adds a PyTorch backend to CUQIpy, allowing the user to define models, distributions, etc. with the PyTorch API.
It also links to Pyro's No-U-Turn Sampler (NUTS), a Hamiltonian Monte Carlo method, for efficient sampling from the joint posterior.
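NUTS builds on plain Hamiltonian Monte Carlo (HMC): it simulates Hamiltonian dynamics with a leapfrog integrator, driven by the gradient of the log density (which the PyTorch backend can supply via automatic differentiation), and then applies a Metropolis accept/reject step. The sketch below is plain fixed-length HMC on a scalar state with unit mass, not Pyro's NUTS (NUTS additionally chooses the trajectory length adaptively). The names `leapfrog`, `hmc_step`, `step_size` and `n_steps` are hypothetical and used only for illustration.

```python
import math
import random


def leapfrog(q, p, grad_log_p, step_size, n_steps):
    """Simulate Hamiltonian dynamics with the leapfrog integrator."""
    p = p + 0.5 * step_size * grad_log_p(q)    # initial half step for momentum
    for i in range(n_steps):
        q = q + step_size * p                   # full step for position
        if i != n_steps - 1:
            p = p + step_size * grad_log_p(q)   # full step for momentum
    p = p + 0.5 * step_size * grad_log_p(q)    # final half step for momentum
    return q, p


def hmc_step(q, log_p, grad_log_p, step_size=0.1, n_steps=20, rng=random):
    """One plain HMC transition (hypothetical helper, unit mass matrix)."""
    p0 = rng.gauss(0.0, 1.0)                    # resample momentum
    q_new, p_new = leapfrog(q, p0, grad_log_p, step_size, n_steps)
    # Metropolis correction using the Hamiltonian H = -log p(q) + p^2 / 2
    log_accept = (log_p(q_new) - 0.5 * p_new ** 2) - (log_p(q) - 0.5 * p0 ** 2)
    if rng.random() < math.exp(min(0.0, log_accept)):
        return q_new
    return q
```

For example, repeatedly calling `hmc_step(x, lambda z: -0.5 * z * z, lambda z: -z)` draws from a standard normal. The gradient `grad_log_p` is exactly what an autodiff backend such as PyTorch makes cheap to obtain.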
Installation
For optimal performance, consider installing PyTorch with conda first, then install CUQIpy-PyTorch with pip:
```bash
pip install cuqipy-pytorch
```
If PyTorch, Pyro, or CUQIpy is not already installed, the command above installs it automatically.
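To confirm which dependencies ended up installed, the standard-library `importlib.metadata` module can be queried. This is a small sketch; it assumes the PyPI distribution names `torch`, `pyro-ppl`, `CUQIpy` and `CUQIpy-PyTorch`, and the helper `installed_versions` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names: PyTorch ships as "torch", Pyro as "pyro-ppl".
PACKAGES = ("torch", "pyro-ppl", "CUQIpy", "CUQIpy-PyTorch")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```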
Quickstart
Example of sampling from the eight schools model:
$$
\begin{align*}
\mu &\sim \mathcal{N}(0, 10^2)\\
\tau &\sim \log\mathcal{N}(5, 1)\\
\boldsymbol \theta' &\sim \mathcal{N}(\mathbf{0}, \mathbf{I}_m)\\
\boldsymbol \theta &= \mu + \tau \boldsymbol \theta'\\
\mathbf{y} &\sim \mathcal{N}(\boldsymbol \theta, \boldsymbol \sigma^2 \mathbf{I}_m)
\end{align*}
$$
where $\mathbf{y}\in\mathbb{R}^m$ and $\boldsymbol \sigma\in\mathbb{R}^m$ are the observed data (here $m = 8$ schools).
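The model above defines an unnormalized log posterior, which is the quantity a gradient-based sampler such as NUTS explores. As a library-free sketch, the log joint density $\log p(\mu, \tau, \boldsymbol\theta', \mathbf{y})$ can be written out term by term in plain Python. It assumes the second parameter of $\mathcal{N}$ and $\log\mathcal{N}$ is a variance (for $\log\mathcal{N}(5, 1)$ the two readings coincide) and that the noise is independent with variances $\sigma_i^2$; all function names are hypothetical.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)


def normal_logpdf(x, mean, var):
    """Log density of a univariate Gaussian N(mean, var)."""
    return -0.5 * (LOG_2PI + math.log(var) + (x - mean) ** 2 / var)


def lognormal_logpdf(x, mean, var):
    """Log density of logN(mean, var), i.e. log(x) ~ N(mean, var), for x > 0."""
    if x <= 0:
        return -math.inf
    return normal_logpdf(math.log(x), mean, var) - math.log(x)


def eight_schools_log_joint(mu, tau, theta_p, y, sigma):
    """log p(mu, tau, theta', y) for the non-centred eight schools model."""
    lp = normal_logpdf(mu, 0.0, 10.0 ** 2)          # mu ~ N(0, 10^2)
    lp += lognormal_logpdf(tau, 5.0, 1.0)           # tau ~ logN(5, 1)
    for t_p, y_i, s_i in zip(theta_p, y, sigma):
        lp += normal_logpdf(t_p, 0.0, 1.0)          # theta'_i ~ N(0, 1)
        theta_i = mu + tau * t_p                    # theta_i = mu + tau * theta'_i
        lp += normal_logpdf(y_i, theta_i, s_i ** 2)  # y_i ~ N(theta_i, sigma_i^2)
    return lp
```

Sampling $\boldsymbol\theta'$ and recovering $\boldsymbol\theta = \mu + \tau\boldsymbol\theta'$ (the non-centred parameterization) avoids the funnel-shaped geometry that the centred form has when $\tau$ is small, which is why the model is written this way.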
```python
import torch as xp
from cuqi.distribution import JointDistribution
from cuqipy_pytorch.distribution import Gaussian, LogGaussian
from cuqipy_pytorch.sampler import NUTS

# Observations
y_obs = xp.tensor([28, 8, -3, 7, -1, 1, 18, 12], dtype=xp.float32)
σ_obs = xp.tensor([15, 10, 16, 11, 9, 11, 10, 18], dtype=xp.float32)

# Bayesian model
μ = Gaussian(0, 10**2)
τ = LogGaussian(5, 1)
θp = Gaussian(xp.zeros(8), 1)
θ = lambda μ, τ, θp: μ + τ * θp
y = Gaussian(θ, cov=σ_obs**2)

# Posterior sampling
joint = JointDistribution(μ, τ, θp, y)    # Define joint distribution
posterior = joint(y=y_obs)                # Condition on the observed data
sampler = NUTS(posterior)                 # Define sampling strategy
samples = sampler.sample(N=500, Nb=500)   # Sample from posterior (N samples, Nb burn-in)

# Plot posterior samples
samples["θp"].plot_violin()
print(samples["μ"].mean())  # Posterior mean of the overall effect μ
print(samples["τ"].mean())  # Posterior mean of the between-school spread τ (a scale, not a variance)
```
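Because the sampler works with the non-centred variable $\boldsymbol\theta'$, per-school effects $\boldsymbol\theta$ have to be reconstructed draw by draw as $\boldsymbol\theta = \mu + \tau\boldsymbol\theta'$. A minimal sketch, assuming the draws of μ, τ and θp have already been extracted into plain Python lists (how to extract them from the returned sample objects is not shown here); `theta_draws` and `column_means` are hypothetical helpers.

```python
def theta_draws(mu_draws, tau_draws, theta_p_draws):
    """Map posterior draws of (mu, tau, theta') to draws of theta.

    mu_draws, tau_draws: length-N lists of scalars.
    theta_p_draws: length-N list of length-m lists.
    Returns a length-N list of length-m lists with theta = mu + tau * theta'.
    """
    return [
        [mu + tau * t for t in tp]
        for mu, tau, tp in zip(mu_draws, tau_draws, theta_p_draws)
    ]


def column_means(draws):
    """Posterior mean of each component across draws."""
    n = len(draws)
    return [sum(col) / n for col in zip(*draws)]
```

The transform must be applied per draw, before averaging: the mean of $\mu + \tau\theta'_i$ is not in general the mean of $\mu$ plus the product of the means of $\tau$ and $\theta'_i$, because $\tau$ and $\theta'_i$ are correlated in the posterior.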
For more examples, see the demos folder.
Download files
Hashes for CUQIpy-PyTorch-0.1.1.post0.dev7.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | f6e490c72c5a5ab508faabca77d4a30aadd2f7e5dc381aa12f7006395cc1de60 |
| MD5 | 5bdb69554fa7c8ac1e9ffa683acb8e65 |
| BLAKE2b-256 | b9b791929f0a5b8d45d4f977b60b6bd1a53a32ea39a720f1b8ec79eea9a9cdff |
Hashes for CUQIpy_PyTorch-0.1.1.post0.dev7-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | cfbe1a522ff7a731fee242bbb7192a9284b0ad8cd2f28ea52b08f1f47392e30a |
| MD5 | aa0970d68cbdcfb984e58ed7aab609cc |
| BLAKE2b-256 | 61eb3969491c657a9d134cd7c025a2ac19a55e7dea2c1f3186339d716950cad6 |