Torchmanager Implementation for Diffusion Model (v1.2 Release Candidate)

Torchmanager Diffusion Models Plug-in

A torchmanager implementation of diffusion models.

Prerequisites

Installation

  • PyPi: pip install torchmanager-diffusion

DDPM Manager Usage

Train DDPM

Compile a DDPMManager directly with a model, a beta space, and a number of time steps, then call its fit method to train the model.

import torch
import diffusion
from diffusion import DDPMManager
from torchmanager import callbacks, data, losses

# initialize dataset
dataset: data.Dataset = ...

# initialize model, beta_space, and time_steps
model: torch.nn.Module = ...
beta_space: diffusion.scheduling.BetaSpace = ...
time_steps: int = ...

# initialize optimizer and loss function
optimizer: torch.optim.Optimizer = ...
loss_fn: losses.Loss = ...

# compile the ddpm manager
manager = DDPMManager(model, beta_space, time_steps, optimizer=optimizer, loss_fn=loss_fn)

# initialize callbacks
callback_list: list[callbacks.Callback] = ...

# train the model
trained_model = manager.fit(dataset, epochs=..., callbacks=callback_list)
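For intuition about what fit optimizes, here is a minimal sketch of one conventional DDPM training step (Ho et al., 2020): pick a random time step t per example, noise the clean batch in closed form as x_t = √ᾱ_t · x_0 + √(1 − ᾱ_t) · ε, and regress the model's output onto ε with a mean-squared error. It assumes a model called as model(x_t, t) that predicts the added noise. Whether DDPMManager uses exactly this objective depends on the loss_fn you pass; ddpm_training_step and TinyNoisePredictor are hypothetical names, not part of the diffusion package.

```python
import torch
import torch.nn.functional as F


class TinyNoisePredictor(torch.nn.Module):
    """Hypothetical stand-in for a real noise-prediction network (e.g. a U-Net); it ignores t."""

    def __init__(self, channels: int = 3) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


def ddpm_training_step(model: torch.nn.Module, x0: torch.Tensor, alpha_bar: torch.Tensor,
                       optimizer: torch.optim.Optimizer) -> float:
    """One gradient step on the noise-prediction loss ||eps - model(x_t, t)||^2 (assumed objective)."""
    batch = x0.shape[0]
    # sample one time step per example, uniformly from [0, T)
    t = torch.randint(0, alpha_bar.numel(), (batch,))
    noise = torch.randn_like(x0)
    # reshape a_bar_t to (B, 1, 1, 1) so it broadcasts over the image dimensions
    a = alpha_bar[t].view(batch, *([1] * (x0.dim() - 1)))
    x_t = a.sqrt() * x0 + (1 - a).sqrt() * noise  # closed-form forward diffusion
    loss = F.mse_loss(model(x_t, t), noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# linear beta schedule: 1e-4 to 0.02 over T = 1000 steps, as in the original DDPM paper
betas = torch.linspace(1e-4, 0.02, 1000)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)  # fraction of signal variance kept at each step

torch.manual_seed(0)
model = TinyNoisePredictor()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x0 = torch.randn(8, 3, 16, 16)  # placeholder batch standing in for real images
step_losses = [ddpm_training_step(model, x0, alpha_bar, optimizer) for _ in range(5)]
```

Because t is resampled every step, individual loss values fluctuate; only the trend over many steps is meaningful.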

Evaluate DDPM

Add the necessary metrics, then call the test method with sampling_images=True to evaluate the trained model.
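With sampling_images=True, test presumably generates images by running the reverse diffusion process before computing the metrics. As a rough picture of that process, here is a minimal sketch of standard DDPM ancestral sampling (Ho et al., 2020), assuming a noise-predicting model called as model(x_t, t) and the variance choice σ_t² = β_t. ddpm_sample is a hypothetical helper; DDPMManager's own sampling may differ in details such as the variance choice or clipping.

```python
from typing import Callable

import torch


@torch.no_grad()
def ddpm_sample(model: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                shape: tuple[int, ...], betas: torch.Tensor) -> torch.Tensor:
    """Ancestral DDPM sampling from x_T ~ N(0, I) down to x_0, using sigma_t^2 = beta_t."""
    alphas = 1.0 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)
    x = torch.randn(shape)  # x_T: pure Gaussian noise
    for i in reversed(range(betas.numel())):
        t = torch.full((shape[0],), i, dtype=torch.long)
        eps = model(x, t)  # predicted noise eps_theta(x_t, t)
        # reverse-process mean: (x_t - beta_t / sqrt(1 - a_bar_t) * eps) / sqrt(alpha_t)
        mean = (x - betas[i] / (1 - alpha_bar[i]).sqrt() * eps) / alphas[i].sqrt()
        # add fresh noise on every step except the last one
        x = (mean + betas[i].sqrt() * torch.randn_like(x)) if i > 0 else mean
    return x


# usage with a hypothetical "model" that always predicts zero noise
torch.manual_seed(0)
samples = ddpm_sample(lambda x, t: torch.zeros_like(x), (2, 3, 8, 8), torch.linspace(1e-4, 0.02, 50))
```

The loop costs one model call per time step, which is why evaluating with image sampling enabled is much slower than computing a training loss.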

import torch
from diffusion import DDPMManager
from torchmanager import data, metrics
from torchvision import models

# load manager from checkpoints
manager = DDPMManager.from_checkpoint(...)
assert isinstance(manager, DDPMManager), "manager is not a DDPMManager."

# initialize dataset
testing_dataset: data.Dataset = ...

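# add necessary metrics: FID computed on Inception-v3 features
# torchvision >= 0.13 weights API; older releases use pretrained=True instead
inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
# drop the classification head so the network outputs 2048-d pooled features
inception.fc = torch.nn.Identity()  # type: ignore
inception.eval()
fid = metrics.FID(inception)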
manager.metrics.update({"FID": fid})

# evaluate the model
summary = manager.test(testing_dataset, sampling_images=True)
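FID fits a Gaussian to the Inception features of real and of generated images and reports the Fréchet distance between the two: d² = ‖μ_r − μ_g‖² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}). For reference, that formula can be sketched in a few lines of PyTorch on precomputed (N, D) feature matrices. frechet_distance is a hypothetical helper for illustration, not torchmanager's metrics.FID implementation, which may differ in numerical details.

```python
import torch


def frechet_distance(real: torch.Tensor, fake: torch.Tensor) -> float:
    """Frechet distance between Gaussians fitted to two (N, D) feature matrices."""
    real, fake = real.double(), fake.double()  # float64 keeps the eigen-decomposition stable
    mu_r, mu_g = real.mean(dim=0), fake.mean(dim=0)
    cov_r, cov_g = torch.cov(real.T), torch.cov(fake.T)  # torch.cov expects variables as rows
    # Tr((cov_r cov_g)^{1/2}) is the sum of square roots of the eigenvalues of cov_r @ cov_g,
    # which are real and non-negative for covariance matrices (clamp removes round-off)
    eigvals = torch.linalg.eigvals(cov_r @ cov_g).real.clamp(min=0)
    trace_sqrt = eigvals.sqrt().sum()
    mean_term = (mu_r - mu_g).pow(2).sum()
    return float(mean_term + cov_r.trace() + cov_g.trace() - 2 * trace_sqrt)


torch.manual_seed(0)
feats_a = torch.randn(500, 8)  # stand-ins for Inception features (real ones are 2048-d)
feats_b = torch.randn(500, 8) + 1.0  # same spread, every mean shifted by 1
same = frechet_distance(feats_a, feats_a)
shifted = frechet_distance(feats_a, feats_b)
```

Identical feature sets give a distance near zero, while shifting each of the 8 feature means by 1 adds roughly 8, so lower FID means the generated features are statistically closer to the real ones.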

Customize Diffusion Algorithm

Subclass DiffusionManager and implement its abstract methods, forward_diffusion and sampling_step, to customize the diffusion algorithm.

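from typing import Any, Optional, Union

import torch
from diffusion import DiffusionManager
from diffusion.data import DiffusionData  # import path assumed; adjust to your installed version

class CustomizedManager(DiffusionManager):
    def forward_diffusion(self, data: Any, condition: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None) -> tuple[Any, torch.Tensor]:
        # forward (noising) process; presumably returns (model input, training target)
        ...

    def sampling_step(self, data: DiffusionData, i: int, /, *, return_noise: bool = False) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        # one reverse (denoising) step at time step i; also returns the noise when return_noise is True
        ...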
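The two hooks are easier to picture with a concrete body. Below is a hypothetical, stand-alone illustration of what they might compute for plain DDPM. ToyDDPMSteps does not subclass DiffusionManager, takes a bare tensor where the real sampling_step receives a DiffusionData object, and assumes forward_diffusion returns (model input, training target); check the package source for the exact contract before adapting it.

```python
from typing import Callable, Optional, Union

import torch


class ToyDDPMSteps:
    """Hypothetical illustration of the two hooks for plain DDPM; NOT a DiffusionManager subclass."""

    def __init__(self, model: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], betas: torch.Tensor) -> None:
        self.model = model
        self.betas = betas
        self.alphas = 1.0 - betas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)

    def forward_diffusion(self, data: torch.Tensor, condition: Optional[torch.Tensor] = None,
                          t: Optional[torch.Tensor] = None) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        # assumed contract: return (model input, training target); `condition` is unused here
        if t is None:
            t = torch.randint(0, self.betas.numel(), (data.shape[0],))
        noise = torch.randn_like(data)
        a = self.alpha_bar[t].view(-1, *([1] * (data.dim() - 1)))  # broadcast over C, H, W
        x_t = a.sqrt() * data + (1 - a).sqrt() * noise
        return (x_t, t), noise

    @torch.no_grad()
    def sampling_step(self, x_t: torch.Tensor, i: int, /, *,
                      return_noise: bool = False) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        # takes a bare tensor here; the real hook receives a DiffusionData object
        t = torch.full((x_t.shape[0],), i, dtype=torch.long)
        eps = self.model(x_t, t)
        mean = (x_t - self.betas[i] / (1 - self.alpha_bar[i]).sqrt() * eps) / self.alphas[i].sqrt()
        x_prev = (mean + self.betas[i].sqrt() * torch.randn_like(x_t)) if i > 0 else mean
        return (x_prev, eps) if return_noise else x_prev


torch.manual_seed(0)
steps = ToyDDPMSteps(lambda x, t: torch.zeros_like(x), torch.linspace(1e-4, 0.02, 100))
(x_t, t), target = steps.forward_diffusion(torch.randn(4, 3, 8, 8))
single = steps.sampling_step(x_t, 10)
x_prev, predicted_noise = steps.sampling_step(x_t, 10, return_noise=True)
```

Keeping the noise optional in sampling_step lets a caller that only needs images ignore it, while a caller that logs or reuses the predicted noise can ask for the pair.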

Download files

Download the file for your platform.

Source Distribution

torchmanager_diffusion-1.2rc1.tar.gz (31.7 kB)

Uploaded Source

Built Distribution

torchmanager_diffusion-1.2rc1-py3-none-any.whl (50.6 kB)

Uploaded Python 3

File details

Details for the file torchmanager_diffusion-1.2rc1.tar.gz.

File metadata

  • Download URL: torchmanager_diffusion-1.2rc1.tar.gz
  • Size: 31.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.16

File hashes

Hashes for torchmanager_diffusion-1.2rc1.tar.gz
  • SHA256: ec944a06e473a880e403108d2e23ab6e4ef198dc664e30c85e2c8f47188a9098
  • MD5: 2af6b468852c7d4f90568b9d78b23f41
  • BLAKE2b-256: 1b009bbe6bdf4a84653f75b9c7748f94f8967981e79eaf313ebef332f3b21efa

File details

Details for the file torchmanager_diffusion-1.2rc1-py3-none-any.whl.

File hashes

Hashes for torchmanager_diffusion-1.2rc1-py3-none-any.whl
  • SHA256: e3ad804d8b10cbf0a2af66407185c79e4c9e21dfd75fdca99b49cda51c4a1998
  • MD5: fe709dfa5e5f1ef83423d05100b022f5
  • BLAKE2b-256: ba9fcb4faf35144b75aa9e4e123e881ae373c15a6afa9550f08b8b125e72d8e2
