
smalldiffusion

A minimal but functional implementation of diffusion model training and sampling.

A lightweight diffusion library for training and sampling from diffusion models. It is built for easy experimentation when training new models and developing new samplers, supporting everything from minimal toy models to state-of-the-art pretrained models. The core of the library's training and sampling code is implemented in less than 100 lines of very readable PyTorch code. To install from PyPI:

pip install smalldiffusion

Toy models

To train a model on the Swissroll toy dataset and sample from it in about 10 lines of code (see examples/toyexample.ipynb for a detailed guide):

import numpy as np
from torch.utils.data import DataLoader
from smalldiffusion import Swissroll, TimeInputMLP, ScheduleLogLinear, training_loop, samples

dataset  = Swissroll(np.pi/2, 5*np.pi, 100)
loader   = DataLoader(dataset, batch_size=2048)
model    = TimeInputMLP(hidden_dims=(16,128,128,128,128,16))
schedule = ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)
trainer  = training_loop(loader, model, schedule, epochs=15000)
losses   = [ns.loss.item() for ns in trainer]
*xts, x0 = samples(model, schedule.sample_sigmas(20), gam=2)

Results on various toy datasets:

U-Net models

The same code can be used to train U-Net-based models. To train a model on the FashionMNIST dataset and generate a batch of samples (after first running accelerate config):

accelerate launch examples/fashion_mnist.py

With the provided default parameters and training on a single GPU for around 2 hours, the model can achieve a FID score of around 12-13, producing the following generated outputs:

StableDiffusion

smalldiffusion's sampler works with any pretrained diffusion model, and supports DDPM and DDIM as well as accelerated sampling algorithms. In examples/diffusers_wrapper.py, we provide a simple wrapper for any pretrained Hugging Face diffusers latent diffusion model, enabling sampling from pretrained models in only a few lines of code:

from diffusers_wrapper import ModelLatentDiffusion
from smalldiffusion import ScheduleLDM, samples

schedule = ScheduleLDM(1000)
model    = ModelLatentDiffusion('stabilityai/stable-diffusion-2-1-base')
model.set_text_condition('An astronaut riding a horse')
*xts, x0 = samples(model, schedule.sample_sigmas(50))
decoded  = model.decode_latents(x0)

It is easy to experiment with different sampler parameters and sampling schedules, as demonstrated in examples/stablediffusion.py; one useful parameter to tweak is gam.
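For instance, a sweep over a few values of gam might look like the following sketch, reusing the model and schedule objects from the snippet above:

# Sketch: sweep the gam parameter, reusing the model and schedule
# objects constructed in the StableDiffusion snippet above.
for gam in (1.0, 1.5, 2.0):
    *xts, x0 = samples(model, schedule.sample_sigmas(50), gam=gam)
    decoded  = model.decode_latents(x0)   # compare decoded images across gam values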

How to use

The core of smalldiffusion depends on the interaction between data, model and schedule objects. Here we give a specification of these objects. For a detailed introduction to diffusion models and the notation used in the code, see the accompanying tutorial.

Data

For training diffusion models, smalldiffusion supports PyTorch Datasets and DataLoaders. The training code expects the iterates from a DataLoader object to be batches of data without labels. To remove labels from an existing dataset, wrap it with the provided MappedDataset class before constructing a DataLoader.
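For example, the following sketch strips labels from torchvision's FashionMNIST. It assumes MappedDataset(dataset, fn) applies fn to each element of the wrapped dataset; check the source for the exact signature.

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from smalldiffusion import MappedDataset

# Assumption: MappedDataset(dataset, fn) applies fn to each element,
# here dropping the label from each (image, label) pair.
dataset = MappedDataset(FashionMNIST('data', train=True, download=True,
                                     transform=transforms.ToTensor()),
                        lambda pair: pair[0])
loader  = DataLoader(dataset, batch_size=256, shuffle=True)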

Two toy datasets, Swissroll and DatasaurusDozen, are provided.

Model

Every model object should be a subclass of torch.nn.Module. Models should have:

  • A parameter input_dims, a tuple containing the dimensions of the input to the model (not including batch-size).
  • A method rand_input(batchsize) which takes in a batch-size and returns an i.i.d. standard normal random input with shape [batchsize, *input_dims]. This method can be inherited from the provided ModelMixin class when the input_dims parameter is set.

Models are called with two arguments:

  • x is a batch of data of batch-size B and shape [B, *model.input_dims].
  • sigma is either a singleton or a batch.
    1. If sigma.shape == [], the same value will be used for each x.
    2. Otherwise sigma.shape == [B, 1, ..., 1], and x[i] will be paired with sigma[i].

Models should return a predicted noise value with the same shape as x.
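As an illustration, here is a minimal model satisfying this interface. This is a sketch, not part of the library; it assumes ModelMixin supplies rand_input once input_dims is set, as described above.

import torch
import torch.nn as nn
from smalldiffusion import ModelMixin

class TinyDenoiser(nn.Module, ModelMixin):
    def __init__(self, dim=2, hidden=128):
        super().__init__()
        self.input_dims = (dim,)                 # enables the inherited rand_input
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden), nn.ReLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x, sigma):
        # Broadcast a scalar sigma to one value per batch element,
        # or flatten a [B, 1, ..., 1] batch of sigmas to [B, 1].
        sigma = sigma.expand(x.shape[0], 1) if sigma.dim() == 0 else sigma.reshape(-1, 1)
        return self.net(torch.cat([x, sigma], dim=1))   # predicted noise, same shape as x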

Schedule

A Schedule object determines the rate at which the noise level sigma increases during the diffusion process. It is constructed by passing in a tensor of increasing sigma values. Schedule objects have the following methods:

  • sample_sigmas(steps) which subsamples the schedule for sampling.
  • sample_batch(batchsize) which generates a batch of sigma values selected uniformly at random, for use in training.

Three schedules are provided:

  1. ScheduleLogLinear is a simple schedule which works well on small datasets and toy models.
  2. ScheduleDDPM is commonly used in pixel-space image diffusion models.
  3. ScheduleLDM is commonly used in latent diffusion models, e.g. StableDiffusion.

The following plot shows these three schedules with default parameters.
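In code, constructing a schedule and calling the two methods above looks like this (a short sketch using the log-linear schedule from the toy example):

from smalldiffusion import ScheduleLogLinear

schedule = ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)
training_sigmas = schedule.sample_batch(256)   # random batch of sigmas for training
sampling_sigmas = schedule.sample_sigmas(20)   # decreasing sigmas for sampling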

Training

The training_loop generator function provides a simple training loop for training a diffusion model, given the loader, model and schedule objects described above. It yields a namespace containing the local variables of the loop, for easy inspection during training. For example, to print out the loss at every iteration:

for ns in training_loop(loader, model, schedule):
    print(ns.loss.item())

Multi-GPU training and sampling are also supported via accelerate.

Sampling

To sample from a diffusion model, the samples generator function takes in a model and a decreasing list of sigmas to use during sampling. This list is usually created by calling the sample_sigmas(steps) method of a Schedule object. The generator yields the sequence of xts produced during sampling. The sampling loop generalizes most commonly-used samplers, including DDPM and DDIM.
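For example, reusing the toy model and schedule trained above:

sigmas   = schedule.sample_sigmas(20)     # decreasing noise levels
*xts, x0 = samples(model, sigmas, gam=2)  # xts: intermediate iterates, x0: final sample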

For more details on how these sampling algorithms can be simplified, generalized and implemented in only 5 lines of code, see Appendix A of [Permenter and Yuan].
