Diffusion Models Made Easy

Diffusion Models Made Easy (dmme) is a collection of easy-to-understand diffusion model implementations in PyTorch.

Getting Started

Installation

pip install dmme

Installing dmme in editable mode requires pip>=22.3. Update pip by running pip install -U pip.

Training

Train DDPM using LightningCLI and the wandb logger with mixed precision:

python scripts/trainer.py fit --config configs/ddpm/cifar10.yaml

Train DDPM from Python using pytorch-lightning:

from pytorch_lightning import Trainer

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from dmme import LitDDPM, CIFAR10
from dmme.ddpm import UNet


def main():
    trainer = Trainer(
        logger=WandbLogger(project="CIFAR10 Image Generation", name="DDPM"),
        callbacks=[
            ModelCheckpoint(save_last=True),
            LearningRateMonitor(),
        ],
        gradient_clip_val=1.0,
        auto_select_gpus=True,
        accelerator="gpu",
        precision=16,
        max_steps=800_000,
    )

    ddpm = LitDDPM(
        decoder=UNet(in_channels=3),
        lr=2e-4,
        warmup=5000,
        imgsize=(3, 32, 32),
        timesteps=1000,
    )
    cifar10 = CIFAR10()

    trainer.fit(ddpm, cifar10)


if __name__ == "__main__":
    main()
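
The warmup=5000 argument above is not documented here. One common reading is a linear learning-rate warmup over the first 5000 optimizer steps, and the plain-Python sketch below shows that multiplier. The function name warmup_factor is hypothetical, and dmme's own scheduler may behave differently.

```python
def warmup_factor(step, warmup=5000):
    """Linear warmup multiplier (assumed behavior, not dmme's actual code).

    Ramps from 1/warmup at step 0 up to 1.0 once `warmup` steps have
    elapsed, then stays at 1.0.
    """
    return min((step + 1) / warmup, 1.0)


base_lr = 2e-4
# Effective learning rate at a few points in training
early_lr = base_lr * warmup_factor(0)       # tiny learning rate at the start
late_lr = base_lr * warmup_factor(10_000)   # full base_lr after warmup
```

A function like this can be passed as lr_lambda to torch.optim.lr_scheduler.LambdaLR, which multiplies the optimizer's base learning rate by the returned factor at each scheduler step.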

Alternatively, use the DDPMSampler class to train with plain PyTorch:

Note: this example does not include gradient clipping, logging, or checkpointing.

from tqdm import tqdm

import torch
from torch.optim import Adam

from dmme import CIFAR10

from dmme.ddpm import UNet, DDPMSampler
from dmme.lr_scheduler import WarmupLR
from dmme.noise_schedules import linear_schedule


def train(timesteps=1000, lr=2e-4, warmup=5000, max_steps=800_000):
    # Use a torch.device in both cases so device.type is always available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNet()
    beta = linear_schedule(timesteps=timesteps)
    sampler = DDPMSampler(model, timesteps=timesteps, beta=beta)
    sampler = sampler.to(device)

    cifar10 = CIFAR10()
    cifar10.prepare_data()
    cifar10.setup("fit")

    train_dataloader = cifar10.train_dataloader()

    optimizer = Adam(sampler.parameters(), lr=lr)
    lr_scheduler = WarmupLR(optimizer, warmup=warmup)

    steps = 0
    while steps < max_steps:
        prog_bar = tqdm(train_dataloader)
        for x_0, _ in prog_bar:
            x_0 = x_0.to(device)
            # Run the forward pass in mixed precision on the active device
            with torch.autocast(device.type):
                loss = sampler.compute_loss(x_0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            steps += 1

            prog_bar.set_postfix({"loss": loss.cpu().item(), "steps": steps})

            if steps == max_steps:
                break


if __name__ == "__main__":
    train()
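
For intuition about what linear_schedule and compute_loss plausibly do, here is a minimal plain-Python sketch of the DDPM forward process and the noise-prediction loss. The defaults beta_start=1e-4 and beta_end=0.02 are the values from the DDPM paper, assumed here rather than read from dmme. All function names are hypothetical, and the code works on lists of floats rather than tensors.

```python
import math


def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Evenly spaced noise variances beta_1..beta_T (assumed DDPM defaults)."""
    if timesteps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def cumulative_alphas(betas):
    """alpha_bar_t = product over s <= t of (1 - beta_s)."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def q_sample(x_0, t, alpha_bars, noise):
    """Forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise."""
    signal = math.sqrt(alpha_bars[t])
    sigma = math.sqrt(1.0 - alpha_bars[t])
    return [signal * x + sigma * e for x, e in zip(x_0, noise)]


def noise_prediction_loss(model, x_0, t, alpha_bars, noise):
    """Mean squared error between the true noise and the model's prediction.

    `model` is any callable taking (x_t, t) and returning a noise estimate.
    """
    x_t = q_sample(x_0, t, alpha_bars, noise)
    predicted = model(x_t, t)
    return sum((e - p) ** 2 for e, p in zip(noise, predicted)) / len(noise)
```

In an actual training step, t would be drawn uniformly at random for each image and noise sampled from a standard normal, with a network such as the UNet playing the role of model.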

Supported Diffusion Models

  • DDPM
  • Score-based models coming soon...
