Autoregressive Diffusion - Pytorch
Implementation of the architecture behind Autoregressive Image Generation without Vector Quantization, in Pytorch
[Figure: Oxford Flowers at 59k steps]
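In the paper, each continuous token x is modeled by a small denoising network conditioned on a vector z from the autoregressive transformer, and trained with a noise-prediction loss ||eps - eps_theta(x_t | t, z)||^2, where x_t = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * eps. Below is a minimal NumPy sketch of that objective for a batch of tokens. The cosine schedule, the helper names `cosine_alpha_bar` and `diffusion_loss`, and the `denoiser(x_t, t, z)` signature are hypothetical illustrations, not this package's API, and the library itself may use a different noise schedule or parameterization.

```python
import numpy as np

def cosine_alpha_bar(t, s=0.008):
    """Cumulative signal level alpha_bar(t) for t in [0, 1].

    Assumption: a cosine schedule; the package may use another one.
    """
    f = lambda u: np.cos((u + s) / (1 + s) * np.pi / 2) ** 2
    return f(t) / f(0.0)

def diffusion_loss(x0, z, denoiser, rng):
    """Noise-prediction loss for a batch of continuous tokens.

    x0: (batch, dim) clean tokens.
    z: (batch, dim_cond) conditioning vectors from the autoregressive backbone.
    denoiser(x_t, t, z): hypothetical network that predicts the added noise.
    """
    batch = x0.shape[0]
    t = rng.uniform(0.0, 1.0, size=(batch, 1))      # one random timestep per token
    eps = rng.standard_normal(x0.shape)             # Gaussian noise to add
    a = cosine_alpha_bar(t)
    x_t = np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps  # forward (noising) process
    eps_pred = denoiser(x_t, t, z)                  # conditioned on z
    return np.mean((eps - eps_pred) ** 2)
```

A denoiser that always predicts zero scores about 1.0 (the variance of the noise), while one that recovers eps exactly scores 0; training pushes the conditioned network toward the latter.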
Install
$ pip install autoregressive-diffusion-pytorch
Usage
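import torch
from autoregressive_diffusion_pytorch import AutoregressiveDiffusion
model = AutoregressiveDiffusion(
    dim = 1024,
    max_seq_len = 32,
    depth = 8,
    mlp_depth = 3,
    mlp_width = 1024
)
# a batch of 3 sequences, each with 32 continuous tokens of 512 features
seq = torch.randn(3, 32, 512)
# the forward pass returns the training loss
loss = model(seq)
loss.backward()
# sample a new batch of 3 sequences, with the same shape as the input
sampled = model.sample(batch_size = 3)
assert sampled.shape == seq.shape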
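Sampling in this setup is sequential: the backbone summarizes the tokens generated so far into a conditioning vector, and the per-token denoiser turns fresh noise into the next token. The sketch below shows only that control flow in NumPy. `sample_sequence`, `backbone`, and `sample_token` are hypothetical stand-ins, not functions exported by this package, and the real `model.sample` may batch, cache, or order the steps differently.

```python
import numpy as np

def sample_sequence(backbone, sample_token, seq_len, dim, rng):
    """Generate seq_len continuous tokens of size dim, one position at a time.

    backbone(prefix): hypothetical; maps the already generated tokens,
        shape (i, dim), to a conditioning vector for position i.
    sample_token(z, rng): hypothetical; returns one token of shape (dim,),
        e.g. by running a reverse diffusion process conditioned on z.
    """
    tokens = np.zeros((0, dim))
    for _ in range(seq_len):
        z = backbone(tokens)        # condition only on the prefix
        nxt = sample_token(z, rng)  # denoise a new token given z
        tokens = np.concatenate([tokens, nxt[None, :]], axis=0)
    return tokens
```

Because each position waits on the previous one, this loop costs seq_len backbone calls; the paper also studies predicting several tokens per step in a random order, which this simple left-to-right loop does not cover.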
For images treated as a sequence of tokens (as in the paper)
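import torch
from autoregressive_diffusion_pytorch import ImageAutoregressiveDiffusion
model = ImageAutoregressiveDiffusion(
    # settings for the underlying autoregressive model
    model = dict(
        dim = 1024,
        depth = 12,
        heads = 12,
    ),
    image_size = 64,
    patch_size = 8
)
# a batch of 3 RGB images, 64 x 64 pixels each
images = torch.randn(3, 3, 64, 64)
# the forward pass returns the training loss
loss = model(images)
loss.backward()
# sample a new batch of 3 images, with the same shape as the input
sampled = model.sample(batch_size = 3)
assert sampled.shape == images.shape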
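With image_size = 64 and patch_size = 8, each image becomes a sequence of (64 / 8)^2 = 64 patch tokens, each holding 3 * 8 * 8 = 192 values. A minimal NumPy sketch of that reshaping follows. The `patchify` and `unpatchify` names and the row-major patch order are assumptions for illustration; the package may order patches or channels differently or apply a learned projection on top.

```python
import numpy as np

def patchify(images, patch_size):
    """(batch, channels, H, W) -> (batch, num_patches, channels * p * p), row-major patches."""
    b, c, h, w = images.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "image size must be divisible by patch size"
    x = images.reshape(b, c, h // p, p, w // p, p)
    x = x.transpose(0, 2, 4, 1, 3, 5)  # (b, h/p, w/p, c, p, p)
    return x.reshape(b, (h // p) * (w // p), c * p * p)

def unpatchify(tokens, patch_size, channels, image_size):
    """Inverse of patchify: (batch, num_patches, c * p * p) -> (batch, c, H, W)."""
    b = tokens.shape[0]
    p = patch_size
    n = image_size // p
    x = tokens.reshape(b, n, n, channels, p, p)
    x = x.transpose(0, 3, 1, 4, 2, 5)  # (b, c, h/p, p, w/p, p)
    return x.reshape(b, channels, image_size, image_size)
```

For a (3, 3, 64, 64) batch, `patchify(images, 8)` yields shape (3, 64, 192), and `unpatchify` restores the original pixels exactly, so sampled token sequences can be mapped back to images.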
Citations
@article{Li2024AutoregressiveIG,
title = {Autoregressive Image Generation without Vector Quantization},
author = {Tianhong Li and Yonglong Tian and He Li and Mingyang Deng and Kaiming He},
journal = {ArXiv},
year = {2024},
volume = {abs/2406.11838},
url = {https://api.semanticscholar.org/CorpusID:270560593}
}
@article{Wu2023ARDiffusionAD,
title = {AR-Diffusion: Auto-Regressive Diffusion Model for Text Generation},
author = {Tong Wu and Zhihao Fan and Xiao Liu and Yeyun Gong and Yelong Shen and Jian Jiao and Haitao Zheng and Juntao Li and Zhongyu Wei and Jian Guo and Nan Duan and Weizhu Chen},
journal = {ArXiv},
year = {2023},
volume = {abs/2305.09515},
url = {https://api.semanticscholar.org/CorpusID:258714669}
}