A minimal but functional implementation of diffusion model training and sampling

Project description

smalldiffusion

A lightweight diffusion library for training and sampling from diffusion models. It is built for easy experimentation when training new models and developing new samplers, supporting everything from minimal toy models to state-of-the-art pretrained models. The core of this library for diffusion training and sampling is implemented in fewer than 100 lines of very readable PyTorch code. To install from PyPI:

pip install smalldiffusion

Toy models

To train and sample from the Swissroll toy dataset in 10 lines of code (see examples/toyexample.ipynb for a detailed guide):

import numpy as np
from torch.utils.data import DataLoader
from smalldiffusion import Swissroll, TimeInputMLP, ScheduleLogLinear, training_loop, samples

dataset  = Swissroll(np.pi/2, 5*np.pi, 100)
loader   = DataLoader(dataset, batch_size=2048)
model    = TimeInputMLP(hidden_dims=(16,128,128,128,128,16))
schedule = ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)
trainer  = training_loop(loader, model, schedule, epochs=15000)
losses   = [ns.loss.item() for ns in trainer]                 # runs training, records each loss
*xt, x0  = samples(model, schedule.sample_sigmas(20), gam=2)  # x0 is the final sample

Results on various toy datasets:

Conditional training and sampling with classifier-free guidance

We can also train conditional diffusion models and sample from them using classifier-free guidance. In examples/cond_tree_model.ipynb, samples from each class in the 2D tree dataset are represented with a different color.
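Classifier-free guidance combines a conditional and an unconditional noise prediction from the same model. The sketch below assumes the common convention eps = eps_uncond + w * (eps_cond - eps_uncond); the library's own parametrization of the guidance scale may differ, and `cfg_combine` is a hypothetical helper written for illustration.

```python
import numpy as np

def cfg_combine(eps_cond: np.ndarray, eps_uncond: np.ndarray, w: float) -> np.ndarray:
    """Classifier-free guidance: move from the unconditional prediction toward
    (and, for w > 1, past) the conditional one.
    w = 0 returns the unconditional prediction, w = 1 the conditional one."""
    return eps_uncond + w * (eps_cond - eps_uncond)

# Toy noise predictions for a batch of two 2D points.
eps_c = np.array([[1.0, 0.0], [0.5, 0.5]])
eps_u = np.array([[0.0, 0.0], [0.5, -0.5]])
guided = cfg_combine(eps_c, eps_u, w=3.0)
print(guided)
```

In practice the unconditional prediction usually comes from the same network, trained with the condition randomly dropped for a fraction of the examples, so a single model provides both terms.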

Diffusion transformer

We provide a concise implementation of the diffusion transformer introduced in [Peebles and Xie 2022]. To train a model on the FashionMNIST dataset and generate a batch of samples (after first running accelerate config):

accelerate launch examples/fashion_mnist_dit.py
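The diffusion transformer of [Peebles and Xie 2022] treats an image as a sequence of non-overlapping patches, each flattened into a token before the transformer layers. Below is a minimal numpy sketch of that patchify step; the helper name `patchify`, the channel-first layout, and the patch size of 4 are illustrative assumptions, not code taken from the example script.

```python
import numpy as np

def patchify(images: np.ndarray, p: int) -> np.ndarray:
    """Split [B, C, H, W] images into non-overlapping p x p patches.

    Returns tokens of shape [B, (H//p)*(W//p), C*p*p], ordered row by row.
    """
    B, C, H, W = images.shape
    assert H % p == 0 and W % p == 0, "image size must be divisible by patch size"
    x = images.reshape(B, C, H // p, p, W // p, p)
    x = x.transpose(0, 2, 4, 1, 3, 5)  # [B, H/p, W/p, C, p, p]
    return x.reshape(B, (H // p) * (W // p), C * p * p)

# A FashionMNIST-sized batch: 28x28 grayscale images, cut into 4x4 patches.
batch = np.random.randn(8, 1, 28, 28)
tokens = patchify(batch, p=4)
print(tokens.shape)  # (8, 49, 16)
```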

With the provided default parameters and training on a single GPU for around 2 hours, the model can achieve an FID of around 5-6, producing the following generated outputs:

U-Net models

The same code can be used to train U-Net-based models.

accelerate launch examples/fashion_mnist_unet.py

We also provide example code to train a U-Net on the CIFAR-10 dataset, with an unconditional generation FID of around 3-4:

accelerate launch examples/cifar_unet.py

StableDiffusion

smalldiffusion's sampler works with any pretrained diffusion model and supports DDPM, DDIM, and accelerated sampling algorithms. In examples/diffusers_wrapper.py, we provide a simple wrapper for any pretrained Hugging Face diffusers latent diffusion model, enabling sampling with only a few lines of code:

from diffusers_wrapper import ModelLatentDiffusion
from smalldiffusion import ScheduleLDM, samples

schedule = ScheduleLDM(1000)
model    = ModelLatentDiffusion('stabilityai/stable-diffusion-2-1-base')
model.set_text_condition('An astronaut riding a horse')
*xts, x0 = samples(model, schedule.sample_sigmas(50))
decoded  = model.decode_latents(x0)
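Part of what such a wrapper must do is convert between noise parametrizations (it also handles latents and text conditioning, which are not shown here). Assuming the sampler works with x = x0 + sigma * eps, and that the pretrained network expects the DDPM-style input x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps with alpha_bar = 1 / (1 + sigma^2), the input only needs rescaling by 1 / sqrt(1 + sigma^2) while the predicted eps is unchanged. A numpy sketch with a hypothetical `pretrained_eps(x_t, t)` callable standing in for the diffusers network (this is not the wrapper's actual code):

```python
import numpy as np

def wrap_vp_model(pretrained_eps, sigmas):
    """Adapt a VP (DDPM-style) noise predictor to the x0 + sigma * eps convention.

    pretrained_eps(x_t, t) is a hypothetical callable that takes the VP-scaled
    input and an integer timestep; sigmas[t] is the noise level of timestep t.
    Assumes a scalar sigma per call.
    """
    sigmas = np.asarray(sigmas, dtype=float)

    def model(x, sigma):
        t = int(np.argmin(np.abs(sigmas - sigma)))  # nearest training timestep
        x_t = x / np.sqrt(1.0 + sigma**2)            # equals sqrt(alpha_bar) * x
        return pretrained_eps(x_t, t)                # eps is the same in both conventions

    return model
```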

It is easy to experiment with different sampler parameters and sampling schedules, as demonstrated in examples/stablediffusion.py. A few examples of tweaking the parameter gam:

How to use

The core of smalldiffusion depends on the interaction between data, model and schedule objects. Here we give a specification of these objects. For a detailed introduction to diffusion models and the notation used in the code, see the accompanying tutorial.

Data

For training diffusion models, smalldiffusion supports PyTorch Datasets and DataLoaders. The training code expects the iterates from a DataLoader object to be batches of data, without labels. To remove labels from existing datasets, extract the data with the provided MappedDataset wrapper before constructing a DataLoader.
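A map-style PyTorch dataset only needs `__len__` and `__getitem__`, so stripping labels amounts to applying a function to each item. The plain-Python class below is a hypothetical stand-in that illustrates the idea; it is not the provided MappedDataset, whose exact signature may differ.

```python
class StripLabels:
    """Hypothetical wrapper exposing only the data part of (data, label) items.

    It implements the map-style dataset protocol (__len__ / __getitem__),
    so an instance can be passed to torch.utils.data.DataLoader.
    """

    def __init__(self, dataset, fn=lambda item: item[0]):
        self.dataset = dataset
        self.fn = fn  # applied to every item; the default keeps element 0

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        return self.fn(self.dataset[i])

labeled = [([0.1, 0.2], 3), ([0.4, 0.5], 7)]
unlabeled = StripLabels(labeled)
print(len(unlabeled), unlabeled[1])  # 2 [0.4, 0.5]
```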

Three 2D toy datasets, Swissroll, DatasaurusDozen, and TreeDataset, are provided.
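For intuition about the toy data: a Swiss roll is a 2D spiral of points (t cos t, t sin t) over a range of t. The sketch below assumes the three Swissroll arguments in the toy example, Swissroll(np.pi/2, 5*np.pi, 100), are the start angle, end angle, and number of points; the function name and the normalization by the end angle are our own choices, and the library's construction may differ.

```python
import numpy as np

def swissroll_points(tmin: float, tmax: float, n: int) -> np.ndarray:
    """Return n points on the spiral (t cos t, t sin t) / tmax for t in [tmin, tmax]."""
    t = np.linspace(tmin, tmax, n)
    return np.stack([t * np.cos(t), t * np.sin(t)], axis=1) / tmax

pts = swissroll_points(np.pi / 2, 5 * np.pi, 100)
print(pts.shape)  # (100, 2)
```

Dividing by the end angle keeps the whole spiral inside the unit disk, so the data scale is comparable to the smallest noise levels of a schedule such as ScheduleLogLinear.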

Model

All model objects should be subclasses of torch.nn.Module. Models should have:

  • A parameter input_dims, a tuple containing the dimensions of the input to the model (not including batch-size).
  • A method rand_input(batchsize) which takes in a batch-size and returns an i.i.d. standard normal random input with shape [batchsize, *input_dims]. This method can be inherited from the provided ModelMixin class when the input_dims parameter is set.
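The rand_input contract is small enough to sketch directly. Below is a hypothetical numpy version of such a mixin; the name `RandInputMixin` is ours, and the provided ModelMixin works with torch tensors rather than numpy arrays.

```python
import numpy as np

class RandInputMixin:
    """Hypothetical mixin: derives rand_input from an input_dims attribute."""
    input_dims: tuple

    def rand_input(self, batchsize: int) -> np.ndarray:
        # i.i.d. standard normal noise of shape [batchsize, *input_dims]
        return np.random.standard_normal((batchsize, *self.input_dims))

class ToyImageModel(RandInputMixin):
    def __init__(self):
        self.input_dims = (1, 28, 28)  # e.g. a single-channel 28x28 image

x = ToyImageModel().rand_input(4)
print(x.shape)  # (4, 1, 28, 28)
```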

Models are called with arguments:

  • x is a batch of data of batch-size B and shape [B, *model.input_dims].
  • sigma is either a singleton or a batch.
    1. If sigma.shape == [], the same value will be used for each x.
    2. Otherwise sigma.shape == [B, 1, ..., 1], and x[i] will be paired with sigma[i].
  • Optionally, cond of shape [B, ...], if the model is conditional.

Models should return a predicted noise value with the same shape as x.
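To make the calling convention concrete, here is a hypothetical numpy "oracle" model for a dataset consisting of a single point c. Assuming noisy data of the form x = c + sigma * eps, the exact noise is (x - c) / sigma, and numpy broadcasting handles both sigma shapes listed above. The class name and the point-mass data are illustrative assumptions.

```python
import numpy as np

class PointMassOracle:
    """Exact noise predictor when every data point equals `c`."""

    def __init__(self, c):
        self.c = np.asarray(c, dtype=float)
        self.input_dims = self.c.shape

    def __call__(self, x, sigma, cond=None):
        # cond is ignored: this toy model is unconditional.
        sigma = np.asarray(sigma, dtype=float)
        # sigma.shape == () broadcasts over the whole batch;
        # sigma.shape == (B, 1, ..., 1) pairs x[i] with sigma[i].
        return (x - self.c) / sigma

model = PointMassOracle([1.0, -2.0])
eps = np.random.standard_normal((3, 2))
sigma = np.array([[0.5], [1.0], [2.0]])  # shape [B, 1]
x = model.c + sigma * eps                # noisy data
print(np.allclose(model(x, sigma), eps))  # True
print(model(x, 1.0).shape)                # (3, 2)
```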

Schedule

A Schedule object determines the rate at which the noise level sigma increases during the diffusion process. It is constructed by simply passing in a tensor of increasing sigma values. Schedule objects have the following methods:

  • sample_sigmas(steps) which subsamples the schedule for sampling.
  • sample_batch(batchsize) which generates a batch of sigma values selected uniformly at random, for use in training.
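A minimal numpy sketch of these two methods, assuming sample_sigmas picks evenly spaced indices from the largest sigma downward and sample_batch draws indices uniformly with replacement. The class `ToySchedule` is hypothetical; the library's exact index rounding, the number of values returned, and whether a final sigma is appended may all differ.

```python
import numpy as np

class ToySchedule:
    def __init__(self, sigmas):
        self.sigmas = np.asarray(sigmas, dtype=float)  # increasing noise levels
        assert np.all(np.diff(self.sigmas) > 0), "sigmas must be increasing"

    def sample_sigmas(self, steps: int) -> np.ndarray:
        # `steps` evenly spaced indices from the end of the schedule back to
        # the start, giving a decreasing list for sampling.
        n = len(self.sigmas)
        idx = np.round(np.linspace(n - 1, 0, steps)).astype(int)
        return self.sigmas[idx]

    def sample_batch(self, batchsize: int, rng=None) -> np.ndarray:
        # Noise levels drawn uniformly at random (with replacement) for training.
        if rng is None:
            rng = np.random.default_rng()
        return self.sigmas[rng.integers(0, len(self.sigmas), size=batchsize)]

sched = ToySchedule(np.linspace(0.01, 10.0, 200))
print(sched.sample_sigmas(5))
print(sched.sample_batch(4))
```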

The following schedules are provided:

  1. ScheduleLogLinear is a simple schedule which works well on small datasets and toy models.
  2. ScheduleDDPM is commonly used in pixel-space image diffusion models.
  3. ScheduleLDM is commonly used in latent diffusion models, e.g. StableDiffusion.
  4. ScheduleSigmoid, introduced in GeoDiff for molecular conformation generation.
  5. ScheduleCosine, introduced in iDDPM.

The following plot shows these schedules with default parameters.
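Two of these schedules are easy to write down. A log-linear schedule spaces sigma evenly in log space between sigma_min and sigma_max. A DDPM-style schedule is usually specified by variances beta_t; with alpha_bar_t = prod(1 - beta_s), the equivalent noise level in the parametrization x = x0 + sigma * eps (which we assume here) is sigma_t = sqrt(1 / alpha_bar_t - 1). The sketch uses the linear beta range 1e-4 to 0.02 from the original DDPM paper; the function names are hypothetical and these defaults may not match the library's.

```python
import numpy as np

def sigmas_log_linear(n: int, sigma_min: float, sigma_max: float) -> np.ndarray:
    # n noise levels evenly spaced in log space, increasing
    return np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), n))

def sigmas_ddpm(n: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> np.ndarray:
    betas = np.linspace(beta_start, beta_end, n)  # linear variance schedule
    alpha_bar = np.cumprod(1.0 - betas)           # cumulative signal fraction
    return np.sqrt(1.0 / alpha_bar - 1.0)         # equivalent sigma, increasing

ll = sigmas_log_linear(200, 0.005, 10)
dd = sigmas_ddpm()
print(ll[0], ll[-1])  # sigma_min and sigma_max, up to rounding
print(dd[0], dd[-1])
```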

Training

The training_loop generator function provides a simple loop for training a diffusion model, given the loader, model and schedule objects described above. It yields a namespace with the local variables, for easy evaluation during training. For example, to print out the loss every iteration:

for ns in training_loop(loader, model, schedule):
    print(ns.loss.item())

Multi-GPU training and sampling is also supported via accelerate.
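Each iteration of such a loop follows the standard noise-prediction objective: draw a batch of sigmas, corrupt the clean data as x0 + sigma * eps, and regress the model output onto eps with a mean squared error. Below is a numpy sketch of a single loss evaluation (no gradient step); `diffusion_loss` is a hypothetical helper, and the x0 + sigma * eps corruption is an assumption about the parametrization.

```python
import numpy as np

def diffusion_loss(model, x0, sigmas, rng):
    """Noise-prediction loss for one batch of clean data x0 of shape [B, ...]."""
    B = x0.shape[0]
    # One sigma per example, shaped [B, 1, ..., 1] to broadcast against x0.
    sigma = rng.choice(sigmas, size=B).reshape(B, *([1] * (x0.ndim - 1)))
    eps = rng.standard_normal(x0.shape)
    x_noisy = x0 + sigma * eps
    return np.mean((model(x_noisy, sigma) - eps) ** 2)

rng = np.random.default_rng(0)
c = np.array([1.0, -2.0])
x0 = np.tile(c, (1024, 1))                      # every datum is the point c
oracle = lambda x, sigma: (x - c) / sigma       # exact noise prediction
zero = lambda x, sigma: np.zeros_like(x)        # always predicts zero noise
sigmas = np.exp(np.linspace(np.log(0.005), np.log(10), 200))
loss_oracle = diffusion_loss(oracle, x0, sigmas, rng)
loss_zero = diffusion_loss(zero, x0, sigmas, rng)
print(loss_oracle)  # ~0
print(loss_zero)    # close to 1, the mean of eps**2
```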

Sampling

To sample from a diffusion model, the samples generator function takes in a model and a decreasing list of sigmas to use during sampling. This list is usually created by calling the sample_sigmas(steps) method of a Schedule object. The generator yields the sequence of xts produced during sampling. The sampling loop generalizes most commonly used samplers, including DDPM, DDIM, and accelerated samplers.

For more details on how these sampling algorithms can be simplified, generalized and implemented in only 5 lines of code, see Appendix A of [Permenter and Yuan].
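The numpy sketch below shows one way a single loop can cover these samplers, in the spirit of [Permenter and Yuan]; it is not a transcription of the library's samples function. The step direction averages the current and previous noise predictions with weight gam, and a noise-injection parameter (called `mu` here; the name and exact update are our assumptions) controls how much fresh noise enters each step. In the x0 + sigma * eps parametrization, gam=1 with mu=0 is a deterministic DDIM step, gam=1 with mu=0.5 reproduces the DDPM ancestral step, and gam=2 with mu=0 corresponds, as we understand it, to the accelerated gradient-estimation sampler.

```python
import numpy as np

def sample_sketch(model, sigmas, x, gam=1.0, mu=0.0, rng=None):
    """Generalized sampler sketch; yields x after every step.

    sigmas: decreasing noise levels; x: initial noise scaled by sigmas[0].
    gam: weight on the current noise prediction (1: DDIM-like, 2: accelerated).
    mu in [0, 1): amount of fresh noise injected per step (0: deterministic).
    """
    if rng is None:
        rng = np.random.default_rng()
    eps = None
    for i in range(len(sigmas) - 1):
        sig, sig_prev = sigmas[i], sigmas[i + 1]
        eps_prev, eps = eps, model(x, sig)
        eps_av = eps if i == 0 else gam * eps + (1 - gam) * eps_prev
        # Step deterministically to level sig_p, then add noise to reach sig_prev.
        sig_p = (sig_prev / sig**mu) ** (1 / (1 - mu))
        eta = np.sqrt(max(sig_prev**2 - sig_p**2, 0.0))
        x = x - (sig - sig_p) * eps_av + eta * rng.standard_normal(x.shape)
        yield x

# Usage with an exact noise predictor for data concentrated at the point c.
c = np.array([1.0, -2.0])
oracle = lambda x, sigma: (x - c) / sigma
sigmas = np.append(np.exp(np.linspace(np.log(10), np.log(0.005), 20)), 0.0)
rng = np.random.default_rng(0)
x_init = sigmas[0] * rng.standard_normal((4, 2))
*xt, x0 = sample_sketch(oracle, sigmas, x_init, gam=2.0)
print(np.allclose(x0, c))  # True: every sample lands on the data point
```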

Download files

Source Distribution

smalldiffusion-0.4.4.tar.gz (23.0 kB)

Built Distribution

smalldiffusion-0.4.4-py3-none-any.whl (18.4 kB)

File details

Details for the file smalldiffusion-0.4.4.tar.gz.

File metadata

  • Download URL: smalldiffusion-0.4.4.tar.gz
  • Size: 23.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.3

File hashes

Hashes for smalldiffusion-0.4.4.tar.gz
Algorithm Hash digest
SHA256 7f55d5f0f1314666711e8848850a305d475ab346d40e0c8f78f72ad8c4d6cca6
MD5 fc5f9374f845c73380936ccea1df730c
BLAKE2b-256 5b33236183e7df33d7dd99b4f298c84ad89f07e885ac76b2e80e1abe9800a447

File details

Details for the file smalldiffusion-0.4.4-py3-none-any.whl.

File metadata

  • Download URL: smalldiffusion-0.4.4-py3-none-any.whl
  • Size: 18.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.3

File hashes

Hashes for smalldiffusion-0.4.4-py3-none-any.whl
Algorithm Hash digest
SHA256 3beac7b8ffbda7b2857928043c2ae2efd52cdc81d0905faa75aa2743e1f3ef07
MD5 29578304fd8983b3ac6f4d4f1e73eeb7
BLAKE2b-256 9fb4f228777304a6795df6ceaf47a9f0bff5661d477231e88969befd67a2f995