
Sequential Monte Carlo and particle filtering in JAX


smcjax



smcjax extends Dynamax and BlackJAX with particle filters and Bayesian workflow diagnostics that neither library provides. All filters are written with jax.lax.scan, so they compile under jax.jit and run on GPU.

Features

  • Bootstrap (SIR) particle filter — Gordon et al. (1993)
  • Auxiliary particle filter — Pitt & Shephard (1999)
  • Liu-West filter — joint state-parameter estimation via kernel density smoothing (Liu & West, 2001)
  • Forward simulation — generate trajectories from state-space models
  • Diagnostics suite — weighted mean/variance/quantiles, parameter summaries, ESS traces, particle diversity, per-step log evidence increments, replicated log-ML, log Bayes factors, CRPS
  • 4 resampling schemes (via BlackJAX): systematic, stratified, multinomial, residual
  • Conditional resampling with configurable ESS threshold
  • All functions are jit- and vmap-compatible
  • Type annotations via jaxtyping
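To make the bootstrap filter and ESS-triggered resampling concrete, here is a minimal pure-JAX sketch for a scalar AR(1)-plus-noise model, written as a single jax.lax.scan over time steps. It is not smcjax's implementation: the function name `bootstrap_filter_sketch`, its parameters (`a`, `q`, `r`, `ess_threshold`), and the use of multinomial resampling via `jax.random.categorical` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def bootstrap_filter_sketch(key, emissions, num_particles=1_000,
                            ess_threshold=0.5, a=0.9, q=0.5, r=1.0):
    """Hypothetical bootstrap PF for the scalar model
    x_0 ~ N(0, 1), x_t = a x_{t-1} + q eps_t, y_t = x_t + r nu_t (t = 1..T).

    Returns per-step particles, normalized log-weights, the ESS computed
    before each (conditional) resampling step, and the log-ML estimate.
    """
    n = num_particles
    key, init_key = jr.split(key)
    x = jr.normal(init_key, (n,))          # particles for x_0
    log_w = jnp.full(n, -jnp.log(n))       # uniform normalized log-weights

    def step(carry, y):
        key, x, log_w = carry
        key, k_res, k_prop = jr.split(key, 3)
        # ESS = 1 / sum(w_i^2), computed in log space
        ess = jnp.exp(-logsumexp(2.0 * log_w))
        resample = ess < ess_threshold * n
        # Multinomial resampling (stand-in for systematic, stratified, ...)
        ancestors = jnp.where(resample,
                              jr.categorical(k_res, log_w, shape=(n,)),
                              jnp.arange(n))
        x = x[ancestors]
        log_w = jnp.where(resample, -jnp.log(n), log_w)
        # Bootstrap proposal: propagate through the transition density
        x = a * x + q * jr.normal(k_prop, (n,))
        # Reweight by the observation likelihood p(y_t | x_t)
        log_w_unnorm = log_w + norm.logpdf(y, loc=x, scale=r)
        log_inc = logsumexp(log_w_unnorm)  # estimate of log p(y_t | y_{1:t-1})
        log_w = log_w_unnorm - log_inc
        return (key, x, log_w), (x, log_w, ess, log_inc)

    _, (xs, log_ws, ess, log_incs) = jax.lax.scan(step, (key, x, log_w), emissions)
    return xs, log_ws, ess, jnp.sum(log_incs)
```

Because `num_particles` sets array shapes, it has to be static under JIT, for example `jax.jit(bootstrap_filter_sketch, static_argnames="num_particles")`. The per-step `log_inc` values play the role of log-evidence increments; their sum is the log marginal likelihood estimate.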

Installation

pip install smcjax

Or from source:

git clone https://github.com/michaelellis003/smcjax.git
cd smcjax
uv sync

Quick Example

import jax.numpy as jnp
import jax.random as jr
import jax.scipy.stats as jstats

from smcjax import bootstrap_filter, weighted_mean, log_ml_increments

# Define a 1-D linear Gaussian state space model
m0, P0 = jnp.array([0.0]), jnp.array([[1.0]])
F, Q = jnp.array([[0.9]]), jnp.array([[0.25]])
H, R = jnp.array([[1.0]]), jnp.array([[1.0]])

chol_P0 = jnp.linalg.cholesky(P0)
chol_Q = jnp.linalg.cholesky(Q)

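def initial_sampler(key, n):
    # Draw n initial particles from N(m0, P0); returns shape (n, 1)
    return m0 + jr.normal(key, (n, 1)) @ chol_P0.T

def transition_sampler(key, state):
    # Sample x_t ~ N(F x_{t-1}, Q) for a single particle of shape (1,)
    mean = (F @ state[:, None]).squeeze(-1)
    return mean + jr.normal(key, (1,)) @ chol_Q.T

def log_observation_fn(emission, state):
    # log p(y_t | x_t) = log N(y_t; H x_t, R) for a single particle
    mean = (H @ state[:, None]).squeeze(-1)
    return jstats.multivariate_normal.logpdf(emission, mean, R)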

# Placeholder observations: i.i.d. N(0, 1) draws, not simulated from the model above
key = jr.PRNGKey(0)
T = 100
emissions = jr.normal(key, (T, 1))

# Run the bootstrap particle filter
posterior = bootstrap_filter(
    key=jr.PRNGKey(1),
    initial_sampler=initial_sampler,
    transition_sampler=transition_sampler,
    log_observation_fn=log_observation_fn,
    emissions=emissions,
    num_particles=1_000,
)

print(f"Log marginal likelihood: {posterior.marginal_loglik:.2f}")
print(f"Particles shape: {posterior.filtered_particles.shape}")
print(f"Mean ESS: {posterior.ess.mean():.1f}")

# Diagnostics
means = weighted_mean(posterior)
increments = log_ml_increments(posterior)
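The diagnostics work on weighted particle sets. As an illustration of the underlying formulas (hypothetical function names operating on 1-D arrays of particles and unnormalized log-weights, not the smcjax diagnostics API), weighted moments, a weighted quantile, and the continuous ranked probability score (CRPS) can be computed as follows. The CRPS uses the identity CRPS(F, y) = E|X - y| - 0.5 E|X - X'| for independent X, X' ~ F.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def _weights(log_weights):
    # Normalize in log space before exponentiating to avoid underflow
    return jnp.exp(log_weights - logsumexp(log_weights))


def weighted_moments(particles, log_weights):
    w = _weights(log_weights)
    mean = jnp.sum(w * particles)
    var = jnp.sum(w * (particles - mean) ** 2)
    return mean, var


def weighted_quantile(particles, log_weights, q):
    w = _weights(log_weights)
    order = jnp.argsort(particles)
    cdf = jnp.cumsum(w[order])
    # Smallest sorted particle whose cumulative weight reaches q
    idx = jnp.clip(jnp.searchsorted(cdf, q), 0, particles.shape[0] - 1)
    return particles[order][idx]


def weighted_crps(particles, log_weights, y):
    # CRPS(F, y) = E|X - y| - 0.5 E|X - X'| for the weighted particle measure F
    w = _weights(log_weights)
    term1 = jnp.sum(w * jnp.abs(particles - y))
    pairwise = jnp.abs(particles[:, None] - particles[None, :])
    term2 = 0.5 * jnp.sum(w[:, None] * w[None, :] * pairwise)
    return term1 - term2
```

The pairwise term in this CRPS sketch needs O(N^2) memory, which is fine for a few thousand particles.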

Architecture

smcjax/
    __init__.py          # Public API (re-exports BlackJAX ESS & resampling)
    types.py             # PRNGKeyT, Scalar (matches Dynamax)
    containers.py        # ParticleState, ParticleFilterPosterior, LiuWestPosterior
    weights.py           # log_normalize, normalize
    bootstrap.py         # Bootstrap (SIR) particle filter
    auxiliary.py         # Auxiliary particle filter (Pitt & Shephard 1999)
    liu_west.py          # Liu-West filter for joint state-parameter estimation
    simulate.py          # Forward simulation from state-space models
    diagnostics.py       # Posterior summaries, model comparison, scoring rules

ESS and resampling (systematic, stratified, multinomial, residual) are provided by BlackJAX and re-exported from smcjax for convenience.
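For reference, the ESS and two of the four resampling schemes can be sketched in a few lines of pure JAX. These are illustrative re-implementations with hypothetical names, not the BlackJAX functions that smcjax re-exports (those live in `blackjax.smc.ess` and `blackjax.smc.resampling`). Both schemes push a set of ordered uniforms through the inverse CDF of the weights; they differ only in how the uniforms are drawn.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import logsumexp


def ess_from_log_weights(log_weights):
    """Effective sample size 1 / sum(w_i^2) for the normalized weights w."""
    log_w = log_weights - logsumexp(log_weights)
    return jnp.exp(-logsumexp(2.0 * log_w))


def _inverse_cdf(uniforms, weights):
    """Map uniforms in [0, 1) to particle indices through the weight CDF."""
    cdf = jnp.cumsum(weights)
    # side="right" skips particles whose weight is exactly zero
    idx = jnp.searchsorted(cdf, uniforms, side="right")
    return jnp.clip(idx, 0, weights.shape[0] - 1)


def systematic(key, weights, num_samples):
    # One shared uniform offset for all strata [i/M, (i+1)/M)
    u = (jr.uniform(key) + jnp.arange(num_samples)) / num_samples
    return _inverse_cdf(u, weights)


def stratified(key, weights, num_samples):
    # An independent uniform inside each stratum [i/M, (i+1)/M)
    u = (jr.uniform(key, (num_samples,)) + jnp.arange(num_samples)) / num_samples
    return _inverse_cdf(u, weights)
```

Systematic resampling uses a single random number, which makes it cheap and typically lower-variance than multinomial resampling; stratified resampling draws one uniform per stratum.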

Cross-Validation

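All filters are tested against reference implementations by comparing log marginal likelihood (log-ML) estimates:

| Module | Reference | Validation |
|---|---|---|
| bootstrap | Dynamax Kalman filter | Log-ML within 5% of the exact value |
| bootstrap | particles (Chopin's Python library) | Log-ML within 3 nats |
| auxiliary | Dynamax Kalman filter | Log-ML within 5% of the exact value |
| auxiliary | Bootstrap filter (with flat auxiliary weights, the auxiliary filter reduces to the bootstrap filter) | Log-ML within 3 nats |
| liu_west | Auxiliary filter with parameters held fixed | Log-ML within 5 nats |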
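The "within 5% of the exact value" checks compare against the exact log marginal likelihood from a Kalman filter (Dynamax's, in the test suite). For the scalar model in the Quick Example, that exact value can also be sketched in pure JAX. The function `kalman_log_ml` is hypothetical, `q` and `r` are variances (matching `Q` and `R` in the Quick Example), and it assumes the first emission depends on an initial state x_1 ~ N(m0, P0).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def kalman_log_ml(emissions, m0=0.0, p0=1.0, f=0.9, q=0.25, h=1.0, r=1.0):
    """Exact log p(y_{1:T}) for the scalar linear Gaussian model
    x_1 ~ N(m0, p0), x_t = f x_{t-1} + N(0, q), y_t = h x_t + N(0, r).

    `emissions` is a 1-D JAX array; returns (total log-ML, per-step increments).
    """

    def step(carry, y):
        m_pred, p_pred = carry
        # One-step predictive distribution of y_t given y_{1:t-1}
        s = h * p_pred * h + r
        log_inc = norm.logpdf(y, loc=h * m_pred, scale=jnp.sqrt(s))
        # Measurement update
        k = p_pred * h / s
        m_filt = m_pred + k * (y - h * m_pred)
        p_filt = (1.0 - k * h) * p_pred
        # Time update to t + 1
        return (f * m_filt, f * p_filt * f + q), log_inc

    init = (jnp.asarray(m0, dtype=emissions.dtype),
            jnp.asarray(p0, dtype=emissions.dtype))
    _, log_incs = jax.lax.scan(step, init, emissions)
    return jnp.sum(log_incs), log_incs
```

If `bootstrap_filter` uses the same time-indexing convention (an assumption), `kalman_log_ml(emissions[:, 0])[0]` gives the exact reference for `posterior.marginal_loglik` in the Quick Example.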

Notebooks

The notebooks/ directory contains a thesis-style reproduction of a Bayesian workflow on a hidden Markov model with unknown parameters. It demonstrates the full pipeline: simulation, Liu-West filtering, parameter recovery, model comparison via log Bayes factors, and evaluation with the continuous ranked probability score (CRPS).
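The Liu-West filter keeps static parameters from degenerating by replacing each parameter particle with a draw from a shrunk Gaussian kernel. With discount factor delta close to 1 (for example 0.98), it sets a = (3 delta - 1) / (2 delta) and h^2 = 1 - a^2, then draws the new theta_i from N(a theta_i + (1 - a) theta_bar, h^2 V), where theta_bar and V are the weighted mean and covariance; the resulting kernel mixture has the same mean and covariance as the particle cloud. The sketch below shows only this kernel step (the full filter also resamples with auxiliary weights evaluated at the kernel locations), and `liu_west_jitter` is a hypothetical name, not smcjax's API.

```python
import jax.numpy as jnp
import jax.random as jr


def liu_west_jitter(key, params, weights, delta=0.98):
    """Hypothetical Liu-West shrink-and-jitter step.

    params: (N, D) parameter particles; weights: (N,) normalized weights.
    Returns (jittered particles, kernel locations).
    """
    a = (3.0 * delta - 1.0) / (2.0 * delta)  # shrinkage factor
    h2 = 1.0 - a ** 2                        # kernel variance scale
    mean = weights @ params                  # weighted mean, shape (D,)
    centered = params - mean
    cov = (weights[:, None] * centered).T @ centered  # weighted covariance (D, D)
    # Kernel locations: shrink each particle toward the weighted mean
    locs = a * params + (1.0 - a) * mean
    # Gaussian jitter with covariance h^2 V (V must be positive definite)
    chol = jnp.linalg.cholesky(h2 * cov)
    noise = jr.normal(key, params.shape) @ chol.T
    return locs + noise, locs
```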

Roadmap

| Phase | What | Status |
|---|---|---|
| 1 | Bootstrap particle filter | Done |
| 2 | Auxiliary particle filter | Done |
| 3 | Forward simulation + diagnostics | Done |
| 4 | Liu-West filter + model comparison | Done |
| 5 | EKF/UKF proposal particle filters | Planned |
| 6 | Particle MCMC (PMMH) | Planned |

Development

uv sync                          # Install all deps
uv run pre-commit install        # Set up pre-commit hooks
uv run pytest -v --cov           # Run tests with coverage
uv run ruff check . --fix        # Lint
uv run pyright                   # Type check

License

Apache-2.0



Download files


Source Distribution

smcjax-1.0.0.tar.gz (258.2 kB)

Uploaded Source

Built Distribution


smcjax-1.0.0-py3-none-any.whl (24.0 kB)

Uploaded Python 3

File details

Details for the file smcjax-1.0.0.tar.gz.

File metadata

  • Download URL: smcjax-1.0.0.tar.gz
  • Upload date:
  • Size: 258.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for smcjax-1.0.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 98d290f59d43c1214a2a076d4799cb051778f9a5553a844038b237b6549e49af |
| MD5 | ff676c0fea4c14ce7b2673a5a2919fa1 |
| BLAKE2b-256 | dee61ca3554019733577525ef716eca59cbd03c38a60a61ae09d2ff594944900 |


Provenance

The following attestation bundles were made for smcjax-1.0.0.tar.gz:

Publisher: release.yml on michaelellis003/smcjax


File details

Details for the file smcjax-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: smcjax-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 24.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for smcjax-1.0.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0a1927bee1d1fbffe8ad0e322fb960abad81852b347716f6a7b215bbf3468548 |
| MD5 | a4593ee95000a53ffc88172658bfc27a |
| BLAKE2b-256 | 0b8e6d9803fdc1f424c70169e7395c968600e5496fddf85ff992e4f822ed982d |


Provenance

The following attestation bundles were made for smcjax-1.0.0-py3-none-any.whl:

Publisher: release.yml on michaelellis003/smcjax

