Extended State Space Model in JAX

Project description

Extended State Space Models in JAX

Given a potentially non-linear state space model, this package solves the forward (filtering) and backward (smoothing) inference steps. For linear state space models these are equivalent to the Kalman filter and Rauch-Tung-Striebel smoother recursions.
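
For a model with additive Gaussian noise, the forward/backward recursions can be sketched as an extended Kalman filter, which linearises the transition and observation means with `jax.jacfwd`, followed by an extended Rauch-Tung-Striebel smoother. This is a minimal illustration under those assumptions, not the essm_jax implementation: `extended_kalman_smoother`, `f`, `h`, `Q` and `R` are hypothetical names, and the noise covariances are assumed constant. When `f` and `h` are linear, the Jacobians are exact and the code reduces to the ordinary Kalman filter and RTS smoother.

```python
import jax
import jax.numpy as jnp


def extended_kalman_smoother(f, h, Q, R, m0, P0, ys, mask):
    """Extended Kalman filter followed by an extended RTS smoother (sketch).

    Illustrative model (hypothetical, not the essm_jax API):
        z[0] ~ N(m0, P0)
        z[t] = f(z[t-1], t) + N(0, Q)
        y[t] = h(z[t], t)   + N(0, R),   t = 1..T
    ys[i] holds y[i + 1]; mask[i] == True marks it as missing.
    """
    ts = jnp.arange(1, ys.shape[0] + 1)

    def filter_step(carry, inputs):
        m, P = carry
        t, y, missing = inputs
        # Predict: linearise f about the current filtered mean.
        F = jax.jacfwd(f)(m, t)
        m_pred = f(m, t)
        P_pred = F @ P @ F.T + Q
        # Update: linearise h about the predicted mean.
        H = jax.jacfwd(h)(m_pred, t)
        S = H @ P_pred @ H.T + R
        K = jnp.linalg.solve(S, H @ P_pred).T  # P_pred H^T S^{-1}
        m_upd = m_pred + K @ (y - h(m_pred, t))
        P_upd = P_pred - K @ S @ K.T
        # A masked observation contributes nothing: keep the prediction.
        m_new = jnp.where(missing, m_pred, m_upd)
        P_new = jnp.where(missing, P_pred, P_upd)
        return (m_new, P_new), (m_new, P_new, m_pred, P_pred, F)

    _, (m_f, P_f, m_p, P_p, Fs) = jax.lax.scan(filter_step, (m0, P0), (ts, ys, mask))

    def smooth_step(carry, inputs):
        m_s_next, P_s_next = carry
        m_filt, P_filt, m_pred_next, P_pred_next, F_next = inputs
        G = jnp.linalg.solve(P_pred_next, F_next @ P_filt).T  # P_filt F^T P_pred^{-1}
        m_s = m_filt + G @ (m_s_next - m_pred_next)
        P_s = P_filt + G @ (P_s_next - P_pred_next) @ G.T
        return (m_s, P_s), (m_s, P_s)

    # The last smoothed estimate equals the last filtered one; run backwards from it.
    _, (m_s, P_s) = jax.lax.scan(
        smooth_step,
        (m_f[-1], P_f[-1]),
        (m_f[:-1], P_f[:-1], m_p[1:], P_p[1:], Fs[1:]),
        reverse=True,
    )
    m_s = jnp.concatenate([m_s, m_f[-1:]])
    P_s = jnp.concatenate([P_s, P_f[-1:]])
    return m_f, P_f, m_s, P_s


# Usage: a 1-D random walk observed with noise at every 3rd step only.
# f and h are linear here, so this is exactly the Kalman filter + RTS smoother.
f = lambda z, t: z
h = lambda z, t: z
ys = jnp.sin(jnp.arange(20.0) / 3.0)[:, None]
mask = jnp.arange(20) % 3 != 0
m_f, P_f, m_s, P_s = extended_kalman_smoother(
    f, h, 0.1 * jnp.eye(1), 0.5 * jnp.eye(1), jnp.zeros(1), jnp.eye(1), ys, mask
)
```

Masked steps skip the update, so the filtered variance grows between observations; the smoother then carries information back from later observations, so in this linear-Gaussian case each smoothed variance is no larger than the corresponding filtered one.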

Supports Python 3.10+.

Example

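All you need to do is define the transition and observation functions and the initial state prior. Each function maps a latent state z and a time t to a MultivariateNormalLinearOperator distribution from tensorflow_probability, and the prior is itself such a distribution; the example below uses the MultivariateNormalTriL subclass.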

import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp

from essm_jax.essm import ExtendedStateSpaceModel

tfpd = tfp.distributions


def transition_fn(z, t):
    mean = z + jnp.sin(2 * jnp.pi * t / 10 * z)
    cov = 0.1 * jnp.eye(np.size(z))
    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))


def observation_fn(z, t):
    mean = z
    cov = t * 0.01 * jnp.eye(np.size(z))
    return tfpd.MultivariateNormalTriL(mean, jnp.linalg.cholesky(cov))


n = 1

initial_state_prior = tfpd.MultivariateNormalTriL(jnp.zeros(n), jnp.eye(n))

essm = ExtendedStateSpaceModel(
    transition_fn=transition_fn,
    observation_fn=observation_fn,
    initial_state_prior=initial_state_prior,
    materialise_jacobians=False,  # Faster: Jacobians are not materialised as matrices
    more_data_than_params=False  # Set True when the observation is larger than the latent state, for a speed-up
)

T = 100
samples = essm.sample(jax.random.PRNGKey(0), num_time=T)

# Suppose we only observe every 3rd time step (mask is True where an observation is missing)
mask = jnp.arange(T) % 3 != 0

# Log marginal likelihood, log p(x[:]) = sum_t log p(x[t] | x[:t-1])
log_prob = essm.log_prob(samples.observation, mask=mask)
print(log_prob)

# Filtered latent distribution, p(z[t] | x[:t])
filter_result = essm.forward_filter(samples.observation, mask=mask)

# Smoothed latent distribution, p(z[t] | x[:]), i.e. each latent conditioned on all observations, past and future
# include_prior=True also returns an updated estimate of the prior state p(z[0])
smooth_result, posterior_prior = essm.backward_smooth(filter_result, include_prior=True)
print(smooth_result)

# Forward simulate the model
forward_samples = essm.forward_simulate(
    key=jax.random.PRNGKey(0),
    num_time=25,
    observations=samples.observation,
    mask=mask
)

import matplotlib.pyplot as plt

plt.plot(samples.t, samples.latent[:, 0], label='latent')
plt.plot(filter_result.t, filter_result.filtered_mean[:, 0], label='filtered latent')
plt.plot(forward_samples.t, forward_samples.latent[:, 0], label='forward_simulated latent')
plt.legend()
plt.show()

plt.plot(samples.t, samples.observation[:, 0], label='observation')
plt.plot(filter_result.t, filter_result.observation_mean[:, 0], label='filtered obs')
plt.plot(forward_samples.t, forward_samples.observation[:, 0], label='forward_simulated obs')
plt.legend()
plt.show()
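
The comment on `essm.log_prob` is the prediction-error decomposition: the log marginal likelihood is the sum of one-step-ahead predictive log densities, which a Kalman-type filter produces as a by-product. The minimal sketch below (not the essm_jax implementation) checks this for a hypothetical scalar random walk by comparing the filter's sum against the joint Gaussian density computed directly; `kalman_log_marginal` and `joint_log_marginal` are illustrative names.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm


def kalman_log_marginal(ys, q, r, m0, p0):
    """Sum of one-step-ahead predictive log densities log p(y[t] | y[1:t-1])
    for the scalar random walk z[t] = z[t-1] + N(0, q), y[t] = z[t] + N(0, r),
    with z[0] ~ N(m0, p0)."""
    def step(carry, y):
        m, p = carry
        p_pred = p + q                        # predict (random walk: mean unchanged)
        s = p_pred + r                        # predictive variance of y[t]
        ll = norm.logpdf(y, m, jnp.sqrt(s))   # log p(y[t] | y[1:t-1])
        k = p_pred / s                        # Kalman gain
        return (m + k * (y - m), (1.0 - k) * p_pred), ll

    init = (jnp.asarray(m0, ys.dtype), jnp.asarray(p0, ys.dtype))
    _, lls = jax.lax.scan(step, init, ys)
    return jnp.sum(lls)


def joint_log_marginal(ys, q, r, m0, p0):
    """The same quantity from the joint Gaussian density of y[1:T]."""
    T = ys.shape[0]
    idx = jnp.arange(1, T + 1)
    # Cov(z[s], z[t]) = p0 + q * min(s, t) for a random walk started at z[0].
    cov_y = p0 + q * jnp.minimum(idx[:, None], idx[None, :]) + r * jnp.eye(T)
    return multivariate_normal.logpdf(ys, jnp.full(T, m0), cov_y)


ys = jnp.array([0.3, -0.1, 0.8, 1.2, 0.9])
ll_filter = kalman_log_marginal(ys, q=0.1, r=0.5, m0=0.0, p0=1.0)
ll_joint = joint_log_marginal(ys, q=0.1, r=0.5, m0=0.0, p0=1.0)
print(ll_filter, ll_joint)  # equal up to floating-point error
```

The filter form costs O(T) for a scalar state, while the direct joint density needs an O(T^3) factorisation of a T x T covariance, which is why filters evaluate the marginal likelihood as a running sum.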

Change Log

13 August 2024: Initial release 1.0.0.

Download files

Source Distribution

essm_jax-1.0.0.tar.gz (24.7 kB)

Uploaded Source

Built Distribution

essm_jax-1.0.0-py3-none-any.whl (25.3 kB)

Uploaded Python 3

File details

Details for the file essm_jax-1.0.0.tar.gz.

File metadata

  • Download URL: essm_jax-1.0.0.tar.gz
  • Upload date:
  • Size: 24.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.9

File hashes

Hashes for essm_jax-1.0.0.tar.gz
Algorithm    Hash digest
SHA256       7e4a9ff4d53edd6c346bcee5f350274cb28def6bf0843bdcd477462f6b1031ec
MD5          46d4df9057e4827d6969db0e8c8f83d9
BLAKE2b-256  4c76543c6c94768523765ad5d504397546090428a7f0fe5328f6ebbd99e985dc

File details

Details for the file essm_jax-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: essm_jax-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 25.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.9

File hashes

Hashes for essm_jax-1.0.0-py3-none-any.whl
Algorithm    Hash digest
SHA256       b2c17501949bd8007c6f1a236a44de779986c45701fdfc7de7bbe61787950a1b
MD5          d71771b14e01c2f96d7f8c2c8beba4a7
BLAKE2b-256  d2bf6da8f1ce018949b0cdef51e9a9d650833a32c758472ceba4bb88c120af0b
