smalldiffusion

A lightweight diffusion library for training and sampling from diffusion and flow models. Features:

  • Designed for ease of experimentation when training new models or developing new samplers
  • Dataset support: 2D toy datasets, pixel and latent-space image datasets
  • Example training code (with near-SOTA FID): FashionMNIST, CIFAR-10, ImageNet
  • Models: MLP, U-Net and DiT
  • Supports multiple parameterizations: score-, flow- or data-prediction
  • Small but extensible core: less than 100 lines of code for training and sampling

To install from pypi:

pip install smalldiffusion

For local development with uv:

uv sync --extra dev --extra test --extra examples
uv run pytest

Toy models

To train and sample from the Swissroll toy dataset in 10 lines of code (see examples/toyexample.ipynb for a detailed guide):

import numpy as np
from torch.utils.data import DataLoader
from smalldiffusion import Swissroll, TimeInputMLP, ScheduleLogLinear, training_loop, samples

dataset  = Swissroll(np.pi/2, 5*np.pi, 100)
loader   = DataLoader(dataset, batch_size=2048)
model    = TimeInputMLP(hidden_dims=(16,128,128,128,128,16))
schedule = ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)
trainer  = training_loop(loader, model, schedule, epochs=15000)
losses   = [ns.loss.item() for ns in trainer]
*xt, x0  = samples(model, schedule.sample_sigmas(20), gam=2)
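
The generated points can then be visualized with a short follow-up. A minimal sketch (it assumes samples accepts a batchsize keyword for drawing many points at once):

import matplotlib.pyplot as plt

# Sketch: scatter-plot the generated 2D points.
# batchsize is an assumption about the samples() signature; adjust if needed.
*xt, x0 = samples(model, schedule.sample_sigmas(20), gam=2, batchsize=1000)
x0 = x0.detach().cpu()
plt.scatter(x0[:, 0], x0[:, 1], s=2)
plt.show()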

Results on various toy datasets:

Conditional training and sampling with classifier-free guidance

We can also train conditional diffusion models and sample from them using classifier-free guidance. In examples/cond_tree_model.ipynb, samples from each class of the 2D tree dataset are shown in a different color.

Diffusion transformer

We provide a concise implementation of the diffusion transformer introduced in [Peebles and Xie 2022].

DiT on ImageNet with flow matching

We provide an example script for training a DiT-B/2 model on ImageNet 256×256 using the flow matching formulation in the latent space of Stable Diffusion's VAE. The script trains on precomputed VAE latents and supports multi-GPU training via accelerate:

uv run accelerate config
uv run accelerate launch examples/imagenet_dit.py

After training for 400k steps (~10 hours on 8 GPUs), the model achieves an unconditional FID of around 27, compared to 33 for SiT and 43 for DiT.

FashionMNIST dataset

To train a diffusion transformer model on the FashionMNIST dataset and generate a batch of samples (after first running uv run accelerate config):

uv run accelerate launch examples/fashion_mnist_dit.py

With the provided default parameters and training on a single GPU for around 2 hours, the model can achieve a FID score of around 5-6.

U-Net models

The same code can be used to train U-Net-based models.

uv run accelerate launch examples/fashion_mnist_unet.py

We also provide example code to train a U-Net on the CIFAR-10 dataset, with an unconditional generation FID of around 3-4:

uv run accelerate launch examples/cifar_unet.py

StableDiffusion

smalldiffusion's sampler works with any pretrained diffusion model, and supports DDPM and DDIM as well as accelerated sampling algorithms. In examples/diffusers_wrapper.py, we provide a simple wrapper for any pretrained Hugging Face diffusers latent diffusion model, enabling sampling from pretrained models with only a few lines of code:

from diffusers_wrapper import ModelLatentDiffusion
from smalldiffusion import ScheduleLDM, samples

schedule = ScheduleLDM(1000)
model    = ModelLatentDiffusion('stabilityai/stable-diffusion-2-1-base')
model.set_text_condition('An astronaut riding a horse')
*xts, x0 = samples(model, schedule.sample_sigmas(50))
decoded  = model.decode_latents(x0)

It is easy to experiment with different sampler parameters and sampling schedules, as demonstrated in examples/stablediffusion.py.
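
For instance, the parameter gam can be varied when calling samples. A minimal sketch, reusing the model and schedule from the snippet above (the specific values are illustrative):

# Sketch: sweep gam; gam=1 corresponds to DDIM-style updates, while larger
# values extrapolate further per step.
for gam in (1.0, 1.5, 2.0):
    *xts, x0 = samples(model, schedule.sample_sigmas(50), gam=gam)
    image    = model.decode_latents(x0)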

How to use

The core of smalldiffusion depends on the interaction between data, model and schedule objects. Here we give a specification of these objects. For a detailed introduction to diffusion models and the notation used in the code, see the accompanying tutorial.

Data

For training diffusion models, smalldiffusion supports pytorch Datasets and DataLoaders. The training code expects the iterates from a DataLoader object to be batches of data, without labels. To remove labels from existing datasets, extract the data with the provided MappedDataset wrapper before constructing a DataLoader.
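
For example, a labeled torchvision dataset might be wrapped as follows. A minimal sketch, assuming MappedDataset(dataset, fn) applies fn to each item:

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from smalldiffusion import MappedDataset

# Wrap a labeled dataset so the loader yields only data tensors, dropping labels.
fashion = datasets.FashionMNIST('data', download=True, transform=transforms.ToTensor())
dataset = MappedDataset(fashion, lambda item: item[0])
loader  = DataLoader(dataset, batch_size=256, shuffle=True)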

Three 2D toy datasets are provided: Swissroll, DatasaurusDozen, and TreeDataset.

Model

All model objects should be a subclass of torch.nn.Module. Models should have:

  • A parameter input_dims, a tuple containing the dimensions of the input to the model (not including batch-size).
  • A method rand_input(batchsize) which takes in a batch-size and returns an i.i.d. standard normal random input with shape [batchsize, *input_dims]. This method can be inherited from the provided ModelMixin class when the input_dims parameter is set.

Models are called with arguments:

  • x is a batch of data of batch-size B and shape [B, *model.input_dims].
  • sigma is either a singleton or a batch.
    1. If sigma.shape == [], the same value will be used for each x.
    2. Otherwise sigma.shape == [B, 1, ..., 1], and x[i] will be paired with sigma[i].
  • Optionally, cond of shape [B, ...], if the model is conditional.

Models should return a predicted noise value with the same shape as x.
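
Putting these requirements together, a minimal model might look like the following sketch (the class name and layer sizes are hypothetical; it assumes ModelMixin is importable from the package root and supplies rand_input once input_dims is set):

import torch
import torch.nn as nn
from smalldiffusion import ModelMixin  # assumed importable from the package root

class TinyDenoiser(nn.Module, ModelMixin):
    # Hypothetical minimal model satisfying the interface above, for 2D toy data.
    input_dims = (2,)  # shape of one input, excluding the batch dimension

    def __init__(self, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(3, hidden), nn.SiLU(),  # 2 data dims + 1 noise-level feature
            nn.Linear(hidden, 2),
        )

    def forward(self, x, sigma, cond=None):
        # Broadcast a scalar sigma to the batch; otherwise reshape [B, 1, ..., 1] to [B, 1].
        sigma = sigma.expand(x.shape[0], 1) if sigma.ndim == 0 else sigma.reshape(-1, 1)
        return self.net(torch.cat([x, sigma.log()], dim=1))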

Schedule

A Schedule object determines the rate at which the noise level sigma increases during the diffusion process. It is constructed by simply passing in a tensor of increasing sigma values. Schedule objects have the methods

  • sample_sigmas(steps) which subsamples the schedule for sampling.
  • sample_batch(batchsize) which generates a batch of sigma values selected uniformly at random, for use in training.
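
For example, with the schedule from the toy example above:

schedule = ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)
sigmas   = schedule.sample_sigmas(20)  # decreasing sigmas for a 20-step sampler
batch    = schedule.sample_batch(32)   # random sigmas for a training batch of 32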

The following schedules are provided:

  1. ScheduleLogLinear is a simple schedule which works well on small datasets and toy models.
  2. ScheduleDDPM is commonly used in pixel-space image diffusion models.
  3. ScheduleLDM is commonly used in latent diffusion models, e.g. StableDiffusion.
  4. ScheduleSigmoid, introduced in GeoDiff for molecular conformation generation.
  5. ScheduleCosine, introduced in iDDPM.

These schedules, with their default parameters, can be compared by plotting their sigma values, as sketched below.
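
A minimal sketch (it assumes each Schedule stores its increasing noise levels in a sigmas attribute, and that the constructors shown have sensible defaults):

import matplotlib.pyplot as plt
from smalldiffusion import ScheduleLogLinear, ScheduleDDPM, ScheduleLDM

# Sketch: plot each schedule's noise levels on a log scale.
for schedule in (ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10),
                 ScheduleDDPM(), ScheduleLDM(1000)):
    plt.plot(schedule.sigmas, label=type(schedule).__name__)
plt.yscale('log')
plt.legend()
plt.show()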

Training

The training_loop generator function provides a simple training loop for training a diffusion model, given the loader, model and schedule objects described above. It yields a namespace with the local variables, for easy evaluation during training. For example, to print out the loss every iteration:

for ns in training_loop(loader, model, schedule):
    print(ns.loss.item())

Multi-GPU training and sampling is also supported via accelerate.
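
A minimal sketch of that setup (it assumes training_loop accepts an accelerator keyword; treat the argument name as an assumption and check the source):

from accelerate import Accelerator

# Sketch: multi-GPU training via accelerate. The accelerator keyword is an
# assumption about the training_loop signature.
accelerator = Accelerator()
for ns in training_loop(loader, model, schedule, accelerator=accelerator):
    if accelerator.is_main_process:
        print(ns.loss.item())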

Sampling

To sample from a diffusion model, the samples generator function takes in a model and a decreasing list of sigmas to use during sampling. This list is usually created by calling the sample_sigmas(steps) method of a Schedule object. The generator will yield a sequence of xts produced during sampling. The sampling loop generalizes most commonly-used samplers, including DDPM, DDIM and accelerated variants.
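
For example, to inspect the intermediate iterates as they are produced (reusing the toy model and schedule above; the last yielded value is the final sample):

# Sketch: iterate over the generator directly instead of unpacking it.
for xt in samples(model, schedule.sample_sigmas(20), gam=2):
    print(xt.shape)  # each xt is one iterate; the final one is the sample x0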

For more details on how these sampling algorithms can be simplified, generalized and implemented in only 5 lines of code, see Appendix A of [Permenter and Yuan].
