MegEngine implementation of Diffusion Models

MegDiffusion

MegEngine implementation of Diffusion Models (in early development).

Current maintainer: @MegChai

Usage

Infer with pre-trained models

You can use megengine.hub to load pre-trained models directly:

import megengine

repo_info = "MegEngine/MegDiffusion:main"
megengine.hub.list(repo_info)

pretrained_model = "ddpm_cifar10_ema_converted"
megengine.hub.help(repo_info, pretrained_model)

model = megengine.hub.load(repo_info, pretrained_model, pretrained=True)
model.eval()

Note that megengine.hub downloads the whole repository from its host, or reuses a cached copy if one exists.

If you have downloaded or installed MegDiffusion, you can get pre-trained models from the pretrain module:

from megdiffusion.model import pretrain

model = pretrain.ddpm_cifar10_ema_converted(pretrained=True)
model.eval()

The sample script shows how to generate 64 CIFAR10-like images and make a grid of them:

python3 -m megdiffusion.pipeline.ddpm.sample
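
As a rough illustration of the grid step only, the sketch below tiles 64 CIFAR10-sized images into a single 8x8 array. It assumes NumPy is available and uses random arrays as stand-ins for model samples; `make_grid` is a hypothetical helper written for this example, not necessarily what the sample script uses.

```python
import numpy as np


def make_grid(images, nrow=8, padding=2):
    """Tile a batch of images shaped (N, H, W, C) into one padded grid image."""
    n, h, w, c = images.shape
    rows = (n + nrow - 1) // nrow  # number of grid rows, rounded up
    grid = np.zeros(
        (rows * (h + padding) + padding, nrow * (w + padding) + padding, c),
        dtype=images.dtype,
    )
    for idx in range(n):
        r, col = divmod(idx, nrow)
        top = padding + r * (h + padding)
        left = padding + col * (w + padding)
        grid[top:top + h, left:left + w] = images[idx]
    return grid


# Stand-in for 64 generated samples (assumption: uint8, 32x32 RGB like CIFAR10).
rng = np.random.default_rng(0)
samples = rng.integers(0, 256, size=(64, 32, 32, 3), dtype=np.uint8)
grid = make_grid(samples)
```

The resulting array can be saved with any image library. The padding leaves a black border between tiles so that individual samples stay distinguishable.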

Train from scratch

  • Take DDPM CIFAR10 for example:

    python3 -m megdiffusion.pipeline.ddpm.train \
        --config ./configs/ddpm/cifar10.yaml
    
  • [Optional] Override arguments:

    python3 -m megdiffusion.pipeline.ddpm.train \
       --config ./configs/ddpm/cifar10.yaml \
       --logdir ./path/to/logdir \
       --parallel --resume
    

See python3 -m megdiffusion.pipeline.ddpm.train --help for more information. For other options such as batch_size, we recommend editing the yaml file and keeping a backup copy of it.

If you want to sample with a model you trained yourself (not the pre-trained model):

python3 -m megdiffusion.pipeline.ddpm.sample --nopretrain \
   --logdir ./path/to/logdir \
   --config ./configs/ddpm/cifar10.yaml  # Could be your own customized file

Development

python3 -m pip install -r requirements.txt
python3 -m pip install -v -e .

Develop on a new local branch, and remember to add the necessary tests. When you are finished, submit a Pull Request to the main branch and wait for review.

Acknowledgment

The following open-source projects were referenced here:

Thanks to people including @gaohuazuo, @xxr3376, @P2Oileen and other contributors for support in this project. The R&D platform and the resources required for the experiment are provided by MEGVII Inc. The deep learning framework used in this project is MegEngine -- a magic weapon.

Citations

@article{ho2020denoising,
    title   = {Denoising Diffusion Probabilistic Models},
    author  = {Jonathan Ho and Ajay Jain and Pieter Abbeel},
    year    = {2020},
    eprint  = {2006.11239},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@article{DBLP,
  title     = {Improved Denoising Diffusion Probabilistic Models},
  author    = {Alex Nichol and Prafulla Dhariwal},
  year      = {2021},
  url       = {https://arxiv.org/abs/2102.09672},
  eprinttype = {arXiv},
  eprint    = {2102.09672},
}
