Extended State Space Modelling in JAX

Extended State Space Models in JAX

Given a potentially non-linear state space model, this package solves the forward (filtering) and backward (smoothing) inference steps. For linear state space models, this is equivalent to the Kalman and Rauch-Tung-Striebel recursions.
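To make the link with the Kalman recursions concrete, here is a minimal sketch of one textbook extended Kalman filter step with additive Gaussian noise. It is not taken from this package's source; the names `ekf_step`, `f`, `h`, `Q` and `R` are assumptions made for illustration. When `f` and `h` are linear, the Jacobians are the exact system matrices and the step reduces to the ordinary Kalman predict and update.

```python
import jax
import jax.numpy as jnp


def ekf_step(f, h, Q, R, mean, cov, x):
    """One extended Kalman filter step: predict, then update.

    f, h: transition and observation mean functions (hypothetical names).
    Q, R: transition and observation noise covariances.
    mean, cov: current filtered mean and covariance of the latent state.
    x: the new observation.
    """
    # Predict: linearise f about the current mean.
    F = jax.jacfwd(f)(mean)
    mean_pred = f(mean)
    cov_pred = F @ cov @ F.T + Q

    # Update: linearise h about the predicted mean.
    H = jax.jacfwd(h)(mean_pred)
    S = H @ cov_pred @ H.T + R  # innovation covariance
    K = jnp.linalg.solve(S, H @ cov_pred).T  # Kalman gain, K = P H^T S^{-1}
    mean_post = mean_pred + K @ (x - h(mean_pred))
    cov_post = cov_pred - K @ S @ K.T
    return mean_post, cov_post


# Linear special case: a scalar random walk observed directly.
A = jnp.array([[1.0]])
H_obs = jnp.array([[1.0]])
mean_post, cov_post = ekf_step(
    lambda z: A @ z,
    lambda z: H_obs @ z,
    Q=jnp.eye(1),
    R=jnp.eye(1),
    mean=jnp.zeros(1),
    cov=jnp.eye(1),
    x=jnp.array([3.0]),
)
```

In the linear case above, the predicted variance is 2, the gain is 2/3, and the posterior mean and variance are 2 and 2/3, which matches the Kalman filter worked by hand.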

Supports Python 3.10+.

Example

Define the transition function, the observation function, and the initial state prior. Each is expressed as a MultivariateNormalLinearOperator distribution (or a subclass, such as the MultivariateNormalTriL used below) from tensorflow_probability.

import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp

from essm_jax.essm import ExtendedStateSpaceModel

tfpd = tfp.distributions


def transition_fn(z, t, t_next, *args):
    mean = z + jnp.sin(2 * jnp.pi * t / 10 * z)
    cov = 0.1 * jnp.eye(np.size(z))
    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))


def observation_fn(z, t, *args):
    mean = z
    cov = t * 0.01 * jnp.eye(np.size(z))
    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))


n = 1

initial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n))

essm = ExtendedStateSpaceModel(
    transition_fn=transition_fn,
    observation_fn=observation_fn,
    initial_state_prior=initial_state_prior,
    materialise_jacobians=False,  # Fast
    more_data_than_params=False  # if observation is bigger than latent we can speed it up.
)

T = 100
samples = essm.sample(jax.random.PRNGKey(0), num_time=T)

# Suppose we only observe every 3rd observation
mask = jnp.arange(T) % 3 != 0

# Marginal likelihood, p(x[:]) = prod_t p(x[t] | x[:t-1])
log_prob = essm.log_prob(samples.observation, mask=mask)
print(log_prob)

# Filtered latent distribution, p(z[t] | x[:t])
filter_result = essm.forward_filter(samples.observation, mask=mask)

# Smoothed latent distribution, p(z[t] | x[:]), i.e. past latents given all future observations
# Including new estimate for prior state p(z[0])
smooth_result, posterior_prior = essm.backward_smooth(filter_result, include_prior=True)
print(smooth_result)

# Forward simulate the model
forward_samples = essm.forward_simulate(
    key=jax.random.PRNGKey(0),
    num_time=25,
    filter_result=filter_result
)

import matplotlib.pyplot as plt

plt.plot(samples.t, samples.latent[:, 0], label='latent')
plt.plot(filter_result.t, filter_result.filtered_mean[:, 0], label='filtered latent')
plt.plot(forward_samples.t, forward_samples.latent[:, 0], label='forward_simulated latent')
plt.legend()
plt.show()

plt.plot(samples.t, samples.observation[:, 0], label='observation')
plt.plot(filter_result.t, filter_result.observation_mean[:, 0], label='filtered obs')
plt.plot(forward_samples.t, forward_samples.observation[:, 0], label='forward_simulated obs')
plt.legend()
plt.show()
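The marginal likelihood factorisation and the role of the mask can be illustrated with a small stand-alone sketch. It uses a scalar linear-Gaussian random walk rather than the package itself, and it assumes that `mask[t] == True` marks a missing observation; the function name `masked_log_prob` and its parameters are hypothetical.

```python
import jax
import jax.numpy as jnp


def masked_log_prob(observations, mask, q=0.1, r=0.01, mean0=0.0, var0=1.0):
    """Log marginal likelihood of a scalar random walk with direct noisy observations.

    Assumes mask[t] == True means x[t] is missing (an assumption of this sketch).
    q, r: transition and observation noise variances.
    """
    dtype = observations.dtype
    init = (jnp.asarray(mean0, dtype), jnp.asarray(var0, dtype))

    def step(carry, inputs):
        mean, var = carry
        x, is_missing = inputs
        # Predictive distribution of x[t] given the earlier observations.
        s = var + r
        log_p = -0.5 * (jnp.log(2 * jnp.pi * s) + (x - mean) ** 2 / s)
        # Kalman update with x[t].
        k = var / s
        mean_upd = mean + k * (x - mean)
        var_upd = (1 - k) * var
        # Where the observation is missing, skip the update and the likelihood term.
        mean = jnp.where(is_missing, mean, mean_upd)
        var = jnp.where(is_missing, var, var_upd)
        log_p = jnp.where(is_missing, 0.0, log_p)
        # Predict the next latent state.
        return (mean, var + q), log_p

    _, log_ps = jax.lax.scan(step, init, (observations, mask))
    return jnp.sum(log_ps)


xs = jnp.array([0.1, 0.0, -0.2, 0.3])
missing = jnp.arange(xs.shape[0]) % 3 != 0  # keep every 3rd observation
total = masked_log_prob(xs, missing)
```

Each observed step contributes one predictive log-density term, so the sum over time is the log of the product in the factorisation; masked steps contribute nothing and leave the filtered state unchanged before the next prediction.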

Online Filtering

See the examples to learn how to do online filtering for interactive applications.
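As a rough illustration of the idea only, the sketch below keeps a running filter state and folds in one observation at a time, with an arbitrary time step between observations. It does not use this package's incremental API; `FilterState`, `online_update`, `q_rate` and `r` are hypothetical names for a scalar random-walk model, and the package's examples remain the reference for the real interface.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FilterState(NamedTuple):  # hypothetical container, not the package's type
    t: jax.Array
    mean: jax.Array
    var: jax.Array


@jax.jit
def online_update(state, x, t_next, q_rate=0.1, r=0.01):
    """Assimilate observation x taken at time t_next into the running filter state."""
    dt = t_next - state.t
    # Predict over an arbitrary time step: random-walk variance grows with dt.
    var_pred = state.var + q_rate * dt
    # Update with the new observation.
    k = var_pred / (var_pred + r)
    mean = state.mean + k * (x - state.mean)
    var = (1.0 - k) * var_pred
    return FilterState(t=jnp.asarray(t_next), mean=mean, var=var)


state = FilterState(t=jnp.asarray(0.0), mean=jnp.asarray(0.0), var=jnp.asarray(1.0))
# Observations arriving one at a time, as (time, value) pairs.
for t_next, x in [(1.0, 0.3), (2.5, 0.1), (3.0, -0.2)]:
    state = online_update(state, x, t_next)
```

Only the latest state is stored, so each new observation costs a fixed amount of work regardless of how much data has already been processed, which is what makes this pattern suitable for interactive use.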

Change Log

13 August 2024: Initial release 1.0.0.

14 August 2024: 1.0.1 released. Added sparse util. Added incremental API for online filtering. Supports arbitrary dt.
