Rectified Flow in Pytorch
Rectified Flow - Pytorch
Implementation of rectified flow and some of its follow-up research and improvements in PyTorch
(Image: samples from training on Oxford Flowers, batch size 32, 11k steps)
Install
```bash
$ pip install rectified-flow-pytorch
```
Usage
```python
import torch
from rectified_flow_pytorch import RectifiedFlow, Unet

model = Unet(dim = 64)
rectified_flow = RectifiedFlow(model)

images = torch.randn(1, 3, 256, 256)   # random tensor standing in for a batch of real images
loss = rectified_flow(images)          # training loss for this batch
loss.backward()

sampled = rectified_flow.sample()      # generate new images starting from noise
assert sampled.shape[1:] == images.shape[1:]
```
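To make the training objective concrete: in the paper (Liu et al., 2022), each data point X1 is paired with a noise sample X0, a time t is drawn uniformly from [0, 1], and the network regresses the constant velocity X1 - X0 at the interpolated point X_t = t X1 + (1 - t) X0. Sampling then integrates the learned ODE from noise at t = 0 to data at t = 1. The following is a self-contained sketch on 2-D toy data, not this package's internals: `ToyVelocity`, `rectified_flow_loss` and `euler_sample` are hypothetical names, the small MLP stands in for the Unet, and a fixed-step Euler solver is assumed for sampling.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class ToyVelocity(nn.Module):
    # hypothetical stand-in for the Unet: a small MLP on flat vectors plus time
    def __init__(self, dim: int, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # condition on time by appending t as an extra input feature
        return self.net(torch.cat([x, t[:, None]], dim=-1))

def rectified_flow_loss(model: nn.Module, x1: torch.Tensor) -> torch.Tensor:
    x0 = torch.randn_like(x1)                        # noise, paired independently with the data
    t = torch.rand(x1.shape[0])                      # t ~ Uniform(0, 1)
    x_t = t[:, None] * x1 + (1 - t[:, None]) * x0    # point on the straight line from x0 to x1
    return F.mse_loss(model(x_t, t), x1 - x0)        # regress the constant velocity x1 - x0

@torch.no_grad()
def euler_sample(model: nn.Module, batch_size: int, dim: int, steps: int = 16) -> torch.Tensor:
    x, dt = torch.randn(batch_size, dim), 1.0 / steps  # start from noise at t = 0
    for i in range(steps):
        t = torch.full((batch_size,), i * dt)
        x = x + model(x, t) * dt                       # Euler step along dx/dt = v(x, t)
    return x                                           # approximate data sample at t = 1

# toy data: a tight Gaussian blob centred at (2, -1)
data = torch.randn(256, 2) * 0.1 + torch.tensor([2.0, -1.0])
model = ToyVelocity(dim=2)
opt = torch.optim.Adam(model.parameters(), lr=5e-3)

losses = []
for _ in range(300):
    loss = rectified_flow_loss(model, data)
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())

samples = euler_sample(model, batch_size=8, dim=2)
print(samples.shape)  # torch.Size([8, 2])
```

Fewer Euler steps make sampling cheaper; how few steps still give good samples depends on how straight the learned paths are, which is what reflow aims to improve.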
For reflow, as described in the paper:
```python
import torch
from rectified_flow_pytorch import RectifiedFlow, Reflow, Unet

model = Unet(dim = 64)
rectified_flow = RectifiedFlow(model)

images = torch.randn(1, 3, 256, 256)
loss = rectified_flow(images)
loss.backward()

# repeat the above (with an optimizer step) over many batches of real images

reflow = Reflow(rectified_flow)
reflow_loss = reflow()
reflow_loss.backward()

# repeat the reflow step above many times in a training loop
# to reflow again, redefine Reflow(reflow.model) and loop once more

sampled = reflow.sample()
assert sampled.shape[1:] == images.shape[1:]
```
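Reflow, as described in the paper, draws noise z0, pushes it through the current flow to obtain z1, and retrains on straight lines between these fixed pairs; this straightens the trajectories so that fewer sampling steps suffice, and repeating it yields the k-rectified flow. Below is a toy sketch of two reflow rounds, assuming a hypothetical MLP (`ToyVelocity`) in place of the Unet and a fixed-step Euler sampler. It is not necessarily how the package's `Reflow` class is implemented; for instance, here the pairs are generated once per round rather than on the fly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class ToyVelocity(nn.Module):
    # hypothetical stand-in for the Unet, operating on 2-D points
    def __init__(self, dim: int = 2, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([x, t[:, None]], dim=-1))

@torch.no_grad()
def euler_sample(model: nn.Module, noise: torch.Tensor, steps: int = 16) -> torch.Tensor:
    # integrate dx/dt = v(x, t) from t = 0 (noise) to t = 1 (data)
    x, dt = noise.clone(), 1.0 / steps
    for i in range(steps):
        x = x + model(x, torch.full((x.shape[0],), i * dt)) * dt
    return x

def train(model: nn.Module, sample_pair, steps: int = 300, lr: float = 5e-3) -> float:
    # regress the velocity x1 - x0 along straight lines between paired points
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(steps):
        x0, x1 = sample_pair()
        t = torch.rand(x0.shape[0])
        x_t = t[:, None] * x1 + (1 - t[:, None]) * x0
        loss = F.mse_loss(model(x_t, t), x1 - x0)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss.item()

# toy "real data": two tight clusters
data = torch.cat([torch.randn(256, 2) * 0.1 + 2.0, torch.randn(256, 2) * 0.1 - 2.0])

# 1-rectified flow: fresh noise paired independently with random data points
model = ToyVelocity()
train(model, lambda: (torch.randn(256, 2), data[torch.randint(0, len(data), (256,))]))

# reflow rounds: pair noise z0 with the current model's own samples z1, then retrain
reflow_losses = []
for _ in range(2):  # each round produces the next k-rectified flow
    z0 = torch.randn(1024, 2)
    z1 = euler_sample(model, z0)  # deterministic coupling, computed before the model is updated

    def sample_pair():
        idx = torch.randint(0, len(z0), (256,))
        return z0[idx], z1[idx]

    reflow_losses.append(train(model, sample_pair))

# after reflow, a single Euler step is a reasonable thing to try
samples = euler_sample(model, torch.randn(8, 2), steps=1)
```

The pairs (z0, z1) are fixed within a round, so the regression target is deterministic given the pair; this is the property the paper relies on to make the paths straighter with each round.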
With a Trainer (built on Hugging Face Accelerate)
```python
import torch
from rectified_flow_pytorch import RectifiedFlow, ImageDataset, Unet, Trainer

model = Unet(dim = 64)
rectified_flow = RectifiedFlow(model)

img_dataset = ImageDataset(
    folder = './path/to/your/images',
    image_size = 256
)

trainer = Trainer(
    rectified_flow,
    dataset = img_dataset,
    num_train_steps = 70_000,
    results_folder = './results'   # samples will be saved periodically to this folder
)

trainer()
```
Citations
@article{Liu2022FlowSA,
title = {Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow},
author = {Xingchao Liu and Chengyue Gong and Qiang Liu},
journal = {ArXiv},
year = {2022},
volume = {abs/2209.03003},
url = {https://api.semanticscholar.org/CorpusID:252111177}
}
@article{Lee2024ImprovingTT,
title = {Improving the Training of Rectified Flows},
author = {Sangyun Lee and Zinan Lin and Giulia Fanti},
journal = {ArXiv},
year = {2024},
volume = {abs/2405.20320},
url = {https://api.semanticscholar.org/CorpusID:270123378}
}
@article{Esser2024ScalingRF,
title = {Scaling Rectified Flow Transformers for High-Resolution Image Synthesis},
author = {Patrick Esser and Sumith Kulal and A. Blattmann and Rahim Entezari and Jonas Muller and Harry Saini and Yam Levi and Dominik Lorenz and Axel Sauer and Frederic Boesel and Dustin Podell and Tim Dockhorn and Zion English and Kyle Lacey and Alex Goodwin and Yannik Marek and Robin Rombach},
journal = {ArXiv},
year = {2024},
volume = {abs/2403.03206},
url = {https://api.semanticscholar.org/CorpusID:268247980}
}
@article{Li2024ImmiscibleDA,
title = {Immiscible Diffusion: Accelerating Diffusion Training with Noise Assignment},
author = {Yiheng Li and Heyang Jiang and Akio Kodaira and Masayoshi Tomizuka and Kurt Keutzer and Chenfeng Xu},
journal = {ArXiv},
year = {2024},
volume = {abs/2406.12303},
url = {https://api.semanticscholar.org/CorpusID:270562607}
}
@article{Yang2024ConsistencyFM,
title = {Consistency Flow Matching: Defining Straight Flows with Velocity Consistency},
author = {Ling Yang and Zixiang Zhang and Zhilong Zhang and Xingchao Liu and Minkai Xu and Wentao Zhang and Chenlin Meng and Stefano Ermon and Bin Cui},
journal = {ArXiv},
year = {2024},
volume = {abs/2407.02398},
url = {https://api.semanticscholar.org/CorpusID:270878436}
}