
Karras et al. (2022) diffusion models for PyTorch

Project description

k-diffusion

An implementation of Elucidating the Design Space of Diffusion-Based Generative Models (Karras et al., 2022) for PyTorch. The patching method in Improving Diffusion Model Efficiency Through Patching is implemented as well.

Installation

k-diffusion can be installed via PyPI (pip install k-diffusion), but the PyPI package does not include the training and inference scripts, only library code that others can depend on. To run the training and inference scripts, clone this repository and run pip install -e <path to repository>.
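
For example, a minimal sketch assuming the project's GitHub repository (substitute your own clone path as needed):

$ git clone https://github.com/crowsonkb/k-diffusion.git
$ pip install -e k-diffusion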

Training:

To train models:

$ ./train.py --config CONFIG_FILE --name RUN_NAME

For instance, to train a model on MNIST:

$ ./train.py --config configs/config_mnist.json --name RUN_NAME

The configuration file lets you specify the dataset type. Currently supported types are "imagefolder" (finds all images in that folder and its subfolders, recursively), "cifar10" (CIFAR-10), "mnist" (MNIST), and "huggingface" (Hugging Face Datasets).
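
As a rough sketch, the dataset section of a configuration file might look like the following. The exact schema is defined by the files under configs/; the "location" field here is an assumption for illustration, not the project's verbatim schema:

{
    "dataset": {
        "type": "imagefolder",
        "location": "/path/to/images"
    }
}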

Multi-GPU and multi-node training is supported with Hugging Face Accelerate. You can configure Accelerate by running:

$ accelerate config

on all nodes, then running:

$ accelerate launch train.py --config CONFIG_FILE --name RUN_NAME

on all nodes.

Enhancements/additional features:

  • k-diffusion supports an experimental model output type, an isotropic Gaussian, which seems to have a lower gradient noise scale and to train faster than Karras et al. (2022) diffusion models.

  • k-diffusion has wrappers for v-diffusion-pytorch, OpenAI diffusion, and CompVis diffusion models, allowing them to be used with its samplers and ODE/SDE solvers (see the sketch after this list).

  • k-diffusion models support progressive growing.

  • k-diffusion implements DPM-Solver, which produces higher quality samples at the same number of function evaluations as Karras Algorithm 2, and supports adaptive step size control. DPM-Solver++(2S) and (2M) are implemented as well for improved quality with low numbers of steps.

  • k-diffusion supports CLIP guided sampling from unconditional diffusion models (see sample_clip_guided.py).

  • k-diffusion supports log likelihood calculation (not a variational lower bound) for native models and all wrapped models.

  • k-diffusion can calculate, during training, the FID and KID vs the training set.

  • k-diffusion can calculate, during training, the gradient noise scale (1 / SNR), from An Empirical Model of Large-Batch Training (https://arxiv.org/abs/1812.06162).
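
As a rough sketch of how the wrappers, samplers, and log likelihood calculation fit together (assuming a trained v-objective model inner_model on a CUDA device; the batch shape and sigma range are illustrative, not prescribed):

import torch
import k_diffusion as K

# Wrap an external (here, v-diffusion-pytorch) model so it exposes the
# denoiser interface that k-diffusion's samplers expect.
denoiser = K.external.VDenoiser(inner_model)

# Build a Karras et al. (2022) noise schedule (sigma_min/sigma_max are illustrative).
sigmas = K.sampling.get_sigmas_karras(50, sigma_min=1e-2, sigma_max=80., device='cuda')

# Start from pure noise at the highest sigma and sample with DPM-Solver++(2M).
x = torch.randn([4, 3, 64, 64], device='cuda') * sigmas[0]
samples = K.sampling.sample_dpmpp_2m(denoiser, x, sigmas)

# Exact log likelihood (not a variational lower bound) via the probability flow ODE.
ll = K.sampling.log_likelihood(denoiser, samples, 1e-2, 80.)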

To do:

  • Anything except unconditional image diffusion models

  • Latent diffusion

Download files

Download the file for your platform.

Source Distribution

k-diffusion-0.0.15.tar.gz (23.9 kB)

Built Distribution

k_diffusion-0.0.15-py3-none-any.whl (25.7 kB)

File details

Details for the file k-diffusion-0.0.15.tar.gz.

File metadata

  • Download URL: k-diffusion-0.0.15.tar.gz
  • Upload date:
  • Size: 23.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.16

File hashes

Hashes for k-diffusion-0.0.15.tar.gz
  • SHA256: 57e3cce038402a30d649039ea9fefe15b412103ea7048f2958d23983addbd345
  • MD5: d30016fbeec74573196ad1c1a734ceff
  • BLAKE2b-256: 7c0bcbb9cef09b8dedddcdc22fc214b340e8fbc5cecf1738acc72f986284c2c8

File details

Details for the file k_diffusion-0.0.15-py3-none-any.whl.

File metadata

  • Download URL: k_diffusion-0.0.15-py3-none-any.whl
  • Upload date:
  • Size: 25.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.16

File hashes

Hashes for k_diffusion-0.0.15-py3-none-any.whl
  • SHA256: 8e392c98ea881cf19a1e2ebec1569fa12c7f31d23c325b3d3bb185f9f10560fb
  • MD5: d2c8d1e02988558a0a6c3c62355b463d
  • BLAKE2b-256: 0c06a6b35223ee1939c8fdade6d7eaa2243eafd453c963081925c28cf86624fb
