Inference in LTI-SDEs using CRFs

linsdex

linsdex is a high-performance JAX-based library for linear stochastic differential equations (SDEs), state-space models, and Gaussian inference. It provides a modular and extensible framework for defining, simulating, and conditioning linear-Gaussian systems, with support for parallelized inference on GPUs.

Agent Skills for Cursor IDE

This project includes Agent Skills to help you get started quickly when using Cursor. Skills provide domain-specific guidance that the AI agent can use to help you with common tasks.

To invoke a skill, type / in the Cursor chat and search for the skill name. The agent can also apply relevant skills automatically based on context.

| Skill | When to Use |
| --- | --- |
| /linsdex | Overview and quick start guide for the entire library |
| /sde-conditioning | Time series interpolation, Brownian bridges, posterior sampling |
| /diffusion-conversions | Converting between y1, score, flow, drift for diffusion models |
| /probability-paths | Bridge path marginals, memoryless sampling, computing p(x_t given y_1) |
| /crf-inference | Discrete-time Gaussian CRF inference, marginals, sampling |
| /gaussian-distributions | Working with three Gaussian parameterizations for numerical stability |
| /matrix-operations | Using specialized matrix types with symbolic tags |

Quick Start

Define an SDE, condition it on a starting point, and sample trajectories in parallel using JAX.

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from linsdex import BrownianMotion

# 1. Define a 2D Brownian Motion SDE
sde = BrownianMotion(sigma=1.0, dim=2)

# 2. Condition on starting at the origin at t=0
conditioned_sde = sde.condition_on_starting_point(t0=0.0, x0=jnp.zeros(2))

# 3. Sample 100 trajectories in parallel on a time grid
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 100)
times = jnp.linspace(0.0, 1.0, 500)

# Use jax.vmap for efficient batch sampling
trajectories = jax.vmap(conditioned_sde.sample, in_axes=(0, None))(keys, times)

# 4. Plot the first 10 trajectories in the (x1, x2) plane
plt.figure(figsize=(8, 6))
for i in range(10):
    plt.plot(trajectories.values[i, :, 0], trajectories.values[i, :, 1], alpha=0.6)
plt.xlabel("x1")
plt.ylabel("x2")
plt.savefig("quick_start_simple.png")
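
For intuition about what step 3 produces, the same kind of sample can be drawn without the library. The following is a minimal NumPy sketch, not linsdex code. It assumes the SDE is $dx_t = \sigma\, dW_t$ started at the origin (treating sigma as the diffusion coefficient is an assumption about the BrownianMotion convention) and samples exactly on the grid by accumulating independent Gaussian increments with variance $\sigma^2 \Delta t$. The function name sample_brownian_paths is hypothetical.

```python
import numpy as np

def sample_brownian_paths(rng, times, n_paths, dim, sigma=1.0):
    """Exact samples of dx = sigma dW on `times`, started at the origin."""
    dts = np.diff(times)
    # Independent increments: x_{t+dt} - x_t ~ N(0, sigma^2 * dt * I)
    noise = rng.standard_normal((n_paths, dts.size, dim))
    increments = sigma * np.sqrt(dts)[None, :, None] * noise
    start = np.zeros((n_paths, 1, dim))
    return np.concatenate([start, np.cumsum(increments, axis=1)], axis=1)

rng = np.random.default_rng(0)
times = np.linspace(0.0, 1.0, 500)
paths = sample_brownian_paths(rng, times, n_paths=100, dim=2)  # (paths, time, dim)
```

The array layout (path, time, dimension) mirrors the indexing used for trajectories.values in the plotting code above.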

Key Features

The library focuses on high performance and numerical stability.

  • Linear SDEs: Comprehensive support for linear time-invariant (LTI) and time-varying SDEs with exact transition distributions.
  • Efficient Inference: Sequential and parallel message passing (parallel scan) for filtering, smoothing, and sampling in chain-structured Gaussian CRFs.
  • Probabilistic Primitives: Multiple Gaussian parameterizations (Standard, Natural, Mixed) with numerically stable operations.
  • Specialized Linear Algebra: A custom matrix library with Diagonal, Block, and Dense types that leverage symbolic tags for optimization.
  • Diffusion Utilities: Unified interface for mapping between clean data predictions, scores, and probability flow for generative modeling.
  • JAX-Native: Fully compatible with jax.vmap, jax.grad, and jax.jit for automatic vectorization and differentiation.

Stochastic Harmonic Oscillator

This example demonstrates how to perform inference in a State Space Model where the latent dynamics are governed by a Linear Stochastic Differential Equation. By conditioning a base SDE on noisy observations, linsdex constructs a Gaussian Conditional Random Field. We then use parallel message passing to compute the posterior $p(x_{1:T} | y_{1:T})$, which allows us to draw samples that interpolate the data while respecting the underlying physical laws of the oscillator.

The latent state $x_t \in \mathbb{R}^2$ represents the position and velocity of a stochastic harmonic oscillator. Its evolution is described by the linear SDE

$$ dx_t = F x_t dt + L dW_t $$

where $F$ defines the deterministic drift and $L$ defines the diffusion. For the harmonic oscillator with angular frequency $\omega$ and damping coefficient $\gamma$,

$$ F = \begin{bmatrix} 0 & 1 \\ -\omega^2 & -\gamma \end{bmatrix} $$

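The library automatically computes the exact transition distribution $p(x_t | x_{t-1})$ between consecutive times separated by $\Delta t$. This distribution is Gaussian, and its covariance is obtained by solving the Lyapunov equation over the interval $\Delta t$.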
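
As an illustration of what an exact transition looks like, the sketch below discretizes the oscillator with plain NumPy. It uses Van Loan's matrix-exponential method, which is one standard way to obtain the solution of the Lyapunov differential equation; whether linsdex uses this method internally is not claimed here. The helper names expm_taylor and discretize_lti_sde are hypothetical, and the values of omega, gamma, and sigma are illustrative. The result is $x_{t+\Delta t} | x_t \sim \mathcal{N}(A x_t, Q)$ with $A = e^{F \Delta t}$ and $Q = \int_0^{\Delta t} e^{F s} L L^\top e^{F^\top s} ds$.

```python
import numpy as np

def expm_taylor(M, squarings=8, terms=18):
    """Matrix exponential via scaling-and-squaring with a truncated Taylor series."""
    M = np.asarray(M, dtype=float) / (2.0 ** squarings)
    result = np.eye(M.shape[0])
    term = np.eye(M.shape[0])
    for k in range(1, terms + 1):
        term = term @ M / k          # M^k / k!
        result = result + term
    for _ in range(squarings):
        result = result @ result     # undo the scaling
    return result

def discretize_lti_sde(F, L, dt):
    """Return (A, Q) such that x_{t+dt} | x_t ~ N(A x_t, Q) for dx = F x dt + L dW."""
    n = F.shape[0]
    block = np.zeros((2 * n, 2 * n))
    block[:n, :n] = -F
    block[:n, n:] = L @ L.T
    block[n:, n:] = F.T
    Phi = expm_taylor(block * dt)
    A = Phi[n:, n:].T                # A = exp(F dt)
    Q = A @ Phi[:n, n:]              # Van Loan: Q = A @ (upper-right block)
    return A, Q

omega, gamma, sigma = 1.0, 0.1, 0.5
F = np.array([[0.0, 1.0], [-omega**2, -gamma]])
L = np.diag([0.0, sigma])
A, Q = discretize_lti_sde(F, L, dt=0.1)
```

A useful sanity check is the composition rule: two steps of length $\Delta t$ must equal one step of length $2\Delta t$, so $A_{2\Delta t} = A A$ and $Q_{2\Delta t} = A Q A^\top + Q$.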

We observe the position $y_t \in \mathbb{R}^1$ through a noisy channel

$$ p(y_t | x_t) = \mathcal{N}(y_t; H x_t, \sigma^2 I) $$

In our example, the encoder PaddingLatentVariableEncoderWithPrior specifies $H = [1, 0]$. This means we observe only the first component of the latent state (the position), corrupted by Gaussian noise with standard deviation $\sigma$.
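
Conditioning on a single observation under this model is the standard linear-Gaussian (Kalman) update. The following NumPy sketch is independent of linsdex; the function name condition_on_observation and all numeric values are illustrative assumptions. It shows that observing only the position still informs the velocity through the prior correlation between the two.

```python
import numpy as np

def condition_on_observation(m, P, y, H, R):
    """Posterior N(m_post, P_post) of x ~ N(m, P) given y = H x + noise, noise ~ N(0, R)."""
    S = H @ P @ H.T + R                   # innovation covariance
    K = np.linalg.solve(S, H @ P).T       # gain K = P H^T S^{-1} (P and S are symmetric)
    m_post = m + K @ (y - H @ m)
    P_post = P - K @ S @ K.T
    return m_post, P_post

m = np.zeros(2)                           # prior mean of [position, velocity]
P = np.array([[1.0, 0.5], [0.5, 1.0]])    # prior covariance with correlated components
H = np.array([[1.0, 0.0]])                # observe the position only
R = np.array([[0.01**2]])                 # observation noise variance sigma^2
y = np.array([0.8])

m_post, P_post = condition_on_observation(m, P, y, H, R)
```

With such a small noise level the posterior position is pulled almost exactly onto the observation, while the velocity variance shrinks only partially.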

The model can be represented as a chain structured graphical model where $x_t$ are latent variables and $y_t$ are noisy observations.

[Figure: SSM plate notation]

import jax
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
from linsdex import TimeSeries, LinearTimeInvariantSDE, DenseMatrix, DiagonalMatrix
from linsdex.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior

# 1. Define time series data
# We create sparse 1D data for interpolation
times = jnp.linspace(0, 10, 5)
values = jnp.sin(times)[:, None]
series = TimeSeries(times, values)

# 2. Define a linear SDE by specifying F and L
# Harmonic oscillator with state x = [position, velocity]
freq, coeff, sigma = 1.0, 0.1, 0.5
F = DenseMatrix(jnp.array([[0, 1], [-freq**2, -coeff]]))
L = DiagonalMatrix(jnp.array([0, sigma]))
sde = LinearTimeInvariantSDE(F=F, L=L)

# 3. Create potentials from data and condition the SDE
# PaddingLatentVariableEncoderWithPrior pads the 1D observations to the 2D latent space
encoder = PaddingLatentVariableEncoderWithPrior(
    y_dim=1,
    x_dim=2,
    sigma=0.01
)
potentials = encoder(series)
conditioned_sde = sde.condition_on(potentials)

# 4. Draw samples from the posterior
key = random.PRNGKey(0)
keys = random.split(key, 128)

# Interpolate on a denser time grid
save_times = jnp.linspace(0, 10, 2000)
samples: TimeSeries = jax.vmap(conditioned_sde.sample, in_axes=(0, None))(keys, save_times)

# 5. Plot the original time series and the posterior samples
fig, axes = samples.plot(show_plot=False)
plt.savefig("quick_start.png")

Core Components

The library provides several layers of abstraction for probabilistic modeling.

Stochastic Differential Equations

The library defines a hierarchy of SDEs starting from AbstractSDE.

Linear Time-Invariant (LTI) SDEs are models where the drift and diffusion coefficients are constant over time. Examples include BrownianMotion, OrnsteinUhlenbeck, and StochasticHarmonicOscillator.

Linear SDEs are models with time-varying coefficients, such as VariancePreserving SDEs used in diffusion models.

Conditioned SDEs allow for conditioning a process on any number of Gaussian potentials using ConditionedLinearSDE.
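
The simplest closed-form instance of conditioning is the Brownian bridge: Brownian motion pinned at both a start and an end point. The sketch below is plain NumPy rather than linsdex code, and the function name brownian_bridge_marginal is hypothetical. For $dx_t = \sigma\, dW_t$ with $x_{t_0} = x_0$ and $x_{t_1} = x_1$, the marginal at time $t$ is Gaussian with a mean that interpolates linearly between the endpoints and a variance of $\sigma^2 (t - t_0)(t_1 - t) / (t_1 - t_0)$ per coordinate.

```python
import numpy as np

def brownian_bridge_marginal(t, t0, x0, t1, x1, sigma=1.0):
    """Mean and per-coordinate variance of x_t given x_{t0} = x0 and x_{t1} = x1."""
    w = (t - t0) / (t1 - t0)
    mean = (1.0 - w) * x0 + w * x1                       # linear interpolation
    var = sigma**2 * (t - t0) * (t1 - t) / (t1 - t0)     # zero at both endpoints
    return mean, var

x0 = np.zeros(2)
x1 = np.array([1.0, -2.0])
mean, var = brownian_bridge_marginal(0.5, 0.0, x0, 1.0, x1, sigma=0.1)
```

The variance vanishes at both endpoints and is largest at the midpoint, which is the characteristic shape of bridge samples.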

Diffusion Conversions

linsdex provides a unified interface for working with different mathematical representations of a diffusion process. The interface maps between neural network predictions (such as the clean data $y_1$) and quantities required for sampling (such as the probability flow or the drift of an SDE).

The DiffusionModelComponents class encapsulates the objects that define a diffusion process, including the base linear SDE, the prior distribution at $t_0$, and the evidence covariance at $t_1$.

import jax.numpy as jnp
from jaxtyping import Array, Float
from linsdex.diffusion_model.probability_path import DiffusionModelComponents
from linsdex import BrownianMotion, StandardGaussian, DiagonalMatrix

dim: int = 10
sde = BrownianMotion(sigma=0.1, dim=dim)
xt0_prior = StandardGaussian(jnp.zeros(dim), DiagonalMatrix.eye(dim))
evidence_cov = DiagonalMatrix.eye(dim) * 0.001

components = DiffusionModelComponents(
    linear_sde=sde,
    t0=0.0,
    x_t0_prior=xt0_prior,
    t1=1.0,
    evidence_cov=evidence_cov
)

The DiffusionModelConversions class maps between different parameterizations of the diffusion path, such as converting a prediction into the score function, probability flow, or drift.

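import jax.numpy as jnp
from jaxtyping import Array, Float
from linsdex.diffusion_model.probability_path import DiffusionModelConversions

conversions = DiffusionModelConversions(components, t=0.5)

# Placeholder inputs for illustration: the current state x_t and a model's
# prediction of the clean data y1 (in practice these come from a sampler and a network)
xt: Float[Array, "dim"] = jnp.zeros(dim)
y1_pred: Float[Array, "dim"] = jnp.zeros(dim)

# Map clean data prediction y1 to different sampling quantities
flow: Float[Array, "dim"] = conversions.y1_to_flow(y1_pred, xt)
drift: Float[Array, "dim"] = conversions.y1_to_drift(y1_pred, xt)
score: Float[Array, "dim"] = conversions.y1_to_score(xt, y1_pred)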

The ProbabilityPathSlice class can be used to compute and cache time-dependent intermediate quantities, avoiding redundant computations when performing multiple conversions at the same time step. Additionally, noise_schedule_drift_correction allows for adjusting the drift when a different noise schedule is used at inference time compared to training.
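
To make the idea of a conversion concrete, here is a minimal NumPy sketch of a y1-to-score mapping in a deliberately simplified setting. It is not the linsdex implementation: it assumes a Brownian bridge pinned at a fixed $x_0$ at $t = 0$ and at $y_1$ at $t = 1$, with no prior over $x_0$ and no evidence covariance. All function names are hypothetical. Because the bridge mean is affine in $y_1$, plugging the predicted $y_1$ (an estimate of $E[y_1 | x_t]$) into the conditional score yields an estimate of the marginal score.

```python
import numpy as np

def bridge_mean_var(t, x0, y1, sigma):
    """Marginal N(mean, var * I) of a Brownian bridge from x0 (t=0) to y1 (t=1)."""
    return (1.0 - t) * x0 + t * y1, sigma**2 * t * (1.0 - t)

def y1_to_score_sketch(t, xt, y1_pred, x0, sigma):
    """Gradient with respect to x_t of log N(x_t; mean_t(y1_pred), var_t * I)."""
    mean, var = bridge_mean_var(t, x0, y1_pred, sigma)
    return -(xt - mean) / var

def log_density(t, xt, y1, x0, sigma):
    """Log of the same Gaussian density, used only to check the score numerically."""
    mean, var = bridge_mean_var(t, x0, y1, sigma)
    return -0.5 * np.sum((xt - mean) ** 2) / var - 0.5 * xt.size * np.log(2 * np.pi * var)

x0 = np.zeros(2)
xt = np.array([0.3, -0.2])
y1_pred = np.array([1.0, -1.0])
score = y1_to_score_sketch(0.5, xt, y1_pred, x0, sigma=0.1)
```

The score points from $x_t$ toward the bridge mean implied by the prediction, scaled by the inverse of the remaining variance.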

Gaussian Potentials

linsdex implements Gaussians in three forms to ensure stability across different operations.

StandardGaussian uses mean ($\mu$) and covariance ($\Sigma$) parameters. This form is best for sampling and interpreting results.

NaturalGaussian uses precision-mean ($h$) and precision ($J$) parameters. This form is best for multiplying densities and message passing.

MixedGaussian uses mean ($\mu$) and precision ($J$) parameters. It provides a stable bridge between standard and natural forms, which is particularly useful for Kalman filtering steps.
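
The relationship between the standard and natural forms can be sketched in a few lines of NumPy. This is not the linsdex API; the function names are hypothetical and dense inverses are used only for brevity. The conversion is $J = \Sigma^{-1}$, $h = J \mu$, and the product of two Gaussian densities (up to normalization) is obtained by adding their natural parameters, which is why this form suits message passing.

```python
import numpy as np

def standard_to_natural(mu, Sigma):
    """(mu, Sigma) -> (h, J) with J = Sigma^{-1} and h = J mu."""
    J = np.linalg.inv(Sigma)
    return J @ mu, J

def natural_to_standard(h, J):
    """(h, J) -> (mu, Sigma) with Sigma = J^{-1} and mu = Sigma h."""
    Sigma = np.linalg.inv(J)
    return Sigma @ h, Sigma

def multiply_natural(n1, n2):
    """Product of two Gaussian densities (unnormalized): natural parameters add."""
    return n1[0] + n2[0], n1[1] + n2[1]

prior = standard_to_natural(np.array([0.0]), np.array([[1.0]]))
likelihood = standard_to_natural(np.array([2.0]), np.array([[1.0]]))
mu_post, Sigma_post = natural_to_standard(*multiply_natural(prior, likelihood))
# Posterior mean is 1.0 and posterior variance is 0.5
```

In standard form the same product would require a Kalman-style update; in natural form it is a pair of additions.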

Conditional Random Fields (CRF)

The CRF class represents a chain-structured probabilistic model. It serves as the engine for discrete-time inference.

import jax
from jaxtyping import Array, Float, PRNGKeyArray
from linsdex import CRF, AbstractPotential

# Create a CRF from node potentials and transitions
# (node_potentials and transitions are assumed to have been constructed already)
crf = CRF(node_potentials, transitions)

# Perform inference
key: PRNGKeyArray = jax.random.PRNGKey(0)
messages: AbstractPotential = crf.get_forward_messages() # Forward pass
marginals: AbstractPotential = crf.get_marginals() # p(x_t | observations)
samples: Float[Array, "T D"] = crf.sample(key) # Draw joint samples

For long sequences, linsdex uses a parallel scan implementation of message passing, which reduces the sequential depth of inference from $O(T)$ to $O(\log T)$ on parallel hardware.
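
The reason a scan can be parallelized is that the per-step updates compose associatively. The NumPy sketch below illustrates this with a simpler stand-in than Gaussian messages: affine maps $x \mapsto A x + b$, which is an assumption made for brevity (the elements used for actual filtering and smoothing are more involved). The names combine and inclusive_scan are hypothetical. Each pass of the while loop updates every position independently, so it counts as one parallel round, and only about $\log_2 T$ rounds are needed.

```python
import numpy as np

def combine(first, second):
    """Compose two affine maps x -> A x + b, applying `first` and then `second`."""
    A1, b1 = first
    A2, b2 = second
    return A2 @ A1, A2 @ b1 + b2

def inclusive_scan(elems, op):
    """Recursive-doubling (Hillis-Steele) scan; returns prefixes and the round count."""
    elems = list(elems)
    rounds, stride = 0, 1
    while stride < len(elems):
        # All positions are updated from the previous round's values only
        elems = [
            elems[i] if i < stride else op(elems[i - stride], elems[i])
            for i in range(len(elems))
        ]
        stride *= 2
        rounds += 1
    return elems, rounds

rng = np.random.default_rng(0)
T, D = 16, 2
As = 0.5 * rng.standard_normal((T, D, D))
bs = rng.standard_normal((T, D))
x0 = rng.standard_normal(D)

# Sequential reference: x_t = A_t x_{t-1} + b_t
xs_seq, x = [], x0
for A, b in zip(As, bs):
    x = A @ x + b
    xs_seq.append(x)
xs_seq = np.stack(xs_seq)

# Scan: prefix t is the composed map from x0 to x_t
prefixes, rounds = inclusive_scan(zip(As, bs), combine)
xs_scan = np.stack([A @ x0 + b for A, b in prefixes])
```

Here the list comprehension runs serially, so the sketch demonstrates correctness and round count rather than speed; in JAX the analogous primitive is jax.lax.associative_scan.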

Specialized Matrix Library

To handle structured models efficiently, linsdex includes a matrix library that avoids expensive dense operations when possible.

DiagonalMatrix is used for decoupled systems or independent noise.

Block2x2Matrix and Block3x3Matrix are optimized for higher-order tracking models.

Matrices carry symbolic tags like TAGS.zero_tags and TAGS.eye_tags. These tags allow the library to symbolically simplify expressions like $0 \times A$ or $I \times B$ before they reach JAX.
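
The effect of such tags can be illustrated with a toy class. This is a sketch of the general idea only, not linsdex's matrix types: the class TaggedMatrix and its is_zero and is_eye fields are hypothetical stand-ins for the library's tags. Because the tags are ordinary Python booleans, the branch is resolved in Python before any array operation is issued.

```python
from dataclasses import dataclass

import numpy as np

@dataclass(frozen=True, eq=False)
class TaggedMatrix:
    values: np.ndarray
    is_zero: bool = False   # matrix is known to be all zeros
    is_eye: bool = False    # matrix is known to be the identity

    def __matmul__(self, other):
        if self.is_zero or other.is_zero:
            # 0 @ A = A @ 0 = 0: skip the multiplication and keep the tag
            shape = (self.values.shape[0], other.values.shape[1])
            return TaggedMatrix(np.zeros(shape), is_zero=True)
        if self.is_eye:
            return other    # I @ B = B
        if other.is_eye:
            return self     # A @ I = A
        return TaggedMatrix(self.values @ other.values)

A = TaggedMatrix(np.arange(4.0).reshape(2, 2))
Z = TaggedMatrix(np.zeros((2, 2)), is_zero=True)
I = TaggedMatrix(np.eye(2), is_eye=True)

zero_product = Z @ A        # tagged zero, no matmul performed
same_matrix = I @ A         # returns A itself
dense_product = A @ A       # falls back to a dense matmul
```

Under jax.jit a check of this kind would run at trace time, so the simplified operation never appears in the compiled computation.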

Installation

pip install .

Citation

If you use linsdex in your research, please cite the following software.

@software{cunningham2025linsdex,
  author       = {Cunningham, Edmond},
  title        = {{linsdex}: A High-Performance JAX-based Library for Linear Stochastic Differential Equations},
  version      = {0.1.0},
  url          = {https://github.com/EddieCunningham/linsdex},
  note         = {Python package},
}
