sbgm
Score-Based Diffusion Models in JAX
Implementation and extension of Score-Based Generative Modeling through Stochastic Differential Equations (Song et al., 2021) and Maximum Likelihood Training of Score-Based Diffusion Models (Song et al., 2021) in jax and equinox.
> [!WARNING]
> :building_construction: This repository is under construction; expect changes. :building_construction:
Score-based diffusion models
Diffusion models are deep hierarchical generative models that use a neural network to model the reverse of a diffusion process, which perturbs the data with a sequence of noise injections.
State-of-the-art diffusion models (see citations) express both the forward and reverse diffusion processes as stochastic differential equations (SDEs).
Figure (from Song et al., 2021): data is mapped to a noise distribution (the prior) with an SDE, and this SDE is reversed for generative modelling. One can also reverse the associated probability flow ODE, which yields a deterministic process that samples from the same distribution as the SDE. Both the reverse-time SDE and the probability flow ODE can be obtained by estimating the score.
For any SDE of the form
$$ \text{d}\boldsymbol{x} = f(\boldsymbol{x}, t)\text{d}t + g(t)\text{d}\boldsymbol{w}, $$
the reverse of the SDE, mapping noise back to data, is given by
$$ \text{d}\boldsymbol{x} = [f(\boldsymbol{x}, t) - g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})]\text{d}t + g(t)\text{d}\bar{\boldsymbol{w}}, $$
where $\bar{\boldsymbol{w}}$ is a Wiener process in reverse time.
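Generating a sample then amounts to drawing $\boldsymbol{x}(1)$ from the prior and integrating the reverse SDE down to $t \approx 0$. A minimal Euler-Maruyama sketch in JAX, assuming for simplicity $f \equiv 0$ (as in a VE SDE); `score_fn` and `g` are hypothetical stand-ins for a trained score network and the diffusion coefficient:

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def reverse_sde_sample(key, score_fn, x_shape, g, n_steps=1000, t1=1.0, t0=1e-3):
    """Euler-Maruyama integration of the reverse SDE from t1 (noise) to t0 (data).

    Assumes f(x, t) = 0, as in a VE SDE; `score_fn(x, t)` is the trained score
    network and `g(t)` the (jnp-traceable) diffusion coefficient.
    """
    dt = (t1 - t0) / n_steps
    key, init_key = jr.split(key)
    x = jr.normal(init_key, x_shape)  # draw x(t1) from the prior

    def step(carry, t):
        x, key = carry
        key, noise_key = jr.split(key)
        # Reverse drift f(x, t) - g(t)^2 * score, with f = 0
        drift = -g(t) ** 2 * score_fn(x, t)
        noise = jr.normal(noise_key, x.shape)
        # Step backwards in time by -dt
        x = x - drift * dt + g(t) * jnp.sqrt(dt) * noise
        return (x, key), None

    ts = jnp.linspace(t1, t0, n_steps)
    (x, _), _ = jax.lax.scan(step, (x, key), ts)
    return x
```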
For every such SDE there exists an associated ordinary differential equation (ODE), the probability flow ODE,
$$ \text{d}\boldsymbol{x} = \bigg[f(\boldsymbol{x}, t) - \frac{1}{2}g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})\bigg]\text{d}t, $$
whose solutions share the same marginal densities $p_t(\boldsymbol{x})$ as the SDE.
The Stein score of the marginal distributions $p_t(\boldsymbol{x})$ is approximated with a neural network, $\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})\approx s_{\theta}(\boldsymbol{x}(t), t)$, whose parameters $\theta$ are fit by minimising the score-matching loss.
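In practice the loss is evaluated in its denoising form: perturb the data with the SDE's Gaussian transition kernel and regress the network onto that kernel's (known) score. A minimal sketch, assuming a VE-type perturbation $\boldsymbol{x}(t) = \boldsymbol{x}(0) + \sigma(t)\boldsymbol{\epsilon}$; `score_net` and the vectorised `sigma` are hypothetical stand-ins, not sbgm's API:

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def dsm_loss(score_net, key, x0, sigma, t0=1e-3, t1=1.0):
    """Denoising score-matching loss for a VE-type perturbation kernel,
    x(t) = x(0) + sigma(t) * eps with eps ~ N(0, I), whose score is
    -(x(t) - x(0)) / sigma(t)^2 = -eps / sigma(t)."""
    t_key, eps_key = jr.split(key)
    batch = x0.shape[0]
    t = jr.uniform(t_key, (batch,), minval=t0, maxval=t1)
    eps = jr.normal(eps_key, x0.shape)
    sigma_t = sigma(t).reshape(batch, *([1] * (x0.ndim - 1)))
    xt = x0 + sigma_t * eps
    score = jax.vmap(score_net)(xt, t)
    # With weighting lambda(t) = sigma(t)^2, the loss reduces to
    # || sigma(t) * s_theta(x(t), t) + eps ||^2
    sq_err = (sigma_t * score + eps) ** 2
    return jnp.mean(jnp.sum(sq_err, axis=tuple(range(1, x0.ndim))))
```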
Computing log-likelihoods with diffusion models
For each SDE there exists a deterministic ODE whose marginal likelihoods $p_t(\boldsymbol{x})$ match those of the SDE for all time $t$,
$$ \text{d}\boldsymbol{x} = \bigg[f(\boldsymbol{x}, t) - \frac{1}{2}g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})\bigg]\text{d}t = F(\boldsymbol{x}(t), t)\,\text{d}t. $$
The continuous normalizing flow formalism gives the instantaneous change of log-density along this ODE,
$$ \frac{\partial}{\partial t} \log p(\boldsymbol{x}(t)) = -\text{Tr}\bigg[ \frac{\partial}{\partial \boldsymbol{x}(t)} F(\boldsymbol{x}(t), t) \bigg], $$
so integrating the trace term along a trajectory of the ODE yields an exact log-likelihood. Note, however, that maximum-likelihood training is prohibitively expensive for SDE-based diffusion models.
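The trace term is the expensive part: evaluating it exactly requires the full Jacobian of $F$, while Hutchinson's trick gives an unbiased Monte Carlo estimate from a single Jacobian-vector product (this is the distinction behind the parallelised exact and approximate likelihood calculations listed under Features). A minimal sketch of both, where `F(x, t)` stands for the probability flow drift above:

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def divergence_hutchinson(F, x, t, key):
    """Unbiased estimate of Tr[dF/dx] via Hutchinson's trick,
    Tr[A] = E[eps^T A eps] for eps with zero mean and identity covariance."""
    eps = jr.rademacher(key, x.shape, dtype=x.dtype)
    # One Jacobian-vector product in place of the full Jacobian
    _, jvp = jax.jvp(lambda xi: F(xi, t), (x,), (eps,))
    return jnp.sum(eps * jvp)

def divergence_exact(F, x, t):
    """Exact Tr[dF/dx] via the full Jacobian (expensive in high dimensions)."""
    jac = jax.jacfwd(lambda xi: F(xi, t))(x)
    return jnp.trace(jac.reshape(x.size, x.size))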
Usage
Install via

```shell
pip install sbgm
```

and run the example script with

```shell
python main.py
```

or fit a model to a dataset with something like
```python
import jax.random as jr

import sbgm
import data     # local module providing datasets
import configs  # local module providing experiment configs

datasets_path = "."
root_dir = "."

config = configs.cifar10_config()

key = jr.key(config.seed)
data_key, model_key, train_key = jr.split(key, 3)

dataset = data.cifar10(datasets_path, data_key)

sharding = sbgm.shard.get_sharding()

# Diffusion model
model = sbgm.models.get_model(
    model_key,
    config.model.model_type,
    dataset.data_shape,
    dataset.context_shape,
    dataset.parameter_dim,
    config
)

# Stochastic differential equation (SDE)
sde = sbgm.sde.get_sde(config.sde)

# Fit model to dataset
model = sbgm.train.train(
    train_key,
    model,
    sde,
    dataset,
    config,
    reload_opt_state=False,
    sharding=sharding,
    save_dir=root_dir
)
```
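Once trained, the score network supports both stochastic (Euler-Maruyama) and deterministic (ODE) sampling. As a sketch of the latter (not sbgm's own API), the probability flow ODE can be integrated backwards from the prior with diffrax; `score_fn`, `f` and `g` below are hypothetical stand-ins for the trained network and the SDE coefficients:

```python
import diffrax as dfx
import jax.random as jr

def ode_sample(key, score_fn, f, g, x_shape, t0=1e-3, t1=1.0):
    """Integrate the probability flow ODE backwards from the prior at t1 to t0."""
    x1 = jr.normal(key, x_shape)  # draw x(t1) from the prior

    def drift(t, x, args):
        # F(x, t) = f(x, t) - (1/2) g(t)^2 * score
        return f(x, t) - 0.5 * g(t) ** 2 * score_fn(x, t)

    sol = dfx.diffeqsolve(
        dfx.ODETerm(drift),
        dfx.Tsit5(),
        t0=t1,
        t1=t0,
        dt0=-1e-2,  # negative step: integrate backwards in time
        y0=x1,
    )
    return sol.ys[0]
```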
Features
- Parallelised exact and approximate log-likelihood calculations,
- UNet and transformer score network implementations,
- VP, SubVP and VE SDEs (standard coefficient choices shown after this list; neural network $\beta(t)$ and $\sigma(t)$ functions are on the list!),
- Multi-modal conditioning (optional conditioning on both parameter vectors and images),
- Checkpointing optimiser and model,
- Multi-device training and sampling.
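For reference, the standard coefficient choices for these three SDEs from Song et al. (2021) are

$$ \text{VP:} \quad \text{d}\boldsymbol{x} = -\tfrac{1}{2}\beta(t)\boldsymbol{x}\,\text{d}t + \sqrt{\beta(t)}\,\text{d}\boldsymbol{w}, $$

$$ \text{SubVP:} \quad \text{d}\boldsymbol{x} = -\tfrac{1}{2}\beta(t)\boldsymbol{x}\,\text{d}t + \sqrt{\beta(t)\big(1 - e^{-2\int_0^t \beta(s)\,\text{d}s}\big)}\,\text{d}\boldsymbol{w}, $$

$$ \text{VE:} \quad \text{d}\boldsymbol{x} = \sqrt{\frac{\text{d}[\sigma^2(t)]}{\text{d}t}}\,\text{d}\boldsymbol{w}. $$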
Samples
> [!NOTE]
> I haven't optimised any training or architecture hyperparameters, or trained for long here; you could do a lot better.
Flowers
Euler-Maruyama sampling
ODE sampling
CIFAR10
Euler-Maruyama sampling
ODE sampling
SDEs
Citations
```bibtex
@misc{song2021scorebasedgenerativemodelingstochastic,
    title={Score-Based Generative Modeling through Stochastic Differential Equations},
    author={Yang Song and Jascha Sohl-Dickstein and Diederik P. Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
    year={2021},
    eprint={2011.13456},
    archivePrefix={arXiv},
    primaryClass={cs.LG},
    url={https://arxiv.org/abs/2011.13456},
}
```

```bibtex
@misc{song2021maximumlikelihoodtrainingscorebased,
    title={Maximum Likelihood Training of Score-Based Diffusion Models},
    author={Yang Song and Conor Durkan and Iain Murray and Stefano Ermon},
    year={2021},
    eprint={2101.09258},
    archivePrefix={arXiv},
    primaryClass={stat.ML},
    url={https://arxiv.org/abs/2101.09258},
}
```