
Diffusion Models Made Easy

Diffusion Models Made Easy (dmme) is a collection of easy-to-understand diffusion model implementations in PyTorch.

Getting Started

Documentation is available at https://diffusion-models-made-easy.readthedocs.io/en/latest/

Installation

Install with pip

pip install dmme

Installing dmme in editable mode requires pip>=22.3; update pip by running pip install -U pip

Install for customization or development

pip install -e ".[dev]"

Install dependencies for testing

pip install dmme[tests]

Install dependencies for docs

pip install dmme[docs]

Training

Train DDPM using LightningCLI and the wandb logger with mixed precision

python scripts/trainer.py fit --config configs/ddpm/cifar10.yaml
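
LightningCLI builds the Trainer, model, and data module from the YAML config. If you want to inspect every available option, or generate a starting config to edit, LightningCLI typically supports printing the merged configuration (the exact flag depends on your pytorch-lightning version):

python scripts/trainer.py fit --config configs/ddpm/cifar10.yaml --print_config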

Train DDPM from Python using PyTorch Lightning

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

from dmme import LitDDPM, DDPMSampler, CIFAR10
from dmme.ddpm import UNet


def main():
    trainer = Trainer(
        logger=WandbLogger(project="CIFAR10 Image Generation", name="DDPM"),
        gradient_clip_val=1.0,  # clip gradients to stabilize training
        auto_select_gpus=True,
        accelerator="gpu",
        precision=16,  # mixed-precision training
        max_steps=800_000,
    )

    ddpm = LitDDPM(
        DDPMSampler(UNet(in_channels=3), timesteps=1000),
        lr=2e-4,
        warmup=5000,
        imgsize=(3, 32, 32),
        timesteps=1000,
        decay=0.9999,
    )
    cifar10 = CIFAR10()

    trainer.fit(ddpm, cifar10)


if __name__ == "__main__":
    main()

Or use the DDPMSampler class to train with plain PyTorch.

note: unlike the Lightning example, this minimal loop does not include logging or checkpointing

from tqdm import tqdm

import torch
from torch.optim import Adam

from dmme import CIFAR10
from dmme.ddpm import UNet, DDPMSampler
from dmme.lr_scheduler import WarmupLR
from dmme.noise_schedules import linear_schedule


def train(timesteps=1000, lr=2e-4, clip_val=1.0, warmup=5000, max_steps=800_000):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # noise-prediction UNet, linear beta schedule, and the DDPM sampler wrapping both
    model = UNet()
    beta = linear_schedule(timesteps=timesteps)
    sampler = DDPMSampler(model, timesteps=timesteps, beta=beta)
    sampler = sampler.to(device)

    cifar10 = CIFAR10()
    cifar10.prepare_data()
    cifar10.setup("fit")

    train_dataloader = cifar10.train_dataloader()

    optimizer = Adam(sampler.parameters(), lr=lr)
    lr_scheduler = WarmupLR(optimizer, warmup=warmup)

    # gradient scaler for float16 autocast on CUDA (no-op on CPU)
    scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

    steps = 0
    while steps < max_steps:
        prog_bar = tqdm(train_dataloader)
        for x_0, _ in prog_bar:
            x_0 = x_0.to(device)
            with torch.autocast(device.type):
                loss = sampler.compute_loss(x_0)

            optimizer.zero_grad()
            scaler.scale(loss).backward()

            # unscale before clipping so the threshold applies to the true gradients
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(sampler.parameters(), clip_val)

            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()

            steps += 1

            prog_bar.set_postfix({"loss": loss.item(), "steps": steps})

            if steps == max_steps:
                break


if __name__ == "__main__":
    train()
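
After training, images are generated by running the learned reverse process, which DDPMSampler implements (see the documentation linked above for its sampling API). As a rough, standalone sketch of DDPM ancestral sampling, assuming the UNet takes a noisy batch and a timestep index and predicts the added noise, and that beta is a 1-D tensor of length timesteps from linear_schedule, the reverse process looks like this:

import torch

@torch.no_grad()
def ddpm_sample(model, beta, shape, device):
    # assumption: model(x_t, t) predicts the noise that was added at step t
    alpha = 1.0 - beta
    alpha_bar = torch.cumprod(alpha, dim=0)

    x = torch.randn(shape, device=device)  # start from pure Gaussian noise x_T
    for t in reversed(range(beta.shape[0])):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)

        # posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        mean = (x - beta[t] / torch.sqrt(1.0 - alpha_bar[t]) * eps) / torch.sqrt(alpha[t])

        # add noise at every step except the last (t == 0), using sigma_t^2 = beta_t
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        x = mean + torch.sqrt(beta[t]) * noise

    return x

Prefer the sampler's own sampling method in practice; this sketch only illustrates the math that the reverse process follows.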

Supported Diffusion Models

  • DDPM (Denoising Diffusion Probabilistic Models), implemented by LitDDPM and DDPMSampler as shown above
