MegEngine implementation of Diffusion Models
MegDiffusion
MegEngine implementation of Diffusion Models (in early development).
Current maintainer: @MegChai
Usage
Infer with pre-trained models
Now users can use megengine.hub to get pre-trained models directly:
import megengine
repo_info = "MegEngine/MegDiffusion:main"
megengine.hub.list(repo_info)
pretrained_model = "ddpm_cifar10_ema_converted"
megengine.hub.help(repo_info, pretrained_model)
model = megengine.hub.load(repo_info, pretrained_model, pretrained=True)
model.eval()
Note that megengine.hub downloads the whole repository from its host, or uses a cached copy if one is available.
If you have downloaded or installed MegDiffusion, you can get pre-trained models from the pretrain module instead:
from megdiffusion.model import pretrain
model = pretrain.ddpm_cifar10_ema_converted(pretrained=True)
model.eval()
The sample script shows how to generate 64 CIFAR10-like images and make a grid of them:
python3 -m megdiffusion.pipeline.ddpm.sample
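The grid step can be illustrated independently of the model. The following is a minimal NumPy sketch, not the code used by megdiffusion.pipeline.ddpm.sample: the function name make_grid, the padding width, and the random stand-in images are assumptions made for illustration only.

```python
import numpy as np


def make_grid(images: np.ndarray, per_row: int = 8, padding: int = 2) -> np.ndarray:
    """Tile a batch of images shaped (N, C, H, W) into one (C, H_total, W_total) array."""
    n, c, h, w = images.shape
    rows = -(-n // per_row)  # ceiling division: number of rows needed
    grid_h = rows * h + (rows + 1) * padding
    grid_w = per_row * w + (per_row + 1) * padding
    grid = np.zeros((c, grid_h, grid_w), dtype=images.dtype)
    for idx in range(n):
        r, col = divmod(idx, per_row)
        top = padding + r * (h + padding)
        left = padding + col * (w + padding)
        grid[:, top:top + h, left:left + w] = images[idx]
    return grid


# Stand-in for 64 generated CIFAR10-sized samples (3 x 32 x 32); real samples
# would come from the diffusion model instead of a random generator.
rng = np.random.default_rng(0)
samples = rng.uniform(0.0, 1.0, size=(64, 3, 32, 32)).astype(np.float32)
grid = make_grid(samples)
print(grid.shape)  # (3, 274, 274)
```

With 64 images and 8 per row, the result is an 8 x 8 grid; each side is 8 * 32 pixels of image plus 9 * 2 pixels of padding.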
Train from scratch
- Take DDPM CIFAR10 for example:
python3 -m megdiffusion.pipeline.ddpm.train \
--config ./configs/ddpm/cifar10.yaml
- [Optional] Overwrite arguments:
python3 -m megdiffusion.pipeline.ddpm.train \
--config ./configs/ddpm/cifar10.yaml \
--logdir ./path/to/logdir \
--parallel --resume
See python3 -m megdiffusion.pipeline.ddpm.train --help for more information.
For other options such as batch_size, we recommend changing them in the yaml file and keeping a backup copy of the file.
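One way to keep such a backup is to copy the config before editing it. The sketch below uses only the Python standard library; the helper name backup_config, the timestamped naming scheme, and the placeholder file content are assumptions for illustration and are not part of MegDiffusion.

```python
import shutil
import tempfile
import time
from pathlib import Path


def backup_config(config_path: str) -> Path:
    """Copy a config file next to itself with a timestamp and return the copy's path."""
    src = Path(config_path)
    stamp = time.strftime("%Y%m%d-%H%M%S")
    dst = src.with_name(f"{src.stem}.{stamp}{src.suffix}.bak")
    shutil.copy2(src, dst)  # copy2 also preserves file metadata such as mtime
    return dst


# Demonstration on a throwaway file; the content is a placeholder,
# not the real configs/ddpm/cifar10.yaml.
with tempfile.TemporaryDirectory() as tmp:
    cfg = Path(tmp) / "cifar10.yaml"
    cfg.write_text("batch_size: 128\n")
    saved = backup_config(str(cfg))
    cfg.write_text("batch_size: 64\n")  # edit the original after backing it up
    print(saved.name.endswith(".yaml.bak"), saved.read_text().strip())
```

The edited file can then be passed to --config as usual, while the backup records the settings of the previous run.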
To sample with a model you trained yourself (rather than the pre-trained model):
python3 -m megdiffusion.pipeline.ddpm.sample --nopretrain \
--logdir ./path/to/logdir \
--config ./configs/ddpm/cifar10.yaml # Could be your own custom file
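For readers unfamiliar with what sampling involves, here is a rough NumPy sketch of DDPM ancestral sampling as described in Ho et al. (2020), using the linear beta schedule and the variance choice sigma_t^2 = beta_t from that paper. It is not MegDiffusion's implementation, which may differ in details such as variance type or clipping. The names linear_beta_schedule, ddpm_sample, and eps_model are hypothetical, and the zero-returning noise predictor is a placeholder for a trained network.

```python
import numpy as np


def linear_beta_schedule(num_steps: int = 1000,
                         beta_start: float = 1e-4,
                         beta_end: float = 0.02) -> np.ndarray:
    """Linearly spaced noise variances beta_1..beta_T."""
    return np.linspace(beta_start, beta_end, num_steps, dtype=np.float64)


def ddpm_sample(eps_model, shape, betas, rng):
    """Run the reverse process from pure noise x_T down to x_0."""
    alphas = 1.0 - betas
    alphas_bar = np.cumprod(alphas)
    x = rng.standard_normal(shape)  # x_T ~ N(0, I)
    for t in reversed(range(len(betas))):
        eps = eps_model(x, t)  # predicted noise at step t
        # Posterior mean computed from the predicted noise.
        mean = (x - betas[t] / np.sqrt(1.0 - alphas_bar[t]) * eps) / np.sqrt(alphas[t])
        if t > 0:
            x = mean + np.sqrt(betas[t]) * rng.standard_normal(shape)
        else:
            x = mean  # no noise is added at the final step
    return x


def dummy_eps_model(x, t):
    """Placeholder noise predictor; a trained network would go here."""
    return np.zeros_like(x)


rng = np.random.default_rng(0)
betas = linear_beta_schedule(num_steps=50)  # shortened schedule to keep the demo fast
images = ddpm_sample(dummy_eps_model, (4, 3, 32, 32), betas, rng)
print(images.shape)  # (4, 3, 32, 32)
```

With the placeholder predictor the output is just noise; the sketch is meant to show the shape of the loop, not to produce meaningful images.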
Development
python3 -m pip install -r requirements.txt
python3 -m pip install -v -e .
Develop on a new local branch, and remember to add the necessary tests.
When you are finished, submit a Pull Request to the main branch and wait for review.
Acknowledgment
The following open-source projects were referenced here:
- hojonathanho/diffusion: The official TensorFlow implementation of DDPM.
- w86763777/pytorch-ddpm: Unofficial PyTorch implementation of Denoising Diffusion Probabilistic Models.
- pesser/pytorch_diffusion: Unofficial PyTorch implementation of DDPM that also provides converted torch checkpoints.
- openai/improved-diffusion: The official codebase for Improved Denoising Diffusion Probabilistic Models.
Thanks to @gaohuazuo, @xxr3376, @P2Oileen and other contributors for their support of this project. The R&D platform and the resources required for the experiments are provided by MEGVII Inc. The deep learning framework used in this project is MegEngine -- a magic weapon.
Citations
@article{ho2020denoising,
title = {Denoising Diffusion Probabilistic Models},
author = {Jonathan Ho and Ajay Jain and Pieter Abbeel},
year = {2020},
eprint = {2006.11239},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@article{DBLP,
title = {Improved Denoising Diffusion Probabilistic Models},
author = {Alex Nichol and Prafulla Dhariwal},
year = {2021},
url = {https://arxiv.org/abs/2102.09672},
eprinttype = {arXiv},
eprint = {2102.09672},
}