
A minimal but functional implementation of diffusion model training and sampling

Project description

smalldiffusion

A lightweight diffusion library for training and sampling from diffusion models. It is built for easy experimentation when training new models and developing new samplers, supporting everything from minimal toy models to state-of-the-art pretrained models. The core diffusion training and sampling code is implemented in under 100 lines of readable PyTorch. To install:

pip install smalldiffusion

Toy models

To train a model on the Swissroll toy dataset and sample from it in a few lines of code (see examples/toyexample.ipynb for a detailed guide):

import numpy as np
from torch.utils.data import DataLoader
from smalldiffusion import Swissroll, TimeInputMLP, ScheduleLogLinear, training_loop, samples

dataset  = Swissroll(np.pi/2, 5*np.pi, 100)
loader   = DataLoader(dataset, batch_size=2048)
model    = TimeInputMLP(hidden_dims=(16,128,128,128,128,16))
schedule = ScheduleLogLinear(N=200)
trainer  = training_loop(loader, model, schedule, epochs=10000)
losses   = [ns.loss.item() for ns in trainer]
*xts, x0 = samples(model, schedule.sample_sigmas(20))

U-Net models

The same code can be used to train U-Net-based models. To train a model on the FashionMNIST dataset and generate a batch of samples (after first running accelerate config):

accelerate launch examples/fashion_mnist.py

With the provided default parameters and around 2 hours of training on a single GPU, the model can achieve an FID score of around 12-13. [Figure: generated FashionMNIST samples.]

StableDiffusion

smalldiffusion's sampler works with any pretrained diffusion model, and supports DDPM and DDIM as well as accelerated sampling algorithms. In examples/diffusers_wrapper.py, we provide a simple wrapper for any pretrained Hugging Face diffusers latent diffusion model, enabling sampling from pretrained models with only a few lines of code:

from diffusers_wrapper import ModelLatentDiffusion
from smalldiffusion import ScheduleLDM, samples

schedule = ScheduleLDM(1000)
model    = ModelLatentDiffusion('stabilityai/stable-diffusion-2-1-base')
model.set_text_condition('An astronaut riding a horse')
*xts, x0 = samples(model, schedule.sample_sigmas(50))
decoded  = model.decode_latents(x0)

It is easy to experiment with different sampler parameters and sampling schedules, as demonstrated in examples/stablediffusion.py.

How to use

The core of smalldiffusion depends on the interaction between data, model and schedule objects. Here we give a specification of these objects. For a detailed introduction to diffusion models and the notation used in the code, see the accompanying tutorial.

Data

For training diffusion models, smalldiffusion supports PyTorch Datasets and DataLoaders. The training code expects the iterates from a DataLoader object to be batches of data without labels. To remove labels from an existing dataset, wrap it with the provided MappedDataset wrapper before constructing a DataLoader.
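As a sketch of the idea, here is a minimal MappedDataset-style wrapper applied to a toy labeled dataset (the library's own wrapper may differ in details; LabeledToyDataset is a placeholder standing in for something like FashionMNIST):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class LabeledToyDataset(Dataset):
    """Stand-in for a labeled dataset such as FashionMNIST."""
    def __init__(self):
        self.items = [(torch.randn(2), i % 10) for i in range(8)]
    def __len__(self):
        return len(self.items)
    def __getitem__(self, i):
        return self.items[i]

class MappedDataset(Dataset):
    """Sketch of a MappedDataset-style wrapper that applies fn to each item."""
    def __init__(self, dataset, fn):
        self.dataset, self.fn = dataset, fn
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, i):
        return self.fn(self.dataset[i])

# keep only the data, dropping the label, so batches are plain tensors
dataset = MappedDataset(LabeledToyDataset(), lambda item: item[0])
loader  = DataLoader(dataset, batch_size=4)
```

Each batch yielded by loader is then a plain tensor of data, as the training code expects.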

Model

All model objects should be subclasses of torch.nn.Module. Models should have:

  • A parameter input_dims, a tuple containing the dimensions of the input to the model (not including batch-size).
  • A method rand_input(batchsize) which takes in a batch-size and returns an i.i.d. standard normal random input with shape [batchsize, *input_dims]. This method can be inherited from the provided ModelMixin class when the input_dims parameter is set.

Models are called with two arguments:

  • x is a batch of data of batch-size B and shape [B, *model.input_dims].
  • sigma is either a singleton or a batch.
    1. If sigma.shape == [], the same value will be used for each x.
    2. Otherwise sigma.shape == [B, 1, ..., 1], and x[i] will be paired with sigma[i].

Models should return a predicted noise value with the same shape as x.
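Putting these requirements together, a minimal model might look like the following (an illustrative sketch only; ToyModel and its single linear layer are placeholders, not part of the library, and rand_input is normally inherited from ModelMixin):

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Minimal model satisfying the interface described above."""
    def __init__(self):
        super().__init__()
        self.input_dims = (2,)              # input shape, excluding batch
        self.net = nn.Linear(3, 2)          # consumes x concatenated with sigma

    def rand_input(self, batchsize):
        # i.i.d. standard normal input of shape [batchsize, *input_dims]
        return torch.randn(batchsize, *self.input_dims)

    def forward(self, x, sigma):
        B = x.shape[0]
        if sigma.shape == torch.Size([]):   # singleton: share sigma across batch
            sigma = sigma.expand(B)
        sigma = sigma.reshape(B, 1)
        # return a predicted noise value with the same shape as x
        return self.net(torch.cat([x, sigma], dim=1))
```

The forward method handles both calling conventions: a scalar sigma shared across the batch, or one sigma per batch element.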

Schedule

A Schedule object determines the rate at which the noise level sigma increases during the diffusion process. It is constructed by passing in a tensor of increasing sigma values. Schedule objects have the methods:

  • sample_sigmas(steps), which subsamples the schedule for use in sampling.
  • sample_batch(batchsize), which generates a batch of sigma values selected uniformly at random, for use in training.
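As an illustration of this interface, a minimal stand-in might look like the following (a sketch only, not the library's implementation; ToySchedule is a placeholder name):

```python
import torch

class ToySchedule:
    """Minimal stand-in illustrating the Schedule interface."""
    def __init__(self, sigmas):
        self.sigmas = sigmas                    # increasing noise levels

    def sample_sigmas(self, steps):
        # subsample the schedule from high to low noise, for sampling
        idx = torch.linspace(len(self.sigmas) - 1, 0, steps).long()
        return self.sigmas[idx]

    def sample_batch(self, batchsize):
        # sigma values chosen uniformly at random, for training
        idx = torch.randint(len(self.sigmas), (batchsize,))
        return self.sigmas[idx]

# log-linearly spaced sigmas, in the spirit of ScheduleLogLinear
schedule = ToySchedule(torch.logspace(-2, 1, 200))
```

Note that sample_sigmas returns sigmas in decreasing order, matching what the sampling loop expects.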

Three schedules are provided:

  1. ScheduleLogLinear is a simple schedule which works well on small datasets and toy models.
  2. ScheduleDDPM is commonly used in pixel-space image diffusion models.
  3. ScheduleLDM is commonly used in latent diffusion models, e.g. StableDiffusion.

[Plot: the three schedules with their default parameters.]

Training

The training_loop generator function provides a simple loop for training a diffusion model, given the loader, model and schedule objects described above. It yields a namespace containing the local variables, for easy inspection during training. For example, to print out the loss every iteration:

for ns in training_loop(loader, model, schedule):
    print(ns.loss.item())

Multi-GPU training and sampling are also supported via accelerate.

Sampling

To sample from a diffusion model, the samples generator function takes in a model and a decreasing list of sigmas to use during sampling. This list is usually created by calling the sample_sigmas(steps) method of a Schedule object. The generator will yield the sequence of xts produced during sampling. The sampling loop generalizes most commonly-used samplers, including DDPM, DDIM and accelerated variants.

For more details on how these sampling algorithms can be simplified, generalized and implemented in only 5 lines of code, see Appendix A of [Permenter and Yuan].
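To give a flavor of how short such a loop can be, here is a hedged sketch of a deterministic, DDIM-like sampler (not the library's exact implementation), exercised on a toy "model" whose data distribution is a point mass at zero, so that the optimal noise prediction for x = 0 + sigma * eps is simply x / sigma:

```python
import torch

def sample(model, sigmas, xt):
    # Sketch of a deterministic, DDIM-like sampling loop (illustrative only;
    # the library's loop also covers DDPM and accelerated samplers).
    yield xt
    for sig, sig_next in zip(sigmas, sigmas[1:]):
        eps = model(xt, sig)              # predicted noise at level sig
        xt = xt - (sig - sig_next) * eps  # step to the next, lower noise level
        yield xt

# Toy "model": data concentrated at 0, so the optimal prediction is x / sigma.
model  = lambda x, sigma: x / sigma
sigmas = torch.logspace(1, -2, 20)        # decreasing noise levels
xt     = torch.randn(4, 2) * sigmas[0]    # start from pure noise
*xts, x0 = sample(model, sigmas, xt)      # x0 ends up near the data (zero)
```

Each update moves xt along the predicted noise direction by the change in sigma, which is the core simplification discussed in the reference above.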

Download files

Download the file for your platform.

Source Distribution

smalldiffusion-0.1.tar.gz (12.9 kB)

Uploaded Source

Built Distribution


smalldiffusion-0.1-py3-none-any.whl (9.3 kB)

Uploaded Python 3

File details

Details for the file smalldiffusion-0.1.tar.gz.

File metadata

  • Download URL: smalldiffusion-0.1.tar.gz
  • Upload date:
  • Size: 12.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for smalldiffusion-0.1.tar.gz:

  • SHA256: 2ba294dac33b479a239769ef0140467171016d6c7fd2f50681fcce1915a64062
  • MD5: 2229c859e8a2d150ad3145867504eae7
  • BLAKE2b-256: 224c7ea4860ba3c161290eded0981948234dfa164bfbb1f9621f139f85b56f1b


File details

Details for the file smalldiffusion-0.1-py3-none-any.whl.

File metadata

  • Download URL: smalldiffusion-0.1-py3-none-any.whl
  • Upload date:
  • Size: 9.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for smalldiffusion-0.1-py3-none-any.whl:

  • SHA256: 1e8b0c46f2e66e807cdfb798fc47b881585c2e8b0de4a3b35a9f0ef0913fe03f
  • MD5: 9c6daf15994d68e787900624b95555e6
  • BLAKE2b-256: 476e6e0c5fca0b00b9e8728d8216583a76e1600d3af26aefe7be9a7cf0e6a9e0

