Score-Based Image-to-Image Brownian Bridge (v1.1)
Project description
Score-Based Image-to-Image Brownian Bridge
This is the official implementation for the paper Score-Based Image-to-Image Brownian Bridge.
- For framework usage, see Usage.
- For training and evaluating examples, see Example Scripts Usage.
- For pre-trained models and checkpoints, see Pre-trained Models and Checkpoints.
Pre-requisites
- Python >= 3.9
- PyTorch >= 2.0.1
- torchmanager-diffusion >= 1.0
Installation
- PyPi:
pip install sde-bbdm
Usage
The SDE-BBDM model can be trained in image space or latent space. The following examples show how to train the model in image space and latent space with SDEBBDMManager class.
Train SDE-BBDM in Image Space
- Initialize a UNet model with
networks.build_unetfunction, create an optimizer and a loss function. The coeffientc_lambdacan be set optionally. Thetime_stepsis the number of diffusion steps.
import torch
from sde_bbdm import networks, nn
from torchmanager import losses
# load model
unet = networks.build(3, 3)
c_lambda: float = ...
time_steps: int = ...
model = nn.ABridgeModule(unet, time_steps, c_lambda=c_lambda)
# load optimizer and loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = losses.MAE()
- Compile SDE-BBDM in Image Space
To train a Score-Based BBDM model in image space, compile
SDEBBDMManagerwith the UNet model, optimizer, loss function.
from diffusion import Manager
# compile manager
manager = Manager(model, optimizer=optimizer, loss_fn=loss_fn)
- Initialize dataset and callbacks
from torchmanager import callbacks, data
# load dataset
dataset: data.Dataset = ...
# initialize callbacks
callback_list: list[callbacks.Callback] = ...
- Train the model
# train the model
epochs: int = ...
trained_model = manager.fit(dataset, epochs=epochs, callbacks=callback_list)
Train SDE-BBDM in Latent Space
To train a Score-Based BBDM model in latent space, build ABridgeModule with pre-trained encoder and decoder loaded as torch.nn.Module.
encoder: torch.nn.Module = ...
decoder: torch.nn.Module = ...
model = nn.ABridgeModule(unet, time_steps, c_lambda=c_lambda, encoder=encoder, decoder=decoder)
Evaluating the model
- Create metrics
Initialize metrics with
metrics.Metricclass and set the metrics to the manager.
from torchmanager import metrics
metric_fns: dict[str, metrics.Metric] = ...
manager.metrics = metric_fns
- Run evaluation
To evaluate the model, use
testmethod by setsampling_imagesasTrue. The method will sample images and return a dictionary containing the evaluation results.
model.test(dataset, sampling_images=True)
Evaluating the model using fast sampling
To evaluate the model using fast sampling, run test method by set sampling_images as True and set sampling steps as a list of integers with fast_sampling as True. The method will return a dictionary containing the evaluation results.
sampling_steps: list[int] = ...
model.fast_sampling_steps = sampling_steps
manager.test(dataset, sampling_images=True, fast_sampling=True, sampling_steps=sampling_steps)
Example Scripts Usage
This section describes how to use the example scripts to train and evaluate Score-Based Image-to-Image Brownian Bridge.
Dependencies for example scripts
All the required packages for example scripts can be installed with the following command:
pip install -r requirements.txt
Install package
To run examples, install the package using pypi first.
Training Script
Use train.py to train a Score-Based Image-to-Image Brownian Bridge model. The script supports training in image space and latent space. The following examples show how to train the model in image space and latent space using edge2shoes dataset.
# go to the exmamples folder
cd examples
# train in image space
python train.py \
edge2shoes \
<data_dir> \
<output_model_path>
# train in latent space
python train.py \
edge2shoes \
<data_dir> \
<output_model_path> \
-vq <vqgan_model_path>
Use --show_verbose to display the training progress bar for each epoch. Set --device as cuda:<gpu_id> to use specific GPU. Use -use_multi_gpus without --device argument to use multiple GPUs.
Evaluation Script
Use eval.py and eval_miou.py to evaluate a Score-Based Image-to-Image Brownian Bridge model. The script supports evaluation of torchmanager checkpoints and pre-trained PyTorch model. The following examples show how to evaluate a torchmanager checkpoint or a pre-trained PyTorch model with edge2shoes dataset using fast sampling method.
# go to the exmamples folder
cd examples
# evaluate torchmanager checkpoint
python eval.py \
edge2shoes \
<data_dir> \
<checkpoint_path> \
--fast_sampling
# evaluate pre-trained PyTorch model
python eval.py \
edge2shoes \
<data_dir> \
<model_path> \
--fast_sampling \
--t 1000 \\
-vq <vqgan_model_path>
Use eval_miou.py to evaluate the model with mIoU metric for cityscapes. The following examples show how to evaluate a torchmanager checkpoint or a pre-trained PyTorch model using fast sampling method.
# go to the exmamples folder
cd examples
# evaluate torchmanager checkpoint
python eval_miou.py \
<deeplabv3_model_path> \
<data_dir> \
<checkpoint_path> \
--fast_sampling
# evaluate pre-trained PyTorch model
python eval_miou.py \
<deeplabv3_model_path> \
<data_dir> \
<model_path> \
--fast_sampling \
--t 1000 \\
-vq <vqgan_model_path>
Again, use --show_verbose to display the training progress bar for each epoch. Set --device as cuda:<gpu_id> to use specific GPU. Use -use_multi_gpus without --device argument to use multiple GPUs.
Generation Script
Use generate.py to generate images from a Score-Based Image-to-Image Brownian Bridge model. The script supports generation from checkpoints and pre-trained PyTorch model. The following examples show how to generate images from a checkpoint or a pre-trained PyTorch model with edge2shoes dataset using fast sampling method.
python generate.py \
edge2shoes \
<data_dir> \
<checkpoint_path> \
--fast_sampling
Pre-trained Models and Checkpoints
We used pre-trained VQGAN separated from the official LDM OpenImage checkpoint. We first exported the state dict of the VQGAN model in the checkpoints, then use convert_vqgan.py script to convert the state dict to our vqgan.VQGAN PyTorch model for easy loading. The following command shows how to convert the state dict to a PyTorch model.
python convert_vqgan.py \
<state_dict_path> \
<output_model_path>
We used the pre-trained deeplabv3 to evaluate the mIoU on Cityscapes from here. Again, we convert the checkpoints in state dict into deeplabv3.DeepLabV3 PyTorch model using convert_deeplabv3.py script for easy loading. The following command shows how to convert the state dict to a PyTorch model.
python convert_deeplabv3.py \
<state_dict_path> \
<output_model_path>
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file sde_bbdm-1.1.tar.gz.
File metadata
- Download URL: sde_bbdm-1.1.tar.gz
- Upload date:
- Size: 28.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.18
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
5b444f05f1d87c33c613b8683b21224ebb35ab6fff3e966112a254aa0dca6693
|
|
| MD5 |
d5d79211d63fd101dd21688e92063670
|
|
| BLAKE2b-256 |
ef27de43cb3133e867459bec7f331cbc2fa0523b2f7bd45c0a52b656ce3b7834
|
File details
Details for the file sde_bbdm-1.1-py3-none-any.whl.
File metadata
- Download URL: sde_bbdm-1.1-py3-none-any.whl
- Upload date:
- Size: 35.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.18
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
e6caa072a89ffc721d7d204145dd16b9580857a087ee3a0d3e7397b658e9823d
|
|
| MD5 |
c9eb3abe292791a754b29f53435667d9
|
|
| BLAKE2b-256 |
d67a957034d7a44b4dd54355b212b84fbd79c0dd5c85058747f635623647e4c8
|