
Extended State Space Model in JAX


Extended State Space Models in JAX

Given a potentially non-linear state space model, this package solves the forward (filtering) and backward (smoothing) inference steps. For linear state space models, these are equivalent to the Kalman filter and Rauch-Tung-Striebel smoother recursions.
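For the linear-Gaussian case, the two passes reduce to the textbook Kalman filter and RTS smoother. The sketch below is an independent, minimal JAX implementation of those two recursions for illustration only; it is not essm_jax's code. The names kalman_filter and rts_smoother and the matrices A, Q, H, R are hypothetical, and it assumes z[t] = A z[t-1] + N(0, Q), x[t] = H z[t] + N(0, R), with the prior N(m0, P0) placed on z[0].

```python
import jax
import jax.numpy as jnp


def kalman_filter(xs, A, Q, H, R, m0, P0):
    """Forward pass: filtered moments of p(z[t] | x[:t]) for the hypothetical model
    z[t] = A z[t-1] + N(0, Q),  x[t] = H z[t] + N(0, R),  z[0] ~ N(m0, P0)."""
    def step(carry, x):
        m, P = carry                          # predictive moments of z[t] given x[:t-1]
        S = H @ P @ H.T + R                   # predictive covariance of x[t]
        K = jnp.linalg.solve(S, H @ P).T      # Kalman gain P H^T S^{-1}
        m_f = m + K @ (x - H @ m)             # update with x[t]
        P_f = P - K @ S @ K.T
        return (A @ m_f, A @ P_f @ A.T + Q), (m_f, P_f)   # predict z[t+1]
    _, (ms, Ps) = jax.lax.scan(step, (m0, P0), xs)
    return ms, Ps


def rts_smoother(ms, Ps, A, Q):
    """Backward pass: smoothed moments of p(z[t] | x[:]) from the filtered moments."""
    def step(carry, filtered):
        m_next, P_next = carry                # smoothed moments of z[t+1]
        m_f, P_f = filtered                   # filtered moments of z[t]
        m_pred = A @ m_f
        P_pred = A @ P_f @ A.T + Q
        G = jnp.linalg.solve(P_pred, A @ P_f).T   # smoother gain P_f A^T P_pred^{-1}
        m_s = m_f + G @ (m_next - m_pred)
        P_s = P_f + G @ (P_next - P_pred) @ G.T
        return (m_s, P_s), (m_s, P_s)
    # The last smoothed state equals the last filtered state; scan backwards from there.
    _, (m_s, P_s) = jax.lax.scan(step, (ms[-1], Ps[-1]), (ms[:-1], Ps[:-1]), reverse=True)
    return jnp.concatenate([m_s, ms[-1:]]), jnp.concatenate([P_s, Ps[-1:]])


# Usage on a 1-D random walk observed with noise
T = 50
A, Q = jnp.eye(1), 0.1 * jnp.eye(1)
H, R = jnp.eye(1), 0.5 * jnp.eye(1)
xs = jnp.cumsum(0.3 * jax.random.normal(jax.random.PRNGKey(0), (T, 1)), axis=0)
ms, Ps = kalman_filter(xs, A, Q, H, R, jnp.zeros(1), jnp.eye(1))
m_s, P_s = rts_smoother(ms, Ps, A, Q)
```

Using jax.lax.scan with reverse=True keeps the backward pass compiled and returns the smoothed moments in forward time order.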

Support for Python 3.10+.

Example

All you need to do is define the transition and observation functions and the initial state prior. The two functions return, and the prior is, a MultivariateNormalLinearOperator distribution from tensorflow_probability (the example below uses the MultivariateNormalTriL subclass).
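When a transition mean is non-linear in z (as in the example below, where it contains a sin term), a Gaussian belief N(m, P) cannot be propagated exactly. The classic extended-Kalman approach linearises the mean function at m using its Jacobian. The sketch below shows that first-order predict step with jax.jacfwd; it is an assumption for illustration and does not describe how essm_jax builds or avoids building its Jacobians. The names ekf_predict and mean_fn are hypothetical.

```python
import jax
import jax.numpy as jnp


def ekf_predict(mean_fn, m, P, Q):
    """First-order (extended Kalman) predict step for z_next = mean_fn(z) + N(0, Q)."""
    F = jax.jacfwd(mean_fn)(m)          # Jacobian of the mean function at the current mean
    m_pred = mean_fn(m)                 # push the mean through the non-linearity
    P_pred = F @ P @ F.T + Q            # push the covariance through the linearisation
    return m_pred, P_pred


# Non-linear mean in the spirit of the example's transition_fn, with t fixed at 1.0 (hypothetical)
t = 1.0
mean_fn = lambda z: z + jnp.sin(2 * jnp.pi * t / 10 * z)

m, P, Q = jnp.array([0.5]), jnp.eye(1), 0.1 * jnp.eye(1)
m_pred, P_pred = ekf_predict(mean_fn, m, P, Q)
```

For a linear mean function z -> A z, the Jacobian is A itself and this step reduces to the usual Kalman predict A P A^T + Q.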

import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp

from essm_jax.essm import ExtendedStateSpaceModel

tfpd = tfp.distributions


def transition_fn(z, t, t_next):
    mean = z + jnp.sin(2 * jnp.pi * t / 10 * z)
    cov = 0.1 * jnp.eye(np.size(z))
    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))


def observation_fn(z, t):
    mean = z
    cov = t * 0.01 * jnp.eye(np.size(z))
    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))


n = 1

initial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n))

essm = ExtendedStateSpaceModel(
    transition_fn=transition_fn,
    observation_fn=observation_fn,
    initial_state_prior=initial_state_prior,
    materialise_jacobians=False,  # Faster: avoids building full Jacobian matrices
    more_data_than_params=False  # Set True when the observation dimension exceeds the latent dimension, for a speed-up
)

T = 100
samples = essm.sample(jax.random.PRNGKey(0), num_time=T)

# Suppose we only observe every 3rd time step; mask is True where the observation is treated as missing
mask = jnp.arange(T) % 3 != 0

# Marginal likelihood, p(x[:]) = prod_t p(x[t] | x[:t-1])
log_prob = essm.log_prob(samples.observation, mask=mask)
print(log_prob)

# Filtered latent distribution, p(z[t] | x[:t])
filter_result = essm.forward_filter(samples.observation, mask=mask)

# Smoothed latent distribution, p(z[t] | x[:]), i.e. each latent conditioned on all observations, past and future
# include_prior=True also returns an updated estimate of the initial state distribution p(z[0])
smooth_result, posterior_prior = essm.backward_smooth(filter_result, include_prior=True)
print(smooth_result)

# Forward simulate the model
forward_samples = essm.forward_simulate(
    key=jax.random.PRNGKey(0),
    num_time=25,
    filter_result=filter_result
)

import matplotlib.pyplot as plt

plt.plot(samples.t, samples.latent[:, 0], label='latent')
plt.plot(filter_result.t, filter_result.filtered_mean[:, 0], label='filtered latent')
plt.plot(forward_samples.t, forward_samples.latent[:, 0], label='forward_simulated latent')
plt.legend()
plt.show()

plt.plot(samples.t, samples.observation[:, 0], label='observation')
plt.plot(filter_result.t, filter_result.observation_mean[:, 0], label='filtered obs')
plt.plot(forward_samples.t, forward_samples.observation[:, 0], label='forward_simulated obs')
plt.legend()
plt.show()
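The log_prob comment in the example relies on the prediction-error decomposition p(x[:]) = prod_t p(x[t] | x[:t-1]). For a linear-Gaussian model each factor is Gaussian with the predictive mean and variance of x[t], and a missing step can simply skip both its factor and its update. The sketch below illustrates this for a scalar random walk, assuming (as the example's comment implies) that mask=True marks a missing observation; masked_log_prob and its parameters q, r, m0, p0 are hypothetical and this is not the essm_jax implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def masked_log_prob(xs, mask, q=0.1, r=0.5, m0=0.0, p0=1.0):
    """Sum of log p(x[t] | x[:t-1]) over observed steps for the scalar model
    z[0] ~ N(m0, p0), z[t] = z[t-1] + N(0, q), x[t] = z[t] + N(0, r).
    mask[t] == True marks x[t] as missing: its factor and its update are skipped."""
    def step(carry, inputs):
        m, p, total = carry                   # predictive mean/variance of z[t]
        x, missing = inputs
        s = p + r                             # predictive variance of x[t]
        ll = norm.logpdf(x, loc=m, scale=jnp.sqrt(s))
        k = p / s                             # Kalman gain
        m = jnp.where(missing, m, m + k * (x - m))
        p = jnp.where(missing, p, (1.0 - k) * p)
        total = total + jnp.where(missing, 0.0, ll)
        return (m, p + q, total), None        # predict z[t+1]
    init = (jnp.asarray(m0, xs.dtype), jnp.asarray(p0, xs.dtype), jnp.zeros((), xs.dtype))
    (_, _, total), _ = jax.lax.scan(step, init, (xs, mask))
    return total


xs = jnp.array([0.3, -0.1, 0.4, 0.9, 0.2, -0.5])
mask = jnp.arange(xs.shape[0]) % 3 != 0       # as in the example: only every 3rd step observed
print(masked_log_prob(xs, mask))
```

Because missing steps still run the predict step, the uncertainty about z grows across gaps, which is what makes the result equal to the marginal density of the observed values alone.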

Online Filtering

See the project's examples to learn how to do online filtering for interactive applications.
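Online filtering amounts to keeping only the current filtered belief and folding in each new observation as it arrives. As a rough, library-independent sketch (not the package's incremental API, whose names are not documented here), a jitted predict-then-update step for a linear-Gaussian model can be called once per incoming observation. OnlineState, online_step and the matrices A, Q, H, R are hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class OnlineState(NamedTuple):
    mean: jax.Array  # filtered mean of the latent state given observations so far
    cov: jax.Array   # filtered covariance of the latent state


@jax.jit
def online_step(state, x, A, Q, H, R):
    """Predict the next latent state, then update with the newly arrived observation x."""
    m = A @ state.mean
    P = A @ state.cov @ A.T + Q
    S = H @ P @ H.T + R
    K = jnp.linalg.solve(S, H @ P).T   # Kalman gain P H^T S^{-1}
    return OnlineState(m + K @ (x - H @ m), P - K @ S @ K.T)


A, Q = jnp.eye(2), 0.01 * jnp.eye(2)
H, R = jnp.eye(2), 0.1 * jnp.eye(2)
state = OnlineState(jnp.zeros(2), jnp.eye(2))
for x in [jnp.array([0.1, 0.2]), jnp.array([0.15, 0.25]), jnp.array([0.2, 0.3])]:  # e.g. a live stream
    state = online_step(state, x, A, Q, H, R)
print(state.mean)
```

Since the carried state is just two fixed-shape arrays, online_step compiles once and each new observation costs a single small predict/update, independent of how much data came before.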

Change Log

  • 13 August 2024: Initial release 1.0.0.
  • 14 August 2024: Released 1.0.1. Added a sparse utility, an incremental API for online filtering, and support for arbitrary time steps (dt).



Download files


Source Distribution

essm_jax-1.0.1.tar.gz (27.5 kB)

Uploaded Source

Built Distribution

essm_jax-1.0.1-py3-none-any.whl (28.6 kB)

Uploaded Python 3

File details

Details for the file essm_jax-1.0.1.tar.gz.

File metadata

  • Download URL: essm_jax-1.0.1.tar.gz
  • Upload date:
  • Size: 27.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.9

File hashes

Hashes for essm_jax-1.0.1.tar.gz
  • SHA256: 8ba75d77f0bd47befb632c36cda86454708ff0d55053bc9e2bb6d10fc81378cf
  • MD5: 9dfac79748ed8313357ef5b91f807f1d
  • BLAKE2b-256: f268d46c37213b6176c8cba863d293c11c0c67a99fe35d43983e01170617f289


File details

Details for the file essm_jax-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: essm_jax-1.0.1-py3-none-any.whl
  • Upload date:
  • Size: 28.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.9

File hashes

Hashes for essm_jax-1.0.1-py3-none-any.whl
  • SHA256: c247876355f1719cb6838a556fdd7e946efeeeb1a6d7b3ec583aaec94b712905
  • MD5: 922088490803ebed9b5d96f367505306
  • BLAKE2b-256: e633b493c0d4c0fd4eee90300fca94ec95e0d7719b2f548605afa5d74199d752

