
Score-based Diffusion models in JAX.


sbgm

Score-Based Diffusion Models in JAX

Implementation and extension of Score-Based Generative Modeling through Stochastic Differential Equations (Song++20) and Maximum Likelihood Training of Score-Based Diffusion Models (Song++21) in JAX and Equinox.

This repository provides a lightweight library of models and of sampling and likelihood routines, suitable for likelihood-free or emulation-based approaches. The code is tested and typed to ensure reliable, benchmarkable training and inference.

[!WARNING] This repository is under construction; expect changes.

Score-based diffusion models

Diffusion models are deep hierarchical models for data that use neural networks to model the reverse of a diffusion process that adds a sequence of noise perturbations to the data.
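To make the forward "sequence of noise perturbations" concrete, the sketch below simulates a forward SDE with the Euler-Maruyama scheme, so that each step adds a small Gaussian perturbation to the data. It assumes the variance-preserving (VP) SDE of Song++20, $\text{d}x_t = -\frac{1}{2}\beta(t)x_t\text{d}t + \sqrt{\beta(t)}\text{d}w_t$, with an illustrative linear schedule $\beta(t) = 0.1 + 19.9t$. All function names here are hypothetical and are not part of sbgm's API.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Illustrative linear schedule for the VP SDE (assumed values)
BETA_MIN, BETA_MAX = 0.1, 20.0


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def vp_drift(x, t):
    # f(x_t, t) = -1/2 beta(t) x_t
    return -0.5 * beta(t) * x


def vp_diffusion(t):
    # g(t) = sqrt(beta(t))
    return jnp.sqrt(beta(t))


def vp_marginal_var(t, var0):
    # Var[x_t] for zero-mean data with variance var0: alpha_t^2 var0 + 1 - alpha_t^2
    alpha2 = jnp.exp(-(BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)))
    return alpha2 * var0 + 1.0 - alpha2


def euler_maruyama(key, drift, diffusion, x0, t0, t1, n_steps):
    """Simulate dx = f(x, t) dt + g(t) dw forwards in time from t0 to t1."""
    dt = (t1 - t0) / n_steps
    ts = t0 + dt * jnp.arange(n_steps)
    keys = jr.split(key, n_steps)

    def step(x, inputs):
        t, k = inputs
        # dw ~ G[0, dt]: one small Gaussian noise perturbation per step
        dw = jnp.sqrt(dt) * jr.normal(k, x.shape)
        return x + drift(x, t) * dt + diffusion(t) * dw, None

    x1, _ = jax.lax.scan(step, x0, (ts, keys))
    return x1


# Zero-mean Gaussian "data" with variance 4, pushed forwards to t = 0.3
data_key, sde_key = jr.split(jr.PRNGKey(0))
x0 = 2.0 * jr.normal(data_key, (20_000,))
x_t = euler_maruyama(sde_key, vp_drift, vp_diffusion, x0, 0.0, 0.3, 300)
print(jnp.var(x_t), vp_marginal_var(0.3, 4.0))  # should roughly agree
```

For zero-mean data with variance $\sigma_0^2$, this SDE has marginal variance $\alpha_t^2\sigma_0^2 + 1 - \alpha_t^2$ with $\alpha_t = \exp(-\frac{1}{2}\int_0^t\beta(s)\text{d}s)$, which the simulated samples should reproduce up to Monte Carlo and discretisation error.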

Modern cutting-edge diffusion models (see citations) express both the forward and reverse diffusion processes as a Stochastic Differential Equation (SDE).


A diagram showing how to map data to a noise distribution (the prior) with an SDE, and reverse this SDE for generative modeling. One can also reverse the associated probability flow ODE, which yields a deterministic process that samples from the same distribution as the SDE. Both the reverse-time SDE and probability flow ODE can be obtained by estimating the score.


For any SDE of the form

$$ \text{d}x_t = f(x_t, t)\text{d}t + g(t)\text{d}w_t, $$

the reverse of the SDE from noise to data is given by

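$$ \text{d}x_t = [f(x_t, t) - g(t)^2\nabla_{x_t}\log p_t(x_t)]\text{d}t + g(t)\text{d}\bar{w}_t, $$

where $w_t$ is a standard Wiener process, $\text{d}w_t \sim \mathcal{G}[\text{d}w_t | 0, \text{d}t]$, and in the reverse SDE time runs backwards from $T$ to $0$, so $\text{d}t$ is a negative time step and $\bar{w}_t$ is a standard Wiener process in reverse time. For every SDE there exists an associated ordinary differential equation (ODE)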

$$ \text{d}x_t = [f(x_t, t) - \frac{1}{2}g(t)^2\nabla_{x_t}\log p_t(x_t)]\text{d}t, $$

where the trajectories of the SDE and ODE have the same marginal PDFs $p_t(x_t)$.
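Once a score is available, both reverse-time processes can be integrated numerically from $t=T$ down to $t=0$. The sketch below assumes the VP SDE with an illustrative linear schedule $\beta(t) = 0.1 + 19.9t$ (so $f = -\frac{1}{2}\beta(t)x_t$ and $g(t)^2 = \beta(t)$) and, so that it runs without a trained network, zero-mean Gaussian toy data with variance $\sigma_0^2 = 4$, for which every marginal is Gaussian and the exact score is $-x_t/\text{Var}[x_t]$. In practice a network $s_\theta(x_t, t)$ would replace `score_fn`; all names are hypothetical, not sbgm's API.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Illustrative linear VP schedule and toy data scale (assumed values)
BETA_MIN, BETA_MAX, T = 0.1, 20.0, 1.0
SIGMA0 = 2.0


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def marginal_var(t):
    # Var[x_t] of the VP SDE started from N(0, SIGMA0^2)
    alpha2 = jnp.exp(-(BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)))
    return alpha2 * SIGMA0**2 + 1.0 - alpha2


def score_fn(x, t):
    # Exact score of the Gaussian marginal; a trained s_theta(x, t) would go here
    return -x / marginal_var(t)


def reverse_drift(x, t, weight):
    # f(x, t) - weight * g(t)^2 * score: weight = 1 (reverse SDE) or 1/2 (ODE)
    return -0.5 * beta(t) * x - weight * beta(t) * score_fn(x, t)


def sample_reverse_sde(key, x_T, n_steps):
    """Euler-Maruyama for the reverse SDE, integrating from t = T down to t = 0."""
    dt = T / n_steps
    ts = T - dt * jnp.arange(n_steps)
    keys = jr.split(key, n_steps)

    def step(x, inputs):
        t, k = inputs
        z = jr.normal(k, x.shape)
        # Step backwards in time: x_{t - dt} = x_t - drift * dt + g(t) sqrt(dt) z
        x = x - reverse_drift(x, t, 1.0) * dt + jnp.sqrt(beta(t) * dt) * z
        return x, None

    x_0, _ = jax.lax.scan(step, x_T, (ts, keys))
    return x_0


def sample_probability_flow_ode(x_T, n_steps):
    """Deterministic Euler integration of the probability flow ODE from T to 0."""
    dt = T / n_steps
    ts = T - dt * jnp.arange(n_steps)

    def step(x, t):
        return x - reverse_drift(x, t, 0.5) * dt, None

    x_0, _ = jax.lax.scan(step, x_T, ts)
    return x_0


prior_key, sde_key = jr.split(jr.PRNGKey(1))
x_T = jnp.sqrt(marginal_var(T)) * jr.normal(prior_key, (20_000,))
x_sde = sample_reverse_sde(sde_key, x_T, 1000)
x_ode = sample_probability_flow_ode(x_T, 1000)
print(jnp.var(x_sde), jnp.var(x_ode))  # both should be close to SIGMA0**2 = 4
```

Both samplers should recover the data variance, illustrating that the SDE and ODE share the same marginals; the ODE additionally maps each $x_T$ to a single $x_0$ deterministically.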

The Stein score $\nabla_{x_t}\log p_t(x_t)$ of the marginal distribution at each time $t$ is approximated with a neural network, $\nabla_{x_t}\log p_t(x_t)\approx s_{\theta}(x_t, t)$, whose parameters $\theta$ are fit by minimising a score-matching loss.
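One common form of this loss, used in Song++20, is denoising score matching: perturb data with the SDE's Gaussian transition kernel $p_{0t}(x_t|x_0)$ and regress the network onto that kernel's score. The sketch below assumes the VP SDE with an illustrative linear schedule, uniformly sampled times and the weighting $\lambda(t) = \sigma_t^2$, under which the target $-\epsilon/\sigma_t$ turns into a noise-prediction residual. `dsm_loss` and the toy score functions are hypothetical; sbgm's own loss may sample times or weight terms differently.

```python
import jax.numpy as jnp
import jax.random as jr

BETA_MIN, BETA_MAX = 0.1, 20.0  # assumed linear VP schedule


def vp_alpha_sigma(t):
    # x_t | x_0 ~ G[alpha_t x_0, sigma_t^2 I] for the VP SDE
    int_beta = BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)
    alpha = jnp.exp(-0.5 * int_beta)
    return alpha, jnp.sqrt(1.0 - alpha**2)


def dsm_loss(key, score_fn, x0, t_eps=1e-5):
    """Denoising score-matching loss with lambda(t) = sigma_t^2, for x0 of shape (n, d)."""
    t_key, noise_key = jr.split(key)
    t = jr.uniform(t_key, (x0.shape[0],), minval=t_eps, maxval=1.0)
    eps = jr.normal(noise_key, x0.shape)
    alpha, sigma = vp_alpha_sigma(t)
    alpha, sigma = alpha[:, None], sigma[:, None]
    x_t = alpha * x0 + sigma * eps
    # Target: grad log p_0t(x_t | x_0) = -eps / sigma_t, hence
    # sigma_t^2 * ||s - target||^2 = ||sigma_t * s + eps||^2
    residual = sigma * score_fn(x_t, t) + eps
    return jnp.mean(jnp.sum(residual**2, axis=-1))


# Toy check on 2D Gaussian data with std SIGMA0, whose exact score is known
SIGMA0 = 2.0


def true_score(x, t):
    alpha, sigma = vp_alpha_sigma(t)
    return -x / (alpha**2 * SIGMA0**2 + sigma**2)[:, None]


def zero_score(x, t):
    return jnp.zeros_like(x)


data_key, loss_key = jr.split(jr.PRNGKey(0))
x0 = SIGMA0 * jr.normal(data_key, (4096, 2))
loss_true = dsm_loss(loss_key, true_score, x0)  # same key: same t and noise
loss_zero = dsm_loss(loss_key, zero_score, x0)
print(loss_true, loss_zero)  # the exact score gives the lower loss
```

In a real model `score_fn` would close over network parameters, and the loss would be minimised with `jax.grad` and an optimiser; the exact score is the minimiser of this objective, which is why it beats the zero function here.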

Computing log-likelihoods with diffusion models

For each SDE there exists a deterministic ODE whose marginal densities $p_t(x_t)$ match those of the SDE for all times $t$:

$$ \text{d}x_t = [f(x_t, t) - \frac{1}{2}g(t)^2\nabla_{x_t}\log p_t(x_t)]\text{d}t = f'(x_t, t)\text{d}t. $$

Using the instantaneous change-of-variables formula from the continuous normalizing flow formalism, the log-density evolves along a trajectory of this ODE as

$$ \frac{\text{d}}{\text{d}t} \log p_t(x_t) = -\nabla_{x_t} \cdot f'(x_t, t), $$

which, integrated from $t=0$ to $t=T$, gives the log-likelihood of a datapoint $x = x_0$ as

$$ \log p(x) = \log p_T(x_T) + \int_{t=0}^{t=T}\text{d}t \; \nabla_{x_t}\cdot f'(x_t, t). $$
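The log-likelihood can be evaluated by integrating $x_t$ together with the accumulated divergence $\int_0^t \nabla_{x_s}\cdot f'(x_s, s)\,\text{d}s$ from $t=0$ to $t=T$, then adding $\log p_T(x_T)$. The sketch below uses fixed-step Euler integration and an exact Jacobian trace, and assumes an illustrative linear VP schedule, a standard normal prior for $p_T$, and Gaussian toy data with known score so the result can be checked analytically. A practical implementation would use $s_\theta$ and typically an adaptive ODE solver; all names are hypothetical.

```python
import jax
import jax.numpy as jnp

BETA_MIN, BETA_MAX = 0.1, 20.0  # assumed linear VP schedule
SIGMA0 = 2.0  # std of the toy Gaussian data


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def score_fn(x, t):
    # Exact score of the VP marginal for N(0, SIGMA0^2 I) data; stands in for s_theta
    alpha2 = jnp.exp(-(BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)))
    return -x / (alpha2 * SIGMA0**2 + 1.0 - alpha2)


def f_prime(x, t):
    # Probability flow drift: f(x, t) - 1/2 g(t)^2 score(x, t), with g(t)^2 = beta(t)
    return -0.5 * beta(t) * x - 0.5 * beta(t) * score_fn(x, t)


def log_likelihood(x, T=1.0, n_steps=1000):
    """log p(x) = log p_T(x_T) + int_0^T div f'(x_t, t) dt, via Euler steps."""
    dt = T / n_steps
    ts = dt * jnp.arange(n_steps)

    def step(carry, t):
        x_t, int_div = carry
        div = jnp.trace(jax.jacfwd(f_prime)(x_t, t))  # exact divergence of f'
        return (x_t + f_prime(x_t, t) * dt, int_div + div * dt), None

    init = (x, jnp.zeros((), dtype=x.dtype))
    (x_T, int_div), _ = jax.lax.scan(step, init, ts)
    d = x.shape[-1]
    # Standard normal prior N(0, I) for p_T
    log_p_T = -0.5 * jnp.sum(x_T**2) - 0.5 * d * jnp.log(2.0 * jnp.pi)
    return log_p_T + int_div


x = jnp.array([0.5, -1.0])
exact = -0.5 * jnp.sum(x**2) / SIGMA0**2 - 0.5 * x.size * jnp.log(2.0 * jnp.pi * SIGMA0**2)
print(log_likelihood(x), exact)  # should agree closely
```

Because `log_likelihood` acts on a single datapoint, `jax.vmap(log_likelihood)` evaluates a whole batch in parallel.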

Note that directly maximising this exact log-likelihood during training is prohibitively expensive for SDE-based diffusion models, since every evaluation requires solving the ODE above.

Usage

Install via

pip install sbgm

and, for the examples, run the following from the root of a clone of the repository

pip install ".[examples]"

To fit a diffusion model to the cifar10 image dataset, try something like

import jax.random as jr

import sbgm
import configs

datasets_path = "./"
root_dir = "./"

config = configs.cifar10_config()

# Independent PRNG keys for data loading, model initialisation and training
key = jr.key(config.seed)
data_key, model_key, train_key = jr.split(key, 3)

dataset = sbgm.data.cifar10(datasets_path, data_key)

# Sharding for multi-device training
sharding = sbgm.shard.get_sharding()

# Diffusion model
model = sbgm.models.get_model(
    model_key,
    config.model.model_type,
    dataset.data_shape,
    dataset.context_shape,
    dataset.parameter_dim,
    config
)

# Stochastic differential equation (SDE)
sde = sbgm.sde.get_sde(config.sde)

# Fit model to dataset
model = sbgm.train.train(
    train_key,
    model,
    sde,
    dataset,
    config,
    sharding=sharding,
    save_dir=root_dir
)

Features

  • Parallelised exact and approximate log-likelihood calculations,
  • UNet and transformer score network implementations,
  • VP, SubVP and VE SDEs (neural-network $\beta(t)$ and $\sigma(t)$ schedules are planned),
  • Multi-modal conditioning (optional parameter and image conditioning),
  • Checkpointing for the optimiser and model,
  • Multi-device training and sampling.
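Exact and approximate log-likelihoods differ in how the divergence $\nabla_{x_t}\cdot f'(x_t, t)$ is computed: exactly, as the trace of the full Jacobian (one forward-mode pass per dimension), or stochastically with the Skilling-Hutchinson estimator $\mathbb{E}_z[z^\top J z] = \operatorname{tr} J$ for probes with $\mathbb{E}[zz^\top] = I$, which needs only one Jacobian-vector product per probe. The sketch below illustrates both in plain JAX with a hypothetical vector field; it is not sbgm's implementation. `jax.vmap` supplies the parallelism over probes and over datapoints.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def exact_divergence(f, x):
    # tr(df/dx): builds the full Jacobian with forward-mode autodiff
    return jnp.trace(jax.jacfwd(f)(x))


def hutchinson_divergence(key, f, x, n_probes=1):
    # Rademacher probes satisfy E[z z^T] = I, so E[z^T J z] = tr(J)
    z = jr.rademacher(key, (n_probes,) + x.shape).astype(x.dtype)

    def z_jz(z_i):
        _, jz = jax.jvp(f, (x,), (z_i,))  # Jacobian-vector product J z_i
        return jnp.vdot(z_i, jz)

    return jnp.mean(jax.vmap(z_jz)(z))


# Hypothetical vector field standing in for f'(x_t, t) at a fixed t
A = jr.normal(jr.PRNGKey(0), (3, 3))
f = lambda x: jnp.tanh(A @ x)

xs = jr.normal(jr.PRNGKey(1), (8, 3))  # a batch of 8 points
keys = jr.split(jr.PRNGKey(2), xs.shape[0])

exact = jax.vmap(lambda x: exact_divergence(f, x))(xs)
approx = jax.vmap(lambda k, x: hutchinson_divergence(k, f, x, n_probes=20_000))(keys, xs)
print(jnp.max(jnp.abs(exact - approx)))  # small; the error shrinks like 1/sqrt(n_probes)
```

The exact trace costs one Jacobian column per data dimension, which becomes expensive for images, whereas the estimator's cost is set by the number of probes, at the price of variance in the log-likelihood estimate.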

Samples

[!NOTE] I haven't optimised any training or architecture hyperparameters, or trained for long, so you could do a lot better.

Flowers

Euler-Maruyama sampling

ODE sampling

CIFAR10

Euler-Maruyama sampling

ODE sampling

SDEs

The library implements the most common SDE parameterisations for Gaussian probability paths (VP, SubVP and VE), and you can easily add your own!
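For reference, these are the VP, SubVP and VE forward SDEs and their Gaussian perturbation kernels $p_{0t}(x_t|x_0)$ as defined in Song++20, written in the form $\text{d}x_t = f(x_t, t)\text{d}t + g(t)\text{d}w_t$ with noise schedules $\beta(t)$ and $\sigma(t)$ and $\alpha_t = \exp(-\frac{1}{2}\int_0^t\beta(s)\,\text{d}s)$. The specific schedules and defaults used by sbgm may differ.

$$ \text{VP:} \quad \text{d}x_t = -\frac{1}{2}\beta(t)x_t\,\text{d}t + \sqrt{\beta(t)}\,\text{d}w_t, \qquad p_{0t}(x_t|x_0) = \mathcal{G}[x_t | \alpha_t x_0, (1 - \alpha_t^2)I], $$

$$ \text{SubVP:} \quad \text{d}x_t = -\frac{1}{2}\beta(t)x_t\,\text{d}t + \sqrt{\beta(t)(1 - \alpha_t^4)}\,\text{d}w_t, \qquad p_{0t}(x_t|x_0) = \mathcal{G}[x_t | \alpha_t x_0, (1 - \alpha_t^2)^2 I], $$

$$ \text{VE:} \quad \text{d}x_t = \sqrt{\frac{\text{d}[\sigma^2(t)]}{\text{d}t}}\,\text{d}w_t, \qquad p_{0t}(x_t|x_0) = \mathcal{G}[x_t | x_0, (\sigma^2(t) - \sigma^2(0))I]. $$

The SubVP variance $(1-\alpha_t^2)^2$ is never larger than the VP variance $1-\alpha_t^2$, while the VE SDE has no drift and lets the variance grow with $\sigma(t)$.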

Contributing

Want to add something? See CONTRIBUTING.md.

Citations

@misc{song2021scorebasedgenerativemodelingstochastic,
      title={Score-Based Generative Modeling through Stochastic Differential Equations}, 
      author={Yang Song and Jascha Sohl-Dickstein and Diederik P. Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
      year={2021},
      eprint={2011.13456},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2011.13456}, 
}
@misc{song2021maximumlikelihoodtrainingscorebased,
      title={Maximum Likelihood Training of Score-Based Diffusion Models}, 
      author={Yang Song and Conor Durkan and Iain Murray and Stefano Ermon},
      year={2021},
      eprint={2101.09258},
      archivePrefix={arXiv},
      primaryClass={stat.ML},
      url={https://arxiv.org/abs/2101.09258}, 
}


Download files


Source Distribution

sbgm-0.0.38.tar.gz (36.8 kB)

Uploaded Source

Built Distribution


sbgm-0.0.38-py3-none-any.whl (50.6 kB)

Uploaded Python 3

File details

Details for the file sbgm-0.0.38.tar.gz.

File metadata

  • Download URL: sbgm-0.0.38.tar.gz
  • Upload date:
  • Size: 36.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for sbgm-0.0.38.tar.gz
Algorithm Hash digest
SHA256 d778298a64ec8b398d3475df6afdc30066c5224a9167aad05b82a052799fc4e2
MD5 d18c3eb60a72acbd2a5a84d02e34b140
BLAKE2b-256 cb5d339e02f262179df1586f508f3d9d705eee54ed7f4bcadbcc5ee59eedbef4


Provenance

The following attestation bundles were made for sbgm-0.0.38.tar.gz:

Publisher: publish.yml on homerjed/sbgm

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file sbgm-0.0.38-py3-none-any.whl.

File metadata

  • Download URL: sbgm-0.0.38-py3-none-any.whl
  • Upload date:
  • Size: 50.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for sbgm-0.0.38-py3-none-any.whl
Algorithm Hash digest
SHA256 1eba9765f54a1622fc5fdb0f1aba6d420a36f72f942e73db59c29ab1888e6ac9
MD5 af04e083ca75d45b46607bf722b2221a
BLAKE2b-256 49a96f6378f8d122920759704a9f385fcd2d740951d6f9e383180472571881ff


Provenance

The following attestation bundles were made for sbgm-0.0.38-py3-none-any.whl:

Publisher: publish.yml on homerjed/sbgm

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
