
Score-Based Image-to-Image Brownian Bridge (v1.0)

This is the official implementation for the paper Score-Based Image-to-Image Brownian Bridge.

Pre-requisites

Usage

The SDE-BBDM model can be trained in either image space or latent space. The following examples show how to train it in each setting with the SDEBBDMManager class.

Train SDE-BBDM in Image Space

  1. Initialize a UNet model with the networks.build_unet function, then create an optimizer and a loss function. The coefficient c_lambda is optional. time_steps is the number of diffusion steps.
import torch
from sde_bbdm import networks
from torchmanager import losses

# load model
model = networks.build_unet(3, 3)
c_lambda: float = ...
time_steps: int = ...

# load optimizer and loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = losses.MAE()
  2. Compile SDE-BBDM in image space. To train a Score-Based BBDM model in image space, compile SDEBBDMManager with the UNet model, the number of time steps, the optimizer, and the loss function.
from sde_bbdm import SDEBBDMManager as Manager

# compile manager
time_steps: int = ...
manager = Manager(model, time_steps, optimizer=optimizer, loss_fn=loss_fn)
  3. Initialize the dataset and callbacks
from torchmanager import callbacks, data

# load dataset
dataset: data.Dataset = ...

# initialize callbacks
callback_list: list[callbacks.Callback] = ...
  4. Train the model
# train the model
epochs: int = ...
trained_model = manager.fit(dataset, epochs=epochs, callbacks=callback_list)

Train SDE-BBDM in Latent Space

To train a Score-Based BBDM model in latent space, compile SDEBBDMManager with a pre-trained encoder and decoder, each loaded as a torch.nn.Module.

encoder: torch.nn.Module = ...
decoder: torch.nn.Module = ...

manager = Manager(model, time_steps, optimizer=optimizer, loss_fn=loss_fn, c_lambda=c_lambda, encoder=encoder, decoder=decoder)

Evaluating the model

  1. Create metrics. Initialize the metrics as metrics.Metric instances and assign them to the manager.
from torchmanager import metrics

metric_fns: dict[str, metrics.Metric] = ...
manager.metrics = metric_fns
  2. Run evaluation. To evaluate the model, call the manager's test method with sampling_images set to True. The method samples images and returns a dictionary containing the evaluation results.
manager.test(dataset, sampling_images=True)

Evaluating the model using fast sampling

To evaluate the model using fast sampling, call the test method with sampling_images and fast_sampling both set to True, and pass the sampling steps as a list of integers. The method returns a dictionary containing the evaluation results.

sampling_steps: list[int] = ...
manager.test(dataset, sampling_images=True, fast_sampling=True, sampling_steps=sampling_steps)
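
One plausible way to build the sampling_steps list is to spread a small number of steps evenly over the full diffusion schedule. The sketch below is only an illustration: the helper name make_sampling_steps is hypothetical, and the assumptions that steps are 0-based, ascending, and include both endpoints should be checked against the package before use.

```python
def make_sampling_steps(time_steps: int, num_samples: int) -> list[int]:
    """Pick `num_samples` evenly spread, strictly increasing step indices.

    Assumption (not confirmed by the package): indices are 0-based and the
    list should cover the schedule from step 0 to step `time_steps - 1`.
    """
    if not 1 <= num_samples <= time_steps:
        raise ValueError("num_samples must be between 1 and time_steps")
    if num_samples == 1:
        return [time_steps - 1]
    # Integer arithmetic keeps the indices exact and strictly increasing.
    last = time_steps - 1
    return [(i * last) // (num_samples - 1) for i in range(num_samples)]


# Example: 5 sampling steps out of a 1000-step schedule.
sampling_steps = make_sampling_steps(1000, 5)
print(sampling_steps)
```

If the manager expects the steps in reverse order, `list(reversed(sampling_steps))` gives the descending variant.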

Example Scripts Usage

This section describes how to use the example scripts to train and evaluate Score-Based Image-to-Image Brownian Bridge.

Dependencies for example scripts

All the required packages for example scripts can be installed with the following command:

pip install -r requirements.txt

Install package

To run examples, install the package first. The following command installs the package in editable mode.

pip install -e .

Training Script

Use train.py to train a Score-Based Image-to-Image Brownian Bridge model. The script supports training in both image space and latent space. The following examples show how to train the model in each setting using the edge2shoes dataset.

# go to the examples folder
cd examples

# train in image space
python train.py \
    edge2shoes \
    <data_dir> \
    <output_model_path>

# train in latent space
python train.py \
    edge2shoes \
    <data_dir> \
    <output_model_path> \
    -vq <vqgan_model_path>

Use --show_verbose to display the training progress bar for each epoch. Set --device to cuda:<gpu_id> to use a specific GPU. Use -use_multi_gpus without the --device argument to use multiple GPUs.
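
When launching several runs, the command lines above can be assembled programmatically. The following is a minimal sketch, assuming the script is started from the examples folder. The helper build_train_command is hypothetical, the paths are placeholders, and only flags mentioned in this section are used.

```python
import shlex
from typing import Optional


def build_train_command(
    dataset: str,
    data_dir: str,
    output_model_path: str,
    vqgan_model_path: Optional[str] = None,
    device: Optional[str] = None,
    show_verbose: bool = False,
) -> list[str]:
    """Assemble the argument list for train.py (hypothetical helper)."""
    cmd = ["python", "train.py", dataset, data_dir, output_model_path]
    if vqgan_model_path is not None:
        # Passing a VQGAN model switches training to latent space.
        cmd += ["-vq", vqgan_model_path]
    if device is not None:
        cmd += ["--device", device]
    if show_verbose:
        cmd.append("--show_verbose")
    return cmd


# Placeholder paths; replace them with real locations.
cmd = build_train_command(
    "edge2shoes", "data/edge2shoes", "models/bbdm.pth",
    vqgan_model_path="models/vqgan.pth", device="cuda:0",
)
print(shlex.join(cmd))
# To launch: subprocess.run(cmd, cwd="examples", check=True)
```

Building the command as a list rather than a single string avoids shell quoting problems when a path contains spaces.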

Evaluation Script

Use eval.py and eval_miou.py to evaluate a Score-Based Image-to-Image Brownian Bridge model. Both scripts support torchmanager checkpoints and pre-trained PyTorch models. The following examples show how to evaluate a torchmanager checkpoint or a pre-trained PyTorch model on the edge2shoes dataset using the fast sampling method.

# go to the examples folder
cd examples

# evaluate torchmanager checkpoint
python eval.py \
    edge2shoes \
    <data_dir> \
    <checkpoint_path> \
    --fast_sampling

# evaluate pre-trained PyTorch model
python eval.py \
    edge2shoes \
    <data_dir> \
    <model_path> \
    --fast_sampling \
    --t 1000 \
    -vq <vqgan_model_path>

Use eval_miou.py to evaluate the model with the mIoU metric on Cityscapes. The following examples show how to evaluate a torchmanager checkpoint or a pre-trained PyTorch model using the fast sampling method.

# go to the examples folder
cd examples

# evaluate torchmanager checkpoint
python eval_miou.py \
    <deeplabv3_model_path> \
    <data_dir> \
    <checkpoint_path> \
    --fast_sampling

# evaluate pre-trained PyTorch model
python eval_miou.py \
    <deeplabv3_model_path> \
    <data_dir> \
    <model_path> \
    --fast_sampling \
    --t 1000 \
    -vq <vqgan_model_path>

Again, use --show_verbose to display the progress bar. Set --device to cuda:<gpu_id> to use a specific GPU. Use -use_multi_gpus without the --device argument to use multiple GPUs.

Generation Script

Use generate.py to generate images from a Score-Based Image-to-Image Brownian Bridge model. The script supports generation from checkpoints and pre-trained PyTorch models. The following example shows how to generate images from a checkpoint on the edge2shoes dataset using the fast sampling method.

python generate.py \
    edge2shoes \
    <data_dir> \
    <checkpoint_path> \
    --fast_sampling

Pre-trained Models and Checkpoints

We used a pre-trained VQGAN extracted from the official LDM OpenImage checkpoint. We first exported the state dict of the VQGAN model from the checkpoint, then used the convert_vqgan.py script to convert the state dict to our vqgan.VQGAN PyTorch model for easy loading. The following command shows how to convert the state dict to a PyTorch model.

python convert_vqgan.py \
    <state_dict_path> \
    <output_model_path>
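
The export step described above amounts to keeping only the checkpoint entries that belong to the VQGAN and dropping their common key prefix. The following is a minimal sketch, assuming the weights sit under a first_stage_model. prefix. That prefix is the usual layout of LDM checkpoints, but it is an assumption here and should be verified against the actual checkpoint. The helper extract_sub_state_dict is hypothetical, and plain strings stand in for tensors so the example runs without PyTorch.

```python
from collections import OrderedDict
from typing import Any, Mapping


def extract_sub_state_dict(state_dict: Mapping[str, Any], prefix: str) -> "OrderedDict[str, Any]":
    """Keep entries whose key starts with `prefix` and strip that prefix."""
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


# Toy checkpoint: strings stand in for weight tensors.
full_state_dict = {
    "model.diffusion_model.input.weight": "unet-w",
    "first_stage_model.encoder.conv_in.weight": "enc-w",
    "first_stage_model.decoder.conv_out.bias": "dec-b",
}

# Assumed prefix; check the key names in the real checkpoint first.
vqgan_state_dict = extract_sub_state_dict(full_state_dict, "first_stage_model.")
print(list(vqgan_state_dict))
```

With a real checkpoint, the input mapping would typically come from torch.load and the result would be written with torch.save before being passed to convert_vqgan.py.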

We used a pre-trained DeepLabV3 model to evaluate the mIoU on Cityscapes. Again, we converted the state-dict checkpoint into our deeplabv3.DeepLabV3 PyTorch model using the convert_deeplabv3.py script for easy loading. The following command shows how to convert the state dict to a PyTorch model.

python convert_deeplabv3.py \
    <state_dict_path> \
    <output_model_path>

