
Torchmanager Implementation for Diffusion Model (v1.0)


Torchmanager Diffusion Models Plug-in

The torchmanager implementation for diffusion models.

Pre-requisites

Installation

  • PyPI: pip install --pre torchmanager-diffusion

DDPM Manager Usage

Train DDPM

Compile a DDPMManager directly with a model, a beta space, and a number of time steps, then call its fit method to train the model.

import torch
import diffusion
from diffusion import DDPMManager
from torchmanager import callbacks, data, losses

# initialize dataset
dataset: data.Dataset = ...

# initialize model, beta_space, and time_steps
model: torch.nn.Module = ...
beta_space: diffusion.scheduling.BetaSpace = ...
time_steps: int = ...

# initialize optimizer and loss function
optimizer: torch.optim.Optimizer = ...
loss_fn: losses.Loss = ...

# compile the ddpm manager
manager = DDPMManager(model, beta_space, time_steps, optimizer=optimizer, loss_fn=loss_fn)

# initialize callbacks
callback_list: list[callbacks.Callback] = ...

# train the model
trained_model = manager.fit(dataset, epochs=..., callbacks=callback_list)
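For intuition about what the beta_space and time_steps arguments control, here is a sketch of the standard DDPM forward process (Ho et al., 2020) in plain Python: a linear beta schedule and the closed-form noising x_t = sqrt(ᾱ_t) * x_0 + sqrt(1 - ᾱ_t) * ε, where ᾱ_t is the cumulative product of (1 - β_s). It is not the library's implementation. The helper names linear_beta_schedule, alpha_bars, and forward_diffusion are hypothetical, the default range 1e-4 to 0.02 is the one used in the DDPM paper, and data is a flat list of floats instead of a tensor.

```python
import math
import random
from typing import List, Optional, Tuple


def linear_beta_schedule(time_steps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> List[float]:
    """Hypothetical helper: betas spaced linearly from beta_start to beta_end."""
    if time_steps < 2:
        raise ValueError("time_steps must be at least 2")
    step = (beta_end - beta_start) / (time_steps - 1)
    return [beta_start + i * step for i in range(time_steps)]


def alpha_bars(betas: List[float]) -> List[float]:
    """Cumulative products alpha_bar_t = prod_{s <= t} (1 - beta_s)."""
    result: List[float] = []
    running = 1.0
    for beta in betas:
        running *= 1.0 - beta
        result.append(running)
    return result


def forward_diffusion(
    x0: List[float], t: int, betas: List[float], noise: Optional[List[float]] = None
) -> Tuple[List[float], List[float]]:
    """Sample x_t ~ q(x_t | x_0) in closed form (0-based t); returns (x_t, noise).

    In DDPM training, the model learns to predict `noise` from (x_t, t).
    """
    a_bar = alpha_bars(betas)[t]
    if noise is None:
        noise = [random.gauss(0.0, 1.0) for _ in x0]
    xt = [math.sqrt(a_bar) * x + math.sqrt(1.0 - a_bar) * e for x, e in zip(x0, noise)]
    return xt, noise
```

With linear_beta_schedule(1000), ᾱ at the last step is close to zero, so x_t at that step is almost pure Gaussian noise; this is why sampling can start from random noise.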

Evaluate DDPM

Add the necessary metrics, then call the test method with sampling_images=True to evaluate the trained model on generated samples.

import torch
from diffusion import DDPMManager
from torchmanager import data, metrics
from torchvision import models

# load manager from checkpoints
manager = DDPMManager.from_checkpoint(...)
assert isinstance(manager, DDPMManager), "manager is not a DDPMManager."

# initialize dataset
testing_dataset: data.Dataset = ...

# add necessary metrics
inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)  # torchvision >= 0.13; older versions use pretrained=True
inception.fc = torch.nn.Identity()  # type: ignore
inception.eval()
fid = metrics.FID(inception)
manager.metrics.update({"FID": fid})

# evaluate the model
summary = manager.test(testing_dataset, sampling_images=True)
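FID compares Gaussian fits to Inception features of real and generated images: FID = ||μ_r - μ_g||² + Tr(Σ_r + Σ_g - 2(Σ_r Σ_g)^(1/2)). The standard metric uses full covariance matrices; the dependency-free sketch below is only an illustration that assumes diagonal covariances (per-feature variances), in which case the trace term reduces to the sum of (σ_r,i - σ_g,i)². The function names feature_stats and diagonal_fid are hypothetical and are not part of torchmanager.

```python
import math
from typing import List, Sequence, Tuple


def feature_stats(features: Sequence[Sequence[float]]) -> Tuple[List[float], List[float]]:
    """Per-dimension mean and unbiased (n - 1) variance of a batch of feature vectors."""
    n = len(features)
    if n < 2:
        raise ValueError("need at least two feature vectors")
    dims = len(features[0])
    means = [sum(f[d] for f in features) / n for d in range(dims)]
    variances = [sum((f[d] - means[d]) ** 2 for f in features) / (n - 1) for d in range(dims)]
    return means, variances


def diagonal_fid(real: Sequence[Sequence[float]], generated: Sequence[Sequence[float]]) -> float:
    """Frechet distance between two Gaussians with diagonal covariances (a simplification of FID)."""
    mu_r, var_r = feature_stats(real)
    mu_g, var_g = feature_stats(generated)
    # squared distance between the feature means
    mean_term = sum((a - b) ** 2 for a, b in zip(mu_r, mu_g))
    # for diagonal covariances, Tr(S_r + S_g - 2 (S_r S_g)^(1/2)) = sum_i (sd_r,i - sd_g,i)^2
    cov_term = sum((math.sqrt(a) - math.sqrt(b)) ** 2 for a, b in zip(var_r, var_g))
    return mean_term + cov_term
```

Lower is better: identical feature statistics give a distance of 0, and both a shifted mean and a mismatched spread increase it.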

Customize Diffusion Algorithm

To customize the diffusion algorithm, subclass DiffusionManager and implement its abstract methods forward_diffusion and sampling_step.

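from typing import Any, Optional, Union

import torch
from diffusion import DiffusionManager
from diffusion.data import DiffusionData  # assumed import path; adjust if your version exports DiffusionData elsewhere

class CustomizedManager(DiffusionManager):
    def forward_diffusion(self, data: Any, condition: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None) -> tuple[Any, torch.Tensor]:
        # forward (noising) process used during training: return a tuple of (model inputs, training targets)
        ...

    def sampling_step(self, data: DiffusionData, i: int, /, *, return_noise: bool = False) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        # one reverse (denoising) step at index i; return a tensor, or a tuple of two tensors when return_noise=True
        ...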
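As a reference point for implementing sampling_step, the standard DDPM reverse update (Ho et al., 2020, Algorithm 2) is x_{t-1} = (1 / sqrt(α_t)) * (x_t - (β_t / sqrt(1 - ᾱ_t)) * ε_θ(x_t, t)) + σ_t * z, with α_t = 1 - β_t, z ~ N(0, I), and no noise added at the final step. The sketch below applies one such step to a flat list of floats, taking the model's predicted noise as an input. It does not use the library's DiffusionData type, the name ddpm_reverse_step is hypothetical, and σ_t² = β_t is one of the two variance choices the paper reports.

```python
import math
import random
from typing import List, Optional


def ddpm_reverse_step(
    xt: List[float],
    predicted_noise: List[float],
    t: int,
    betas: List[float],
    z: Optional[List[float]] = None,
) -> List[float]:
    """One DDPM ancestral sampling step x_t -> x_{t-1}, with 0-based t."""
    beta_t = betas[t]
    alpha_t = 1.0 - beta_t
    # alpha_bar_t = prod_{s <= t} (1 - beta_s)
    alpha_bar_t = 1.0
    for beta in betas[: t + 1]:
        alpha_bar_t *= 1.0 - beta
    coef = beta_t / math.sqrt(1.0 - alpha_bar_t)
    mean = [(x - coef * e) / math.sqrt(alpha_t) for x, e in zip(xt, predicted_noise)]
    if t == 0:
        return mean  # no noise is added at the final step
    if z is None:
        z = [random.gauss(0.0, 1.0) for _ in xt]
    sigma_t = math.sqrt(beta_t)  # the sigma_t^2 = beta_t variant
    return [m + sigma_t * n for m, n in zip(mean, z)]
```

A full sampler starts from Gaussian noise at the last step and calls this update for t = T-1, ..., 0, querying the model for predicted_noise at each step.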



Download files


Source Distribution

torchmanager_diffusion-1.0.tar.gz (23.4 kB)

Uploaded Source

Built Distribution

torchmanager_diffusion-1.0-py3-none-any.whl (33.4 kB)

Uploaded Python 3

File details

Details for the file torchmanager_diffusion-1.0.tar.gz.

File metadata

  • Download URL: torchmanager_diffusion-1.0.tar.gz
  • Size: 23.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.6

File hashes

Hashes for torchmanager_diffusion-1.0.tar.gz
Algorithm Hash digest
SHA256 a8710bf997b32181cb4edc64437034988d4216a210b4e0d0d3d715062c4e12b6
MD5 5480c14e73f98046396d9ccca143a091
BLAKE2b-256 c786e169c478a4183dd29fd522720b506e70a48475fe23549e239e48c7ae464e


File details

Details for the file torchmanager_diffusion-1.0-py3-none-any.whl.


File hashes

Hashes for torchmanager_diffusion-1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 b4e5ee2f0b64778566ab5e7641a5f426282deedb555d2b78c3f48d9c23727135
MD5 eebb168badbf1a76ccf87a8dbda4ebed
BLAKE2b-256 01ba6a9fa0e6ab4c0869b4eb4bbb7bba1d0d64fb002ae3205e89a61b9f619bd4
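To check a downloaded file against the SHA256 digests listed above, hash it locally and compare. Here is a minimal sketch using Python's standard hashlib module; the function names sha256_of_file and matches_published_digest are hypothetical, and the constant holds the wheel's SHA256 digest from the table above.

```python
import hashlib
from pathlib import Path
from typing import Union

# SHA256 digest listed above for torchmanager_diffusion-1.0-py3-none-any.whl
WHEEL_SHA256 = "b4e5ee2f0b64778566ab5e7641a5f426282deedb555d2b78c3f48d9c23727135"


def sha256_of_file(path: Union[str, Path], chunk_size: int = 1 << 16) -> str:
    """Hex SHA256 digest of a file, read in chunks so large files need little memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_digest(path: Union[str, Path], expected_hex: str) -> bool:
    """True if the file's SHA256 equals the published hex digest (case-insensitive)."""
    return sha256_of_file(path) == expected_hex.strip().lower()
```

For example, after downloading the wheel into the current directory, matches_published_digest("torchmanager_diffusion-1.0-py3-none-any.whl", WHEEL_SHA256) should return True; False means the file differs from the one published.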

