Karras et al. (2022) diffusion models for PyTorch

k-diffusion

An implementation of Elucidating the Design Space of Diffusion-Based Generative Models (Karras et al., 2022) for PyTorch, with enhancements and additional features, such as improved sampling algorithms and transformer-based diffusion models.

Installation

k-diffusion can be installed from PyPI (pip install k-diffusion), but the PyPI package contains only the library code that others can depend on, not the training and inference scripts. To run the training and inference scripts, clone this repository and run pip install -e <path to repository>.

Training

To train models:

$ ./train.py --config CONFIG_FILE --name RUN_NAME

For instance, to train a model on MNIST:

$ ./train.py --config configs/config_mnist_transformer.json --name RUN_NAME

The configuration file allows you to specify the dataset type. Currently supported types are "imagefolder" (finds all images in that folder and its subfolders, recursively), "cifar10" (CIFAR-10), "mnist" (MNIST), and "huggingface" (Hugging Face Datasets).
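
As a rough illustration of the dataset setting described above, the following sketch reads a JSON configuration and checks its dataset type against the four types named in this README. The key names "dataset" and "type" and the helper name dataset_type are assumptions made for illustration only; consult the files under configs/ for the actual schema.

```python
import json

# Dataset types named in this README.
SUPPORTED_DATASET_TYPES = {"imagefolder", "cifar10", "mnist", "huggingface"}


def dataset_type(config_text):
    """Return the dataset type from a JSON config string, or raise ValueError.

    Assumes (hypothetically) a layout like {"dataset": {"type": "mnist"}}.
    """
    config = json.loads(config_text)
    kind = config.get("dataset", {}).get("type")
    if kind not in SUPPORTED_DATASET_TYPES:
        raise ValueError(f"unsupported dataset type: {kind!r}")
    return kind


if __name__ == "__main__":
    # Hypothetical config fragment, not copied from the repository.
    print(dataset_type('{"dataset": {"type": "mnist"}}'))
```

A check like this fails early on a misspelled dataset type instead of partway through training setup.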

Multi-GPU and multi-node training is supported with Hugging Face Accelerate. You can configure Accelerate by running:

$ accelerate config

then running:

$ accelerate launch train.py --config CONFIG_FILE --name RUN_NAME

Enhancements/additional features

  • k-diffusion has support for training transformer-based diffusion models (like DiT but improved).

  • k-diffusion supports a soft version of Min-SNR loss weighting for improved training at high resolutions with fewer hyperparameters than the loss weighting used in Karras et al. (2022).

  • k-diffusion has wrappers for v-diffusion-pytorch, OpenAI diffusion, and CompVis diffusion models, allowing them to be used with its samplers and ODE/SDE.

  • k-diffusion implements DPM-Solver, which produces higher quality samples at the same number of function evaluations as Karras Algorithm 2 and also supports adaptive step size control. DPM-Solver++(2S) and (2M) are also implemented, for improved quality with low numbers of steps.

  • k-diffusion supports CLIP guided sampling from unconditional diffusion models (see sample_clip_guided.py).

  • k-diffusion supports log likelihood calculation (not a variational lower bound) for native models and all wrapped models.

  • k-diffusion can calculate, during training, the FID and KID vs the training set.

  • k-diffusion can calculate, during training, the gradient noise scale (1 / SNR), from An Empirical Model of Large-Batch Training (https://arxiv.org/abs/1812.06162).
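
To make the last item more concrete, here is a minimal sketch of the two-batch-size estimator for the simple gradient noise scale described in An Empirical Model of Large-Batch Training. It is not taken from k-diffusion's own code: the function name and arguments are hypothetical, and it only assumes you can measure the squared gradient norm at two different batch sizes.

```python
def gradient_noise_scale(sq_norm_small, sq_norm_big, n_small, n_big):
    """Estimate the simple noise scale B_simple = tr(Sigma) / |G|^2.

    sq_norm_small and sq_norm_big are squared L2 norms of the gradient
    averaged over batches of n_small and n_big examples (n_big > n_small).
    """
    if n_big <= n_small:
        raise ValueError("n_big must be larger than n_small")
    # Unbiased estimate of the true squared gradient norm |G|^2.
    g_sq = (n_big * sq_norm_big - n_small * sq_norm_small) / (n_big - n_small)
    # Unbiased estimate of the trace of the per-example gradient covariance.
    trace_sigma = (sq_norm_small - sq_norm_big) / (1 / n_small - 1 / n_big)
    return trace_sigma / g_sq


if __name__ == "__main__":
    # Synthetic numbers: |G|^2 = 4 and tr(Sigma) = 100, so E|G_B|^2 = 4 + 100 / B.
    print(gradient_noise_scale(4 + 100 / 8, 4 + 100 / 64, 8, 64))
```

Because the ratio of two noisy estimates is biased, the paper recommends averaging the numerator and denominator separately over many steps (for example with exponential moving averages) before dividing; the single-step version above is only meant to show the arithmetic.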

To do

  • Latent diffusion
