
Scalable and simulation-free training of latent SDEs via score matching (Bartosh et al., 2025)

SDEmatching

This repo implements the 2025 paper SDE Matching: Scalable and Simulation-Free Training of Latent Stochastic Differential Equations, by Bartosh, Vetrov, and Naesseth [1]. The paper shows that a latent Stochastic Differential Equation (SDE) can be trained to fit data in a simulation-free manner, by score matching between the trained model and a variational approximation of the data.

Score matching SDEs against variational approximation

The following is paraphrased from the paper [1]. We assume the existence of an SDE of the form

$$dz_t = f(z_t, t) dt + \sigma(z_t, t) dW_t$$

with prior distribution $p_0(z_0)$ and emission distribution $p_e(x\vert z_t)$. Here $W_t$ is a Wiener process. The function $f(z_t, t)$ is called the drift, and $\sigma(z_t, t)$ is the diffusion term, a matrix-valued function.

This process generates a set of observation series, each of the form $X=[x_{t_1}, x_{t_2}, ..., x_{t_N}]$. In the following we assume a single sampled series of observations, but the implementation works with a batch of series.

The problem the paper addresses is how to estimate the drift $f_\theta$, the diffusion $\sigma_\theta$, and the prior and emission distributions $p_\theta(z_0)$ and $p_\theta(x\vert z_t)$ such that the SDE

$$ dz_t = f_\theta(z_t, t) dt + \sigma_\theta(z_t, t) dW_t$$

generates similar data.

This can be done by defining a variational conditional distribution over the latent states through a transformation $z_t = F_\phi(\varepsilon, t, X)$, where $\varepsilon\sim \mathcal{N}(0,I)$. This transformation implicitly defines the conditional marginal distribution $q_\phi(z_t\vert X)$, of which $z_t = F_\phi(\varepsilon, t, X)$ is a sample.

In this implementation $q_\phi$ is chosen to be Gaussian, conditional on the inputs $t$ and $X$, but it can be exchanged for any distribution as long as $F_\phi$ is invertible in $\varepsilon$ and differentiable in $t$. Any normalizing flow conditioned on $X$ and $t$ will do.
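The Gaussian choice can be sketched as a reparameterization: $F_\phi$ shifts and scales standard normal noise, and is trivially invertible in $\varepsilon$. The functions `mu_phi` and `sigma_phi` below are toy closed forms standing in for learned networks conditioned on $(t, X)$; they are assumptions for illustration, not the package's actual code.

```python
import numpy as np

def mu_phi(t, X):
    # toy stand-in for a learned conditional mean network
    return X.mean() * np.exp(-t)

def sigma_phi(t, X):
    # toy stand-in for a learned conditional std network (kept positive)
    return 0.5 + 0.1 * t

def F(eps, t, X):
    """Sample z_t = F_phi(eps, t, X) with eps ~ N(0, I)."""
    return mu_phi(t, X) + sigma_phi(t, X) * eps

def F_inv(z_t, t, X):
    """Invert in eps: eps = (z_t - mu_phi) / sigma_phi."""
    return (z_t - mu_phi(t, X)) / sigma_phi(t, X)

X = np.array([1.0, 2.0, 3.0])
eps = np.random.randn()
z = F(eps, 0.7, X)
assert np.isclose(F_inv(z, 0.7, X), eps)  # invertible in eps
```

Any richer conditional flow would replace the affine map while keeping the same forward/inverse interface.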

Now define the time derivative of $F_\phi$, given a fixed sample of $\varepsilon$:

$$\overline{f}_{\phi}(z_{t},t,X)=\frac{\partial F_{\phi}(\varepsilon,t,X)}{\partial t}\Big\vert_{\varepsilon=F_{\phi}^{-1}(z_{t},t,X)}. $$

Starting from $z_0 \sim q_\phi(z_0\vert X)$ and integrating the ODE

$$ dz_t = \overline{f}_{\phi}(z_{t},t,X)\, dt, $$

we obtain, at every time $t$, a sample from the variational marginal distribution $q_\phi(z_t\vert X)$.
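For a Gaussian $q_\phi$ the ODE drift $\overline{f}_\phi$ has a closed form: with $F(\varepsilon, t) = \mu(t) + s(t)\varepsilon$, the time derivative at fixed $\varepsilon$ is $\mu'(t) + s'(t)\varepsilon$, where $\varepsilon$ is recovered as $(z - \mu(t))/s(t)$. The sketch below uses toy closed forms for $\mu$ and $s$ (an assumption for illustration; in practice they are networks) and checks the drift against a finite difference of $F$ at fixed $\varepsilon$.

```python
import numpy as np

# Toy closed forms standing in for learned mean/std networks.
mu  = lambda t: np.sin(t)
dmu = lambda t: np.cos(t)
s   = lambda t: 1.0 + 0.5 * t
ds  = lambda t: 0.5

def f_bar(z, t):
    eps = (z - mu(t)) / s(t)       # eps = F^{-1}(z, t)
    return dmu(t) + ds(t) * eps    # dF/dt evaluated at fixed eps

# Sanity check against a central finite difference of F at fixed eps.
F = lambda e, t: mu(t) + s(t) * e
eps, t, h = 0.8, 0.3, 1e-5
z = F(eps, t)
fd = (F(eps, t + h) - F(eps, t - h)) / (2 * h)
assert abs(f_bar(z, t) - fd) < 1e-6
```

With a non-Gaussian flow, the same quantity would be obtained by differentiating $F_\phi$ with respect to $t$ via automatic differentiation.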

We are now interested in minimizing the KL divergence between $q_\phi(z_t\vert X)$ and $p_\theta(z_t)$ over path measures. By Girsanov's theorem, this divergence is finite only if both processes share the same diffusion term $\sigma_\theta$.

Let $\sigma^2_\theta(z_t, t)$ be a shorthand for $\sigma_\theta(z_t, t)\sigma_\theta(z_t, t)^\top$. If we then define

$$ f_\phi(z_t, t, X) = \overline{f}_{\phi}(z_{t},t,X) + \frac{1}{2}\sigma^2_\theta(z_t, t)\, \nabla_{z_t} \ln q_\phi(z_t\vert X) + \frac{1}{2}\nabla_{z_t} \cdot \sigma^2_\theta(z_t, t),$$

then a result in [4] gives us that the SDE defined by

$$ dz_t = f_\phi(z_t, t, X) dt + \sigma_\theta(z_t, t) dW_t $$

also has the marginal distribution $q_\phi(z_t\vert X)$, regardless of what $\sigma_\theta$ looks like.
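For a Gaussian $q_\phi$ the score has a closed form, $\nabla_z \ln q_\phi = -(z-\mu(t))/s(t)^2$, and if $\sigma_\theta$ is assumed constant (state-independent), the divergence term vanishes. A minimal sketch under those assumptions, with toy closed forms standing in for the learned networks:

```python
import numpy as np

# Toy stand-ins for the learned conditional mean/std and their derivatives.
mu, dmu = (lambda t: np.sin(t)), (lambda t: np.cos(t))
s,  ds  = (lambda t: 1.0 + 0.5 * t), (lambda t: 0.5)
sigma = 0.7  # assumed constant diffusion, so the divergence term is zero

def score(z, t):
    # grad_z log N(z; mu(t), s(t)^2)
    return -(z - mu(t)) / s(t) ** 2

def f_bar(z, t):
    # ODE drift: dF/dt at fixed eps, with eps = (z - mu) / s
    return dmu(t) + ds(t) * (z - mu(t)) / s(t)

def f_phi(z, t):
    # f_phi = f_bar + 0.5 * sigma^2 * score  (divergence term omitted)
    return f_bar(z, t) + 0.5 * sigma ** 2 * score(z, t)

# At z = mu(t) the score is zero, so f_phi reduces to mu'(t).
assert abs(f_phi(mu(0.3), 0.3) - dmu(0.3)) < 1e-12
```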

Now, if we approximate $f_\phi(z_t, t, X)$ by a neural network $f_\theta(z_t, t)$, then we in turn obtain an SDE whose marginal distribution approximately matches the variational marginal distribution $q_\phi(z_t\vert X)$.

This means that if we also learn an emission distribution $p_\theta(x\vert z)$, then we can couple everything together and minimize the negative evidence lower bound

$$ -\mathrm{ELBO}(\theta, \phi) = \mathcal{L}_{\text{prior}} + \mathcal{L}_{\text{diff}} + \mathcal{L}_{\text{rec}}, $$

where

$$\mathcal{L}_{\text{prior}} = D_{\mathrm{KL}}\big(q_\phi(z_0\vert X)\,\Vert\, p_\theta(z_0)\big)$$

$$\mathcal{L}_{\text{rec}} = -\log p_\theta(x_{t_i}\vert z_{t_i})$$

$$\mathcal{L}_{\text{diff}} = \tfrac{1}{2}\big\Vert \sigma_\theta^{-1}(z_t, t)\big(f_\theta(z_t, t) - f_\phi(z_t, t, X)\big)\big\Vert^2$$
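The diffusion term can be evaluated at a single random time without ever integrating the SDE, which is what makes training simulation-free. A sketch of one such loss evaluation, reusing the toy closed-form networks from above (all names here are illustrative assumptions, not the package's API):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins for the learned networks.
mu, dmu = (lambda t: np.sin(t)), (lambda t: np.cos(t))
s,  ds  = (lambda t: 1.0 + 0.5 * t), (lambda t: 0.5)
sigma = 0.7                               # assumed constant diffusion
f_theta = lambda z, t: np.cos(t)          # hypothetical model drift

# One simulation-free evaluation of L_diff: draw a random time, push eps
# through F_phi to get z_t, and compare model drift with target drift.
t = rng.uniform(0.0, 1.0)
eps = rng.standard_normal()
z = mu(t) + s(t) * eps                        # z_t = F_phi(eps, t, X)
f_bar = dmu(t) + ds(t) * eps                  # time derivative of F_phi
score = -(z - mu(t)) / s(t) ** 2              # grad_z log q_phi(z_t | X)
f_target = f_bar + 0.5 * sigma ** 2 * score   # marginal-preserving drift
loss_diff = 0.5 * ((f_theta(z, t) - f_target) / sigma) ** 2
assert loss_diff >= 0.0
```

In training, this scalar would be averaged over a minibatch of times and series and combined with the prior and reconstruction terms.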

Implementation

Ideas for further research

As explained above, the main obstacle with this model is how to construct the variational process. I have tried a transformer model, which works when all latent dimensions are observed, but collapses when some latent dimensions are hidden.

I have also used a Gaussian process, which produces estimates of the mean and variance, as well as of the derivative of the mean and the variance of that derivative. This works very well for a simple physics system with restrictions on the drift matrix.

However, neither of these generalizes. In the original paper, the authors use the ODE-RNN introduced by Rubanova et al. [2]. I personally do not like this approach, since it breaks with the simulation-free spirit of the SDE matching method.

My suggestion for further research would be to use a 1D convolutional network that maps from `observation_dim * series_length` to `latent_dim * series_length`, and then perhaps use a deep kernel [3] for further refinement.
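A shape sketch of that proposed encoder: a 1D convolution over the time axis mapping an `(observation_dim, series_length)` array to `(latent_dim, series_length)`. The dimensions and the naive loop implementation are illustrative assumptions; a real encoder would use a framework's conv layer.

```python
import numpy as np

rng = np.random.default_rng(0)
obs_dim, latent_dim, L, k = 3, 4, 50, 5       # hypothetical sizes
W = rng.standard_normal((latent_dim, obs_dim, k)) * 0.1  # conv kernels
X = rng.standard_normal((obs_dim, L))                    # one observed series

pad = k // 2
Xp = np.pad(X, ((0, 0), (pad, pad)))          # "same" padding along time
H = np.empty((latent_dim, L))
for c in range(latent_dim):                   # naive 1D convolution
    for i in range(L):
        H[c, i] = np.sum(W[c] * Xp[:, i:i + k])
assert H.shape == (latent_dim, L)             # series length preserved
```

Stacking a few such layers (with nonlinearities) would give a simulation-free encoder from the whole series to per-time latent parameters.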

How to run

In order to use the SDEMatching package do the following:

  1. Move to the folder where you want the code
  2. Clone this repository: `git clone https://github.com/simoneiriksson/SDEMatching.git`
  3. Optionally, create a new Python environment: `python -m venv .venv`
  4. Activate it: `source .venv/bin/activate`
  5. Install the package into your active environment: `pip install -e .`

References

  1. Bartosh, G., Vetrov, D. & Naesseth, C. A. (2025). SDE Matching: Scalable and Simulation-Free Training of Latent Stochastic Differential Equations. Proceedings of the 42nd International Conference on Machine Learning, PMLR 267:3054-3070. Available from https://proceedings.mlr.press/v267/bartosh25a.html.

  2. Rubanova, Y., Chen, R. T. Q. & Duvenaud, D. (2019). Latent ODEs for Irregularly-Sampled Time Series. Proceedings of the 33rd International Conference on Neural Information Processing Systems. Curran Associates Inc., Red Hook, NY, USA, Article 478, 5320-5330.

  3. Wilson, A. G., Hu, Z., Salakhutdinov, R. & Xing, E. P. (2016). Deep Kernel Learning. Proceedings of the 19th International Conference on Artificial Intelligence and Statistics, PMLR 51:370-378. Available from https://proceedings.mlr.press/v51/wilson16.html.

  4. Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S. & Poole, B. (2020). Score-Based Generative Modeling through Stochastic Differential Equations. arXiv preprint arXiv:2011.13456.

