
Score-based Diffusion models in JAX.


sbgm


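Implementation and extension of Score-Based Generative Modeling through Stochastic Differential Equations (Song et al., 2021) and Maximum Likelihood Training of Score-Based Diffusion Models (Song et al., 2021) in JAX and Equinox.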

Warning: this repository is under construction; expect changes.

Score-based diffusion models

Diffusion models are deep hierarchical models for data that use neural networks to model the reverse of a diffusion process that adds a sequence of noise perturbations to the data.

A diagram (see citations) showing how to map data to a noise distribution (the prior) with an SDE, and reverse this SDE for generative modeling. One can also reverse the associated probability flow ODE, which yields a deterministic process that samples from the same distribution as the SDE. Both the reverse-time SDE and probability flow ODE can be obtained by estimating the score.
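The score $\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})$ is unknown, so in practice it is learned by denoising score matching. The idea is to perturb data with the forward SDE's transition kernel and regress onto that kernel's score. Below is a minimal JAX sketch for a VP SDE with a linear $\beta(t)$ schedule and the weighting $\lambda(t) = \sigma_t^2$. Both are assumptions rather than sbgm's confirmed choices. `perturb`, `dsm_loss` and the schedule constants are hypothetical names. A toy Gaussian dataset with a closed-form score checks that the true score attains a lower loss than a trivial zero score.

```python
import jax.numpy as jnp
import jax.random as jr

BETA_MIN, BETA_MAX = 0.1, 20.0  # assumed linear VP schedule
DATA_STD = 2.0                  # toy data distribution N(0, DATA_STD² I)


def int_beta(t):
    # ∫_0^t β(s) ds for β(t) = β_min + t (β_max - β_min)
    return BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)


def perturb(key, x0, t):
    # Sample x_t ~ p(x_t | x_0) = N(m_t x_0, σ_t² I) for the VP SDE
    m = jnp.exp(-0.5 * int_beta(t))[:, None]
    std = jnp.sqrt(1.0 - m**2)
    eps = jr.normal(key, x0.shape)
    return m * x0 + std * eps, std, eps


def dsm_loss(score_fn, key, x0, t_eps=1e-3):
    # Weighted DSM objective: E_t E_{x0, ε} || σ_t s(x_t, t) + ε ||²
    t_key, n_key = jr.split(key)
    t = jr.uniform(t_key, (x0.shape[0],), minval=t_eps, maxval=1.0)
    xt, std, eps = perturb(n_key, x0, t)
    residual = std * score_fn(xt, t) + eps
    return jnp.mean(jnp.sum(residual**2, axis=-1))


def true_score(x, t):
    # Closed-form score of p_t for Gaussian data (stands in for a network)
    m2 = jnp.exp(-int_beta(t))[:, None]
    return -x / (DATA_STD**2 * m2 + (1.0 - m2))


def zero_score(x, t):
    return jnp.zeros_like(x)


data_key, loss_key = jr.split(jr.key(0))
x0 = DATA_STD * jr.normal(data_key, (10_000, 2))
loss_true = dsm_loss(true_score, loss_key, x0)
loss_zero = dsm_loss(zero_score, loss_key, x0)
print(loss_true, loss_zero)
```

Both losses share the same key, so they see identical times and noise draws. The comparison therefore isolates the effect of the score function.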

Modern diffusion models (see citations) express both the forward and reverse diffusion processes as stochastic differential equations (SDEs).
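As a concrete instance, an SDE $\text{d}\boldsymbol{x} = f(\boldsymbol{x}, t)\text{d}t + g(t)\text{d}\boldsymbol{w}$ can be simulated with the Euler-Maruyama scheme $\boldsymbol{x}_{t+\Delta t} = \boldsymbol{x}_t + f(\boldsymbol{x}_t, t)\Delta t + g(t)\sqrt{\Delta t}\,\boldsymbol{z}$, where $\boldsymbol{z}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})$. The sketch below runs the forward (noising) process for the VP SDE, with $f = -\tfrac{1}{2}\beta(t)\boldsymbol{x}$ and $g = \sqrt{\beta(t)}$, assuming a linear $\beta(t)$ schedule. `euler_maruyama` and the constants are illustrative and not part of sbgm.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

BETA_MIN, BETA_MAX = 0.1, 20.0  # assumed linear VP schedule


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def drift(x, t):
    # VP SDE drift f(x, t) = -β(t) x / 2
    return -0.5 * beta(t) * x


def diffusion(t):
    # VP SDE diffusion coefficient g(t) = sqrt(β(t))
    return jnp.sqrt(beta(t))


def euler_maruyama(key, x0, n_steps=1000, t1=1.0):
    dt = t1 / n_steps
    keys = jr.split(key, n_steps)

    def step(x, inputs):
        i, k = inputs
        t = i * dt
        z = jr.normal(k, x.shape)
        x = x + drift(x, t) * dt + diffusion(t) * jnp.sqrt(dt) * z
        return x, None

    x1, _ = jax.lax.scan(step, x0, (jnp.arange(n_steps), keys))
    return x1


x0 = jnp.full((10_000,), 3.0)  # every sample starts at x = 3
x1 = euler_maruyama(jr.key(0), x0)
print(x1.mean(), x1.var())
```

Starting every sample at $x = 3$, the final samples should have a mean near 0 and a variance near 1. In other words, the forward process has mapped the data to (approximately) the standard normal prior.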

For any SDE of the form

$$ \text{d}\boldsymbol{x} = f(\boldsymbol{x}, t)\text{d}t + g(t)\text{d}\boldsymbol{w} $$

there exists an associated ordinary differential equation (ODE), the probability flow ODE,

$$ \text{d}\boldsymbol{x} = \Big[f(\boldsymbol{x}, t) - \frac{1}{2}g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})\Big]\text{d}t $$

where the trajectories of the SDE and ODE have the same marginal PDFs $p_t(\boldsymbol{x})$.
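Given the SDE coefficients and a score function, the probability flow drift is $f - \tfrac{1}{2}g^2\nabla_{\boldsymbol{x}}\log p_t$. The sketch below integrates it backwards in time with a plain Euler scheme. To make the result checkable, it assumes Gaussian data $\mathcal{N}(0, s^2)$ under a linear-schedule VP SDE, where the score is known in closed form: $\nabla_x\log p_t(x) = -x/v_t$ with $v_t = s^2 e^{-\int_0^t\beta} + 1 - e^{-\int_0^t\beta}$. In practice a trained score network would replace `true_score`. All names and constants here are illustrative.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

BETA_MIN, BETA_MAX = 0.1, 20.0  # assumed linear VP schedule
DATA_STD = 2.0                  # toy data distribution N(0, DATA_STD²)


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def int_beta(t):
    return BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)


def marginal_var(t):
    # Variance of p_t when x(0) ~ N(0, DATA_STD²) under the VP SDE
    m2 = jnp.exp(-int_beta(t))
    return DATA_STD**2 * m2 + (1.0 - m2)


def true_score(x, t):
    return -x / marginal_var(t)


def ode_drift(x, t):
    # F(x, t) = f(x, t) - ½ g(t)² ∇ log p_t(x), with f = -½ β x and g² = β
    return -0.5 * beta(t) * x - 0.5 * beta(t) * true_score(x, t)


def ode_sample(x1, n_steps=1000, t0=1e-3):
    # Deterministic Euler integration from t = 1 back to t = t0
    ts = jnp.linspace(1.0, t0, n_steps + 1)

    def step(x, t_pair):
        t, t_next = t_pair
        return x + ode_drift(x, t) * (t_next - t), None

    x0, _ = jax.lax.scan(step, x1, (ts[:-1], ts[1:]))
    return x0


x1 = jnp.sqrt(marginal_var(1.0)) * jr.normal(jr.key(0), (10_000,))
x0 = ode_sample(x1)
print(x0.std())  # should be close to DATA_STD
```

No noise is injected during integration. The recovered samples still have standard deviation close to `DATA_STD`, which illustrates that the ODE shares the SDE's marginals.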

Computing log-likelihoods with diffusion models

This ODE is deterministic, and its marginal densities $p_t(\boldsymbol{x})$ match those of the SDE at every time $t$. Writing its drift as $F$,

$$ \frac{\text{d}\boldsymbol{x}}{\text{d}t} = f(\boldsymbol{x}, t) - \frac{1}{2}g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x}) = F(\boldsymbol{x}(t), t) $$

Treating this ODE as a continuous normalizing flow, the log-density along a trajectory $\boldsymbol{x}(t)$ evolves according to the instantaneous change-of-variables formula

$$ \frac{\text{d}}{\text{d}t} \log p_t(\boldsymbol{x}(t)) = -\text{Tr}\bigg[ \frac{\partial}{\partial \boldsymbol{x}(t)} F(\boldsymbol{x}(t), t) \bigg] $$
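Integrating this identity from $t=0$ to $t=T$ gives $\log p_0(\boldsymbol{x}(0)) = \log p_T(\boldsymbol{x}(T)) + \int_0^T \text{Tr}\big[\partial F/\partial\boldsymbol{x}(t)\big]\,\text{d}t$. The data log-likelihood therefore follows from solving the ODE forward while accumulating the Jacobian trace. Below is a minimal JAX sketch that computes the trace exactly with `jax.jacfwd` and uses a plain Euler solver. It reuses the toy assumptions of Gaussian data under a linear-schedule VP SDE with a closed-form score; a trained network would replace the score term in `F`. `exact_log_likelihood` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jss

BETA_MIN, BETA_MAX = 0.1, 20.0  # assumed linear VP schedule
DATA_STD = 2.0                  # toy data: x(0) ~ N(0, DATA_STD² I)


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def marginal_var(t):
    m2 = jnp.exp(-(BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)))
    return DATA_STD**2 * m2 + (1.0 - m2)


def F(x, t):
    # Probability flow drift with the closed-form score -x / v_t
    return -0.5 * beta(t) * x + 0.5 * beta(t) * x / marginal_var(t)


def exact_log_likelihood(x0, n_steps=1000, T=1.0):
    dt = T / n_steps

    def step(carry, i):
        x, trace_integral = carry
        t = i * dt
        # Exact Jacobian trace: one forward-mode pass per data dimension
        trace = jnp.trace(jax.jacfwd(F)(x, t))
        return (x + F(x, t) * dt, trace_integral + trace * dt), None

    init = (x0, jnp.zeros((), dtype=x0.dtype))
    (xT, trace_integral), _ = jax.lax.scan(step, init, jnp.arange(n_steps))
    # Prior log-density at t = T (here the exact marginal p_T)
    log_pT = jnp.sum(jss.norm.logpdf(xT, scale=jnp.sqrt(marginal_var(T))))
    return log_pT + trace_integral


x0 = jnp.array([[0.5, -1.0], [2.0, 3.0]])
log_p = jax.vmap(exact_log_likelihood)(x0)  # parallelised over data points
log_p_true = jnp.sum(jss.norm.logpdf(x0, scale=DATA_STD), axis=-1)
print(log_p, log_p_true)
```

`jax.vmap` batches the computation over data points. The exact trace needs one forward-mode pass per data dimension, which becomes costly for images.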

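Note, however, that maximum-likelihood training with this formula is prohibitively expensive for SDE-based diffusion models, since every likelihood evaluation requires solving the ODE and evaluating the Jacobian trace at each solver step.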
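The Features list mentions approximate log-likelihoods. One standard way to reduce the trace cost is Hutchinson's estimator, $\text{Tr}[J] = \mathbb{E}_{\boldsymbol{\epsilon}}[\boldsymbol{\epsilon}^\top J \boldsymbol{\epsilon}]$ for noise with zero mean and identity covariance. It needs only a Jacobian-vector product instead of the full Jacobian. Whether sbgm uses exactly this estimator is an assumption. Below is a minimal JAX sketch with Rademacher noise; `hutchinson_trace` is a hypothetical helper, not sbgm's API.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def hutchinson_trace(key, fn, x, n_probes=10_000):
    # Unbiased estimate of Tr[∂fn/∂x] using Rademacher probes ε ∈ {-1, +1}^d
    eps = jr.rademacher(key, (n_probes,) + x.shape).astype(x.dtype)

    def probe(e):
        # jax.jvp returns (fn(x), J @ e); only the JVP is needed
        _, jvp = jax.jvp(fn, (x,), (e,))
        return jnp.vdot(e, jvp)

    return jnp.mean(jax.vmap(probe)(eps))


A = jnp.array([[2.0, 1.0], [0.5, 3.0]])
fn = lambda x: A @ x  # Jacobian is A, whose trace is 5
x = jnp.zeros(2)
estimate = hutchinson_trace(jr.key(0), fn, x)
exact = jnp.trace(jax.jacfwd(fn)(x))
print(estimate, exact)
```

In a likelihood computation, a single probe per solver step is often used. This trades variance for a cost that no longer grows with the data dimension.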

Usage

Install via

```
pip install sbgm
```

then run

```
python main.py
```

or, from Python, use something like

```python
import jax.random as jr

import sbgm
import data     # assumed: local dataset module (not part of the pip package)
import configs  # assumed: local config module

datasets_path = "."
root_dir = "."

config = configs.cifar10_config()

key = jr.key(config.seed)
data_key, model_key, train_key = jr.split(key, 3)

dataset = data.cifar10(datasets_path, data_key)

# Device sharding for multi-device training
sharding = sbgm.shard.get_sharding()

# Diffusion model
model = sbgm.models.get_model(
    model_key,
    config.model.model_type,
    dataset.data_shape,
    dataset.context_shape,
    dataset.parameter_dim,
    config
)

# Stochastic differential equation (SDE)
sde = sbgm.sde.get_sde(config.sde)

# Fit model to dataset
model = sbgm.train.train(
    train_key,
    model,
    sde,
    dataset,
    config,
    reload_opt_state=False,
    sharding=sharding,
    save_dir=root_dir
)
```

Features

  • Parallelised exact and approximate log-likelihood calculations,
  • UNet and transformer score network implementations,
  • VP, SubVP and VE SDEs (neural-network parameterisations of $\beta(t)$ and $\sigma(t)$ are planned),
  • Multi-modal conditioning (optional parameter and image conditioning),
  • Multi-device training and sampling.
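For multi-device training, JAX's sharding API can split each batch across devices along the batch axis. The sketch below shows plain data-parallel sharding with `jax.sharding`. What `sbgm.shard.get_sharding` actually returns is not documented here, so this is an assumption about the general approach, not its implementation; `get_data_sharding` is a hypothetical name. The sketch also runs on a single CPU device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def get_data_sharding():
    # One mesh axis spanning every visible device; batches are split along it
    mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
    return NamedSharding(mesh, PartitionSpec("batch"))


sharding = get_data_sharding()
n_devices = len(jax.devices())
# The batch size must be divisible by the number of devices
batch = jnp.ones((8 * n_devices, 32, 32, 3))
batch = jax.device_put(batch, sharding)


@jax.jit
def mean_pixel(x):
    # jit compiles one program that runs on every shard of the input
    return jnp.mean(x)


print(batch.sharding, mean_pixel(batch))
```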

Samples

Note: I haven't optimised the training or architecture hyperparameters, nor trained for long, so these results could be improved considerably.

Flowers

Euler-Maruyama sampling: [image: Flowers, Euler-Maruyama samples]

ODE sampling: [image: Flowers, ODE samples]

CIFAR10

Euler-Maruyama sampling: [image: CIFAR10, Euler-Maruyama samples]

ODE sampling: [image: CIFAR10, ODE samples]
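Euler-Maruyama samples like these come from discretising the reverse-time SDE $\text{d}\boldsymbol{x} = [f(\boldsymbol{x}, t) - g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})]\text{d}t + g(t)\text{d}\bar{\boldsymbol{w}}$ from $t=1$ down to $t\approx 0$. Below is a minimal JAX sketch, again assuming a linear-schedule VP SDE and a toy Gaussian data distribution whose score is known exactly; a trained score network would take its place. It is not sbgm's sampler, and `reverse_sde_sample` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

BETA_MIN, BETA_MAX = 0.1, 20.0  # assumed linear VP schedule
DATA_STD = 2.0                  # toy data distribution N(0, DATA_STD²)


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def score(x, t):
    # Closed-form score of p_t for Gaussian data; stands in for a network
    m2 = jnp.exp(-(BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)))
    return -x / (DATA_STD**2 * m2 + (1.0 - m2))


def reverse_sde_sample(key, x1, n_steps=1000, t0=1e-3):
    ts = jnp.linspace(1.0, t0, n_steps + 1)
    keys = jr.split(key, n_steps)

    def step(x, inputs):
        t, t_next, k = inputs
        dt = t - t_next  # positive step size; time runs backwards
        # Reverse drift f - g² ∇log p_t, with f = -½ β x and g² = β
        drift = -0.5 * beta(t) * x - beta(t) * score(x, t)
        z = jr.normal(k, x.shape)
        x = x - drift * dt + jnp.sqrt(beta(t) * dt) * z
        return x, None

    x0, _ = jax.lax.scan(step, x1, (ts[:-1], ts[1:], keys))
    return x0


prior_key, sample_key = jr.split(jr.key(0))
x1 = jr.normal(prior_key, (20_000,))  # prior N(0, 1) at t = 1
x0 = reverse_sde_sample(sample_key, x1)
print(x0.std())  # should approach DATA_STD
```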

SDEs

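The three SDE families listed under Features are defined in Song et al. (2021) by their drift $f(\boldsymbol{x}, t)$ and diffusion $g(t)$. Below is a minimal JAX sketch of those coefficients. It uses that paper's default hyperparameters ($\beta_{\min}=0.1$, $\beta_{\max}=20$, $\sigma_{\min}=0.01$, $\sigma_{\max}=50$) as assumptions; sbgm's own classes and defaults may differ, and the function names are hypothetical.

```python
import jax.numpy as jnp

BETA_MIN, BETA_MAX = 0.1, 20.0     # VP / SubVP schedule (assumed defaults)
SIGMA_MIN, SIGMA_MAX = 0.01, 50.0  # VE noise range (assumed defaults)


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def int_beta(t):
    # ∫_0^t β(s) ds
    return BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)


def vp(x, t):
    # Variance preserving: f = -½ β(t) x, g = sqrt(β(t))
    return -0.5 * beta(t) * x, jnp.sqrt(beta(t))


def subvp(x, t):
    # Sub-VP: same drift, g = sqrt(β(t) (1 - exp(-2 ∫_0^t β)))
    g2 = beta(t) * (1.0 - jnp.exp(-2.0 * int_beta(t)))
    return -0.5 * beta(t) * x, jnp.sqrt(g2)


def ve(x, t):
    # Variance exploding: f = 0, g = σ(t) sqrt(2 log(σ_max / σ_min)),
    # with σ(t) = σ_min (σ_max / σ_min)^t
    sigma = SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t
    return jnp.zeros_like(x), sigma * jnp.sqrt(2.0 * jnp.log(SIGMA_MAX / SIGMA_MIN))


x = jnp.ones(3)
for name, sde in [("VP", vp), ("SubVP", subvp), ("VE", ve)]:
    f, g = sde(x, 0.5)
    print(name, f[0], g)
```

The SubVP diffusion is never larger than the VP diffusion, and it vanishes at $t=0$. The VE SDE has no drift, so it only adds noise of growing scale.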

Citations

@misc{song2021scorebasedgenerativemodelingstochastic,
      title={Score-Based Generative Modeling through Stochastic Differential Equations}, 
      author={Yang Song and Jascha Sohl-Dickstein and Diederik P. Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
      year={2021},
      eprint={2011.13456},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2011.13456}, 
}
@misc{song2021maximumlikelihoodtrainingscorebased,
      title={Maximum Likelihood Training of Score-Based Diffusion Models}, 
      author={Yang Song and Conor Durkan and Iain Murray and Stefano Ermon},
      year={2021},
      eprint={2101.09258},
      archivePrefix={arXiv},
      primaryClass={stat.ML},
      url={https://arxiv.org/abs/2101.09258}, 
}

