
A Python package for MCMC sampling.


MCMC Samplers

A Python package for performing MCMC sampling with PyTorch.

Installation

To install this package, run the following command from the terminal:

pip install mcmc-samplers

Framework

Each sampling algorithm is run through the __call__ method inherited from the Sampler base class. The samplers differ mainly in how they define the following two attributes:

  • proposals
  • acceptance_kernels
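The division of labor between these two attributes can be illustrated with a minimal pure-Python sketch (one-dimensional, no PyTorch). The class names, the singular proposal/acceptance_kernel attributes, and the sample/accept methods are hypothetical and are not this package's API; the sketch only mirrors the idea that the __call__ loop is generic while the proposal and the acceptance kernel supply the algorithm-specific behavior.

```python
import math
import random
from typing import Callable, List, Tuple


class GaussianRandomWalk1D:
    """Hypothetical proposal: x' = x + N(0, step^2)."""

    def __init__(self, step: float) -> None:
        self.step = step

    def sample(self, x: float) -> float:
        return x + random.gauss(0.0, self.step)


class MetropolisKernel:
    """Hypothetical acceptance kernel for a symmetric proposal."""

    def accept(self, log_p_current: float, log_p_proposed: float) -> bool:
        log_alpha = min(0.0, log_p_proposed - log_p_current)
        # 1 - random() lies in (0, 1], so the log is always defined
        return math.log(1.0 - random.random()) <= log_alpha


class SketchSampler:
    """Hypothetical base class: a generic loop driven by its two components."""

    def __init__(self, target: Callable[[float], float], x0: float,
                 proposal: GaussianRandomWalk1D,
                 acceptance_kernel: MetropolisKernel) -> None:
        self.target = target
        self.proposal = proposal
        self.acceptance_kernel = acceptance_kernel
        self.x = x0
        self.log_p = target(x0)

    def __call__(self, num_samples: int) -> Tuple[List[float], List[float]]:
        samples, log_probs = [], []
        for _ in range(num_samples):
            x_new = self.proposal.sample(self.x)
            log_p_new = self.target(x_new)
            if self.acceptance_kernel.accept(self.log_p, log_p_new):
                self.x, self.log_p = x_new, log_p_new
            # On rejection the current state is repeated in the chain
            samples.append(self.x)
            log_probs.append(self.log_p)
        return samples, log_probs


random.seed(0)
sampler = SketchSampler(target=lambda x: -0.5 * x * x,  # unnormalized standard normal
                        x0=0.0,
                        proposal=GaussianRandomWalk1D(step=1.0),
                        acceptance_kernel=MetropolisKernel())
samples, log_probs = sampler(5000)
```

Swapping in a different proposal or kernel changes the algorithm without touching the loop, which is the design idea behind the Sampler base class.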

Common proposals, acceptance kernels, and pairings of these two objects are described next.

Proposals

One of the most important design choices in creating an MCMC sampler is how to propose samples. Proposals should be easy to sample from, efficient to evaluate, and resemble the target distribution as closely as possible. This package implements the following proposals:

  • GaussianRandomWalk: A Gaussian random walk.
  • AdaptiveCovariance: A Gaussian random walk where the covariance adapts to match the empirical covariance of the past states of the Markov chain.
  • ScaledCovariance: A Gaussian random walk where the covariance is a scaled version of another (possibly time-varying) covariance.
  • HamiltonianDynamics: A gradient-based proposal that uses Hamiltonian dynamics to propose samples from high-density regions.
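For a Hamiltonian dynamics proposal, the core computation is usually a leapfrog integration of Hamilton's equations. The sketch below assumes an identity mass matrix and plain Python lists; leapfrog, hamiltonian, and the step size are hypothetical choices for illustration, not the package's implementation. The final momentum is negated so that applying the map twice returns the starting point, which is what allows a Metropolis-Hastings-type kernel to accept the move with probability min(1, exp(H(x, p) - H(x', p'))).

```python
from typing import Callable, List, Tuple

Vector = List[float]


def hamiltonian(x: Vector, p: Vector, log_prob: Callable[[Vector], float]) -> float:
    """H(x, p) = -log pi(x) + p.p / 2 (identity mass matrix)."""
    return -log_prob(x) + 0.5 * sum(pk * pk for pk in p)


def leapfrog(x: Vector, p: Vector, grad_log_prob: Callable[[Vector], Vector],
             step_size: float, num_steps: int) -> Tuple[Vector, Vector]:
    """Run num_steps >= 1 leapfrog steps and return (x', -p')."""
    g = grad_log_prob(x)
    # Half step for the momentum: dp/dt = grad log pi(x)
    p = [pk + 0.5 * step_size * gk for pk, gk in zip(p, g)]
    for step in range(num_steps):
        # Full step for the position: dx/dt = p
        x = [xk + step_size * pk for xk, pk in zip(x, p)]
        g = grad_log_prob(x)
        if step < num_steps - 1:
            p = [pk + step_size * gk for pk, gk in zip(p, g)]
    # Final half step for the momentum, then negate it for reversibility
    p = [pk + 0.5 * step_size * gk for pk, gk in zip(p, g)]
    return x, [-pk for pk in p]


def log_prob(x: Vector) -> float:
    return -0.5 * sum(xk * xk for xk in x)  # unnormalized standard normal


def grad_log_prob(x: Vector) -> Vector:
    return [-xk for xk in x]


x0, p0 = [1.0, 0.5], [0.3, -0.7]
x1, p1 = leapfrog(x0, p0, grad_log_prob, step_size=0.1, num_steps=20)
# Log acceptance probability of the proposed (x1, p1)
log_accept = min(0.0, hamiltonian(x0, p0, log_prob) - hamiltonian(x1, p1, log_prob))
```

Because leapfrog only approximately conserves H, the acceptance step is still needed; smaller step sizes raise the acceptance rate at the cost of more gradient evaluations per proposal.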

Acceptance kernels

The acceptance kernel in an MCMC scheme decides whether each proposed sample is accepted or rejected, and it is what guarantees that the target distribution is the stationary distribution of the Markov chain. The most common acceptance kernel comes from the Metropolis-Hastings algorithm, and most other acceptance kernels build on it. Indeed, all of the samplers in this package use an acceptance kernel derived from the Metropolis-Hastings one.
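As a concrete reference point, the Metropolis-Hastings kernel accepts a proposal x' drawn from q(x' | x) with probability alpha = min{1, pi(x') q(x | x') / [pi(x) q(x' | x)]}, where pi may be unnormalized because its normalizing constant cancels. The minimal sketch below works in log space to avoid underflow; the function names and arguments are hypothetical, not this package's API.

```python
import math
import random


def log_mh_acceptance(log_target_current: float, log_target_proposed: float,
                      log_q_forward: float = 0.0, log_q_reverse: float = 0.0) -> float:
    """Log of the Metropolis-Hastings acceptance probability.

    log_q_forward = log q(x' | x) and log_q_reverse = log q(x | x').
    For a symmetric proposal such as a Gaussian random walk, both can be
    left at 0 because they cancel.
    """
    log_ratio = (log_target_proposed + log_q_reverse) - (log_target_current + log_q_forward)
    return min(0.0, log_ratio)


def mh_accept(log_alpha: float) -> bool:
    # Accept when log(u) <= log(alpha) for u uniform on (0, 1]
    return math.log(1.0 - random.random()) <= log_alpha
```

A move to a point of higher density is always accepted, while a move to a point of lower density is accepted with probability equal to the density ratio.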

Samplers

Any MCMC sampler can be seen simply as a proposal paired with an acceptance kernel. This package implements the following common pairings:

  • MetropolisHastings: Pairs the Metropolis-Hastings acceptance kernel with a Gaussian random walk proposal.
  • DelayedRejectionAdaptiveMetropolis: Pairs a two-stage delayed rejection acceptance kernel with a Gaussian random walk proposal that uses an adaptive covariance.
  • HamiltonianMonteCarlo: Pairs the Metropolis-Hastings acceptance kernel with a Hamiltonian dynamics proposal.
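The second stage of delayed rejection is what distinguishes DRAM from plain adaptive Metropolis. In the standard two-stage scheme, a first proposal y1 is tried; if it is rejected, a second proposal y2 (typically from a narrower Gaussian) is tried and accepted with probability alpha2(x, y1, y2) = min{1, pi(y2) q1(y1 | y2) [1 - alpha1(y2, y1)] / (pi(x) q1(y1 | x) [1 - alpha1(x, y1)])}, where alpha1 is the first-stage Metropolis acceptance probability and the symmetric second-stage proposal density cancels. The one-dimensional sketch below implements that formula; the function names and the stage-1 variance var1 are assumptions for illustration, and the package's DelayedRejectionAdaptiveMetropolis may be organized differently.

```python
import math
from typing import Callable


def log_normal_pdf(y: float, mean: float, var: float) -> float:
    return -0.5 * math.log(2.0 * math.pi * var) - 0.5 * (y - mean) ** 2 / var


def log_alpha1(log_p_from: float, log_p_to: float) -> float:
    """First-stage Metropolis acceptance (symmetric proposal), in log space."""
    return min(0.0, log_p_to - log_p_from)


def log1m_exp(a: float) -> float:
    """log(1 - exp(a)) for a <= 0; returns -inf when a == 0."""
    return -math.inf if a == 0.0 else math.log(-math.expm1(a))


def log_alpha2(log_target: Callable[[float], float],
               x: float, y1: float, y2: float, var1: float) -> float:
    """Second-stage delayed-rejection acceptance, in log space.

    Only meaningful after y1 was rejected, i.e. when alpha1(x, y1) < 1.
    var1 is the variance of the stage-1 Gaussian random walk.
    """
    lp_x, lp_y1, lp_y2 = log_target(x), log_target(y1), log_target(y2)
    log_num = lp_y2 + log_normal_pdf(y1, y2, var1) + log1m_exp(log_alpha1(lp_y2, lp_y1))
    log_den = lp_x + log_normal_pdf(y1, x, var1) + log1m_exp(log_alpha1(lp_x, lp_y1))
    return min(0.0, log_num - log_den)
```

The extra factors involving y1 are what keep the second stage reversible: they account for the fact that y2 is only ever proposed after y1 has been rejected.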

Example: Banana distribution

A walkthrough of part of the banana example is given here to demonstrate how the package can be used.

First import the necessary packages.

import torch
from torch.distributions.multivariate_normal import MultivariateNormal
from mcmc_samplers import *

Then define a function that evaluates the log probability of the target distribution; the target only needs to be known up to a normalizing constant. In this example, the function is the log_prob method of a class.

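class Banana:
    """Banana-shaped target: a correlated 2D Gaussian under a nonlinear warp."""

    def __init__(self):
        # Correlated Gaussian in the warped coordinates y
        mean = torch.zeros(2)
        cov = torch.tensor([[1., 0.9], [0.9, 1.]])
        self.mvn = MultivariateNormal(mean, cov)

    def log_prob(self, x: torch.Tensor) -> torch.Tensor:
        # Accept a single point (shape [2]) or a batch of points (shape [N, 2])
        x = torch.atleast_2d(x)
        # Warp: y1 = x1, y2 = x2 + (x1 + 1)^2
        y = torch.cat((x[:, 0:1], x[:, 1:2] + (x[:, 0:1] + 1)**2), dim=1)
        return self.mvn.log_prob(y)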
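A quick way to sanity-check a target like this is to write its density out by hand. The warp y = (x1, x2 + (x1 + 1)^2) has Jacobian matrix [[1, 0], [2(x1 + 1), 1]] with determinant 1, so Banana.log_prob is in fact a normalized density and its x1 marginal is the standard normal. The plain-Python sketch below (the name banana_log_prob and the RHO constant are hypothetical) reproduces the same formula without PyTorch, using the closed-form log density of a bivariate normal with unit variances and correlation 0.9.

```python
import math

RHO = 0.9  # correlation of the Gaussian in the Banana example


def banana_log_prob(x1: float, x2: float) -> float:
    """Log density of the banana distribution at the point (x1, x2)."""
    # Same warp as Banana.log_prob
    y1 = x1
    y2 = x2 + (x1 + 1.0) ** 2
    # log N(y; 0, [[1, RHO], [RHO, 1]]) in closed form
    det = 1.0 - RHO ** 2
    quad = (y1 ** 2 - 2.0 * RHO * y1 * y2 + y2 ** 2) / det
    return -math.log(2.0 * math.pi) - 0.5 * math.log(det) - 0.5 * quad
```

At the initial sample (0, -1) used below, the warp gives y = (0, 0), so the log density there equals -log(2 pi) - 0.5 log(0.19).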

Inside the example's main function, instantiate the Banana class and specify the sampling parameters. This example uses the delayed rejection adaptive Metropolis (DRAM) algorithm, so an initial sample and an initial covariance for the adaptive Gaussian random walk must be specified.

target = Banana()

init_sample = torch.tensor([0.,-1.])
init_cov = torch.tensor([[1., 0.9], [0.9, 1.]])
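In adaptive schemes of this kind, the initial covariance typically drives the proposal until enough states have been collected to estimate one from the chain. A common recipe from the adaptive Metropolis literature tracks the empirical covariance of past states online, scales it by 2.38^2 / d, and adds a small multiple of the identity for numerical stability; whether AdaptiveCovariance uses exactly this scaling, regularization, or update schedule is an assumption here. The pure-Python sketch below uses a Welford-style update so the chain history does not have to be stored; RunningCovariance and adapted_proposal_cov are hypothetical names.

```python
from typing import List, Sequence


class RunningCovariance:
    """Hypothetical online (Welford-style) mean and covariance of chain states."""

    def __init__(self, dim: int) -> None:
        self.n = 0
        self.mean = [0.0] * dim
        # Running sum of products of deviations from the mean
        self.m2 = [[0.0] * dim for _ in range(dim)]

    def update(self, x: Sequence[float]) -> None:
        self.n += 1
        delta = [xk - mk for xk, mk in zip(x, self.mean)]  # deviation from old mean
        self.mean = [mk + dk / self.n for mk, dk in zip(self.mean, delta)]
        delta_new = [xk - mk for xk, mk in zip(x, self.mean)]  # deviation from new mean
        for i in range(len(delta)):
            for j in range(len(delta)):
                self.m2[i][j] += delta[i] * delta_new[j]

    def covariance(self) -> List[List[float]]:
        if self.n < 2:
            raise ValueError("need at least two states to estimate a covariance")
        return [[v / (self.n - 1) for v in row] for row in self.m2]


def adapted_proposal_cov(est: RunningCovariance, eps: float = 1e-6) -> List[List[float]]:
    """Scale the empirical covariance by 2.38^2 / d and add eps * I (assumed recipe)."""
    d = len(est.mean)
    s_d = 2.38 ** 2 / d
    cov = est.covariance()
    return [[s_d * (cov[i][j] + (eps if i == j else 0.0)) for j in range(d)]
            for i in range(d)]


est = RunningCovariance(dim=2)
# These three states are collinear, so their covariance is singular;
# the eps * I term keeps the proposal covariance positive definite.
for state in [(1.0, 2.0), (3.0, 6.0), (5.0, 10.0)]:
    est.update(state)
proposal_cov = adapted_proposal_cov(est)
```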

Lastly, create the Sampler object and run for the desired number of iterations.

dram = DelayedRejectionAdaptiveMetropolis(
        target = target.log_prob,
        x0 = init_sample,
        cov = init_cov
    )

num_samples = int(1e4)
samples, log_probs = dram(num_samples)

To visualize the results, the samples can be used to create a SampleVisualizer object, whose methods plot the sample chains and the 1D and 2D histograms.

import matplotlib.pyplot as plt

labels = ['$x_1$', '$x_2$']
visualizer = SampleVisualizer(samples)
visualizer.chains(labels=labels)
visualizer.triangular_hist(bins=50, labels=labels)
plt.show()

Author information

Author: Nicholas Galioto
Email: ngalioto@umich.edu
License: GPL3
Copyright 2024, Nicholas Galioto

Download files

Download the file for your platform.

Source Distribution

mcmc-samplers-0.0.2.tar.gz (51.6 kB)

Uploaded Source

Built Distribution

mcmc_samplers-0.0.2-py3-none-any.whl (43.5 kB)

Uploaded Python 3

File details

Details for the file mcmc-samplers-0.0.2.tar.gz.

File metadata

  • Download URL: mcmc-samplers-0.0.2.tar.gz
  • Size: 51.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.11.8

File hashes

Hashes for mcmc-samplers-0.0.2.tar.gz
Algorithm Hash digest
SHA256 23840c0cd03081f7ee5eda542ea3a692e23c8e42763cc6eb7aed9aa990700501
MD5 7dc718bcced3f1154243c1b1e5fd8231
BLAKE2b-256 7b46c81b591d0daf0d74d0f3b9795100138d2c803ea53b4cfe21b7221e11cd8b

File details

Details for the file mcmc_samplers-0.0.2-py3-none-any.whl.

File hashes

Hashes for mcmc_samplers-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 d018d76618a21f6c9e28356df79aceae8142b8bcdee9295c10adae504db432d4
MD5 8db841127f7f72332e0ab4abd62ee7ba
BLAKE2b-256 052c8580fb44fde6064125290974fbf10c5ce19eb2dfc73394873976dbd9286a
