
Score-Based Image-to-Image Brownian Bridge (v1.1 Release Candidate 1)

This is the official implementation for the paper Score-Based Image-to-Image Brownian Bridge.
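The model builds on a Brownian bridge diffusion process, which is pinned to a source image at t = 0 and a target image at t = T. As a rough illustration of the underlying idea (a standard Brownian bridge, not the paper's exact noise schedule), the bridge has mean (1 − t/T)·x0 + (t/T)·xT and variance t(T − t)/T, so the noise vanishes at both endpoints:

```python
import numpy as np

def brownian_bridge_sample(x0, xT, t, T, rng=None):
    """Sample x_t from a standard Brownian bridge pinned at x0 (t=0) and xT (t=T)."""
    rng = np.random.default_rng() if rng is None else rng
    mean = (1 - t / T) * x0 + (t / T) * xT  # linear interpolation of the endpoints
    var = t * (T - t) / T                   # variance is zero at t=0 and t=T
    return mean + np.sqrt(var) * rng.standard_normal(x0.shape)

x0 = np.zeros((2, 2))
xT = np.ones((2, 2))
print(brownian_bridge_sample(x0, xT, t=0.0, T=1.0))  # exactly x0: no noise at t=0
print(brownian_bridge_sample(x0, xT, t=1.0, T=1.0))  # exactly xT: no noise at t=T
```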

Pre-requisites

Installation

  • PyPI: pip install sde-bbdm

Usage

The SDE-BBDM model can be trained in image space or latent space. The following examples show how to train the model in each space using the SDEBBDMManager class.

Train SDE-BBDM in Image Space

  1. Initialize a UNet model with the networks.build function, and create an optimizer and a loss function. The coefficient c_lambda can be set optionally, and time_steps is the number of diffusion steps.
import torch
from sde_bbdm import networks, nn
from torchmanager import losses

# load model
unet = networks.build(3, 3)
c_lambda: float = ...
time_steps: int = ...
model = nn.ABridgeModule(unet, time_steps, c_lambda=c_lambda)

# load optimizer and loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = losses.MAE()
  2. Compile SDE-BBDM in image space. To train a Score-Based BBDM model in image space, compile SDEBBDMManager with the UNet model, optimizer, and loss function.
from sde_bbdm import SDEBBDMManager

# compile manager
manager = SDEBBDMManager(model, optimizer=optimizer, loss_fn=loss_fn)
  3. Initialize the dataset and callbacks
from torchmanager import callbacks, data

# load dataset
dataset: data.Dataset = ...

# initialize callbacks
callback_list: list[callbacks.Callback] = ...
  4. Train the model
# train the model
epochs: int = ...
trained_model = manager.fit(dataset, epochs=epochs, callbacks=callback_list)

Train SDE-BBDM in Latent Space

To train a Score-Based BBDM model in latent space, build ABridgeModule with a pre-trained encoder and decoder, each loaded as a torch.nn.Module.

encoder: torch.nn.Module = ...
decoder: torch.nn.Module = ...
model = nn.ABridgeModule(unet, time_steps, c_lambda=c_lambda, encoder=encoder, decoder=decoder)
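In practice the encoder and decoder come from a pre-trained autoencoder such as a converted VQGAN; the only contract is a torch.nn.Module pair whose shapes round-trip between image and latent space. A minimal sketch of that contract (ToyEncoder and ToyDecoder are hypothetical placeholders, not part of sde_bbdm):

```python
import torch

class ToyEncoder(torch.nn.Module):
    """Hypothetical stand-in for a pre-trained encoder (e.g. a VQGAN encoder)."""
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)  # downsample a 3-channel image to a 4-channel latent

class ToyDecoder(torch.nn.Module):
    """Hypothetical stand-in for the matching pre-trained decoder."""
    def __init__(self):
        super().__init__()
        self.deconv = torch.nn.ConvTranspose2d(4, 3, kernel_size=4, stride=2, padding=1)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.deconv(z)  # upsample the latent back to image space

encoder, decoder = ToyEncoder(), ToyDecoder()
x = torch.randn(1, 3, 32, 32)
z = encoder(x)
assert z.shape == (1, 4, 16, 16)      # latent is spatially half the image
assert decoder(z).shape == x.shape    # decoder restores the image shape
```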

Evaluating the model

  1. Create metrics. Initialize metrics with the metrics.Metric class and set them on the manager.
from torchmanager import metrics

metric_fns: dict[str, metrics.Metric] = ...
manager.metrics = metric_fns
  2. Run evaluation. To evaluate the model, call the test method with sampling_images set to True. The method samples images and returns a dictionary containing the evaluation results.
manager.test(dataset, sampling_images=True)

Evaluating the model using fast sampling

To evaluate the model with fast sampling, set the model's fast_sampling_steps to a list of integers, then call the test method with sampling_images and fast_sampling set to True. The method returns a dictionary containing the evaluation results.

sampling_steps: list[int] = ...
model.fast_sampling_steps = sampling_steps
manager.test(dataset, sampling_images=True, fast_sampling=True, sampling_steps=sampling_steps)
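The sampling steps form a subset of the training timesteps. One common choice, shown here as an assumption rather than anything the package mandates, is an evenly spaced schedule over the full range of diffusion steps:

```python
def evenly_spaced_steps(time_steps: int, num_samples: int) -> list[int]:
    """Pick num_samples evenly spaced timesteps out of range(time_steps)."""
    stride = time_steps // num_samples
    return list(range(0, time_steps, stride))[:num_samples]

# e.g. 10 sampling steps out of 1000 diffusion steps
sampling_steps = evenly_spaced_steps(1000, 10)
print(sampling_steps)  # → [0, 100, 200, 300, 400, 500, 600, 700, 800, 900]
```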

Example Scripts Usage

This section describes how to use the example scripts to train and evaluate Score-Based Image-to-Image Brownian Bridge.

Dependencies for example scripts

All the required packages for example scripts can be installed with the following command:

pip install -r requirements.txt

Install package

To run the examples, install the package from PyPI first.

Training Script

Use train.py to train a Score-Based Image-to-Image Brownian Bridge model. The script supports training in image space and latent space. The following examples show how to train the model in each space using the edge2shoes dataset.

# go to the examples folder
cd examples

# train in image space
python train.py \
    edge2shoes \
    <data_dir> \
    <output_model_path>

# train in latent space
python train.py \
    edge2shoes \
    <data_dir> \
    <output_model_path> \
    -vq <vqgan_model_path>

Use --show_verbose to display the training progress bar for each epoch. Set --device to cuda:<gpu_id> to use a specific GPU, or pass -use_multi_gpus without the --device argument to use multiple GPUs.

Evaluation Script

Use eval.py and eval_miou.py to evaluate a Score-Based Image-to-Image Brownian Bridge model. The scripts support evaluation of torchmanager checkpoints and pre-trained PyTorch models. The following examples show how to evaluate a torchmanager checkpoint or a pre-trained PyTorch model on the edge2shoes dataset using the fast sampling method.

# go to the examples folder
cd examples

# evaluate torchmanager checkpoint
python eval.py \
    edge2shoes \
    <data_dir> \
    <checkpoint_path> \
    --fast_sampling

# evaluate pre-trained PyTorch model
python eval.py \
    edge2shoes \
    <data_dir> \
    <model_path> \
    --fast_sampling \
    --t 1000 \
    -vq <vqgan_model_path>

Use eval_miou.py to evaluate the model with mIoU metric for cityscapes. The following examples show how to evaluate a torchmanager checkpoint or a pre-trained PyTorch model using fast sampling method.

# go to the examples folder
cd examples

# evaluate torchmanager checkpoint
python eval_miou.py \
    <deeplabv3_model_path> \
    <data_dir> \
    <checkpoint_path> \
    --fast_sampling

# evaluate pre-trained PyTorch model
python eval_miou.py \
    <deeplabv3_model_path> \
    <data_dir> \
    <model_path> \
    --fast_sampling \
    --t 1000 \
    -vq <vqgan_model_path>

Again, use --show_verbose to display the progress bar for each epoch. Set --device to cuda:<gpu_id> to use a specific GPU, or pass -use_multi_gpus without the --device argument to use multiple GPUs.

Generation Script

Use generate.py to generate images from a Score-Based Image-to-Image Brownian Bridge model. The script supports generation from checkpoints and pre-trained PyTorch models. The following example shows how to generate images from a checkpoint or a pre-trained PyTorch model on the edge2shoes dataset using the fast sampling method.

python generate.py \
    edge2shoes \
    <data_dir> \
    <checkpoint_path> \
    --fast_sampling

Pre-trained Models and Checkpoints

We used the pre-trained VQGAN separated from the official LDM OpenImage checkpoint. We first exported the state dict of the VQGAN model from the checkpoint, then used the convert_vqgan.py script to convert the state dict to our vqgan.VQGAN PyTorch model for easy loading. The following command shows how to convert the state dict to a PyTorch model.

python convert_vqgan.py \
    <state_dict_path> \
    <output_model_path>

We used a pre-trained DeepLabV3 model to evaluate the mIoU on Cityscapes. Again, we converted the state dict checkpoint into our deeplabv3.DeepLabV3 PyTorch model using the convert_deeplabv3.py script for easy loading. The following command shows how to convert the state dict to a PyTorch model.

python convert_deeplabv3.py \
    <state_dict_path> \
    <output_model_path>
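Both conversion scripts follow the same pattern: load a raw state dict, attach it to a module instance, and save the whole module so it can later be restored with a single torch.load. A minimal sketch of that pattern, using a stand-in torch.nn.Linear instead of the actual vqgan.VQGAN or deeplabv3.DeepLabV3 classes:

```python
import os
import tempfile

import torch

# Stand-in module; the real scripts use vqgan.VQGAN / deeplabv3.DeepLabV3 instead.
model = torch.nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    state_dict_path = os.path.join(tmp, "weights.pth")
    output_model_path = os.path.join(tmp, "model.pth")

    # export: save only the weights (what the original checkpoints contain)
    torch.save(model.state_dict(), state_dict_path)

    # convert: load the state dict into a module instance and save the whole module
    converted = torch.nn.Linear(4, 2)
    converted.load_state_dict(torch.load(state_dict_path))
    torch.save(converted, output_model_path)

    # easy loading: the converted file restores the module in one call
    restored = torch.load(output_model_path, weights_only=False)
    assert torch.equal(restored.weight, model.weight)
```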
