A minimal preconditioned Crank-Nicolson MCMC sampler

minipcn

A minimal implementation of preconditioned Crank-Nicolson MCMC sampling.
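For reference, here is a minimal NumPy sketch of one textbook pCN update applied to a batch of walkers. It assumes a fixed Gaussian reference N(mu, cov) and a step size beta. The names `pcn_step` and `gaussian_log_density` are hypothetical, and this is not minipcn's internal code, which may choose or adapt the reference differently.

```python
import numpy as np


def gaussian_log_density(x, mu, cov_inv):
    """Log-density of N(mu, cov) for each row of x, up to an additive constant."""
    d = x - mu
    return -0.5 * np.einsum("ni,ij,nj->n", d, cov_inv, d)


def pcn_step(x, log_prob_fn, mu, cov, beta, rng):
    """One textbook pCN step for a batch x of shape (n, dims).

    The proposal is reversible with respect to N(mu, cov), so the
    Metropolis-Hastings ratio only involves the target divided by that reference.
    """
    chol = np.linalg.cholesky(cov)
    cov_inv = np.linalg.inv(cov)
    z = rng.standard_normal(x.shape)
    # Crank-Nicolson proposal: shrink towards mu, then add noise with covariance beta**2 * cov.
    x_prop = mu + np.sqrt(1.0 - beta**2) * (x - mu) + beta * (z @ chol.T)
    log_ratio = (
        log_prob_fn(x_prop) - gaussian_log_density(x_prop, mu, cov_inv)
        - log_prob_fn(x) + gaussian_log_density(x, mu, cov_inv)
    )
    # Accept or reject each walker independently.
    accept = np.log(rng.uniform(size=x.shape[0])) < log_ratio
    x_new = np.where(accept[:, None], x_prop, x)
    return x_new, accept
```

Because the proposal leaves N(mu, cov) invariant, a target equal to that reference accepts every move. Rejections come from how far the actual target departs from the reference.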

Installation

minipcn can be installed from PyPI using pip:

pip install minipcn

Usage

The basic usage is:

from minipcn import Sampler
import numpy as np

log_prob_fn = ...    # Log-probability function; must be vectorized: (n, dims) -> (n,)
dims = ...    # The number of dimensions
rng = np.random.default_rng(42)

sampler = Sampler(
    log_prob_fn=log_prob_fn,
    dims=dims,
    step_fn="pcn",  # Or "tpcn"
)

x0 = rng.normal(size=(100, dims))

chain, history = sampler.sample(x0, n_steps=500, rng=rng)

For a complete example, see the examples directory.
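The sampler calls `log_prob_fn` on the whole batch of walkers at once. The examples store walkers as rows of an array of shape `(n, dims)`. Assuming that layout, the sketch below shows what "vectorized" means here. It evaluates a correlated Gaussian log-density for every row in one call and returns an array of shape `(n,)`. The mean and covariance are illustrative values.

```python
import numpy as np

# Illustrative target: a 2-D Gaussian with correlated components.
mean = np.array([1.0, -1.0])
cov = np.array([[1.0, 0.8], [0.8, 1.0]])
cov_inv = np.linalg.inv(cov)


def log_prob_fn(x):
    """Vectorized log-density, up to a constant: (n, dims) -> (n,)."""
    d = x - mean
    return -0.5 * np.sum((d @ cov_inv) * d, axis=-1)


rng = np.random.default_rng(42)
x0 = rng.normal(size=(100, 2))

# The batched call matches a row-by-row loop without the Python-level overhead.
batched = log_prob_fn(x0)
looped = np.array([log_prob_fn(row[None, :])[0] for row in x0])
```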

Array API support

minipcn also supports other Array API backends, such as PyTorch and JAX, via array-api-compat, and uses orng for random number generation.

Usage is the same as with NumPy, except that the RNG must come from orng and the array namespace must be passed via xp:

from minipcn import Sampler
from orng import RandomGenerator
import torch

log_prob_fn = ...    # Log-probability function; must be vectorized: (n, dims) -> (n,)
dims = ...    # The number of dimensions
rng = RandomGenerator(backend="torch", seed=42)

sampler = Sampler(
    log_prob_fn=log_prob_fn,
    dims=dims,
    step_fn="pcn",    # Or "tpcn"
    xp=torch,
)

# Generate initial samples
x0 = rng.randn(size=(100, dims))

# Run the sampler
chain, history = sampler.sample(x0, n_steps=500, rng=rng)
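The array namespace is passed to the sampler via `xp`, so the log-probability function should use that same namespace instead of hard-coding NumPy. One way to do this is to close over `xp`, sketched below with a hypothetical factory `make_log_prob`. This assumes the namespace accepts the Array API spelling `sum(x, axis=...)`.

```python
import numpy as np


def make_log_prob(xp, scale=1.0):
    """Hypothetical factory: isotropic Gaussian log-density written against `xp`.

    Only Array API style operations are used, so the returned function should
    work with any namespace that follows the standard.
    """
    def log_prob_fn(x):
        # x has shape (n, dims); the result has shape (n,).
        return -0.5 * xp.sum((x / scale) ** 2, axis=-1)

    return log_prob_fn


log_prob_np = make_log_prob(np, scale=2.0)
values = log_prob_np(np.ones((5, 3)))
```

With PyTorch, you would build `make_log_prob(torch)` and pass it together with `xp=torch`, as in the example above.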

Note: the tpCN step falls back to NumPy to fit the Student-t distribution.
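minipcn's own Student-t fitting routine is not documented here. As a rough illustration of one standard approach, the sketch below runs EM iterations for the location and scale matrix of a multivariate Student-t, holding the degrees of freedom `nu` fixed. The function name `fit_student_t`, the default `nu` and the fixed iteration count are assumptions.

```python
import numpy as np


def fit_student_t(x, nu=5.0, n_iter=50):
    """EM estimate of the location and scale matrix of a multivariate t with fixed nu.

    x has shape (n, dims). Returns (mu, sigma).
    """
    n, dims = x.shape
    mu = x.mean(axis=0)
    sigma = np.atleast_2d(np.cov(x, rowvar=False))
    for _ in range(n_iter):
        d = x - mu
        # Squared Mahalanobis distance of each sample under the current fit.
        delta = np.sum((d @ np.linalg.inv(sigma)) * d, axis=1)
        # E-step: latent precision weights; outlying samples get smaller weights.
        w = (nu + dims) / (nu + delta)
        # M-step: weighted mean, then weighted scatter matrix.
        mu = (w[:, None] * x).sum(axis=0) / w.sum()
        d = x - mu
        sigma = (w[:, None] * d).T @ d / n
    return mu, sigma
```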

Functional API

minipcn also supports explicit functional RNG state via Sampler.sample_functional(...). Use this path for JAX compilation or for any workflow where RNG state must be threaded explicitly.

Instead of an RNG object, the functional API takes an explicit RNG state, created here from an orng functional backend:

import jax
import jax.numpy as jnp
from minipcn import Sampler
from orng.functional import create_functional_backend

dims = 4
rng_backend = create_functional_backend("jax")
rng_state = rng_backend.init_state(seed=42, generator=None)
x0, rng_state = rng_backend.normal(
    rng_state,
    loc=0.0,
    scale=1.0,
    size=(32, dims),
    dtype=jnp.float32,
)

def log_prob_fn(x):
    return -0.5 * jnp.sum(x**2, axis=-1)

sampler = Sampler(
    log_prob_fn=log_prob_fn,
    dims=dims,
    step_fn="pcn",
    xp=jnp,
)

samples, history, next_rng_state = sampler.sample_functional(
    x0,
    n_steps=8,
    rng_state=rng_state,
    verbose=False,
    return_last_only=True,
)

sample_functional(...) returns (chain, history, next_rng_state).
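The idea behind threading RNG state explicitly can be shown without JAX. The NumPy sketch below illustrates the pattern and is not orng's API. The hypothetical function `normal_functional` takes a PCG64 state, draws samples, and returns them together with the advanced state. Calling it twice with the same input state reproduces the same draws, which is what makes such functions safe to compile or replay.

```python
import numpy as np


def normal_functional(state, size):
    """Pure-style sampler: (state, size) -> (samples, next_state).

    `state` is a PCG64 state dict; the input dict is left unchanged.
    """
    bit_gen = np.random.PCG64()
    bit_gen.state = state  # restore the generator from the explicit state
    samples = np.random.Generator(bit_gen).normal(size=size)
    return samples, bit_gen.state


state0 = np.random.PCG64(42).state
a, state1 = normal_functional(state0, size=(4, 2))
b, _ = normal_functional(state0, size=(4, 2))  # same state, same draws
c, _ = normal_functional(state1, size=(4, 2))  # advanced state, new draws
```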

To use it under jax.jit, thread the state through the compiled function:

@jax.jit
def run(x, state):
    samples, history, next_state = sampler.sample_functional(
        x,
        n_steps=8,
        rng_state=state,
        verbose=False,
        return_last_only=True,
    )
    return samples, history, next_state

samples, history, rng_state = run(x0, rng_state)
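The self-contained JAX sketch below shows why the key has to be threaded through a compiled function. It is not minipcn's implementation. It runs a few textbook pCN steps inside `jax.jit` with a fixed standard-normal reference and an illustrative Gaussian target. `run_pcn`, `log_ref`, the step size and the step count are assumptions. Each step splits the key, and the final key is returned so the caller can continue the random stream.

```python
import jax
import jax.numpy as jnp


def log_prob_fn(x):
    # Illustrative vectorized target: isotropic Gaussian with standard deviation 2.
    return -0.5 * jnp.sum((x / 2.0) ** 2, axis=-1)


def log_ref(x):
    # pCN reference distribution: standard normal N(0, I), constant dropped.
    return -0.5 * jnp.sum(x**2, axis=-1)


@jax.jit
def run_pcn(x, key, beta=0.5):
    """Eight pCN steps; the PRNG key is split every step and returned at the end."""

    def body(_, carry):
        x, key = carry
        key, k_noise, k_accept = jax.random.split(key, 3)
        noise = jax.random.normal(k_noise, x.shape)
        x_prop = jnp.sqrt(1.0 - beta**2) * x + beta * noise
        log_ratio = (log_prob_fn(x_prop) - log_ref(x_prop)) - (log_prob_fn(x) - log_ref(x))
        u = jax.random.uniform(k_accept, (x.shape[0],))
        x = jnp.where((jnp.log(u) < log_ratio)[:, None], x_prop, x)
        return x, key

    return jax.lax.fori_loop(0, 8, body, (x, key))


x0 = jnp.zeros((32, 4))
x1, key1 = run_pcn(x0, jax.random.PRNGKey(42))
x2, _ = run_pcn(x0, jax.random.PRNGKey(42))  # same key, identical result
```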

The backend for sample_functional(...) is inferred from xp. For example:

  • xp=np uses the NumPy functional backend
  • xp=jax.numpy uses the JAX functional backend
  • xp=torch uses the PyTorch functional backend
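minipcn's actual inference logic is not documented here. As an illustration only, a backend could be chosen from the module name of `xp`. The helper `infer_rng_backend` below is hypothetical. It also assumes that array-api-compat wrapper namespaces, such as `array_api_compat.torch`, should map to the wrapped library.

```python
import types

import numpy as np


def infer_rng_backend(xp):
    """Hypothetical helper: map an array namespace to an RNG backend name."""
    name = getattr(xp, "__name__", "")
    prefix = "array_api_compat."
    if name.startswith(prefix):
        name = name[len(prefix):]  # e.g. "array_api_compat.torch" -> "torch"
    root = name.split(".")[0]  # e.g. "jax.numpy" -> "jax"
    if root in ("numpy", "jax", "torch"):
        return root
    raise ValueError(f"unsupported array namespace: {name!r}")


backend = infer_rng_backend(np)
# A stand-in module object lets us check the JAX case without importing JAX.
fake_jax_numpy = types.ModuleType("jax.numpy")
jax_backend = infer_rng_backend(fake_jax_numpy)
```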

Use sample(...) for stateful RNG objects and sample_functional(...) when you want explicit RNG state.

Citing minipcn

If you use minipcn in your work, please cite our DOI.

If you use the tpcn kernel, please also cite Grumitt et al.

Download files

Source Distribution

minipcn-0.2.0a4.tar.gz (18.6 kB)

Built Distribution

minipcn-0.2.0a4-py3-none-any.whl (13.4 kB)

File details

Details for the file minipcn-0.2.0a4.tar.gz.

File metadata

  • Download URL: minipcn-0.2.0a4.tar.gz
  • Size: 18.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for minipcn-0.2.0a4.tar.gz
Algorithm Hash digest
SHA256 9515c8a0d523baa6c5a96db0e634c748a00f015d1d21d6252eeff45f982d5b9a
MD5 e4a61479d33bbfb3514d261a32052d57
BLAKE2b-256 c95a83d75ad50094c928a1fed2fd4bab7aa8adf3c8d738307f16d5bec3650739

Provenance

The following attestation bundles were made for minipcn-0.2.0a4.tar.gz:

Publisher: publish.yml on mj-will/minipcn

File details

Details for the file minipcn-0.2.0a4-py3-none-any.whl.

File metadata

  • Download URL: minipcn-0.2.0a4-py3-none-any.whl
  • Size: 13.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for minipcn-0.2.0a4-py3-none-any.whl
Algorithm Hash digest
SHA256 edb0a58d4c8f698a402e93308b0ea942a52ed94e478e45c46f31c32fd431747a
MD5 6eb9083bc27194f95a15fc7f37b068ea
BLAKE2b-256 657af10c71eb6772d9bf7a16c7a1fbaa90565758ce4f1cb83c280cec3493f737

Provenance

The following attestation bundles were made for minipcn-0.2.0a4-py3-none-any.whl:

Publisher: publish.yml on mj-will/minipcn
