Extended State Space Model in JAX
Given a potentially non-linear state space model, this library solves the forward (filtering) and backward (smoothing) inference steps. For linear state space models, this is equivalent to the Kalman filter and Rauch-Tung-Striebel smoother recursions.
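For intuition about the linear case, here is a minimal pure-JAX sketch of a Kalman filter forward pass and a Rauch-Tung-Striebel backward pass. It is not essm_jax's implementation: the function names kalman_filter and rts_smoother are hypothetical, and the convention that z[0] is the prior state with observations x[1..T] is an assumption for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def kalman_filter(m0, P0, F, Q, H, R, xs):
    """Forward pass for the (assumed) linear-Gaussian model
        z[0] ~ N(m0, P0)
        z[t] = F z[t-1] + N(0, Q),  x[t] = H z[t] + N(0, R),  t = 1..T
    Returns filtered means, filtered covariances and log p(x[1..T])."""

    def step(carry, x):
        m, P, log_lik = carry
        # Predict: p(z[t] | x[1..t-1])
        m_pred = F @ m
        P_pred = F @ P @ F.T + Q
        # Predicted observation p(x[t] | x[1..t-1]) = N(H m_pred, S)
        x_pred = H @ m_pred
        S = H @ P_pred @ H.T + R
        log_lik = log_lik + multivariate_normal.logpdf(x, x_pred, S)
        # Update with Kalman gain K = P_pred H^T S^-1 (S is symmetric)
        K = jnp.linalg.solve(S, H @ P_pred).T
        m_filt = m_pred + K @ (x - x_pred)
        P_filt = P_pred - K @ S @ K.T
        return (m_filt, P_filt, log_lik), (m_filt, P_filt)

    init = (m0, P0, jnp.zeros((), dtype=m0.dtype))
    (_, _, log_lik), (ms, Ps) = jax.lax.scan(step, init, xs)
    return ms, Ps, log_lik


def rts_smoother(ms, Ps, F, Q):
    """Rauch-Tung-Striebel backward pass over filtered moments; returns p(z[t] | x[1..T])."""

    def step(carry, inputs):
        m_next_s, P_next_s = carry
        m, P = inputs
        m_pred = F @ m
        P_pred = F @ P @ F.T + Q
        # Smoother gain G = P F^T P_pred^-1 (P_pred is symmetric)
        G = jnp.linalg.solve(P_pred, F @ P).T
        m_s = m + G @ (m_next_s - m_pred)
        P_s = P + G @ (P_next_s - P_pred) @ G.T
        return (m_s, P_s), (m_s, P_s)

    # At the last time step the filtered moments are already the smoothed ones
    init = (ms[-1], Ps[-1])
    _, (ms_s, Ps_s) = jax.lax.scan(step, init, (ms[:-1], Ps[:-1]), reverse=True)
    return jnp.concatenate([ms_s, ms[-1:]]), jnp.concatenate([Ps_s, Ps[-1:]])
```

Unlike backward_smooth(..., include_prior=True) in the example below, this sketch only smooths z[1..T] and does not re-estimate the prior p(z[0]).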
Supports Python 3.10+.
Example
All you need to do is define the transition and observation functions and the initial state prior. All three are expressed as MultivariateNormalLinearOperator distributions from tensorflow_probability (the example uses its subclass MultivariateNormalTriL).
```python
import jax
import matplotlib.pyplot as plt
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp

from essm_jax.essm import ExtendedStateSpaceModel

tfpd = tfp.distributions


def transition_fn(z, t):
    # Non-linear latent dynamics with additive Gaussian noise
    mean = z + jnp.sin(2 * jnp.pi * t / 10 * z)
    cov = 0.1 * jnp.eye(np.size(z))
    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))


def observation_fn(z, t):
    # Identity observation with a noise variance that grows with t
    mean = z
    cov = t * 0.01 * jnp.eye(np.size(z))
    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))


n = 1  # Latent dimension
initial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n))

essm = ExtendedStateSpaceModel(
    transition_fn=transition_fn,
    observation_fn=observation_fn,
    initial_state_prior=initial_state_prior,
    materialise_jacobians=False,  # Fast
    more_data_than_params=False  # Set True when the observation is larger than the latent state to speed things up
)

T = 100
samples = essm.sample(jax.random.PRNGKey(0), num_time=T)

# Suppose we only observe every 3rd time step (mask=True marks a missing observation)
mask = jnp.arange(T) % 3 != 0

# Marginal likelihood, p(x[:]) = prod_t p(x[t] | x[:t-1])
log_prob = essm.log_prob(samples.observation, mask=mask)
print(log_prob)

# Filtered latent distribution, p(z[t] | x[:t])
filter_result = essm.forward_filter(samples.observation, mask=mask)

# Smoothed latent distribution, p(z[t] | x[:]), i.e. each latent conditioned on all
# observations, past and future, including a new estimate for the prior state p(z[0])
smooth_result, posterior_prior = essm.backward_smooth(filter_result, include_prior=True)
print(smooth_result)

# Forward simulate the model
forward_samples = essm.forward_simulate(
    key=jax.random.PRNGKey(0),
    num_time=25,
    observations=samples.observation,
    mask=mask
)

plt.plot(samples.t, samples.latent[:, 0], label='latent')
plt.plot(filter_result.t, filter_result.filtered_mean[:, 0], label='filtered latent')
plt.plot(forward_samples.t, forward_samples.latent[:, 0], label='forward_simulated latent')
plt.legend()
plt.show()

plt.plot(samples.t, samples.observation[:, 0], label='observation')
plt.plot(filter_result.t, filter_result.observation_mean[:, 0], label='filtered obs')
plt.plot(forward_samples.t, forward_samples.observation[:, 0], label='forward_simulated obs')
plt.legend()
plt.show()
```
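The mask argument lets the filter skip time steps that have no observation. Below is a minimal sketch of one way to do that in a jit/scan-compatible Kalman update, assuming mask=True marks a missing step as in the example; masked_kalman_update is a hypothetical helper, not the essm_jax API.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def masked_kalman_update(m_pred, P_pred, x, H, R, missing):
    """Hypothetical Kalman update that is skipped when `missing` is True.

    Returns the updated mean and covariance and this step's log-likelihood
    contribution, which is zero for a missing observation."""
    x_pred = H @ m_pred
    S = H @ P_pred @ H.T + R
    # Replace a missing observation by its prediction so NaN or junk values never propagate
    x = jnp.where(missing, x_pred, x)
    K = jnp.linalg.solve(S, H @ P_pred).T
    m_upd = m_pred + K @ (x - x_pred)
    P_upd = P_pred - K @ S @ K.T
    log_lik = multivariate_normal.logpdf(x, x_pred, S)
    # jnp.where (not a Python if) keeps shapes static, so this also works under jit and lax.scan
    m_new = jnp.where(missing, m_pred, m_upd)
    P_new = jnp.where(missing, P_pred, P_upd)
    log_lik = jnp.where(missing, 0.0, log_lik)
    return m_new, P_new, log_lik
```

In a full filter this update would sit inside the jax.lax.scan step, with the per-time-step mask passed alongside the observations.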
Change Log
13 August 2024: Initial release 1.0.0.
Download files
Source Distribution
essm_jax-1.0.0.tar.gz (24.7 kB)
Built Distribution
essm_jax-1.0.0-py3-none-any.whl (25.3 kB)
File details
Details for the file essm_jax-1.0.0.tar.gz.
File metadata
- Download URL: essm_jax-1.0.0.tar.gz
- Upload date:
- Size: 24.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | 7e4a9ff4d53edd6c346bcee5f350274cb28def6bf0843bdcd477462f6b1031ec
MD5 | 46d4df9057e4827d6969db0e8c8f83d9
BLAKE2b-256 | 4c76543c6c94768523765ad5d504397546090428a7f0fe5328f6ebbd99e985dc
File details
Details for the file essm_jax-1.0.0-py3-none-any.whl.
File metadata
- Download URL: essm_jax-1.0.0-py3-none-any.whl
- Upload date:
- Size: 25.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | b2c17501949bd8007c6f1a236a44de779986c45701fdfc7de7bbe61787950a1b
MD5 | d71771b14e01c2f96d7f8c2c8beba4a7
BLAKE2b-256 | d2bf6da8f1ce018949b0cdef51e9a9d650833a32c758472ceba4bb88c120af0b