SDEmatching
This repo implements the 2025 paper SDE Matching: Scalable and Simulation-Free Training of Latent Stochastic Differential Equations, by Bartosh, Vetrov, and Naesseth [1]. The paper shows that it is possible to train a latent Stochastic Differential Equation (SDE) to fit data in a simulation-free manner, by score matching between the trained model and a variational approximation of the data.
Score matching SDEs against a variational approximation
The following is paraphrased from the paper [1]. We assume the existence of an SDE of the form
$$dz_t = f(z_t, t) dt + \sigma(z_t, t) dW_t$$
with prior distribution $p_0(z_0)$ and emission distribution $p_e(x\vert z_t)$. Here $W_t$ is a Wiener process. The function $f(z_t, t)$ is called the drift, and $\sigma(z_t, t)$ is the diffusion term, a function that returns a matrix.
This process generates a set of observation series, each of the form $X=[x_{t_1}, x_{t_2}, \dots, x_{t_N}]$. In the following we assume a single sampled series of observations, but the implementation works with a batch of samples.
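For concreteness, here is a minimal sketch of how such a process can be simulated with an Euler-Maruyama discretization. The Ornstein-Uhlenbeck-style drift, the diagonal diffusion, and the Gaussian emission noise are hypothetical stand-ins, not the paper's model:

```python
import torch

def euler_maruyama(f, sigma, z0, ts):
    """Simulate dz = f(z, t) dt + sigma(z, t) dW on the time grid ts."""
    z, path = z0, [z0]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        dt = t1 - t0
        dW = torch.randn_like(z) * dt.sqrt()       # Wiener increment
        z = z + f(z, t0) * dt + sigma(z, t0) * dW  # diagonal diffusion for simplicity
        path.append(z)
    return torch.stack(path)

# Hypothetical Ornstein-Uhlenbeck-style latent dynamics with Gaussian emissions.
f = lambda z, t: -z                                # drift
sigma = lambda z, t: torch.full_like(z, 0.5)       # diffusion (diagonal)
ts = torch.linspace(0.0, 1.0, 100)
zs = euler_maruyama(f, sigma, torch.randn(2), ts)  # latent path z_{t_1}, ..., z_{t_N}
xs = zs + 0.1 * torch.randn_like(zs)               # noisy observations x_{t_i}
```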
The problem the paper addresses is how to estimate the drift $f_\theta$, the diffusion $\sigma_\theta$, and the prior $p_\theta(z_0)$ and emission $p_\theta(x\vert z_t)$ distributions such that the SDE
$$ dz_t = f_\theta(z_t, t) dt + \sigma_\theta(z_t, t) dW_t$$
generates similar data.
This can be done by defining a variational conditional marginal distribution over the latent states via the transformation $z_t = F_\phi(\varepsilon, t, X)$, where $\varepsilon\sim \mathcal{N}(0,I)$. This transformation implicitly defines the conditional distribution $q_\phi(z_t\vert X)$, of which $z_t = F_\phi(\varepsilon, t, X)$ is a sample.
In this implementation $q_\phi$ is chosen to be Gaussian, conditional on the inputs $t$ and $X$, but it can be exchanged for any distribution as long as $F_\phi$ is invertible in $\varepsilon$ and differentiable in $t$. Any normalizing flow conditioned on $X$ and $t$ will do.
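A minimal sketch of this Gaussian choice: $F_\phi$ is affine in $\varepsilon$, hence trivially invertible, and differentiable in $t$ as long as the networks are. The encoder network and `ctx` (some fixed-size encoding of $X$) are hypothetical placeholders:

```python
import torch
import torch.nn as nn

class GaussianFlow(nn.Module):
    """z_t = F_phi(eps, t, X) = mu_phi(t, X) + std_phi(t, X) * eps, with eps ~ N(0, I)."""
    def __init__(self, ctx_dim, latent_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(            # maps (t, encoding of X) -> (mu, log_std)
            nn.Linear(1 + ctx_dim, hidden), nn.SiLU(),
            nn.Linear(hidden, 2 * latent_dim),
        )

    def forward(self, eps, t, ctx):
        mu, log_std = self.net(torch.cat([t, ctx], dim=-1)).chunk(2, dim=-1)
        return mu + log_std.exp() * eps      # affine in eps => invertible

    def inverse(self, z, t, ctx):
        mu, log_std = self.net(torch.cat([t, ctx], dim=-1)).chunk(2, dim=-1)
        return (z - mu) / log_std.exp()      # eps = F_phi^{-1}(z, t, X)
```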
Now define the time derivative of $F_\phi$, given a fixed sample of $\varepsilon$:
$$\overline{f}_{\phi}(z_{t},t,X)=\frac{\partial F_{\phi}(\varepsilon,t,X)}{\partial t}\Big\vert_{\varepsilon=F_{\phi}^{-1}(z_{t},t,X)}. $$
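Continuing the Gaussian sketch above, this derivative can be computed with forward-mode automatic differentiation, holding $\varepsilon$ fixed at its recovered value:

```python
from torch.func import jvp

def f_bar(flow, z, t, ctx):
    """Time derivative of F_phi at fixed eps = F_phi^{-1}(z, t, X)."""
    eps = flow.inverse(z, t, ctx)  # recover the fixed noise sample
    # d/dt F_phi(eps, t, X): forward-mode JVP with tangent 1 in the t-slot
    _, dz_dt = jvp(lambda tt: flow(eps, tt, ctx), (t,), (torch.ones_like(t),))
    return dz_dt
```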
Starting from $z_0 \sim q_\phi(z_0\vert X)$ and integrating the ODE
$$ dz_t = \overline{f}_{\phi}(z_{t},t,X)\, dt, $$
we obtain at every time $t$ a sample from the variational marginal distribution $q_\phi(z_t\vert X)$.
We are now interested in minimizing the KL divergence between $q_\phi(z_t\vert X)$ and $p_\theta(z_t)$ over path measures. By Girsanov's theorem, this divergence is finite only if both processes share the same diffusion term $\sigma_\theta$.
Let $\sigma^2_\theta(z_t, t)$ be a shorthand for $\sigma_\theta(z_t, t)\sigma_\theta(z_t, t)^\top$. If we then define
$$ f_\phi(z_t, t, X) = \overline{f}_{\phi}(z_{t},t,X) + \frac{1}{2}\sigma^2_\theta(z_t, t)\, \nabla_{z_t} \ln q_\phi(z_t\vert X) + \frac{1}{2}\nabla_{z_t} \cdot \sigma^2_\theta(z_t, t),$$
then a result in [4] gives us that the SDE defined by
$$ dz_t = f_\phi(z_t, t, X) dt + \sigma_\theta(z_t, t) dW_t $$
also has the marginal distribution $q_\phi(z_t\vert X)$, regardless of what $\sigma_\theta$ looks like.
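For the Gaussian $q_\phi$ above, the score $\nabla_{z_t} \ln q_\phi(z_t\vert X)$ is available in closed form. A sketch, additionally assuming a diagonal, state-independent diffusion so that the divergence term vanishes (`sigma2` is a hypothetical function of $t$ returning $\sigma^2_\theta$):

```python
def posterior_drift(flow, sigma2, z, t, ctx):
    """f_phi = f_bar + 0.5 * sigma^2 * score(z), assuming diagonal, state-independent
    sigma, so that the divergence term 0.5 * div(sigma^2) vanishes."""
    mu, log_std = flow.net(torch.cat([t, ctx], dim=-1)).chunk(2, dim=-1)
    score = -(z - mu) / (2 * log_std).exp()  # closed-form Gaussian score
    return f_bar(flow, z, t, ctx) + 0.5 * sigma2(t) * score
```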
Now, if we approximate $f_\phi(z_t, t, X)$ by a neural network $f_\theta(z_t, t)$, then we in turn obtain an SDE whose marginal distribution approximately matches the variational marginal distribution $q_\phi(z_t\vert X)$.
This means that if we also learn an approximation $p_\theta(x\vert z_t)$ of the emission distribution, then we can couple everything together and minimize the negative Evidence Lower Bound
$$ \mathcal{L}(\theta, \phi) = \mathcal{L}_{\text{prior}} + \mathcal{L}_{\text{diff}} + \mathcal{L}_{\text{rec}}, $$
where
$$\mathcal{L}_{\text{prior}} = D_{\mathrm{KL}}\big(q_\phi(z_0\vert X)\,\big\Vert\, p_\theta(z_0)\big)$$
$$\mathcal{L}_{\text{rec}} = -\sum_{i=1}^{N} \mathbb{E}_{q_\phi}\big[\log p_\theta(x_{t_i}\vert z_{t_i})\big]$$
$$\mathcal{L}_{\text{diff}} = \mathbb{E}_{t,\, q_\phi}\Big[\tfrac{1}{2}\big\Vert \sigma_\theta^{-1}(z_t, t)\big(f_\theta(z_t, t) - f_\phi(z_t, t, X)\big)\big\Vert^2\Big]$$
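A sketch of how these terms might be estimated at a single randomly drawn time, in the simulation-free spirit of the method. For brevity it assumes the emission is Gaussian noise of (hypothetical) scale 0.1 around the latent state itself, and it omits the prior term, which is evaluated once at $t = 0$:

```python
def sde_matching_loss(flow, f_theta, sigma2, x, t, ctx):
    """One Monte Carlo estimate of L_diff + L_rec for a single (x_t, t) pair."""
    eps = torch.randn_like(x)
    z = flow(eps, t, ctx)  # z_t ~ q_phi(z_t | X) via reparameterization
    # L_diff: drift-matching residual, weighted by the inverse (diagonal) diffusion
    resid = f_theta(z, t) - posterior_drift(flow, sigma2, z, t, ctx)
    l_diff = 0.5 * (resid.pow(2) / sigma2(t)).sum(-1)
    # L_rec: Gaussian emission negative log-likelihood, up to an additive constant
    l_rec = 0.5 * ((x - z) / 0.1).pow(2).sum(-1)
    return l_diff + l_rec  # add L_prior, evaluated at t = 0, for the full loss
```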
Implementation
Ideas for further research
As explained above, the main obstacle with this model is how to construct the variational process. I have tried a transformer model, which works when all states are observed, but collapses when some latent dimensions are hidden.
In contrast, I have used a Gaussian process, which produces estimates of the mean and variance, and also of the derivative of the mean and the variance of that derivative. This works very well for a simple physics system with restrictions on the drift matrix.
However, neither of these generalizes. In the original paper, the authors use the ODE-RNN introduced by Rubanova et al. [2]. I personally do not like this approach, since it breaks with the simulation-free spirit of the SDE matching method.
My suggestion for further research would be to use a 1D convolutional network that maps from `observation_dim * series_length` to `latent_dim * series_length`, and then perhaps use a deep kernel [3] for further refinement; see the sketch below.
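A minimal sketch of the suggested convolutional encoder; the layer sizes and kernel widths are illustrative assumptions:

```python
import torch.nn as nn

class ConvEncoder(nn.Module):
    """Maps a series of shape (batch, observation_dim, series_length)
    to per-step context of shape (batch, latent_dim, series_length)."""
    def __init__(self, observation_dim, latent_dim, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(observation_dim, hidden, kernel_size=5, padding=2), nn.SiLU(),
            nn.Conv1d(hidden, hidden, kernel_size=5, padding=2), nn.SiLU(),
            nn.Conv1d(hidden, latent_dim, kernel_size=5, padding=2),
        )

    def forward(self, x):
        return self.net(x)
```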
How to run
In order to use the SDEMatching package, do the following:

- Move to the folder where you want the code.
- Clone this repository:
  ```
  git clone https://github.com/simoneiriksson/SDEMatching.git
  ```
- If you prefer, create a new Python environment:
  ```
  python -m venv .venv
  ```
- And activate the new Python environment:
  ```
  source .venv/bin/activate
  ```
- Install the package into your active environment:
  ```
  pip install -e .
  ```
References

1. Bartosh, G., Vetrov, D. & Naesseth, C. A. (2025). SDE Matching: Scalable and Simulation-Free Training of Latent Stochastic Differential Equations. Proceedings of the 42nd International Conference on Machine Learning, PMLR 267:3054-3070. https://proceedings.mlr.press/v267/bartosh25a.html
2. Rubanova, Y., Chen, R. T. Q. & Duvenaud, D. (2019). Latent ODEs for Irregularly-Sampled Time Series. Proceedings of the 33rd International Conference on Neural Information Processing Systems, 5320-5330. Curran Associates Inc., Red Hook, NY, USA.
3. Wilson, A. G., Hu, Z., Salakhutdinov, R. & Xing, E. P. (2016). Deep Kernel Learning. Proceedings of the 19th International Conference on Artificial Intelligence and Statistics, PMLR 51:370-378. https://proceedings.mlr.press/v51/wilson16.html
4. Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S. & Poole, B. (2020). Score-Based Generative Modeling through Stochastic Differential Equations. arXiv preprint arXiv:2011.13456.