
smcjax


Sequential Monte Carlo and particle filtering in JAX.

smcjax is a JAX implementation of the methods developed in my master's thesis on sequential inference for Hidden Markov Models (University of Arkansas, 2018). It extends Dynamax and BlackJAX with particle filters and Bayesian workflow diagnostics. All filters are written as jax.lax.scan loops, so they can be JIT-compiled and run on GPU.
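
To make the scan-plus-JIT design concrete, here is a minimal sketch in plain JAX of how a bootstrap filter can be written as a jax.lax.scan loop and compiled with jax.jit. It is not smcjax's implementation: `bootstrap_log_ml` and `_weight_and_resample` are hypothetical names, the model is the scalar version of the quick example below (F = 0.9, Q = 0.25, R = 1, initial state N(0, 1)), resampling is multinomial at every step, and the first emission is assumed to be generated by the initial state.

```python
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def _weight_and_resample(key, particles, y, obs_std):
    """Weight particles by p(y | x); return resampled particles and log-ML increment."""
    log_w = norm.logpdf(y, loc=particles, scale=obs_std)
    n = particles.shape[0]
    # Log of the average unnormalized weight estimates log p(y_t | y_{0:t-1}).
    increment = logsumexp(log_w) - jnp.log(n)
    # Multinomial resampling: draw N ancestor indices with probability proportional to weight.
    idx = jr.categorical(key, log_w, shape=(n,))
    return particles[idx], increment


@partial(jax.jit, static_argnames="num_particles")
def bootstrap_log_ml(key, emissions, num_particles=1_000, phi=0.9, q_std=0.5, obs_std=1.0):
    """Bootstrap-filter estimate of log p(y_0, ..., y_{T-1}) for a scalar AR(1) model."""
    k_init, k_first, k_scan = jr.split(key, 3)
    # x_0 ~ N(0, 1); emissions[0] is assumed to be generated by x_0.
    particles = jr.normal(k_init, (num_particles,))
    particles, first_inc = _weight_and_resample(k_first, particles, emissions[0], obs_std)

    def step(particles, inputs):
        step_key, y = inputs
        k_prop, k_res = jr.split(step_key)
        # Propagate through the transition: x_t = phi * x_{t-1} + q_std * eps.
        particles = phi * particles + q_std * jr.normal(k_prop, particles.shape)
        return _weight_and_resample(k_res, particles, y, obs_std)

    # One key per remaining time step; scan carries the particle cloud forward.
    keys = jr.split(k_scan, emissions.shape[0] - 1)
    _, increments = jax.lax.scan(step, particles, (keys, emissions[1:]))
    return first_inc + increments.sum()


emissions = jr.normal(jr.PRNGKey(0), (50,))
log_ml = bootstrap_log_ml(jr.PRNGKey(1), emissions)
```

Because the time loop is a single scan, the whole filter compiles to one XLA program, and `jax.vmap` over the key gives independent replicates of the log-likelihood estimate.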

Features

  • Bootstrap (SIR) particle filter — Gordon et al. (1993)
  • Auxiliary particle filter — Pitt & Shephard (1999)
  • Liu-West filter — joint state-parameter estimation via kernel density smoothing (Liu & West, 2001)
  • Forward simulation — generate trajectories from state-space models
  • Diagnostics — weighted mean/variance/quantiles, parameter summaries, ESS traces, particle diversity, per-step log evidence increments, replicated log-ML, log Bayes factors, CRPS
  • 4 resampling schemes (via BlackJAX): systematic, stratified, multinomial, residual
  • Conditional resampling with configurable ESS threshold
  • All functions are jit- and vmap-compatible
  • Type annotations via jaxtyping
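
Conditional resampling can be illustrated with a short, self-contained sketch. The helpers below (`ess`, `systematic_indices`, `maybe_resample`) are hypothetical names written in plain JAX, not smcjax's or BlackJAX's API. They show the usual rule of resampling only when the effective sample size drops below a fraction of the particle count; the 0.5 threshold is an assumed default, and systematic resampling stands in for any of the four schemes.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import logsumexp


def ess(log_weights):
    """Effective sample size 1 / sum_i w_i^2 of the normalized weights."""
    w = jnp.exp(log_weights - logsumexp(log_weights))
    return 1.0 / jnp.sum(w**2)


def systematic_indices(key, log_weights):
    """Systematic resampling: one uniform offset shared by N evenly spaced points."""
    n = log_weights.shape[0]
    w = jnp.exp(log_weights - logsumexp(log_weights))
    positions = (jr.uniform(key) + jnp.arange(n)) / n
    # The ancestor of position u is the first index whose cumulative weight exceeds u.
    idx = jnp.searchsorted(jnp.cumsum(w), positions, side="right")
    # Clip in case round-off leaves the last cumulative weight slightly below 1.
    return jnp.minimum(idx, n - 1)


def maybe_resample(key, particles, log_weights, ess_threshold=0.5):
    """Resample only when ESS < ess_threshold * N; otherwise keep the weights."""
    n = log_weights.shape[0]

    def resample(_):
        idx = systematic_indices(key, log_weights)
        # After resampling every particle carries weight 1/N.
        return particles[idx], jnp.full_like(log_weights, -jnp.log(n))

    def keep(_):
        return particles, log_weights

    # lax.cond keeps the branch inside the compiled program (jit/vmap friendly).
    return jax.lax.cond(ess(log_weights) < ess_threshold * n, resample, keep, None)


particles = jnp.arange(4.0)
log_w = jnp.log(jnp.array([0.7, 0.1, 0.1, 0.1]))
new_particles, new_log_w = jax.jit(maybe_resample)(jr.PRNGKey(0), particles, log_w)
```

Here the ESS is about 1.92, below the threshold of 2, so the heavily weighted first particle is duplicated and all weights are reset to 1/4.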

Requirements

  • Python 3.10 or later
  • uv (needed only to install from source or for development)

Installation

pip install smcjax

Or from source:

git clone https://github.com/michaelellis003/smcjax.git
cd smcjax
uv sync

Install the pre-commit hooks (one-time setup):

uv run pre-commit install
uv run pre-commit install --hook-type commit-msg
uv run pre-commit install --hook-type pre-push

Quick example

import jax.numpy as jnp
import jax.random as jr
import jax.scipy.stats as jstats

from smcjax import bootstrap_filter, weighted_mean, log_ml_increments

# Define a 1-D linear Gaussian state space model
m0, P0 = jnp.array([0.0]), jnp.array([[1.0]])
F, Q = jnp.array([[0.9]]), jnp.array([[0.25]])
H, R = jnp.array([[1.0]]), jnp.array([[1.0]])

chol_P0 = jnp.linalg.cholesky(P0)
chol_Q = jnp.linalg.cholesky(Q)


def initial_sampler(key, n):
    return m0 + jr.normal(key, (n, 1)) @ chol_P0.T


def transition_sampler(key, state):
    mean = (F @ state[:, None]).squeeze(-1)
    return mean + jr.normal(key, (1,)) @ chol_Q.T


def log_observation_fn(emission, state):
    mean = (H @ state[:, None]).squeeze(-1)
    return jstats.multivariate_normal.logpdf(emission, mean, R)


# Placeholder observations: i.i.d. standard-normal draws, not simulated from the model
key = jr.PRNGKey(0)
T = 100
emissions = jr.normal(key, (T, 1))

# Run the bootstrap particle filter
posterior = bootstrap_filter(
    key=jr.PRNGKey(1),
    initial_sampler=initial_sampler,
    transition_sampler=transition_sampler,
    log_observation_fn=log_observation_fn,
    emissions=emissions,
    num_particles=1_000,
)

print(f"Log marginal likelihood: {posterior.marginal_loglik:.2f}")
print(f"Particles shape: {posterior.filtered_particles.shape}")
print(f"Mean ESS: {posterior.ess.mean():.1f}")

# Diagnostics
means = weighted_mean(posterior)
increments = log_ml_increments(posterior)
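
Because the quick-example model is linear and Gaussian, its exact log marginal likelihood is available from a Kalman filter, which gives a reference value for `posterior.marginal_loglik`. The sketch below is plain JAX and not part of smcjax; `kalman_log_ml` is a hypothetical helper whose defaults are the quick example's parameters, and it assumes emissions[0] is generated by the initial state (check that this matches smcjax's time-indexing convention). Its per-step terms log p(y_t | y_0, ..., y_{t-1}) are the exact counterparts of per-step log evidence increments and sum to the total.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import norm


def kalman_log_ml(emissions, m0=0.0, p0=1.0, f=0.9, q=0.25, h=1.0, r=1.0):
    """Exact log p(y_0, ..., y_{T-1}) for a scalar linear Gaussian model.

    Model: x_0 ~ N(m0, p0), x_t = f x_{t-1} + N(0, q), y_t = h x_t + N(0, r).
    Assumes emissions[0] is generated by the initial state x_0.
    Returns the total and the per-step increments log p(y_t | y_0, ..., y_{t-1}).
    """

    def step(carry, y):
        m_pred, p_pred = carry
        # One-step predictive density of the observation.
        s = h * p_pred * h + r
        log_inc = norm.logpdf(y, loc=h * m_pred, scale=jnp.sqrt(s))
        # Measurement update with Kalman gain k.
        k = p_pred * h / s
        m_filt = m_pred + k * (y - h * m_pred)
        p_filt = (1.0 - k * h) * p_pred
        # Time update to the next step's predictive mean and variance.
        return (f * m_filt, f * p_filt * f + q), log_inc

    init = (jnp.asarray(m0, dtype=emissions.dtype), jnp.asarray(p0, dtype=emissions.dtype))
    _, log_incs = jax.lax.scan(step, init, emissions)
    return log_incs.sum(), log_incs


y = jr.normal(jr.PRNGKey(0), (100,))
exact_log_ml, exact_increments = kalman_log_ml(y)
```

With the quick example's data, `kalman_log_ml(emissions[:, 0])[0]` should agree with `posterior.marginal_loglik` up to Monte Carlo error, which shrinks as `num_particles` grows.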

Development

A Makefile collects the common development tasks:

make test        # lint + pytest
make lint        # ruff check, format check, license headers, ty
make format      # add license headers, ruff format, ruff fix
make license     # add missing license headers
make docs        # build documentation
make serve-docs  # serve documentation locally
make install     # uv sync
make clean       # git clean (preserves .venv)

How releases work

Releases are fully automated. When a commit lands on main and CI passes, python-semantic-release inspects the commit history to determine whether a version bump is warranted:

  • fix: ... produces a patch release
  • feat: ... produces a minor release
  • A BREAKING CHANGE: footer, or a ! before the colon (for example feat!: ...), produces a major release
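
The bump rule above can be modeled with a toy function. This is a simplified illustration of the Conventional Commits rules listed, not python-semantic-release's actual (configurable) parser; `release_level` and `next_bump` are hypothetical names.

```python
import re

# Conventional Commit header: type, optional (scope), optional "!", then ": description".
HEADER = re.compile(r"^(?P<type>\w+)(?:\([^)]*\))?(?P<bang>!)?: .+")


def release_level(message: str):
    """Return 'major', 'minor', 'patch', or None for one commit message (toy rule)."""
    lines = message.splitlines()
    match = HEADER.match(lines[0]) if lines else None
    if match is None:
        return None
    # A "!" in the header or a BREAKING CHANGE footer always means a major bump.
    if match["bang"] or any(line.startswith("BREAKING CHANGE:") for line in lines[1:]):
        return "major"
    if match["type"] == "feat":
        return "minor"
    if match["type"] == "fix":
        return "patch"
    return None


def next_bump(messages):
    """The highest level across all commits since the last release wins."""
    order = {None: 0, "patch": 1, "minor": 2, "major": 3}
    return max((release_level(m) for m in messages), key=order.__getitem__, default=None)
```

For example, a history containing docs:, fix: and feat: commits yields a minor release, while docs-only history yields no release.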

License

Apache-2.0. See LICENSE for the full text.



Download files


Source Distribution

smcjax-1.1.0.tar.gz (19.5 kB)

Uploaded Source

Built Distribution


smcjax-1.1.0-py3-none-any.whl (25.6 kB)

Uploaded Python 3

File details

Details for the file smcjax-1.1.0.tar.gz.

File metadata

  • Download URL: smcjax-1.1.0.tar.gz
  • Upload date:
  • Size: 19.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for smcjax-1.1.0.tar.gz
Algorithm Hash digest
SHA256 c24a711c9acac0b24d0c2f16f3ed44ac745f2463ecb9ecc8eb3cddee7da925ad
MD5 1d663530ebf42dca2ebb4d66e26a2ee0
BLAKE2b-256 52e45425262b2b44b3f4775eacf4952650f0db5bd3192616ff2ffe9d52918d68


Provenance

The following attestation bundles were made for smcjax-1.1.0.tar.gz:

Publisher: release.yml on michaelellis003/smcjax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file smcjax-1.1.0-py3-none-any.whl.

File metadata

  • Download URL: smcjax-1.1.0-py3-none-any.whl
  • Upload date:
  • Size: 25.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for smcjax-1.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 ea2314dd6b2836bfb40df32956d09001be9a480aff1f9a98ecb552eecbb33f88
MD5 daa10abb870d75173bc69dc0bac3aa62
BLAKE2b-256 277873574d6f2ca3fb1c299f19b90c3a524650910a3e55d73027befda8756918


Provenance

The following attestation bundles were made for smcjax-1.1.0-py3-none-any.whl:

Publisher: release.yml on michaelellis003/smcjax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
