Skip to main content

Score-Based Image-to-Image Brownian Bridge (v1.1)

Project description

Score-Based Image-to-Image Brownian Bridge

This is the official implementation for the paper Score-Based Image-to-Image Brownian Bridge.

Pre-requisites

Installation

  • PyPi: pip install sde-bbdm

Usage

The SDE-BBDM model can be trained in image space or latent space. The following examples show how to train the model in image space and latent space with SDEBBDMManager class.

Train SDE-BBDM in Image Space

  1. Initialize a UNet model with networks.build_unet function, create an optimizer and a loss function. The coeffient c_lambda can be set optionally. The time_steps is the number of diffusion steps.
import torch
from sde_bbdm import networks, nn
from torchmanager import losses

# load model
unet = networks.build(3, 3)
c_lambda: float = ...
time_steps: int = ...
model = nn.ABridgeModule(unet, time_steps, c_lambda=c_lambda)

# load optimizer and loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = losses.MAE()
  1. Compile SDE-BBDM in Image Space To train a Score-Based BBDM model in image space, compile SDEBBDMManager with the UNet model, optimizer, loss function.
from diffusion import Manager

# compile manager
manager = Manager(model, optimizer=optimizer, loss_fn=loss_fn)
  1. Initialize dataset and callbacks
from torchmanager import callbacks, data

# load dataset
dataset: data.Dataset = ...

# initialize callbacks
callback_list: list[callbacks.Callback] = ...
  1. Train the model
# train the model
epochs: int = ...
trained_model = manager.fit(dataset, epochs=epochs, callbacks=callback_list)

Train SDE-BBDM in Latent Space

To train a Score-Based BBDM model in latent space, build ABridgeModule with pre-trained encoder and decoder loaded as torch.nn.Module.

encoder: torch.nn.Module = ...
decoder: torch.nn.Module = ...
model = nn.ABridgeModule(unet, time_steps, c_lambda=c_lambda, encoder=encoder, decoder=decoder)

Evaluating the model

  1. Create metrics Initialize metrics with metrics.Metric class and set the metrics to the manager.
from torchmanager import metrics

metric_fns: dict[str, metrics.Metric] = ...
manager.metrics = metric_fns
  1. Run evaluation To evaluate the model, use test method by set sampling_images as True. The method will sample images and return a dictionary containing the evaluation results.
model.test(dataset, sampling_images=True)

Evaluating the model using fast sampling

To evaluate the model using fast sampling, run test method by set sampling_images as True and set sampling steps as a list of integers with fast_sampling as True. The method will return a dictionary containing the evaluation results.

sampling_steps: list[int] = ...
model.fast_sampling_steps = sampling_steps
manager.test(dataset, sampling_images=True, fast_sampling=True, sampling_steps=sampling_steps)

Example Scripts Usage

This section describes how to use the example scripts to train and evaluate Score-Based Image-to-Image Brownian Bridge.

Dependencies for example scripts

All the required packages for example scripts can be installed with the following command:

pip install -r requirements.txt

Install package

To run examples, install the package using pypi first.

Training Script

Use train.py to train a Score-Based Image-to-Image Brownian Bridge model. The script supports training in image space and latent space. The following examples show how to train the model in image space and latent space using edge2shoes dataset.

# go to the exmamples folder
cd examples

# train in image space
python train.py \
    edge2shoes \
    <data_dir> \
    <output_model_path>

# train in latent space
python train.py \
    edge2shoes \
    <data_dir> \
    <output_model_path> \
    -vq <vqgan_model_path>

Use --show_verbose to display the training progress bar for each epoch. Set --device as cuda:<gpu_id> to use specific GPU. Use -use_multi_gpus without --device argument to use multiple GPUs.

Evaluation Script

Use eval.py and eval_miou.py to evaluate a Score-Based Image-to-Image Brownian Bridge model. The script supports evaluation of torchmanager checkpoints and pre-trained PyTorch model. The following examples show how to evaluate a torchmanager checkpoint or a pre-trained PyTorch model with edge2shoes dataset using fast sampling method.

