
CUQIpy plugin for PyTorch


CUQIpy-PyTorch

CUQIpy-PyTorch is a plugin for the CUQIpy software package.

It adds a PyTorch backend to CUQIpy, so that models, distributions and other CUQIpy objects can be defined with the PyTorch API.

It also provides an interface to the No-U-Turn Sampler (NUTS), an adaptive Hamiltonian Monte Carlo method, as implemented in Pyro, for efficient sampling from the joint posterior.
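Pyro's NUTS belongs to the Hamiltonian Monte Carlo (HMC) family: it treats the parameters as a position, draws a random momentum, and simulates Hamiltonian dynamics with the leapfrog integrator. The sketch below is plain HMC with a fixed number of leapfrog steps, a unit mass and a one-dimensional target, written in standard-library Python for illustration only; it is not Pyro's implementation. NUTS additionally chooses the trajectory length automatically, growing each trajectory until it starts to double back on itself (the "U-turn" criterion) and then picking a state from it, and Pyro's version can also adapt the step size during warm-up. The function names leapfrog and hmc_step are hypothetical.

```python
import math
import random


def leapfrog(q, p, grad_log_prob, step_size, n_steps):
    """Simulate Hamiltonian dynamics (unit mass) with the leapfrog scheme.

    q is the position (the parameter), p the momentum, and grad_log_prob
    the gradient of the log target density.
    """
    p = p + 0.5 * step_size * grad_log_prob(q)  # initial half step in momentum
    for i in range(n_steps):
        q = q + step_size * p  # full step in position
        if i != n_steps - 1:
            p = p + step_size * grad_log_prob(q)  # full step in momentum
    p = p + 0.5 * step_size * grad_log_prob(q)  # final half step in momentum
    return q, p


def hmc_step(q, log_prob, grad_log_prob, step_size, n_steps, rng):
    """One HMC transition with a fixed trajectory length and Metropolis correction."""
    p0 = rng.gauss(0.0, 1.0)  # fresh momentum from N(0, 1)
    q_new, p_new = leapfrog(q, p0, grad_log_prob, step_size, n_steps)
    # Hamiltonian H(q, p) = -log_prob(q) + p^2 / 2
    h_old = -log_prob(q) + 0.5 * p0 ** 2
    h_new = -log_prob(q_new) + 0.5 * p_new ** 2
    accept_prob = math.exp(min(0.0, h_old - h_new))
    if rng.random() < accept_prob:
        return q_new
    return q


# Usage: sample a 1-D standard normal, log p(q) = -q^2 / 2 + const
rng = random.Random(0)
q = 0.0
samples = []
for _ in range(5000):
    q = hmc_step(q, lambda x: -0.5 * x * x, lambda x: -x, 0.2, 10, rng)
    samples.append(q)
print(sum(samples) / len(samples))  # should be close to 0
```

Because the leapfrog integrator nearly conserves the Hamiltonian, most proposals are accepted even though they land far from the starting point; this is what makes HMC-type samplers efficient on smooth posteriors.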

Installation

For optimal performance, consider installing PyTorch with conda first, then install CUQIpy-PyTorch with pip:

pip install cuqipy-pytorch

If PyTorch, Pyro or CUQIpy is not already installed, the command above installs it automatically as a dependency.
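To confirm which of these packages ended up in the active environment, their versions can be queried with the standard-library importlib.metadata module. A minimal sketch: the distribution names are the PyPI names (for example pyro-ppl for Pyro and torch for PyTorch), and the helper name installed_versions is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# PyPI distribution names of CUQIpy-PyTorch and its main dependencies
PACKAGES = ("cuqipy-pytorch", "cuqipy", "torch", "pyro-ppl")


def installed_versions(names=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions().items():
    print(f"{name:15} {ver or 'not installed'}")
```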

Quickstart

Example: sampling from the posterior of the eight schools model:

$$
\begin{align*}
\mu &\sim \mathcal{N}(0, 10^2)\\
\tau &\sim \log\mathcal{N}(5, 1)\\
\boldsymbol \theta' &\sim \mathcal{N}(\mathbf{0}, \mathbf{I}_m)\\
\boldsymbol \theta &= \mu + \tau \boldsymbol \theta'\\
\mathbf{y} &\sim \mathcal{N}\big(\boldsymbol \theta, \operatorname{diag}(\boldsymbol \sigma^2)\big)
\end{align*}
$$

where $\mathbf{y}\in\mathbb{R}^m$ and $\boldsymbol \sigma\in\mathbb{R}^m$ are observed data, here with $m = 8$ schools, and $\boldsymbol \sigma^2$ is taken elementwise.
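In this model θ is built from θ' instead of being drawn as θ ~ N(μ, τ²I_m) directly. For fixed μ and τ > 0 the two forms define the same distribution, since an affine map μ + τz of a standard normal z is distributed as N(μ, τ²). This non-centered parametrization is a common choice for hierarchical models because it usually gives HMC-type samplers a better-conditioned posterior to explore when τ is weakly identified by the data. The plain-Python Monte Carlo check below illustrates the equivalence; the values μ = 3 and τ = 2 and the helper name non_centered_draws are illustrative assumptions.

```python
import random
import statistics


def non_centered_draws(mu, tau, n, rng):
    """Draw theta = mu + tau * theta_prime with theta_prime ~ N(0, 1)."""
    return [mu + tau * rng.gauss(0.0, 1.0) for _ in range(n)]


rng = random.Random(42)
draws = non_centered_draws(mu=3.0, tau=2.0, n=100_000, rng=rng)
# Sample mean and standard deviation should be close to mu = 3 and tau = 2
print(statistics.fmean(draws), statistics.stdev(draws))
```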

import torch as xp
import matplotlib.pyplot as plt
from cuqi.distribution import JointDistribution
from cuqipy_pytorch.distribution import Gaussian, Lognormal
from cuqipy_pytorch.sampler import NUTS

# Observations: per-school effect estimates y and their standard errors σ
y_obs = xp.tensor([28, 8, -3,  7, -1, 1,  18, 12], dtype=xp.float32)
σ_obs = xp.tensor([15, 10, 16, 11, 9, 11, 10, 18], dtype=xp.float32)

# Bayesian model (mirrors the equations above)
μ     = Gaussian(0, 10**2)                  # μ ~ N(0, 10²)
τ     = Lognormal(5, 1)                     # τ ~ logN(5, 1)
θp    = Gaussian(xp.zeros(8), 1)            # θ' ~ N(0, I)
θ     = lambda μ, τ, θp: μ+τ*θp             # θ = μ + τθ' (deterministic)
y     = Gaussian(θ, cov=σ_obs**2)           # y ~ N(θ, diag(σ²))

# Posterior sampling
joint = JointDistribution(μ, τ, θp, y)   # Define joint distribution
posterior = joint(y=y_obs)               # Condition on the observed data
sampler = NUTS(posterior)                # Define sampling strategy
samples = sampler.sample(N=500, Nb=500)  # 500 samples after 500 burn-in steps

# Plot posterior samples
samples["θp"].plot_violin()
plt.show()
print(samples["μ"].mean())  # Posterior mean of μ (average effect)
print(samples["τ"].mean())  # Posterior mean of τ (between-school standard deviation)
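For reference, the unnormalized log posterior that joint(y=y_obs) represents can be written out by hand from the model equations. The plain-Python sketch below does this with the math module only; it is not how CUQIpy or CUQIpy-PyTorch compute the density. It assumes, as in the equations, that the second parameter of each normal and log-normal term is a variance (so logN(5, 1) means log τ ~ N(5, 1)); the function names are hypothetical. A gradient-based sampler such as NUTS needs this function, up to an additive constant, together with its gradient.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)


def normal_logpdf(x, mean, var):
    """Log density of N(mean, var) at scalar x (var is a variance)."""
    return -0.5 * (LOG_2PI + math.log(var) + (x - mean) ** 2 / var)


def lognormal_logpdf(x, m, s2):
    """Log density at x of a log-normal with log(x) ~ N(m, s2)."""
    if x <= 0:
        return -math.inf
    # Change of variables: p(x) = N(log x | m, s2) / x
    return normal_logpdf(math.log(x), m, s2) - math.log(x)


def log_joint(mu, tau, theta_p, y, sigma):
    """Unnormalized log posterior of the eight schools model (hypothetical helper)."""
    lp = normal_logpdf(mu, 0.0, 10.0 ** 2)         # mu ~ N(0, 10^2)
    lp += lognormal_logpdf(tau, 5.0, 1.0)          # tau ~ logN(5, 1)
    for tp, yi, si in zip(theta_p, y, sigma):
        lp += normal_logpdf(tp, 0.0, 1.0)          # theta'_i ~ N(0, 1)
        theta_i = mu + tau * tp                    # theta_i = mu + tau * theta'_i
        lp += normal_logpdf(yi, theta_i, si ** 2)  # y_i ~ N(theta_i, sigma_i^2)
    return lp


# Observed data from the quickstart, as plain Python lists
y_data = [28, 8, -3, 7, -1, 1, 18, 12]
sigma_data = [15, 10, 16, 11, 9, 11, 10, 18]
print(log_joint(mu=5.0, tau=2.0, theta_p=[0.0] * 8, y=y_data, sigma=sigma_data))
```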

For more examples, see the demos folder.


Download files


Source Distribution

cuqipy_pytorch-0.5.0.tar.gz (14.5 kB)

Uploaded Source

Built Distribution


cuqipy_pytorch-0.5.0-py3-none-any.whl (15.6 kB)

Uploaded Python 3

File details

Details for the file cuqipy_pytorch-0.5.0.tar.gz.

File metadata

  • Download URL: cuqipy_pytorch-0.5.0.tar.gz
  • Size: 14.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for cuqipy_pytorch-0.5.0.tar.gz
Algorithm Hash digest
SHA256 badfa0fca21b509abc75e28b5c78e59382c1fa4c5fe1ac03940fe5e72cd6aa62
MD5 ba63f694445c9e6998b241ff66662030
BLAKE2b-256 021564d82767e117152ac39f6f72386078acfdbe8586b708ab46717c47c15e98


File details

Details for the file cuqipy_pytorch-0.5.0-py3-none-any.whl.

File metadata

  • Download URL: cuqipy_pytorch-0.5.0-py3-none-any.whl
  • Size: 15.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for cuqipy_pytorch-0.5.0-py3-none-any.whl
Algorithm Hash digest
SHA256 267dd9b208aa5a874247cd57c0d241769aee7c581e6b885117816bcd90bd80e5
MD5 bdf810e00b1ca9b915dd7617bf452db8
BLAKE2b-256 acef5536cd11074d922708ba86c54f794840daf34cb4ecd847158d9e2f9fd287
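One way to check a downloaded file against the SHA256 digests listed above is to hash it locally with Python's standard hashlib module. A minimal sketch: the helper names sha256_of_file and verify are hypothetical, and the file is assumed to have been downloaded to the current directory.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 digests listed on this page for version 0.5.0
EXPECTED_SHA256 = {
    "cuqipy_pytorch-0.5.0.tar.gz":
        "badfa0fca21b509abc75e28b5c78e59382c1fa4c5fe1ac03940fe5e72cd6aa62",
    "cuqipy_pytorch-0.5.0-py3-none-any.whl":
        "267dd9b208aa5a874247cd57c0d241769aee7c581e6b885117816bcd90bd80e5",
}


def verify(path, expected_hex):
    """True if the file's SHA256 digest matches expected_hex (case-insensitive)."""
    return sha256_of_file(path) == expected_hex.lower()
```

For example, verify("cuqipy_pytorch-0.5.0-py3-none-any.whl", EXPECTED_SHA256["cuqipy_pytorch-0.5.0-py3-none-any.whl"]) returns True only if the downloaded bytes match the published digest. On Python 3.11 and later, hashlib.file_digest performs the chunked read for you.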

