smcjax
Sequential Monte Carlo and particle filtering in JAX.
smcjax is a JAX implementation of the methods developed in my master's thesis on sequential inference for Hidden Markov Models (University of Arkansas, 2018). It extends Dynamax and BlackJAX with particle filters and Bayesian workflow diagnostics. All filters are JIT-compiled via jax.lax.scan and GPU-ready.
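As an illustration of how a particle filter fits into jax.lax.scan, the sketch below implements a minimal bootstrap filter with ESS-triggered systematic resampling for a hard-coded 1-D linear Gaussian model. It is not smcjax's implementation: bootstrap_filter_sketch and systematic_resample are hypothetical names, and the first observation is assumed to be emitted from the initial state (the Dynamax convention). The scan carry holds the particles and their normalised log weights; each step returns its log-likelihood increment and the ESS.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

# Hypothetical hard-coded 1-D linear Gaussian model:
# x_1 ~ N(0, 1), x_t = 0.9 x_{t-1} + N(0, 0.5^2), y_t = x_t + N(0, 1)
PHI, Q_STD, R_STD = 0.9, 0.5, 1.0


def systematic_resample(key, log_weights):
    """Ancestor indices from systematic resampling of (log) weights."""
    n = log_weights.shape[0]
    cdf = jnp.cumsum(jnp.exp(log_weights - logsumexp(log_weights)))
    # n evenly spaced points in [0, 1) sharing a single uniform offset
    u = (jr.uniform(key) + jnp.arange(n)) / n
    # Clip guards against round-off leaving cdf[-1] slightly below 1
    return jnp.minimum(jnp.searchsorted(cdf, u), n - 1)


def bootstrap_filter_sketch(key, emissions, num_particles, ess_threshold=0.5):
    """Return (log marginal likelihood estimate, ESS before each update)."""

    def log_obs(y, x):
        return norm.logpdf(y, loc=x, scale=R_STD)

    def step(carry, inputs):
        particles, log_w = carry  # log_w is normalised: logsumexp(log_w) == 0
        step_key, y = inputs
        res_key, prop_key = jr.split(step_key)
        # Resample only when the ESS drops below ess_threshold * N
        ess = 1.0 / jnp.sum(jnp.exp(2.0 * log_w))
        resample = ess < ess_threshold * num_particles
        idx = systematic_resample(res_key, log_w)
        particles = jnp.where(resample, particles[idx], particles)
        log_w = jnp.where(resample, -jnp.log(num_particles), log_w)
        # Bootstrap proposal: propagate through the transition density
        particles = PHI * particles + Q_STD * jr.normal(prop_key, particles.shape)
        # Reweight; log_inc estimates log p(y_t | y_{1:t-1})
        log_w = log_w + log_obs(y, particles)
        log_inc = logsumexp(log_w)
        return (particles, log_w - log_inc), (log_inc, ess)

    init_key, scan_key = jr.split(key)
    # First observation assumed to be emitted from the initial state
    particles = jr.normal(init_key, (num_particles,))
    log_w = log_obs(emissions[0], particles)
    log_ml0 = logsumexp(log_w) - jnp.log(num_particles)
    keys = jr.split(scan_key, emissions.shape[0] - 1)
    _, (log_incs, ess) = jax.lax.scan(
        step, (particles, log_w - logsumexp(log_w)), (keys, emissions[1:])
    )
    return log_ml0 + jnp.sum(log_incs), ess


# Usage on stand-in data (i.i.d. standard normal observations)
emissions = jr.normal(jr.PRNGKey(0), (50,))
log_ml, ess = bootstrap_filter_sketch(jr.PRNGKey(1), emissions, num_particles=1_000)
```

The conditional resample is written with jnp.where, so the resampled set is computed at every step and then kept or discarded; that keeps the step free of Python control flow on traced values, which is what allows jax.vmap over keys (and jax.jit, with num_particles marked static).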
Features
- Bootstrap (SIR) particle filter — Gordon et al. (1993)
- Auxiliary particle filter — Pitt & Shephard (1999)
- Liu-West filter — joint state-parameter estimation via kernel density smoothing (Liu & West, 2001)
- Forward simulation — generate trajectories from state-space models
- Diagnostics — weighted mean/variance/quantiles, parameter summaries, ESS traces, particle diversity, per-step log evidence increments, replicated log-ML, log Bayes factors, CRPS
- 4 resampling schemes (via BlackJAX): systematic, stratified, multinomial, residual
- Conditional resampling with configurable ESS threshold
- All functions are jit- and vmap-compatible
- Type annotations via jaxtyping
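Several of the diagnostics listed above reduce to short computations on weighted particles. The sketch below assumes 1-D particles with unnormalised log weights; normalise, weighted_mean_sketch, weighted_quantile_sketch and crps_sketch are hypothetical helpers, not smcjax's API, and the quantile uses the simple "first particle whose cumulative weight reaches q" rule without interpolation.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def normalise(log_weights):
    """Convert unnormalised log weights to weights that sum to one."""
    return jnp.exp(log_weights - logsumexp(log_weights))


def weighted_mean_sketch(particles, log_weights):
    return jnp.sum(normalise(log_weights) * particles)


def weighted_quantile_sketch(particles, log_weights, q):
    """Smallest particle whose cumulative weight reaches q."""
    order = jnp.argsort(particles)
    cdf = jnp.cumsum(normalise(log_weights)[order])
    idx = jnp.minimum(jnp.searchsorted(cdf, q), particles.shape[0] - 1)
    return particles[order][idx]


def crps_sketch(particles, log_weights, y):
    """CRPS(F, y) = E|X - y| - 0.5 E|X - X'| for the weighted empirical F.

    Builds an N x N matrix of pairwise distances, so it suits modest N only.
    """
    w = normalise(log_weights)
    spread = jnp.abs(particles[:, None] - particles[None, :])
    return jnp.sum(w * jnp.abs(particles - y)) - 0.5 * (w @ spread @ w)


# Four weighted particles
particles = jnp.array([0.0, 1.0, 2.0, 3.0])
log_w = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))
print(weighted_mean_sketch(particles, log_w))           # ≈ 2.0
print(weighted_quantile_sketch(particles, log_w, 0.5))  # 2.0
print(crps_sketch(particles, log_w, 2.0))               # ≈ 0.26
```

A log Bayes factor between two models fitted to the same observations is simply the difference of their log marginal likelihoods, so it needs no particle-level computation.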
Requirements
- Python 3.10 or later
- uv (needed for installing from source and for development)
Installation
pip install smcjax
Or from source:
git clone https://github.com/michaelellis003/smcjax.git
cd smcjax
uv sync
Install the pre-commit hooks (one-time setup):
uv run pre-commit install
uv run pre-commit install --hook-type commit-msg
uv run pre-commit install --hook-type pre-push
Quick example
import jax.numpy as jnp
import jax.random as jr
import jax.scipy.stats as jstats
from smcjax import bootstrap_filter, weighted_mean, log_ml_increments
# Define a 1-D linear Gaussian state space model
m0, P0 = jnp.array([0.0]), jnp.array([[1.0]])
F, Q = jnp.array([[0.9]]), jnp.array([[0.25]])
H, R = jnp.array([[1.0]]), jnp.array([[1.0]])
chol_P0 = jnp.linalg.cholesky(P0)
chol_Q = jnp.linalg.cholesky(Q)
def initial_sampler(key, n):
    return m0 + jr.normal(key, (n, 1)) @ chol_P0.T
def transition_sampler(key, state):
    mean = (F @ state[:, None]).squeeze(-1)
    return mean + jr.normal(key, (1,)) @ chol_Q.T
def log_observation_fn(emission, state):
    mean = (H @ state[:, None]).squeeze(-1)
    return jstats.multivariate_normal.logpdf(emission, mean, R)
# Stand-in observations (i.i.d. standard normal draws, not simulated from the model)
key = jr.PRNGKey(0)
T = 100
emissions = jr.normal(key, (T, 1))
# Run the bootstrap particle filter
posterior = bootstrap_filter(
key=jr.PRNGKey(1),
initial_sampler=initial_sampler,
transition_sampler=transition_sampler,
log_observation_fn=log_observation_fn,
emissions=emissions,
num_particles=1_000,
)
print(f"Log marginal likelihood: {posterior.marginal_loglik:.2f}")
print(f"Particles shape: {posterior.filtered_particles.shape}")
print(f"Mean ESS: {posterior.ess.mean():.1f}")
# Diagnostics
means = weighted_mean(posterior)
increments = log_ml_increments(posterior)
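Because the quick-example model is linear and Gaussian, its exact log marginal likelihood is available from a Kalman filter, which gives a useful check on posterior.marginal_loglik. The sketch below simulates data from that model and computes the exact value in plain JAX. simulate and kalman_log_ml are hypothetical helpers, not part of smcjax, and they assume the first observation is emitted from the initial state (the Dynamax convention); if smcjax indexes time differently, the two numbers are not directly comparable.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import multivariate_normal

# Same linear Gaussian model as the quick example
m0, P0 = jnp.array([0.0]), jnp.array([[1.0]])
F, Q = jnp.array([[0.9]]), jnp.array([[0.25]])
H, R = jnp.array([[1.0]]), jnp.array([[1.0]])


def simulate(key, T):
    """Draw (states, emissions), each of shape (T, 1), from the model."""
    k0, k1, k2 = jr.split(key, 3)
    x0 = m0 + jnp.linalg.cholesky(P0) @ jr.normal(k0, (1,))

    def step(x, k):
        x_next = F @ x + jnp.linalg.cholesky(Q) @ jr.normal(k, (1,))
        return x_next, x_next

    _, xs = jax.lax.scan(step, x0, jr.split(k1, T - 1))
    states = jnp.concatenate([x0[None], xs])
    noise = jr.normal(k2, (T, 1)) @ jnp.linalg.cholesky(R).T
    return states, states @ H.T + noise


def kalman_log_ml(emissions):
    """Exact log p(y_1:T), assuming y_1 is emitted from the initial state."""

    def update(m, P, y):
        S = H @ P @ H.T + R  # predictive covariance of y
        K = jnp.linalg.solve(S, H @ P).T  # Kalman gain P H^T S^{-1}
        ll = multivariate_normal.logpdf(y, H @ m, S)
        return m + K @ (y - H @ m), P - K @ S @ K.T, ll

    def step(carry, y):
        m, P = carry
        m, P = F @ m, F @ P @ F.T + Q  # predict
        m, P, ll = update(m, P, y)  # correct on y
        return (m, P), ll

    m, P, ll0 = update(m0, P0, emissions[0])
    _, lls = jax.lax.scan(step, (m, P), emissions[1:])
    return ll0 + jnp.sum(lls)


states, emissions = simulate(jr.PRNGKey(0), 100)
exact = kalman_log_ml(emissions)
```

Running the bootstrap filter on these emissions should typically give a marginal_loglik close to exact, with the gap shrinking as num_particles grows; a persistent offset would suggest a different time-indexing convention than the one assumed here.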
Development
A Makefile collects the common development tasks:
make test # lint + pytest
make lint # ruff check, format check, license headers, ty
make format # add license headers, ruff format, ruff fix
make license # add missing license headers
make docs # build documentation
make serve-docs # serve documentation locally
make install # uv sync
make clean # git clean (preserves .venv)
How releases work
Releases are fully automated. When a commit lands on main and CI
passes, python-semantic-release inspects the commit history to
determine whether a version bump is warranted:
- fix: ... produces a patch release
- feat: ... produces a minor release
- A BREAKING CHANGE footer or ! suffix produces a major release
License
Apache-2.0. See LICENSE for the full text.
Download files
File details
Details for the file smcjax-1.1.0.tar.gz.
File metadata
- Download URL: smcjax-1.1.0.tar.gz
- Upload date:
- Size: 19.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c24a711c9acac0b24d0c2f16f3ed44ac745f2463ecb9ecc8eb3cddee7da925ad |
| MD5 | 1d663530ebf42dca2ebb4d66e26a2ee0 |
| BLAKE2b-256 | 52e45425262b2b44b3f4775eacf4952650f0db5bd3192616ff2ffe9d52918d68 |
Provenance
The following attestation bundles were made for smcjax-1.1.0.tar.gz:
Publisher: release.yml on michaelellis003/smcjax
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: smcjax-1.1.0.tar.gz
- Subject digest: c24a711c9acac0b24d0c2f16f3ed44ac745f2463ecb9ecc8eb3cddee7da925ad
- Sigstore transparency entry: 1271557263
- Sigstore integration time:
- Permalink: michaelellis003/smcjax@dec3d685b85499cd56e0775716a4549d6553be20
- Branch / Tag: refs/heads/main
- Owner: https://github.com/michaelellis003
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@dec3d685b85499cd56e0775716a4549d6553be20
- Trigger Event: workflow_run
File details
Details for the file smcjax-1.1.0-py3-none-any.whl.
File metadata
- Download URL: smcjax-1.1.0-py3-none-any.whl
- Upload date:
- Size: 25.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ea2314dd6b2836bfb40df32956d09001be9a480aff1f9a98ecb552eecbb33f88 |
| MD5 | daa10abb870d75173bc69dc0bac3aa62 |
| BLAKE2b-256 | 277873574d6f2ca3fb1c299f19b90c3a524650910a3e55d73027befda8756918 |
Provenance
The following attestation bundles were made for smcjax-1.1.0-py3-none-any.whl:
Publisher: release.yml on michaelellis003/smcjax
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: smcjax-1.1.0-py3-none-any.whl
- Subject digest: ea2314dd6b2836bfb40df32956d09001be9a480aff1f9a98ecb552eecbb33f88
- Sigstore transparency entry: 1271557354
- Sigstore integration time:
- Permalink: michaelellis003/smcjax@dec3d685b85499cd56e0775716a4549d6553be20
- Branch / Tag: refs/heads/main
- Owner: https://github.com/michaelellis003
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@dec3d685b85499cd56e0775716a4549d6553be20
- Trigger Event: workflow_run