# go to the exmamples folder
cd examples

# evaluate torchmanager checkpoint
python eval.py \
    edge2shoes \
    <data_dir> \
    <checkpoint_path> \
    --fast_sampling

# evaluate pre-trained PyTorch model
python eval.py \
    edge2shoes \
    <data_dir> \
    <model_path> \
    --fast_sampling \
    --t 1000 \\
    -vq <vqgan_model_path>

Use eval_miou.py to evaluate the model with mIoU metric for cityscapes. The following examples show how to evaluate a torchmanager checkpoint or a pre-trained PyTorch model using fast sampling method.

# go to the exmamples folder
cd examples

# evaluate torchmanager checkpoint
python eval_miou.py \
    <deeplabv3_model_path> \
    <data_dir> \
    <checkpoint_path> \
    --fast_sampling

# evaluate pre-trained PyTorch model
python eval_miou.py \
    <deeplabv3_model_path> \
    <data_dir> \
    <model_path> \
    --fast_sampling \
    --t 1000 \\
    -vq <vqgan_model_path>

Again, use --show_verbose to display the training progress bar for each epoch. Set --device as cuda:<gpu_id> to use specific GPU. Use -use_multi_gpus without --device argument to use multiple GPUs.

Generation Script

Use generate.py to generate images from a Score-Based Image-to-Image Brownian Bridge model. The script supports generation from checkpoints and pre-trained PyTorch model. The following examples show how to generate images from a checkpoint or a pre-trained PyTorch model with edge2shoes dataset using fast sampling method.

python generate.py \
    edge2shoes \
    <data_dir> \
    <checkpoint_path> \
    --fast_sampling

Pre-trained Models and Checkpoints

We used pre-trained VQGAN separated from the official LDM OpenImage checkpoint. We first exported the state dict of the VQGAN model in the checkpoints, then use convert_vqgan.py script to convert the state dict to our vqgan.VQGAN PyTorch model for easy loading. The following command shows how to convert the state dict to a PyTorch model.

python convert_vqgan.py \
    <state_dict_path> \
    <output_model_path>

We used the pre-trained deeplabv3 to evaluate the mIoU on Cityscapes from here. Again, we convert the checkpoints in state dict into deeplabv3.DeepLabV3 PyTorch model using convert_deeplabv3.py script for easy loading. The following command shows how to convert the state dict to a PyTorch model.

python convert_deeplabv3.py \
    <state_dict_path> \
    <output_model_path>

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

sde_bbdm-1.1.tar.gz (28.9 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

sde_bbdm-1.1-py3-none-any.whl (35.8 kB view details)

Uploaded Python 3

File details

Details for the file sde_bbdm-1.1.tar.gz.

File metadata

  • Download URL: sde_bbdm-1.1.tar.gz
  • Upload date:
  • Size: 28.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.18

File hashes

Hashes for sde_bbdm-1.1.tar.gz
Algorithm Hash digest
SHA256 5b444f05f1d87c33c613b8683b21224ebb35ab6fff3e966112a254aa0dca6693
MD5 d5d79211d63fd101dd21688e92063670
BLAKE2b-256 ef27de43cb3133e867459bec7f331cbc2fa0523b2f7bd45c0a52b656ce3b7834

See more details on using hashes here.

File details

Details for the file sde_bbdm-1.1-py3-none-any.whl.

File metadata

  • Download URL: sde_bbdm-1.1-py3-none-any.whl
  • Upload date:
  • Size: 35.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.18

File hashes

Hashes for sde_bbdm-1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 e6caa072a89ffc721d7d204145dd16b9580857a087ee3a0d3e7397b658e9823d
MD5 c9eb3abe292791a754b29f53435667d9
BLAKE2b-256 d67a957034d7a44b4dd54355b212b84fbd79c0dd5c85058747f635623647e4c8

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page