Extended State Space Modelling in JAX
Extended State Space Models in JAX
Given a potentially non-linear state space model, this package solves the forward (filtering) and backward (smoothing) inference steps. For linear state space models this is equivalent to the Kalman and Rauch-Tung-Striebel recursions.
Supports Python 3.10+.
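For orientation, the linear case mentioned above can be sketched in a few lines of plain Python. This is a minimal scalar illustration of the textbook Kalman filter and Rauch-Tung-Striebel smoother, not the code used inside essm_jax. The function names `kalman_filter` and `rts_smoother` and the parameters `a`, `q`, `h`, `r`, `m0`, `p0` are made up for this sketch.

```python
def kalman_filter(xs, a, q, h, r, m0, p0):
    """Scalar Kalman filter (illustrative only, not the essm_jax API).

    Assumed model: z[t+1] = a * z[t] + N(0, q),  x[t] = h * z[t] + N(0, r),
    with prior z[0] ~ N(m0, p0).
    """
    filtered = []   # (mean, variance) of z[t] given x[0..t]
    predicted = []  # (mean, variance) of z[t] given x[0..t-1]
    m_pred, p_pred = m0, p0
    for x in xs:
        predicted.append((m_pred, p_pred))
        s = h * p_pred * h + r            # innovation variance
        k = p_pred * h / s                # Kalman gain
        m = m_pred + k * (x - h * m_pred)
        p = (1.0 - k * h) * p_pred
        filtered.append((m, p))
        # Predict the next latent state from the current filtered estimate.
        m_pred, p_pred = a * m, a * p * a + q
    return filtered, predicted


def rts_smoother(filtered, predicted, a):
    """Backward Rauch-Tung-Striebel pass over the filter output."""
    smoothed = [filtered[-1]]  # the last smoothed estimate is the last filtered one
    for t in range(len(filtered) - 2, -1, -1):
        m_f, p_f = filtered[t]
        m_pred, p_pred = predicted[t + 1]
        g = p_f * a / p_pred              # smoother gain
        m_next, p_next = smoothed[0]
        m_s = m_f + g * (m_next - m_pred)
        p_s = p_f + g * (p_next - p_pred) * g
        smoothed.insert(0, (m_s, p_s))
    return smoothed


# Example: a static latent (a=1, q=0) observed twice with unit noise.
xs = [1.0, 1.0]
filtered, predicted = kalman_filter(xs, a=1.0, q=0.0, h=1.0, r=1.0, m0=0.0, p0=1.0)
smoothed = rts_smoother(filtered, predicted, a=1.0)
```

With `q=0.0` and `a=1.0` the latent never changes, so every smoothed estimate coincides with the final filtered one. The extended model generalises this by linearising non-linear transition and observation functions around the current estimate.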
Example
All you need to do is define the transition and observation functions, and the initial state prior. These are all in terms of MultivariateNormalLinearOperator distributions from tensorflow_probability.
import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp
from essm_jax.essm import ExtendedStateSpaceModel
tfpd = tfp.distributions
def transition_fn(z, t, t_next, *args):
    mean = z + jnp.sin(2 * jnp.pi * t / 10 * z)
    cov = 0.1 * jnp.eye(np.size(z))
    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))
def observation_fn(z, t, *args):
    mean = z
    cov = t * 0.01 * jnp.eye(np.size(z))
    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))
n = 1
initial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n))
essm = ExtendedStateSpaceModel(
    transition_fn=transition_fn,
    observation_fn=observation_fn,
    initial_state_prior=initial_state_prior,
    materialise_jacobians=False,  # Fast
    more_data_than_params=False  # if observation is bigger than latent we can speed it up.
)
T = 100
samples = essm.sample(jax.random.PRNGKey(0), num_time=T)
# Suppose we only observe every 3rd observation
mask = jnp.arange(T) % 3 != 0
# Log marginal likelihood, log p(x[:]) = sum_t log p(x[t] | x[:t-1])
log_prob = essm.log_prob(samples.observation, mask=mask)
print(log_prob)
# Filtered latent distribution, p(z[t] | x[:t])
filter_result = essm.forward_filter(samples.observation, mask=mask)
# Smoothed latent distribution, p(z[t] | x[:]), i.e. past latents given all future observations
# Including new estimate for prior state p(z[0])
smooth_result, posterior_prior = essm.backward_smooth(filter_result, include_prior=True)
print(smooth_result)
# Forward simulate the model
forward_samples = essm.forward_simulate(
    key=jax.random.PRNGKey(0),
    num_time=25,
    filter_result=filter_result
)
import matplotlib.pyplot as plt
plt.plot(samples.t, samples.latent[:, 0], label='latent')
plt.plot(filter_result.t, filter_result.filtered_mean[:, 0], label='filtered latent')
plt.plot(forward_samples.t, forward_samples.latent[:, 0], label='forward_simulated latent')
plt.legend()
plt.show()
plt.plot(samples.t, samples.observation[:, 0], label='observation')
plt.plot(filter_result.t, filter_result.observation_mean[:, 0], label='filtered obs')
plt.plot(forward_samples.t, forward_samples.observation[:, 0], label='forward_simulated obs')
plt.legend()
plt.show()
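To make the roles of the linearisation and the mask in the example above more concrete, here is a scalar, dependency-free sketch of an extended Kalman filter that accumulates the log marginal likelihood. It is an illustration of the general technique under stated assumptions, not the essm_jax implementation: the transition and observation models are copied from the example, the Jacobian is written by hand, `mask[i] == True` is taken to mean "observation i is skipped" (inferred from the comment in the example), and the time grid `ts` starting at 1 is an assumption made here so that the observation variance `0.01 * t` is never zero. The names `transition_mean`, `transition_jacobian` and `ekf_log_prob` are hypothetical.

```python
import math


def transition_mean(z, t):
    # Same transition mean as in the example: z + sin(2*pi*t/10 * z).
    return z + math.sin(2.0 * math.pi * t / 10.0 * z)


def transition_jacobian(z, t):
    # Hand-written derivative of transition_mean with respect to z.
    w = 2.0 * math.pi * t / 10.0
    return 1.0 + w * math.cos(w * z)


def ekf_log_prob(xs, ts, mask, m0=0.0, p0=1.0, q=0.1):
    """Scalar extended Kalman filter with masked observations (sketch).

    Returns the accumulated log marginal likelihood of the unmasked
    observations and the list of filtered means.
    """
    m, p = m0, p0  # prior on the latent at ts[0]
    log_prob = 0.0
    means = []
    for i, (x, t) in enumerate(zip(xs, ts)):
        if i > 0:
            # Predict: linearise the transition around the current mean.
            f = transition_jacobian(m, ts[i - 1])
            m = transition_mean(m, ts[i - 1])
            p = f * p * f + q
        if not mask[i]:
            # Update: the observation mean is z itself, with variance 0.01 * t.
            r = 0.01 * t
            s = p + r  # innovation variance
            log_prob += -0.5 * (math.log(2.0 * math.pi * s) + (x - m) ** 2 / s)
            k = p / s  # Kalman gain
            m = m + k * (x - m)
            p = (1.0 - k) * p
        # Masked steps keep the predicted estimate and add nothing to log_prob.
        means.append(m)
    return log_prob, means


# Hypothetical data: six observations, keeping only every 3rd one.
xs = [0.1, -0.2, 0.3, 0.0, 0.5, -0.1]
ts = [float(i + 1) for i in range(len(xs))]
mask = [i % 3 != 0 for i in range(len(xs))]
log_prob, means = ekf_log_prob(xs, ts, mask)
```

Because masked steps only run the prediction, the values stored at masked positions of `xs` have no effect on either the filtered means or the log marginal likelihood.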
Online Filtering
See the examples to learn how to do online filtering for interactive applications.
Change Log
13 August 2024: Initial release 1.0.0.
14 August 2024: 1.0.1 released. Added sparse util. Added incremental API for online filtering. Added support for arbitrary dt.