MegEngine implementation of Diffusion Models
MegDiffusion
MegEngine implementation of Diffusion Models (in early development).
Current maintainer: @MegChai
Usage
Infer with pre-trained models
Now users can use megengine.hub to get pre-trained models directly:
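import megengine
repo_info = "MegEngine/MegDiffusion:main"
print(megengine.hub.list(repo_info))  # names of the entrypoints the repository exposes
pretrained_model = "ddpm_cifar10_ema_converted"
print(megengine.hub.help(repo_info, pretrained_model))  # docstring of the chosen entrypoint
model = megengine.hub.load(repo_info, pretrained_model, pretrained=True)
model.eval()  # switch to inference mode before sampling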
Note that megengine.hub downloads the whole repository from its host, or reuses a local cache if one exists.
If you have already downloaded or installed MegDiffusion, you can get pre-trained models from the pretrain module instead:
from megdiffusion.model import pretrain
model = pretrain.ddpm_cifar10_ema_converted(pretrained=True)
model.eval()
The sample script shows how to generate 64 CIFAR10-like images and make a grid of them:
python3 -m megdiffusion.pipeline.ddpm.sample
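For orientation, here is a minimal, dependency-free sketch of DDPM ancestral sampling (Algorithm 2 in Ho et al., 2020, using sigma_t^2 = beta_t). It works on a flat list of floats instead of image tensors and is not the code of megdiffusion.pipeline.ddpm.sample. eps_model is a hypothetical stand-in for a trained noise-prediction network with signature (x_t, t), and timesteps are 0-indexed, so t = 0 is the last denoising step.

```python
import math
import random


def p_sample_loop(eps_model, betas, dim, rng):
    """Ancestral sampling (Ho et al., 2020, Algorithm 2) with sigma_t^2 = beta_t.

    eps_model(x_t, t) is a hypothetical stand-in for a trained network that
    predicts the noise in x_t; t is 0-indexed, so t = 0 is the final step.
    """
    # alpha_bar_t = prod_{s <= t} (1 - beta_s)
    alpha_bar, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        alpha_bar.append(prod)

    x = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # x_T ~ N(0, I)
    for t in reversed(range(len(betas))):
        beta = betas[t]
        eps = eps_model(x, t)
        coef = beta / math.sqrt(1.0 - alpha_bar[t])
        # Model mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        mean = [(xi - coef * ei) / math.sqrt(1.0 - beta) for xi, ei in zip(x, eps)]
        if t > 0:
            sigma = math.sqrt(beta)
            x = [m + sigma * rng.gauss(0.0, 1.0) for m in mean]
        else:
            x = mean  # no noise is added at the final step
    return x
```

Generating 64 images amounts to running this loop once per image (or batching it), with each image flattened to dim = 3 * 32 * 32 values for CIFAR10.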
Train from scratch
- Take DDPM CIFAR10 as an example:
python3 -m megdiffusion.pipeline.ddpm.train \
--config ./configs/ddpm/cifar10.yaml
- [Optional] Override arguments:
python3 -m megdiffusion.pipeline.ddpm.train \
--config ./configs/ddpm/cifar10.yaml \
--logdir ./path/to/logdir \
--parallel --resume
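The second command layers flags on top of the YAML file. The following sketch assumes that flags passed explicitly take precedence over the file and that omitted flags leave the file's values untouched. That is an assumption, and the pipeline's real option handling may differ. Under it, the merge can be written with the standard argparse module. The dictionary argument stands in for an already-parsed cifar10.yaml (for example, loaded with PyYAML's yaml.safe_load), and its keys are hypothetical.

```python
import argparse


def merge_config(file_config, argv):
    """Overlay explicitly passed CLI flags on a config dict (hypothetical keys)."""
    parser = argparse.ArgumentParser()
    # default=None lets us tell "flag omitted" apart from "flag passed".
    parser.add_argument("--logdir", default=None)
    parser.add_argument("--parallel", action="store_true", default=None)
    parser.add_argument("--resume", action="store_true", default=None)
    args = parser.parse_args(argv)

    merged = dict(file_config)  # copy so the loaded file config is not mutated
    for key, value in vars(args).items():
        if value is not None:  # only override with flags the user actually passed
            merged[key] = value
    return merged
```

Options that have no flag, such as the hypothetical batch_size key here, always keep the file's value. That matches the advice to manage them in the YAML file.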
See python3 -m megdiffusion.pipeline.ddpm.train --help for more information.
For other options such as batch_size, we recommend changing them in the yaml file and keeping a backup copy of that file.
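The DDPM paper (Ho et al., 2020) trains the network to predict the noise added by the closed-form forward process. The objective is L_simple = ||eps - eps_theta(sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps, t)||^2. Below is a pure-Python sketch of one sample of that loss. It assumes the paper's linear schedule, with T = 1000 and beta increasing from 1e-4 to 0.02, and this README does not say whether configs/ddpm/cifar10.yaml uses the same values. eps_model is a hypothetical stand-in for the network, and this is not the pipeline's training code.

```python
import math
import random


def linear_beta_schedule(T=1000, beta_1=1e-4, beta_T=0.02):
    """Variances increasing linearly from beta_1 to beta_T (Ho et al., 2020)."""
    step = (beta_T - beta_1) / (T - 1)
    return [beta_1 + i * step for i in range(T)]


def cumulative_alphas(betas):
    """alpha_bar_t = prod_{s <= t} (1 - beta_s), with 0-indexed t."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def q_sample(x0, t, alpha_bar, noise):
    """Closed-form draw of x_t ~ q(x_t | x_0) for a flat list of pixel values."""
    a = alpha_bar[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * n for x, n in zip(x0, noise)]


def simple_loss(x0, eps_model, alpha_bar, rng):
    """One Monte Carlo sample of L_simple, averaged over pixels."""
    t = rng.randrange(len(alpha_bar))          # t ~ Uniform over all steps
    noise = [rng.gauss(0.0, 1.0) for _ in x0]  # eps ~ N(0, I)
    x_t = q_sample(x0, t, alpha_bar, noise)
    pred = eps_model(x_t, t)                   # hypothetical noise-prediction network
    return sum((n - p) ** 2 for n, p in zip(noise, pred)) / len(x0)
```

With this schedule, alpha_bar at the last step is roughly 4e-5, so x_T is close to pure Gaussian noise. That is why sampling can start from N(0, I).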
To sample with a model you trained yourself (instead of the pre-trained one):
python3 -m megdiffusion.pipeline.ddpm.sample --nopretrain \
--logdir ./path/to/logdir \
--config ./configs/ddpm/cifar10.yaml # Could be your own custom config file
Development
python3 -m pip install -r requirements.txt
python3 -m pip install -v -e .
Develop in a new local branch, and remember to add the necessary tests.
When you are finished, submit a pull request to the main branch and wait for review.
Acknowledgment
The following open-source projects were referenced:
- hojonathanho/diffusion: The official TensorFlow implementation of DDPM.
- w86763777/pytorch-ddpm: Unofficial PyTorch implementation of Denoising Diffusion Probabilistic Models.
- pesser/pytorch_diffusion: Unofficial PyTorch implementation of DDPM that also provides converted PyTorch checkpoints.
- openai/improved-diffusion: The official codebase for Improved Denoising Diffusion Probabilistic Models.
Thanks to @gaohuazuo, @xxr3376, @P2Oileen and other contributors for their support of this project. The R&D platform and the resources required for the experiments were provided by MEGVII Inc. The deep learning framework used in this project is MegEngine, a magic weapon.
Citations
@article{ho2020denoising,
title = {Denoising Diffusion Probabilistic Models},
author = {Jonathan Ho and Ajay Jain and Pieter Abbeel},
year = {2020},
eprint = {2006.11239},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@article{DBLP,
title = {Improved Denoising Diffusion Probabilistic Models},
author = {Alex Nichol and Prafulla Dhariwal},
year = {2021},
url = {https://arxiv.org/abs/2102.09672},
eprinttype = {arXiv},
eprint = {2102.09672},
}
File details
Details for the file megdiffusion-0.0.2.tar.gz.
File metadata
- Download URL: megdiffusion-0.0.2.tar.gz
- Upload date:
- Size: 24.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.9.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 47aa766113c9d54322ccf5e1623f2ac7209c810c9307495b02aa25d67960f94a
MD5 | e37487704c843407905ff3bb9eaee19e
BLAKE2b-256 | 6d5e5534659124418b0cb08c29299ab9eb2927964e14029784baf27c74ce68c8
File details
Details for the file megdiffusion-0.0.2-py3-none-any.whl.
File metadata
- Download URL: megdiffusion-0.0.2-py3-none-any.whl
- Upload date:
- Size: 29.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.9.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 1534c195e9d41abedb99b77d5dbf2213fc85785174cd0dc36ab5bee609d360ea
MD5 | a626a95504d2a1f99703dcfc3036fe7d
BLAKE2b-256 | e624b1b25ec265f9141b0d6132f0eecd77dff5f9edb994f324784da7c4452389
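To check that a downloaded file matches the digests listed above, you can hash it locally with Python's standard hashlib module. This minimal sketch streams the file in chunks, so memory use stays constant regardless of file size. The function names are illustrative.

```python
import hashlib


def file_digest(path, algorithm="sha256", chunk_size=1 << 16):
    """Hex digest of a file, read in fixed-size chunks."""
    h = hashlib.new(algorithm)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected_hex, algorithm="sha256"):
    """True if the file's digest matches the published hex string."""
    return file_digest(path, algorithm) == expected_hex.strip().lower()
```

For an intact download of the source distribution, verify("megdiffusion-0.0.2.tar.gz", "47aa766113c9d54322ccf5e1623f2ac7209c810c9307495b02aa25d67960f94a") should return True. The BLAKE2b-256 digest needs hashlib.blake2b(digest_size=32), because hashlib.new("blake2b") defaults to a 512-bit digest. pip can also enforce hashes itself in its hash-checking mode (--require-hashes, with --hash=sha256:... entries in a requirements file).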