MCMC Samplers
A Python package for performing MCMC sampling with PyTorch.
Installation
To install this package, run the following command from the terminal:
pip install mcmc-samplers
Framework
Each sampling algorithm is run using the __call__ method inherited from the Sampler base class. The differences in implementation between the samplers are primarily determined by how the following two attributes are defined:
proposals
acceptance_kernels
Common proposals, acceptance kernels, and pairings of these two objects are described next.
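To make the proposal and acceptance kernel split concrete, here is a minimal plain-Python sketch of a sampler whose __call__ method drives the chain. It is not the package's implementation: the class SketchSampler, the helper functions, and their signatures are all hypothetical, and it works on scalars rather than PyTorch tensors.

```python
import math
import random


class SketchSampler:
    """Hypothetical stand-in for a sampler base class (not the package's code)."""

    def __init__(self, target, x0, proposal, acceptance_kernel, seed=0):
        self.target = target                        # log density, up to a constant
        self.x = x0                                 # current state of the chain
        self.proposal = proposal                    # callable: (x, rng) -> candidate
        self.acceptance_kernel = acceptance_kernel  # callable: (lp_x, lp_y) -> probability
        self.rng = random.Random(seed)

    def __call__(self, num_samples):
        samples, log_probs = [], []
        lp_x = self.target(self.x)
        for _ in range(num_samples):
            y = self.proposal(self.x, self.rng)
            lp_y = self.target(y)
            # Move to the candidate with the probability given by the kernel
            if self.rng.random() < self.acceptance_kernel(lp_x, lp_y):
                self.x, lp_x = y, lp_y
            samples.append(self.x)
            log_probs.append(lp_x)
        return samples, log_probs


def gaussian_step(x, rng, scale=1.0):
    """Symmetric Gaussian random walk proposal in one dimension."""
    return x + rng.gauss(0.0, scale)


def metropolis(lp_x, lp_y):
    """Acceptance probability for a symmetric proposal."""
    return math.exp(min(0.0, lp_y - lp_x))


def std_normal_log_prob(x):
    """Unnormalized log density of a standard normal."""
    return -0.5 * x * x


sampler = SketchSampler(std_normal_log_prob, 0.0, gaussian_step, metropolis)
samples, log_probs = sampler(5000)
```

Swapping either callable changes the algorithm without touching the loop, which is the design idea described above.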
Proposals
One of the most important design choices in creating an MCMC sampler is how to propose samples. Proposals should be easy to sample from, efficient to evaluate, and resemble the target distribution as closely as possible. This package implements the following proposals:
- GaussianRandomWalk: A Gaussian random walk.
- AdaptiveCovariance: A Gaussian random walk where the covariance adapts to match the empirical covariance of the past states of the Markov chain.
- ScaledCovariance: A Gaussian random walk where the covariance is a scaled version of another (possibly time-varying) covariance.
- HamiltonianDynamics: A gradient-based proposal that uses Hamiltonian dynamics to propose samples from high-density regions.
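As an illustration of how a covariance can track the empirical covariance of past chain states, the sketch below keeps a running mean and covariance with a Welford-style update. The class name RunningCovariance is hypothetical and the code uses plain Python lists. How the package itself performs the adaptation is not stated here, so treat this as one possible approach.

```python
class RunningCovariance:
    """Hypothetical helper: incremental sample mean and covariance of chain states."""

    def __init__(self, dim):
        self.dim = dim
        self.n = 0
        self.mean = [0.0] * dim
        # Running sum of outer products of deviations from the mean
        self.m2 = [[0.0] * dim for _ in range(dim)]

    def update(self, x):
        self.n += 1
        # Deviation from the mean before and after including x
        delta_old = [xi - mi for xi, mi in zip(x, self.mean)]
        self.mean = [mi + di / self.n for mi, di in zip(self.mean, delta_old)]
        delta_new = [xi - mi for xi, mi in zip(x, self.mean)]
        for i in range(self.dim):
            for j in range(self.dim):
                self.m2[i][j] += delta_old[i] * delta_new[j]

    def covariance(self):
        if self.n < 2:
            raise ValueError("need at least two states to estimate a covariance")
        return [[v / (self.n - 1) for v in row] for row in self.m2]


tracker = RunningCovariance(dim=2)
for state in [[1.0, 2.0], [2.0, 1.0], [4.0, 3.0], [5.0, 6.0]]:
    tracker.update(state)
cov = tracker.covariance()
```

Each update costs O(d^2) and never revisits earlier states, which is why incremental formulas of this kind suit long chains. Adaptive schemes often also scale the estimate and add a small diagonal term for numerical stability. Whether this package does so is not covered in this description.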
Acceptance kernels
The acceptance kernel in an MCMC scheme is used to guarantee that the Markov chain converges to the target distribution. The most common acceptance kernel comes from the Metropolis-Hastings algorithm, and most other acceptance kernels are based on the one used in that algorithm. Indeed, all of the samplers in this package use an acceptance kernel derived from the Metropolis-Hastings one.
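For reference, the standard Metropolis-Hastings acceptance probability can be written as follows. The notation is an assumption made here rather than taken from the package: \pi is the (possibly unnormalized) target density, q(y \mid x) is the density of proposing y from the current state x, and \alpha is the probability of accepting the move.

```latex
\alpha(x, y) = \min\left(1, \frac{\pi(y)\, q(x \mid y)}{\pi(x)\, q(y \mid x)}\right)
```

For a symmetric proposal such as a Gaussian random walk, q(x \mid y) = q(y \mid x), so the ratio reduces to \pi(y) / \pi(x). Any normalizing constant of \pi cancels in the ratio, which is why only an unnormalized log probability is needed.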
Samplers
Any MCMC sampler can be seen simply as a proposal paired with an acceptance kernel. This package implements the following common pairings:
- MetropolisHastings: Pairs the Metropolis-Hastings acceptance kernel with a Gaussian random walk proposal.
- DelayedRejectionAdaptiveMetropolis: Pairs a two-stage delayed rejection acceptance kernel with an adaptive-covariance Gaussian random walk proposal.
- HamiltonianMonteCarlo: Pairs the Metropolis-Hastings acceptance kernel with a Hamiltonian dynamics proposal.
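To show what pairing a Hamiltonian dynamics proposal with a Metropolis-Hastings style acceptance step can look like, here is a one-dimensional sketch using the leapfrog integrator. The function names, the unit mass, and the step settings are assumptions chosen for illustration. The package's own HamiltonianMonteCarlo class may be organized differently.

```python
import math


def leapfrog(x, p, grad_log_prob, step_size, num_steps):
    """Simulate Hamiltonian dynamics with unit mass using leapfrog steps."""
    p = p + 0.5 * step_size * grad_log_prob(x)      # initial half step for momentum
    for i in range(num_steps):
        x = x + step_size * p                       # full step for position
        if i < num_steps - 1:
            p = p + step_size * grad_log_prob(x)    # full step for momentum
    p = p + 0.5 * step_size * grad_log_prob(x)      # final half step for momentum
    return x, p


def hamiltonian(x, p, log_prob):
    """Potential energy (negative log density) plus kinetic energy."""
    return -log_prob(x) + 0.5 * p * p


def log_prob(x):
    return -0.5 * x * x   # standard normal target, unnormalized


def grad_log_prob(x):
    return -x


x0, p0 = 1.0, 0.5
x1, p1 = leapfrog(x0, p0, grad_log_prob, step_size=0.1, num_steps=20)

# The candidate is accepted with probability min(1, exp(H_old - H_new))
h0 = hamiltonian(x0, p0, log_prob)
h1 = hamiltonian(x1, p1, log_prob)
accept_prob = math.exp(min(0.0, h0 - h1))
```

Because leapfrog nearly conserves the Hamiltonian, the acceptance probability stays close to one even though the candidate can be far from the starting point.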
Example: Banana distribution
A walkthrough of part of the banana example is given here to demonstrate how the package can be used.
First import the necessary packages.
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
from mcmc_samplers import *
Then define a function that evaluates the log probability of the (unnormalized) target distribution. In this example, a class is used.
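class Banana:
    def __init__(self):
        # Correlated 2D Gaussian; log_prob warps it into a banana shape
        mean = torch.zeros(2)
        cov = torch.tensor([[1., 0.9], [0.9, 1.]])
        self.mvn = MultivariateNormal(mean, cov)

    def log_prob(self, x: torch.Tensor) -> torch.Tensor:
        # Accept a single point or a batch of points with shape (n, 2)
        x = torch.atleast_2d(x)
        # Shift the second coordinate by (x_1 + 1)^2 before evaluating the Gaussian
        y = torch.cat((x[:,0:1], x[:,1:2] + (x[:,0:1] + 1)**2), dim=1)
        return self.mvn.log_prob(y)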
Inside the main function, instantiate the Banana class and specify sampling parameters. This example uses the DRAM algorithm, so the initial sample and initial covariance for the adaptive Gaussian random walk must be specified.
target = Banana()
init_sample = torch.tensor([0.,-1.])
init_cov = torch.tensor([[1., 0.9], [0.9, 1.]])
Lastly, create the Sampler object and run for the desired number of iterations.
dram = DelayedRejectionAdaptiveMetropolis(
    target = target.log_prob,
    x0 = init_sample,
    cov = init_cov
)
num_samples = int(1e4)
samples, log_probs = dram(num_samples)
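A quick sanity check after a run is the fraction of iterations in which the chain moved. The helper below is a hypothetical sketch, not part of the package. It assumes the chain is available as a list of per-iteration states, for example after converting a tensor of samples with .tolist(). The short chain used here is made-up illustrative data, not output from the sampler.

```python
def acceptance_rate(chain, burn_in=0):
    """Fraction of transitions in which the state changed (hypothetical helper)."""
    kept = chain[burn_in:]
    if len(kept) < 2:
        raise ValueError("need at least two states after burn-in")
    # With a continuous proposal, a repeated state indicates a rejected candidate
    moves = sum(1 for prev, curr in zip(kept, kept[1:]) if list(prev) != list(curr))
    return moves / (len(kept) - 1)


# Made-up chain for illustration only
toy_chain = [[0.0, -1.0], [0.0, -1.0], [0.3, -0.8], [0.3, -0.8], [0.5, -1.1]]
rate = acceptance_rate(toy_chain)
rate_after_burn_in = acceptance_rate(toy_chain, burn_in=1)
```

A rate very close to zero or one usually suggests that the proposal scale is poorly matched to the target.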
To visualize the results, the samples can be used to create a SampleVisualizer object. The methods of this object can then be called to plot the sample chains and the 1D and 2D histograms.
import matplotlib.pyplot as plt
labels = ['$x_1$', '$x_2$']
visualizer = SampleVisualizer(samples)
visualizer.chains(labels=labels)
visualizer.triangular_hist(bins=50, labels=labels)
plt.show()
Author information
Author: Nicholas Galioto
Email: ngalioto@umich.edu
License: GPL3
Copyright 2024, Nicholas Galioto