smcjax
Sequential Monte Carlo and particle filtering in JAX.
smcjax extends Dynamax and BlackJAX with particle filters and Bayesian workflow diagnostics that neither library provides. Every filter runs its time loop with `jax.lax.scan`, so the filters JIT-compile and run on GPU.
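To illustrate what a scan-based particle filter looks like, here is a minimal bootstrap filter for a scalar linear Gaussian model in plain JAX. It is a sketch, not smcjax's implementation: the model constants `A`, `Q`, `R`, the name `bootstrap_scan`, multinomial resampling at every step, and the convention that the initial particles represent the first observed state (as in Dynamax) are all assumptions. Each scan step weights, records the log-evidence increment and ESS, resamples, and propagates; the log marginal likelihood is the sum of the increments.

```python
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

# Hypothetical scalar model: x_1 ~ N(0, 1), x_t = A x_{t-1} + N(0, Q), y_t = x_t + N(0, R)
A, Q, R = 0.9, 0.25, 1.0


@partial(jax.jit, static_argnames="num_particles")
def bootstrap_scan(key, ys, num_particles=1_000):
    """Minimal bootstrap filter; returns (log_ml, per-step log-evidence increments, ESS)."""
    key, init_key = jr.split(key)
    x = jr.normal(init_key, (num_particles,))  # particles for x_1

    def step(carry, y):
        key, x = carry
        key, k_res, k_prop = jr.split(key, 3)
        # Weight by the observation likelihood (incoming weights are uniform)
        log_w = norm.logpdf(y, loc=x, scale=jnp.sqrt(R))
        # log p(y_t | y_{1:t-1}) is estimated by the log of the mean weight
        inc = logsumexp(log_w) - jnp.log(num_particles)
        w = jnp.exp(log_w - logsumexp(log_w))
        ess = 1.0 / jnp.sum(w**2)
        # Multinomial resampling, then propagate to the next time step
        idx = jr.choice(k_res, num_particles, (num_particles,), p=w)
        x = A * x[idx] + jnp.sqrt(Q) * jr.normal(k_prop, (num_particles,))
        return (key, x), (inc, ess)

    _, (incs, ess) = jax.lax.scan(step, (key, x), ys)
    return incs.sum(), incs, ess
```

Because `num_particles` fixes array shapes, it is marked static; changing it triggers a recompile, while new observations of the same length reuse the compiled loop.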
Features
- Bootstrap (SIR) particle filter — Gordon et al. (1993)
- Auxiliary particle filter — Pitt & Shephard (1999)
- Liu-West filter — joint state-parameter estimation via kernel density smoothing (Liu & West, 2001)
- Forward simulation — generate trajectories from state-space models
- Diagnostics suite — weighted mean/variance/quantiles, parameter summaries, ESS traces, particle diversity, per-step log evidence increments, replicated log-ML, log Bayes factors, CRPS
- 4 resampling schemes (via BlackJAX): systematic, stratified, multinomial, residual
- Conditional resampling with configurable ESS threshold
- All functions are `jit`- and `vmap`-compatible
- Type annotations via jaxtyping
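The weighted summaries and CRPS in the diagnostics suite reduce to a few lines on a particle cloud. The sketch below works on one time step with scalar particles `x` and normalized weights `w`, computing a weighted mean and variance, an inverse-CDF weighted quantile, and the CRPS via the identity CRPS(F, y) = E|X - y| - 0.5 E|X - X'|. The function names and the quantile convention are assumptions; smcjax's `diagnostics.py` may define them differently.

```python
import jax.numpy as jnp


def weighted_mean_var(x, w):
    """Weighted mean and variance of scalar particles x with normalized weights w."""
    mean = jnp.sum(w * x)
    var = jnp.sum(w * (x - mean) ** 2)
    return mean, var


def weighted_quantile(x, w, q):
    """Smallest particle value whose cumulative weight reaches q (inverse-CDF rule)."""
    order = jnp.argsort(x)
    x_sorted, cdf = x[order], jnp.cumsum(w[order])
    idx = jnp.searchsorted(cdf, q, side="left")
    # Clip guards against cdf[-1] rounding slightly below 1
    return x_sorted[jnp.clip(idx, 0, x.shape[0] - 1)]


def weighted_crps(x, w, y):
    """CRPS of the weighted particle distribution against observation y (lower is better)."""
    term1 = jnp.sum(w * jnp.abs(x - y))
    pairwise = jnp.abs(x[:, None] - x[None, :])
    term2 = 0.5 * jnp.sum(w[:, None] * w[None, :] * pairwise)
    return term1 - term2
```

The pairwise CRPS term costs O(N^2) memory per time step; for large particle counts a sorted-particle formula is cheaper, at the price of more code.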
Installation
```bash
pip install smcjax
```
Or from source:
```bash
git clone https://github.com/michaelellis003/smcjax.git
cd smcjax
uv sync
```
Quick Example
```python
import jax.numpy as jnp
import jax.random as jr
import jax.scipy.stats as jstats

from smcjax import bootstrap_filter, weighted_mean, log_ml_increments

# Define a 1-D linear Gaussian state space model
m0, P0 = jnp.array([0.0]), jnp.array([[1.0]])
F, Q = jnp.array([[0.9]]), jnp.array([[0.25]])
H, R = jnp.array([[1.0]]), jnp.array([[1.0]])
chol_P0 = jnp.linalg.cholesky(P0)
chol_Q = jnp.linalg.cholesky(Q)


def initial_sampler(key, n):
    # Draw n initial particles from N(m0, P0); shape (n, 1)
    return m0 + jr.normal(key, (n, 1)) @ chol_P0.T


def transition_sampler(key, state):
    # Sample one particle's next state from N(F x, Q); state has shape (1,)
    mean = (F @ state[:, None]).squeeze(-1)
    return mean + jr.normal(key, (1,)) @ chol_Q.T


def log_observation_fn(emission, state):
    # log p(y | x) for y ~ N(H x, R)
    mean = (H @ state[:, None]).squeeze(-1)
    return jstats.multivariate_normal.logpdf(emission, mean, R)


# Placeholder observations: i.i.d. N(0, 1) draws, not simulated from the model
key = jr.PRNGKey(0)
T = 100
emissions = jr.normal(key, (T, 1))

# Run the bootstrap particle filter
posterior = bootstrap_filter(
    key=jr.PRNGKey(1),
    initial_sampler=initial_sampler,
    transition_sampler=transition_sampler,
    log_observation_fn=log_observation_fn,
    emissions=emissions,
    num_particles=1_000,
)

print(f"Log marginal likelihood: {posterior.marginal_loglik:.2f}")
print(f"Particles shape: {posterior.filtered_particles.shape}")
print(f"Mean ESS: {posterior.ess.mean():.1f}")

# Diagnostics
means = weighted_mean(posterior)
increments = log_ml_increments(posterior)
```
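Because the quick-example model is linear and Gaussian, its exact log marginal likelihood is available from a Kalman filter, which is the quantity the Cross-Validation table compares particle estimates against (there via Dynamax). A minimal scalar Kalman filter in JAX is sketched here. `kalman_loglik` is a hypothetical helper, its defaults mirror the example's `m0`, `P0`, `F`, `Q`, `H`, `R`, and it assumes the initial distribution describes the first observed state, as in Dynamax.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def kalman_loglik(ys, m0=0.0, P0=1.0, F=0.9, Q=0.25, H=1.0, R=1.0):
    """Exact log p(y_{1:T}) for x_1 ~ N(m0, P0), x_t = F x_{t-1} + N(0, Q), y_t = H x_t + N(0, R)."""

    def step(carry, y):
        m_pred, P_pred = carry  # moments of p(x_t | y_{1:t-1})
        # One-step predictive density of y_t
        S = H * P_pred * H + R
        ll = norm.logpdf(y, loc=H * m_pred, scale=jnp.sqrt(S))
        # Update with y_t
        K = P_pred * H / S
        m_filt = m_pred + K * (y - H * m_pred)
        P_filt = (1.0 - K * H) * P_pred
        # Predict x_{t+1}
        return (F * m_filt, F * P_filt * F + Q), ll

    init = (jnp.asarray(m0, dtype=jnp.float32), jnp.asarray(P0, dtype=jnp.float32))
    _, lls = jax.lax.scan(step, init, ys)
    return lls.sum()
```

With the quick example's data, `kalman_loglik(emissions[:, 0])` is the value `posterior.marginal_loglik` should approach as `num_particles` grows, provided smcjax uses the same initial-state convention.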
Architecture
```text
smcjax/
    __init__.py      # Public API (re-exports BlackJAX ESS & resampling)
    types.py         # PRNGKeyT, Scalar (matches Dynamax)
    containers.py    # ParticleState, ParticleFilterPosterior, LiuWestPosterior
    weights.py       # log_normalize, normalize
    bootstrap.py     # Bootstrap (SIR) particle filter
    auxiliary.py     # Auxiliary particle filter (Pitt & Shephard 1999)
    liu_west.py      # Liu-West filter for joint state-parameter estimation
    simulate.py      # Forward simulation from state-space models
    diagnostics.py   # Posterior summaries, model comparison, scoring rules
```
ESS and resampling (systematic, stratified, multinomial, residual) are provided by BlackJAX and re-exported from smcjax for convenience.
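For reference, the ESS and conditional-resampling logic can be written in a few lines. The sketch below computes the ESS from log-weights and applies systematic resampling only when the ESS drops below a fraction of the particle count. It is plain JAX rather than a call into BlackJAX, and the names `ess_from_log_weights`, `systematic_resample`, `maybe_resample`, and the default threshold of 0.5 are assumptions, not smcjax's API.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import logsumexp


def ess_from_log_weights(log_w):
    """Effective sample size 1 / sum(W_i^2) for the normalized weights W."""
    log_W = log_w - logsumexp(log_w)
    return jnp.exp(-logsumexp(2.0 * log_W))


def systematic_resample(key, weights, num_samples):
    """Ancestor indices from one uniform offset plus evenly spaced points."""
    u = (jr.uniform(key) + jnp.arange(num_samples)) / num_samples
    idx = jnp.searchsorted(jnp.cumsum(weights), u)
    # Clip guards against the cumulative sum rounding slightly below 1
    return jnp.clip(idx, 0, weights.shape[0] - 1)


def maybe_resample(key, particles, log_w, threshold=0.5):
    """Resample and reset to uniform weights only if ESS < threshold * N."""
    n = log_w.shape[0]

    def do_resample(_):
        W = jnp.exp(log_w - logsumexp(log_w))
        idx = systematic_resample(key, W, n)
        return particles[idx], jnp.full_like(log_w, -jnp.log(n))

    def keep(_):
        return particles, log_w

    needs_resampling = ess_from_log_weights(log_w) < threshold * n
    return jax.lax.cond(needs_resampling, do_resample, keep, None)
```

Using `jax.lax.cond` keeps the branch inside the traced program, so the decision can be made per time step inside a `lax.scan` without breaking JIT compilation.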
Cross-Validation
All filters are tested against reference libraries:
| Module | Reference | Validation |
|---|---|---|
| bootstrap | Dynamax Kalman filter | Log-ML within 5% of exact |
| bootstrap | particles (Chopin) | Log-ML within 3 nats |
| auxiliary | Dynamax Kalman filter | Log-ML within 5% of exact |
| auxiliary | Bootstrap (flat auxiliary = bootstrap) | Log-ML within 3 nats |
| liu_west | Auxiliary filter (fixed params) | Log-ML within 5 nats |
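The "flat auxiliary = bootstrap" check rests on how the auxiliary particle filter weights particles. Below is a minimal single APF step for a scalar model: first-stage weights add a look-ahead score to the previous log-weights, ancestors are resampled from them, and second-stage weights divide the look-ahead back out. With a constant look-ahead, the first stage reduces to resampling on the previous weights and the second stage to the plain observation likelihood, i.e. a bootstrap step. The model constants, the name `apf_step`, the point-estimate look-ahead, and multinomial resampling are assumptions, not smcjax's implementation.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

# Hypothetical scalar model: x_t = A x_{t-1} + N(0, Q), y_t = x_t + N(0, R)
A, Q, R = 0.9, 0.25, 1.0


def apf_step(key, x_prev, log_w_prev, y, flat=False):
    """One APF step; returns (particles, log-weights, log-evidence increment)."""
    n = x_prev.shape[0]
    k_res, k_prop = jr.split(key)
    log_W_prev = log_w_prev - logsumexp(log_w_prev)
    # Look-ahead score log p(y | A x_prev), or zero for the "flat" variant
    lookahead = jnp.zeros(n) if flat else norm.logpdf(y, loc=A * x_prev, scale=jnp.sqrt(R))
    # First stage: resample ancestors on previous weights times the look-ahead
    log_first = log_W_prev + lookahead
    ancestors = jr.categorical(k_res, log_first, shape=(n,))
    # Propagate the selected ancestors through the transition
    x = A * x_prev[ancestors] + jnp.sqrt(Q) * jr.normal(k_prop, (n,))
    # Second stage: true likelihood divided by the look-ahead that was used
    log_w = norm.logpdf(y, loc=x, scale=jnp.sqrt(R)) - lookahead[ancestors]
    # Estimate of log p(y_t | y_{1:t-1})
    log_inc = logsumexp(log_first) + logsumexp(log_w) - jnp.log(n)
    return x, log_w, log_inc
```

Both variants estimate the same predictive likelihood; an informative look-ahead typically lowers the variance of the second-stage weights when observations are sharp.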
Notebooks
The notebooks/ directory reproduces a thesis-style Bayesian workflow on a hidden Markov model with unknown parameters, covering the full pipeline: simulation, Liu-West filtering, parameter recovery, model comparison via log Bayes factors, and CRPS evaluation.
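The Liu-West filter handles static parameters by jittering each parameter particle with a shrunk kernel density, which preserves the weighted mean and variance of the parameter cloud. A minimal sketch of that move for a scalar parameter follows, using the shrinkage a = (3δ - 1)/(2δ) and kernel variance fraction h² = 1 - a² from Liu & West (2001), where δ is a discount factor slightly below 1. The name `liu_west_jitter` and the default δ = 0.98 are assumptions; in the full filter this move sits inside an auxiliary-style resampling of the joint state-parameter particles.

```python
import jax.numpy as jnp
import jax.random as jr


def liu_west_jitter(key, theta, w, delta=0.98):
    """Shrink-and-jitter move for scalar parameter particles theta with normalized weights w."""
    a = (3.0 * delta - 1.0) / (2.0 * delta)  # shrinkage factor
    h2 = 1.0 - a**2  # kernel variance as a fraction of the cloud variance
    mean = jnp.sum(w * theta)
    var = jnp.sum(w * (theta - mean) ** 2)
    # Shrink kernel locations toward the weighted mean (variance becomes a^2 * var)
    loc = a * theta + (1.0 - a) * mean
    # Gaussian noise with variance h^2 * var restores the total to var
    return loc + jnp.sqrt(h2 * var) * jr.normal(key, theta.shape)
```

Because a² + h² = 1, the jitter fights particle degeneracy in the parameters without inflating their posterior spread, which plain random-walk jitter would do.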
Roadmap
| Phase | What | Status |
|---|---|---|
| 1 | Bootstrap particle filter | Done |
| 2 | Auxiliary particle filter | Done |
| 3 | Forward simulation + diagnostics | Done |
| 4 | Liu-West filter + model comparison | Done |
| 5 | EKF/UKF proposal particle filters | Planned |
| 6 | Particle MCMC (PMMH) | Planned |
Development
```bash
uv sync                     # Install all deps
uv run pre-commit install   # Set up pre-commit hooks
uv run pytest -v --cov      # Run tests with coverage
uv run ruff check . --fix   # Lint
uv run pyright              # Type check
```
License
Apache-2.0
Download files
File details
Details for the file smcjax-1.0.0.tar.gz.
File metadata
- Download URL: smcjax-1.0.0.tar.gz
- Upload date:
- Size: 258.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 98d290f59d43c1214a2a076d4799cb051778f9a5553a844038b237b6549e49af |
| MD5 | ff676c0fea4c14ce7b2673a5a2919fa1 |
| BLAKE2b-256 | dee61ca3554019733577525ef716eca59cbd03c38a60a61ae09d2ff594944900 |
Provenance
The following attestation bundles were made for smcjax-1.0.0.tar.gz:
Publisher: release.yml on michaelellis003/smcjax
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: smcjax-1.0.0.tar.gz
- Subject digest: 98d290f59d43c1214a2a076d4799cb051778f9a5553a844038b237b6549e49af
- Sigstore transparency entry: 1007828691
- Sigstore integration time:
- Permalink: michaelellis003/smcjax@bb8cb55ae44d1d7bb1fed1a36df53b2cf77f0e3a
- Branch / Tag: refs/heads/main
- Owner: https://github.com/michaelellis003
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@bb8cb55ae44d1d7bb1fed1a36df53b2cf77f0e3a
- Trigger Event: workflow_dispatch
File details
Details for the file smcjax-1.0.0-py3-none-any.whl.
File metadata
- Download URL: smcjax-1.0.0-py3-none-any.whl
- Upload date:
- Size: 24.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0a1927bee1d1fbffe8ad0e322fb960abad81852b347716f6a7b215bbf3468548 |
| MD5 | a4593ee95000a53ffc88172658bfc27a |
| BLAKE2b-256 | 0b8e6d9803fdc1f424c70169e7395c968600e5496fddf85ff992e4f822ed982d |
Provenance
The following attestation bundles were made for smcjax-1.0.0-py3-none-any.whl:
Publisher: release.yml on michaelellis003/smcjax
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: smcjax-1.0.0-py3-none-any.whl
- Subject digest: 0a1927bee1d1fbffe8ad0e322fb960abad81852b347716f6a7b215bbf3468548
- Sigstore transparency entry: 1007828699
- Sigstore integration time:
- Permalink: michaelellis003/smcjax@bb8cb55ae44d1d7bb1fed1a36df53b2cf77f0e3a
- Branch / Tag: refs/heads/main
- Owner: https://github.com/michaelellis003
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@bb8cb55ae44d1d7bb1fed1a36df53b2cf77f0e3a
- Trigger Event: workflow_dispatch