Inference in LTI-SDEs using CRFs
linsdex
linsdex is a high-performance JAX-based library for linear stochastic differential equations (SDEs), state-space models, and Gaussian inference. It provides a modular, extensible framework for defining, simulating, and conditioning linear-Gaussian systems, with support for parallelized inference on GPUs.
Agent Skills for Cursor IDE
This project includes Agent Skills to help you get started quickly when using Cursor. Skills provide domain-specific guidance that the AI agent can use to help you with common tasks.
To invoke a skill, type / in the Cursor chat and search for the skill name. The agent can also apply relevant skills automatically based on context.
| Skill | When to Use |
|---|---|
| /linsdex | Overview and quick start guide for the entire library |
| /sde-conditioning | Time series interpolation, Brownian bridges, posterior sampling |
| /diffusion-conversions | Converting between y1, score, flow, drift for diffusion models |
| /probability-paths | Bridge path marginals, memoryless sampling, computing p(x_t given y_1) |
| /crf-inference | Discrete-time Gaussian CRF inference, marginals, sampling |
| /gaussian-distributions | Working with three Gaussian parameterizations for numerical stability |
| /matrix-operations | Using specialized matrix types with symbolic tags |
Quick Start
Define an SDE, condition it on a starting point, and sample trajectories in parallel using JAX.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from linsdex import BrownianMotion
# 1. Define a 2D Brownian Motion SDE
sde = BrownianMotion(sigma=1.0, dim=2)
# 2. Condition on starting at the origin at t=0
conditioned_sde = sde.condition_on_starting_point(t0=0.0, x0=jnp.zeros(2))
# 3. Sample 100 trajectories in parallel on a time grid
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 100)
times = jnp.linspace(0.0, 1.0, 500)
# Use jax.vmap for efficient batch sampling
trajectories = jax.vmap(conditioned_sde.sample, in_axes=(0, None))(keys, times)
# 4. Plot the trajectories
plt.figure(figsize=(8, 6))
for i in range(10):
    plt.plot(trajectories.values[i, :, 0], trajectories.values[i, :, 1], alpha=0.6)
plt.xlabel("x1")
plt.ylabel("x2")
plt.savefig("quick_start_simple.png")
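To make concrete what sampling on a time grid involves in this simple case, here is a minimal sketch in plain JAX that does not use linsdex at all. It assumes that sigma scales a standard Brownian motion, so increments over a step of length dt are independent draws from $\mathcal{N}(0, \sigma^2\,dt\,I)$. The helper name sample_brownian_path is hypothetical and is not part of the library.

```python
import jax
import jax.numpy as jnp

def sample_brownian_path(key, times, sigma=1.0, dim=2):
    """Illustrative helper: sample dx = sigma dW on `times`, starting at the origin."""
    dts = jnp.diff(times)                                   # step sizes, shape (T-1,)
    noise = jax.random.normal(key, (dts.shape[0], dim))     # standard normal draws
    increments = sigma * jnp.sqrt(dts)[:, None] * noise     # N(0, sigma^2 dt I) increments
    # Prepend the starting point and accumulate the increments along time
    return jnp.concatenate(
        [jnp.zeros((1, dim)), jnp.cumsum(increments, axis=0)], axis=0
    )

keys = jax.random.split(jax.random.PRNGKey(0), 100)
times = jnp.linspace(0.0, 1.0, 500)
# One independent path per key, all on the same time grid
paths = jax.vmap(sample_brownian_path, in_axes=(0, None))(keys, times)
```

The result has shape (100, 500, 2), one axis each for samples, time points, and state dimensions.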
Key Features
The library focuses on high performance and numerical stability.
- Linear SDEs: Comprehensive support for linear time-invariant (LTI) and time-varying SDEs with exact transition distributions.
- Efficient Inference: Sequential and parallel message passing (parallel scan) for filtering, smoothing, and sampling in chain-structured Gaussian CRFs.
- Probabilistic Primitives: Multiple Gaussian parameterizations (Standard, Natural, Mixed) with numerically stable operations.
- Specialized Linear Algebra: A custom matrix library with Diagonal, Block, and Dense types that leverage symbolic tags for optimization.
- Diffusion Utilities: Unified interface for mapping between clean data predictions, scores, and probability flow for generative modeling.
- JAX-Native: Fully compatible with jax.vmap, jax.grad, and jax.jit for automatic vectorization and differentiation.
Stochastic Harmonic Oscillator
This example demonstrates how to perform inference in a State Space Model where the latent dynamics are governed by a Linear Stochastic Differential Equation. By conditioning a base SDE on noisy observations, linsdex constructs a Gaussian Conditional Random Field. We then use parallel message passing to compute the posterior $p(x_{1:T} | y_{1:T})$, which allows us to draw samples that interpolate the data while respecting the underlying physical laws of the oscillator.
The latent state $x_t \in \mathbb{R}^2$ represents the position and velocity of a stochastic harmonic oscillator. Its evolution is described by the linear SDE
$$ dx_t = F x_t dt + L dW_t $$
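where $F$ defines the deterministic drift, $L$ defines the diffusion, and $W_t$ is a standard Brownian motion. For the harmonic oscillator with natural frequency $\omega$ and damping coefficient $\gamma$,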
$$ F = \begin{bmatrix} 0 & 1 \\ -\omega^2 & -\gamma \end{bmatrix} $$
The library automatically computes the exact transition distribution $p(x_t | x_{t-1})$ by solving the Lyapunov equation over the time interval $\Delta t$.
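For an LTI SDE the transition is Gaussian, $p(x_{t+\Delta t} | x_t) = \mathcal{N}(A x_t, Q)$, where $A = e^{F \Delta t}$ and $Q$ is the solution at time $\Delta t$ of the Lyapunov differential equation $\dot{Q} = F Q + Q F^\top + L L^\top$ with $Q(0) = 0$. The sketch below computes both with Van Loan's matrix-exponential method in plain JAX. It is one standard way to obtain these quantities and is not necessarily how linsdex does it internally. The function name discretize_lti_sde is hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

def discretize_lti_sde(F, L, dt):
    """Illustrative helper: return (A, Q) with x_{t+dt} | x_t ~ N(A x_t, Q)."""
    n = F.shape[0]
    # Van Loan block matrix [[-F, L L^T], [0, F^T]] scaled by dt
    top = jnp.concatenate([-F, L @ L.T], axis=1)
    bottom = jnp.concatenate([jnp.zeros((n, n)), F.T], axis=1)
    M = expm(jnp.concatenate([top, bottom], axis=0) * dt)
    A = M[n:, n:].T        # lower-right block is expm(F dt)^T
    Q = A @ M[:n, n:]      # upper-right block is A^{-1} Q
    return A, Q

# Harmonic oscillator with the same parameter values as the example in this section
freq, coeff, sigma = 1.0, 0.1, 0.5
F = jnp.array([[0.0, 1.0], [-freq**2, -coeff]])
L = jnp.diag(jnp.array([0.0, sigma]))
A, Q = discretize_lti_sde(F, L, dt=0.1)
```

In the scalar Ornstein-Uhlenbeck case $F = -\theta$, $L = \sigma$, this reduces to the closed form $Q = \sigma^2 (1 - e^{-2\theta \Delta t}) / (2\theta)$, which gives a quick sanity check.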
We observe the position $y_t \in \mathbb{R}^1$ through a noisy channel
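$$ p(y_t | x_t) = \mathcal{N}(y_t; H x_t, \sigma^2 I) $$
In our example, the encoder PaddingLatentVariableEncoderWithPrior specifies $H = [1, 0]$, so only the first component of the latent state (the position) is observed, corrupted by additive Gaussian noise with standard deviation $\sigma$. This $\sigma$ is the observation noise level and is distinct from the diffusion scale in $L$.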
The model can be represented as a chain structured graphical model where $x_t$ are latent variables and $y_t$ are noisy observations.
import jax
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
from linsdex import TimeSeries, LinearTimeInvariantSDE, DenseMatrix, DiagonalMatrix
from linsdex.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior
# 1. Define time series data
# We create sparse 1D data for interpolation
times = jnp.linspace(0, 10, 5)
values = jnp.sin(times)[:, None]
series = TimeSeries(times, values)
# 2. Define a linear SDE by specifying F and L
# Harmonic oscillator with state x = [position, velocity]
freq, coeff, sigma = 1.0, 0.1, 0.5
F = DenseMatrix(jnp.array([[0, 1], [-freq**2, -coeff]]))
L = DiagonalMatrix(jnp.array([0, sigma]))
sde = LinearTimeInvariantSDE(F=F, L=L)
# 3. Create potentials from data and condition the SDE
# PaddingLatentVariableEncoderWithPrior pads the 1D observations to the 2D latent space
encoder = PaddingLatentVariableEncoderWithPrior(
    y_dim=1,
    x_dim=2,
    sigma=0.01,
)
potentials = encoder(series)
conditioned_sde = sde.condition_on(potentials)
# 4. Draw samples from the posterior
key = random.PRNGKey(0)
keys = random.split(key, 128)
# Interpolate on a denser time grid
save_times = jnp.linspace(0, 10, 2000)
samples: TimeSeries = jax.vmap(conditioned_sde.sample, in_axes=(0, None))(keys, save_times)
# 5. Plot the original time series and the posterior samples
fig, axes = samples.plot(show_plot=False)
plt.savefig("quick_start.png")
Core Components
The library provides several layers of abstraction for probabilistic modeling.
Stochastic Differential Equations
The library defines a hierarchy of SDEs starting from AbstractSDE.
Linear Time-Invariant (LTI) SDEs are models where the drift and diffusion coefficients are constant over time. Examples include BrownianMotion, OrnsteinUhlenbeck, and StochasticHarmonicOscillator.
Linear SDEs are models with time-varying coefficients, such as VariancePreserving SDEs used in diffusion models.
Conditioned SDEs, represented by ConditionedLinearSDE, condition a process on any number of Gaussian potentials.
Diffusion Conversions
linsdex provides a unified interface for working with different mathematical representations of a diffusion process. The conversions map between neural network predictions (such as the clean data $y_1$) and the quantities required for sampling (such as the probability flow or the drift of an SDE).
The DiffusionModelComponents class encapsulates the objects that define a diffusion process, including the base linear SDE, the prior distribution at $t_0$, and the evidence covariance at $t_1$.
import jax.numpy as jnp
from jaxtyping import Array, Float
from linsdex.diffusion_model.probability_path import DiffusionModelComponents
from linsdex import BrownianMotion, StandardGaussian, DiagonalMatrix
dim: int = 10
sde = BrownianMotion(sigma=0.1, dim=dim)
xt0_prior = StandardGaussian(jnp.zeros(dim), DiagonalMatrix.eye(dim))
evidence_cov = DiagonalMatrix.eye(dim) * 0.001
components = DiffusionModelComponents(
    linear_sde=sde,
    t0=0.0,
    x_t0_prior=xt0_prior,
    t1=1.0,
    evidence_cov=evidence_cov,
)
The DiffusionModelConversions class maps between different parameterizations of the diffusion path, such as converting a prediction into the score function, probability flow, or drift.
from jaxtyping import Array, Float
from linsdex.diffusion_model.probability_path import DiffusionModelConversions
conversions = DiffusionModelConversions(components, t=0.5)
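# Map clean data prediction y1 to different sampling quantities.
# y1_pred (the clean-data prediction) and xt (the current state) are assumed
# to be arrays of shape (dim,) defined elsewhere.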
flow: Float[Array, "dim"] = conversions.y1_to_flow(y1_pred, xt)
drift: Float[Array, "dim"] = conversions.y1_to_drift(y1_pred, xt)
score: Float[Array, "dim"] = conversions.y1_to_score(xt, y1_pred)
The ProbabilityPathSlice class can be used to compute and cache time-dependent intermediate quantities, avoiding redundant computations when performing multiple conversions at the same time step. Additionally, noise_schedule_drift_correction allows for adjusting the drift when a different noise schedule is used at inference time compared to training.
Gaussian Potentials
linsdex implements Gaussians in three forms to ensure stability across different operations.
StandardGaussian uses mean ($\mu$) and covariance ($\Sigma$) parameters. This form is best for sampling and interpreting results.
NaturalGaussian uses precision-weighted mean ($h = J\mu$) and precision ($J = \Sigma^{-1}$) parameters. This form is best for multiplying densities and message passing, since a product of Gaussians reduces to adding parameters.
MixedGaussian uses mean ($\mu$) and precision ($J$) parameters. It provides a stable bridge between standard and natural forms, which is particularly useful for Kalman filtering steps.
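The relationship between the standard and natural forms can be sketched in a few lines of plain JAX. This is an illustration of the underlying algebra only and does not use the linsdex classes. The function names are hypothetical, and explicit matrix inverses are used for readability even though a production implementation would typically prefer linear solves or Cholesky factors.

```python
import jax.numpy as jnp

def standard_to_natural(mu, Sigma):
    """(mu, Sigma) -> (h, J) with J = Sigma^{-1} and h = J mu."""
    J = jnp.linalg.inv(Sigma)
    return J @ mu, J

def natural_to_standard(h, J):
    """(h, J) -> (mu, Sigma) with Sigma = J^{-1} and mu = Sigma h."""
    Sigma = jnp.linalg.inv(J)
    return Sigma @ h, Sigma

def multiply_natural(h1, J1, h2, J2):
    """Product of two Gaussian densities (up to normalization): parameters add."""
    return h1 + h2, J1 + J2

# Multiply N(0, 1) by N(2, 1) in natural form, then convert back
h1, J1 = standard_to_natural(jnp.array([0.0]), jnp.eye(1))
h2, J2 = standard_to_natural(jnp.array([2.0]), jnp.eye(1))
mu, Sigma = natural_to_standard(*multiply_natural(h1, J1, h2, J2))
# mu = [1.0], Sigma = [[0.5]]
```

The product has its mean halfway between the two factors and half the variance of either, which is the behavior expected when fusing two equally confident Gaussian beliefs.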
Conditional Random Fields (CRF)
The CRF class represents a chain-structured probabilistic model. It serves as the engine for discrete-time inference.
from jaxtyping import Array, Float, PRNGKeyArray
from linsdex import CRF, AbstractPotential
# Create a CRF from node potentials and transitions
# (node_potentials, transitions, and key are assumed to be defined beforehand)
crf = CRF(node_potentials, transitions)
# Perform inference
messages: AbstractPotential = crf.get_forward_messages() # Forward pass
marginals: AbstractPotential = crf.get_marginals() # p(x_t | observations)
samples: Float[Array, "T D"] = crf.sample(key) # Draw joint samples
For long sequences, linsdex uses a parallel scan implementation of message passing, which reduces the sequential depth from $O(T)$ to $O(\log T)$ on parallel hardware.
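The idea behind a parallel scan can be illustrated with a simpler problem than full Gaussian message passing: evaluating the affine recurrence $x_t = A_t x_{t-1} + b_t$ for all $t$. Composition of affine maps is associative, so jax.lax.associative_scan can compute every prefix composition in logarithmic depth. The sketch below is an assumption-laden illustration of the technique, not the linsdex implementation. The names combine and affine_prefix_scan are hypothetical.

```python
import jax
import jax.numpy as jnp

def combine(earlier, later):
    """Compose two affine maps: apply `earlier` first, then `later`."""
    A1, b1 = earlier
    A2, b2 = later
    # later(earlier(x)) = A2 (A1 x + b1) + b2
    return A2 @ A1, jnp.einsum("...ij,...j->...i", A2, b1) + b2

def affine_prefix_scan(A, b, x0):
    """Return x_1..x_T for x_t = A_t x_{t-1} + b_t using a parallel prefix scan."""
    A_cum, b_cum = jax.lax.associative_scan(combine, (A, b))
    # Each prefix map sends x0 directly to x_t
    return jnp.einsum("tij,j->ti", A_cum, x0) + b_cum

T, D = 64, 2
key_A, key_b = jax.random.split(jax.random.PRNGKey(0))
A = 0.5 * jax.random.normal(key_A, (T, D, D))
b = jax.random.normal(key_b, (T, D))
x0 = jnp.ones(D)
xs_parallel = affine_prefix_scan(A, b, x0)
```

A sequential loop over t produces the same values in $T$ dependent steps. The scan does more total work but exposes it as batched matrix products that a GPU can run concurrently.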
Specialized Matrix Library
To handle structured models efficiently, linsdex includes a matrix library that avoids expensive dense operations when possible.
DiagonalMatrix is used for decoupled systems or independent noise.
Block2x2Matrix and Block3x3Matrix are optimized for higher-order tracking models.
Matrices carry symbolic tags like TAGS.zero_tags and TAGS.eye_tags. These tags allow the library to symbolically simplify expressions like $0 \times A$ or $I \times B$ before they reach JAX.
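A toy version of tag-based simplification is sketched below. It is a hypothetical stand-in written for illustration. The class TaggedMatrix and its string tags are assumptions and do not reflect the actual linsdex matrix types or the TAGS API. The point is only that a product involving a known zero or identity operand can be resolved without performing a matrix multiplication.

```python
from dataclasses import dataclass
import jax.numpy as jnp

@dataclass(frozen=True)
class TaggedMatrix:
    """Hypothetical matrix wrapper carrying a symbolic tag."""
    values: jnp.ndarray
    tag: str = "dense"  # one of "zero", "eye", "dense"

    def __matmul__(self, other):
        # 0 @ B or A @ 0: the result is known to be zero, no multiply needed
        if self.tag == "zero" or other.tag == "zero":
            shape = (self.values.shape[0], other.values.shape[1])
            return TaggedMatrix(jnp.zeros(shape), "zero")
        # I @ B = B and A @ I = A: return the other operand unchanged
        if self.tag == "eye":
            return other
        if other.tag == "eye":
            return self
        # General case: fall back to a dense product
        return TaggedMatrix(self.values @ other.values, "dense")

I = TaggedMatrix(jnp.eye(2), "eye")
Z = TaggedMatrix(jnp.zeros((2, 2)), "zero")
B = TaggedMatrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]))

same_as_B = I @ B      # returns B itself, no computation
still_zero = Z @ B     # tagged "zero", no matrix product performed
```

Because the tags are ordinary Python values rather than traced arrays, branches like these are resolved at trace time and the simplified expression is what gets compiled.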
Installation
pip install .
Citation
If you use linsdex in your research, please cite the following software.
@software{cunningham2025linsdex,
  author = {Cunningham, Edmond},
  title = {{linsdex}: A High-Performance JAX-based Library for Linear Stochastic Differential Equations},
  version = {0.1.0},
  url = {https://github.com/EddieCunningham/linsdex},
  note = {Python package},
}