Delayed Acceptance MCMC Sampler
tinyDA
Delayed Acceptance (Christen & Fox, 2005) MCMC sampler with finite-length subchain sampling and adaptive error modelling.
This is intended as a simple, lightweight implementation, with minimal dependencies, i.e. nothing beyond the SciPy stack.
It is fully imperative and easy to use!
Installation
tinyDA can be installed from PyPI:
pip install tinyDA
Features
Proposals
- Random Walk Metropolis Hastings (RWMH) - Metropolis et al. (1953), Hastings (1970)
- preconditioned Crank-Nicolson (pCN) - Cotter et al. (2013)
- Adaptive Metropolis (AM) - Haario et al. (2001)
- Adaptive pCN - Hu and Yao (2016)
- DREAM(Z) - Vrugt (2016)
- Multiple-Try Metropolis (MTM) - Liu et al. (2000)
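The pCN proposal (Cotter et al., 2013) moves the current state towards a fresh draw from a zero-mean Gaussian prior, which keeps acceptance rates reasonable in high dimensions. Below is a minimal NumPy sketch of a single pCN step. It assumes a zero-mean Gaussian prior with covariance `C` and a step size `beta` in [0, 1]. The helper `pcn_propose` is hypothetical and is not part of tinyDA:

```python
import numpy as np

def pcn_propose(x, C, beta, rng):
    """Hypothetical helper: pCN proposal x' = sqrt(1 - beta^2) * x + beta * xi, xi ~ N(0, C).

    Assumes a zero-mean Gaussian prior N(0, C).
    """
    xi = rng.multivariate_normal(np.zeros(len(x)), C)
    return np.sqrt(1.0 - beta**2) * x + beta * xi

rng = np.random.default_rng(0)
x = np.ones(3)
C = np.eye(3)
x_new = pcn_propose(x, C, beta=0.2, rng=rng)
```

This proposal leaves the prior N(0, C) invariant. The Metropolis-Hastings acceptance probability therefore reduces to the likelihood ratio min(1, L(x')/L(x)).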
Adaptive Error Models
- State independent - Cui et al. (2018)
- State dependent - Cui et al. (2018)
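In the state-independent error model, the discrepancy between the fine and coarse model outputs is treated as a Gaussian random variable. Its mean and covariance are estimated from pairs of model evaluations. The coarse Gaussian likelihood is then shifted by that mean, and its covariance is inflated by that covariance. Below is a minimal NumPy sketch of this correction. It assumes a Gaussian likelihood with noise covariance `cov_noise`. Both helper names are hypothetical, and this is not the code behind `tda.AdaptiveLogLike`:

```python
import numpy as np

def estimate_bias(fine_outputs, coarse_outputs):
    """Hypothetical helper: mean and covariance of the model error B = F_fine - F_coarse.

    Both inputs have shape (n_pairs, n_data); each row comes from one parameter sample.
    """
    bias = np.asarray(fine_outputs, dtype=float) - np.asarray(coarse_outputs, dtype=float)
    return bias.mean(axis=0), np.atleast_2d(np.cov(bias, rowvar=False))

def corrected_coarse_loglike(data, coarse_output, cov_noise, mu_bias, cov_bias):
    """Hypothetical helper: unnormalised Gaussian log-likelihood of the bias-corrected coarse output."""
    residual = data - coarse_output - mu_bias   # shift by the mean model error
    cov_total = cov_noise + cov_bias            # inflate by the model-error covariance
    return -0.5 * residual @ np.linalg.solve(cov_total, residual)
```

The state-dependent variant instead corrects the coarse output using the model error observed at the current state. It is not covered by this sketch.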
Diagnostics
- A bunch of plotting functions
- Rank-normalised split-R̂ and ESS - Vehtari et al. (2020)
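The classic split-R̂ is a simpler relative of the rank-normalised diagnostic. It splits each chain in half and compares within-chain and between-chain variance; values close to 1 suggest the chains have mixed. Below is a minimal NumPy sketch. It assumes `chains` is an array of shape (n_chains, n_draws) for one scalar parameter. The function name is hypothetical, and the sketch deliberately omits the rank normalisation and folding of Vehtari et al. (2020):

```python
import numpy as np

def split_rhat(chains):
    """Hypothetical helper: classic (non-rank-normalised) split-R-hat for one scalar parameter.

    chains: array of shape (n_chains, n_draws).
    """
    chains = np.asarray(chains, dtype=float)
    n_draws = chains.shape[1] // 2 * 2      # drop the last draw if the length is odd
    half = n_draws // 2
    # Split every chain into two halves, giving 2 * n_chains sub-chains.
    splits = np.concatenate([chains[:, :half], chains[:, half:n_draws]], axis=0)
    n = splits.shape[1]
    chain_means = splits.mean(axis=1)
    W = splits.var(axis=1, ddof=1).mean()   # mean within-chain variance
    B = n * chain_means.var(ddof=1)         # between-chain variance
    var_plus = (n - 1) / n * W + B / n      # pooled estimate of the posterior variance
    return np.sqrt(var_plus / W)
```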
Dependencies:
Usage
A few illustrative examples are available as Jupyter notebooks in the root directory of the repository. Below is a short summary of the core features.
Distributions
The prior and likelihood can be defined using standard scipy.stats classes:
import numpy as np
import tinyDA as tda
from scipy.stats import multivariate_normal
# n_dim (number of parameters), data (observations) and sigma (noise standard deviation) are user-defined
mean_prior = np.zeros(n_dim)
cov_prior = np.eye(n_dim)
cov_likelihood = sigma**2*np.eye(data.shape[0])
my_prior = multivariate_normal(mean_prior, cov_prior)
my_loglike = tda.LogLike(data, cov_likelihood)
If you use a Gaussian likelihood, we recommend the tinyDA implementation. It is unnormalised and works well with tda.AdaptiveLogLike, which is used for the Adaptive Error Model. Home-brew distributions are easy to define. As in the SciPy implementation, they must have an .rvs() method for drawing random samples and a .logpdf(x) method for computing the log-density.
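The example below illustrates the .rvs() / .logpdf(x) interface with a minimal home-brew Gaussian distribution. Like the tinyDA Gaussian likelihood described above, it drops the normalising constant. The class name `UnnormalisedGaussian` is hypothetical. Whether tinyDA expects any further methods or attributes is an assumption to check against its docstrings:

```python
import numpy as np

class UnnormalisedGaussian:
    """Hypothetical home-brew Gaussian with the SciPy-style .rvs() / .logpdf(x) interface."""

    def __init__(self, mean, cov, seed=None):
        self.mean = np.asarray(mean, dtype=float)
        self.cov = np.asarray(cov, dtype=float)
        self.cov_inv = np.linalg.inv(self.cov)
        self.rng = np.random.default_rng(seed)

    def rvs(self):
        # Draw one random sample.
        return self.rng.multivariate_normal(self.mean, self.cov)

    def logpdf(self, x):
        # Log-density up to an additive constant: -0.5 * (x - mu)^T C^{-1} (x - mu).
        r = np.asarray(x, dtype=float) - self.mean
        return -0.5 * r @ self.cov_inv @ r

my_custom_prior = UnnormalisedGaussian(np.zeros(2), np.eye(2), seed=42)
x = my_custom_prior.rvs()
log_density = my_custom_prior.logpdf(x)
```

With a fixed covariance, dropping the constant leaves MCMC unaffected. Only differences of log-densities enter the acceptance probability.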
tinyDA.LinkFactory
At the heart of the tinyDA sampler sits what we call a LinkFactory, which is responsible for:
- Calling the model with some parameters (a proposal) and collecting the model output.
- Evaluating the prior density of the parameters and the likelihood of the data, given the model output.
- Constructing tda.Link instances that hold the information for each sample.
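The three responsibilities listed above can be mimicked in a few lines of plain Python. This sketch is purely illustrative. The `Link` namedtuple, its fields and the `create_link` helper are hypothetical stand-ins, not the internals of tinyDA's `tda.Link` or `tda.LinkFactory`:

```python
from collections import namedtuple

import numpy as np
from scipy.stats import multivariate_normal

# Hypothetical container for the information held about one sample.
Link = namedtuple("Link", ["parameters", "prior", "model_output", "likelihood", "qoi", "posterior"])

def create_link(parameters, evaluate_model, prior, loglike):
    # 1. Call the model with the proposed parameters.
    model_output, qoi = evaluate_model(parameters)
    # 2. Evaluate the log-prior of the parameters and the log-likelihood of the output.
    log_prior = prior.logpdf(parameters)
    log_likelihood = loglike.logpdf(model_output)
    # 3. Bundle everything, including the unnormalised log-posterior, into a Link.
    return Link(parameters, log_prior, model_output, log_likelihood, qoi, log_prior + log_likelihood)

# Toy usage: a linear model on five points with Gaussian prior and likelihood.
x = np.linspace(0, 1, 5)
data = 1.0 + 2.0 * x
prior = multivariate_normal(np.zeros(2), np.eye(2))
# A Gaussian centred on the data is symmetric in data and output, so
# logpdf(model_output) equals the log-likelihood of the data given the output.
loglike = multivariate_normal(data, 0.1 * np.eye(5))
link = create_link(np.array([1.0, 2.0]), lambda p: (p[0] + p[1] * x, None), prior, loglike)
```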
The LinkFactory must be defined by inheriting from either tda.LinkFactory or tda.BlackBoxLinkFactory. The former computes the model output directly from the input parameters, using pure Python or whichever external library you want to call. Its evaluate_model() method must therefore be overridden:
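class MyLinkFactory(tda.LinkFactory):
    def evaluate_model(self, parameters):
        # simple linear model evaluated at user-defined points x
        output = parameters[0] + parameters[1]*x
        # no quantity of interest (QoI) is computed here
        qoi = None
        return output, qoi

my_link_factory = MyLinkFactory(my_prior, my_loglike)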
The latter lets you pass a model object and a set of datapoints to the LinkFactory at initialisation. These are then stored as the attributes self.model and self.datapoints, which is useful for, e.g., PDE solvers. Again, the evaluate_model() method must be overridden:
class MyLinkFactory(tda.BlackBoxLinkFactory):
    def evaluate_model(self, parameters):
        # solve the model and extract the output at the datapoints
        self.model.solve(parameters)
        output = self.model.get_data(self.datapoints)
        # optionally compute a quantity of interest (QoI)
        if self.get_qoi:
            qoi = self.model.get_qoi()
        else:
            qoi = None
        return output, qoi

my_link_factory = MyLinkFactory(my_model, my_datapoints, my_prior, my_loglike, get_qoi=True)
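The BlackBoxLinkFactory example above only assumes that the model object exposes solve(), get_data() and get_qoi(). A toy stand-in with that interface can help you test the link factory before plugging in a real PDE solver. The class `ToyModel` and its internals are hypothetical and for illustration only:

```python
import numpy as np

class ToyModel:
    """Hypothetical stand-in for a PDE solver: u(x) = a * exp(-b * x) on a fixed grid."""

    def __init__(self, grid):
        self.grid = np.asarray(grid, dtype=float)
        self.solution = None

    def solve(self, parameters):
        a, b = parameters
        # "Solve" the model on the full grid and cache the result.
        self.solution = a * np.exp(-b * self.grid)

    def get_data(self, datapoints):
        # Linearly interpolate the cached solution at the observation points.
        return np.interp(datapoints, self.grid, self.solution)

    def get_qoi(self):
        # Example quantity of interest: trapezoidal integral of the solution over the grid.
        return np.sum(0.5 * (self.solution[1:] + self.solution[:-1]) * np.diff(self.grid))

my_model = ToyModel(np.linspace(0, 1, 101))
my_datapoints = np.array([0.1, 0.5, 0.9])
my_model.solve([2.0, 1.0])
```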
Proposals
A proposal is simply initialised with its parameters:
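am_cov = np.eye(n_dim)  # initial proposal covariance
am_t0 = 1000  # number of iterations before adaptation starts
am_sd = 1  # scaling factor for the adapted covariance
am_epsilon = 1e-6  # small regularisation keeping the covariance positive definite
my_proposal = tda.AdaptiveMetropolis(C0=am_cov, t0=am_t0, sd=am_sd, epsilon=am_epsilon)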
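Haario et al. (2001) use the initial covariance for the first t0 iterations. After that, they switch to the scaled, regularised empirical covariance of the chain history, C_t = sd * Cov(x_0, ..., x_{t-1}) + sd * epsilon * I. Below is a minimal NumPy sketch of that rule. It assumes the parameters correspond to tinyDA's C0, t0, sd and epsilon arguments as their names suggest. The helper `am_covariance` is hypothetical and is not tinyDA's implementation:

```python
import numpy as np

def am_covariance(history, C0, t0, sd, epsilon):
    """Hypothetical helper: Adaptive Metropolis proposal covariance (Haario et al., 2001).

    history: array of shape (t, n_dim) holding all states visited so far.
    """
    history = np.asarray(history, dtype=float)
    t, n_dim = history.shape
    if t <= t0:
        # Non-adaptive phase: keep the initial covariance.
        return C0
    # Adaptive phase: scaled empirical covariance plus a small ridge for positive definiteness.
    empirical = np.atleast_2d(np.cov(history, rowvar=False))
    return sd * empirical + sd * epsilon * np.eye(n_dim)
```

Haario et al. suggest sd = 2.4**2 / n_dim. The example above simply uses sd = 1.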
Sampling
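The Delayed Acceptance sampler can then be initialised and run simply with:
# my_link_factory_coarse and my_link_factory_fine are two LinkFactory instances, defined as above, wrapping the coarse (cheap) and fine (expensive) models
my_chain = tda.DAChain(my_link_factory_coarse, my_link_factory_fine, my_proposal, subsampling_rate)
my_chain.sample(n_samples)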
If you decide you need more samples, you can simply call my_chain.sample() again, since all samples and tuning parameters are cached:
my_chain.sample(additional_n_samples)
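The toy example below shows how the coarse and fine levels interact, assuming that subsampling_rate is the length of the coarse subchain. Starting from the current fine state, a Metropolis chain on the coarse posterior runs for subsampling_rate steps. Its final state is then proposed to the fine level and accepted with the second-stage Delayed Acceptance probability. The sketch uses a Gaussian random-walk proposal and explicit log-posteriors. The function `da_sample` and its arguments are hypothetical and are not tinyDA's implementation:

```python
import numpy as np

def da_sample(log_post_coarse, log_post_fine, x0, n_samples, subsampling_rate, step, seed=0):
    """Hypothetical toy Delayed Acceptance sampler with a coarse random-walk subchain."""
    rng = np.random.default_rng(seed)
    x = np.atleast_1d(np.asarray(x0, dtype=float))
    lc_x, lf_x = log_post_coarse(x), log_post_fine(x)
    samples = np.empty((n_samples, x.size))
    for i in range(n_samples):
        # Stage 1: Metropolis subchain on the coarse posterior, started at the fine state.
        y, lc_y = x.copy(), lc_x
        for _ in range(subsampling_rate):
            y_prop = y + step * rng.standard_normal(x.size)
            lc_prop = log_post_coarse(y_prop)
            if np.log(rng.uniform()) < lc_prop - lc_y:
                y, lc_y = y_prop, lc_prop
        # Stage 2: accept the end of the subchain with the DA correction
        # pi_f(y) * pi_c(x) / (pi_f(x) * pi_c(y)).
        lf_y = log_post_fine(y)
        if np.log(rng.uniform()) < (lf_y - lf_x) + (lc_x - lc_y):
            x, lc_x, lf_x = y, lc_y, lf_y
        samples[i] = x
    return samples

# Fine target N(0, 1); the coarse model is a cheap, slightly biased approximation.
log_post_fine = lambda x: -0.5 * np.sum(x**2)
log_post_coarse = lambda x: -0.5 * np.sum((x - 0.2)**2) / 1.2**2
samples = da_sample(log_post_coarse, log_post_fine, 0.0, 10000, subsampling_rate=5, step=1.0)
```

With a symmetric proposal, the coarse subchain is reversible with respect to the coarse posterior. The second-stage correction therefore yields samples from the fine posterior even though the coarse model is biased.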
Postprocessing
The entire sampling history is then stored in my_chain, and you can extract an array of samples by doing:
samples_fine = tda.get_parameters(my_chain)
samples_coarse = tda.get_parameters(my_chain, level='coarse')
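Once extracted, the samples can be summarised with plain NumPy. This minimal sketch assumes that samples_fine is a 2D array of shape (n_samples, n_dim); check the tda.get_parameters docstring for the actual shape. The helper name `summarise` and the burn-in length are illustrative choices:

```python
import numpy as np

def summarise(samples, burnin=0, level=0.95):
    """Hypothetical helper: posterior mean, std and equal-tailed credible interval per dimension."""
    kept = np.asarray(samples, dtype=float)[burnin:]   # discard the burn-in
    tail = 100 * (1 - level) / 2
    lower, upper = np.percentile(kept, [tail, 100 - tail], axis=0)
    return {
        "mean": kept.mean(axis=0),
        "std": kept.std(axis=0, ddof=1),
        "lower": lower,
        "upper": upper,
    }
```

For instance, summarise(samples_fine, burnin=1000) would drop the first 1000 samples before computing the summaries.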
Some diagnostics are available in the diagnostics module. Please refer to their respective docstrings for usage instructions.
TODO
- Parallel multi-chain sampling
- Population-based proposals
- Multilevel Delayed Acceptance
- More user-friendly diagnostics