Diffusion Models Made Easy

Diffusion Models Made Easy (dmme) is a collection of easy-to-understand diffusion model implementations in PyTorch.

Getting Started

Documentation is available at https://diffusion-models-made-easy.readthedocs.io/en/latest/

Installation

Install from pip

pip install dmme

Note: installing dmme in editable mode requires pip>=22.3. Update pip by running pip install -U pip

Install for customization or development

pip install -e ".[dev]"

Install dependencies for testing

pip install "dmme[tests]"

Install dependencies for docs

pip install "dmme[docs]"

Train Diffusion Models

Train DDPM using LightningCLI and the wandb logger with mixed precision

dmme.trainer fit --config configs/ddpm/cifar10.yaml

Train DDPM from Python using PyTorch Lightning

from pytorch_lightning import Trainer

from pytorch_lightning.loggers import WandbLogger

from dmme.ddpm import LitDDPM, UNet
from dmme.data import CIFAR10

from dmme.callbacks import GenerateImage


def main():
    trainer = Trainer(
        logger=WandbLogger(project="CIFAR10_Image_Generation", name="DDPM"),
        callbacks=GenerateImage((3, 32, 32), timesteps=1000),
        gradient_clip_val=1.0,
        auto_select_gpus=True,
        accelerator="gpu",
        precision=16,
        max_steps=800_000,
    )

    ddpm = LitDDPM(
        UNet(in_channels=3),
        lr=2e-4,
        warmup=5000,
        imgsize=(3, 32, 32),
        timesteps=1000,
        decay=0.9999,
    )
    cifar10 = CIFAR10()

    trainer.fit(ddpm, cifar10)


if __name__ == "__main__":
    main()
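
The warmup=5000 argument in the example above suggests that the learning rate is ramped up over the first 5000 optimizer steps. As a rough illustration, here is a minimal sketch of a linear warmup multiplier. This is an assumption about the general technique, not dmme's actual scheduler, and linear_warmup_factor is a hypothetical helper name.

```python
def linear_warmup_factor(step: int, warmup: int = 5000) -> float:
    """Multiplier applied to the base learning rate at a given optimizer step.

    Assumption: the factor grows linearly to 1.0 over `warmup` steps and then
    stays constant. dmme's own scheduler may behave differently.
    """
    if warmup <= 0:
        return 1.0
    return min(1.0, (step + 1) / warmup)


if __name__ == "__main__":
    base_lr = 2e-4
    # Print the effective learning rate at a few points in training
    for step in (0, 2499, 4999, 10_000):
        print(step, base_lr * linear_warmup_factor(step))
```

A function of this shape can be passed as lr_lambda to torch.optim.lr_scheduler.LambdaLR, which multiplies the optimizer's base learning rate by the returned factor on every scheduler step.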

Or train with plain PyTorch using the DDPM class.

Note: this example does not include logging or checkpointing.

from tqdm import tqdm

import torch
from torch.optim import Adam
import torch.nn.functional as F

from dmme.data import CIFAR10

from dmme.ddpm import DDPM, UNet
from dmme.lr_scheduler import WarmupLR

from dmme.common import gaussian_like, uniform_int


def train(timesteps=1000, lr=2e-4, clip_val=1.0, warmup=5000, max_steps=800_000):
    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

    model = UNet()
    model = model.to(device)

    ddpm = DDPM(timesteps=timesteps)
    ddpm = ddpm.to(device)

    cifar10 = CIFAR10()
    cifar10.prepare_data()
    cifar10.setup("fit")

    train_dataloader = cifar10.train_dataloader()

    optimizer = Adam(model.parameters(), lr=lr)
    lr_scheduler = WarmupLR(optimizer, warmup=warmup)

    steps = 0
    while steps < max_steps:
        prog_bar = tqdm(train_dataloader)
        for x_0, _ in prog_bar:
            x_0 = x_0.to(device)

            batch_size: int = x_0.size(0)
            t = uniform_int(0, timesteps, batch_size, device=x_0.device)

            noise = gaussian_like(x_0)

            # Run the forward pass under autocast on whichever device holds the batch
            with torch.autocast(device_type=x_0.device.type):
                x_t = ddpm.forward_process(x_0, t, noise)

                noise_estimate = model(x_t, t)

                loss = F.mse_loss(noise, noise_estimate)

            optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_val)

            optimizer.step()
            lr_scheduler.step()

            steps += 1

            # Log a plain float so the progress bar does not print the tensor repr
            prog_bar.set_postfix({"loss": loss.item(), "steps": steps})

            if steps == max_steps:
                break


if __name__ == "__main__":
    train()
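
To make the forward_process call in the loop above more concrete: in the standard DDPM formulation, a noisy sample is x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, where alpha_bar_t is the cumulative product of (1 - beta) up to step t. The sketch below computes these two coefficients in plain Python. The linear schedule and its endpoints (1e-4 and 0.02) are the values from the DDPM paper; whether dmme uses the same defaults is an assumption, and all function names here are hypothetical.

```python
import math


def linear_beta_schedule(timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances beta_0 .. beta_{T-1} (assumed defaults)."""
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def cumulative_alpha_bar(betas):
    """alpha_bar_t = product of (1 - beta_s) for s <= t."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def forward_coefficients(alpha_bars, t):
    """Return (signal_scale, noise_scale) for timestep t.

    A noisy sample is then x_t = signal_scale * x_0 + noise_scale * noise.
    """
    return math.sqrt(alpha_bars[t]), math.sqrt(1.0 - alpha_bars[t])


if __name__ == "__main__":
    alpha_bars = cumulative_alpha_bar(linear_beta_schedule())
    # Early steps keep almost all of the image, late steps are almost pure noise
    for t in (0, 499, 999):
        print(t, forward_coefficients(alpha_bars, t))
```

The squared coefficients always sum to one, so the forward process preserves the overall variance of unit-variance data while gradually replacing signal with noise.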

Supported Diffusion Models

