Torchmanager Implementation for Diffusion Model (v1.1 Release Candidate)

Torchmanager Diffusion Models Plug-in

The torchmanager implementation for diffusion models.

Pre-requisites

Installation

  • PyPI: pip install torchmanager-diffusion

DDPM Manager Usage

Train DDPM

Compile a DDPMManager directly with a model, a beta space, and a number of time steps. Then call the fit method to train the model.

import diffusion
import torch
from diffusion import DDPMManager
from torchmanager import callbacks, data, losses

# initialize dataset
dataset: data.Dataset = ...

# initialize model, beta_space, and time_steps
model: torch.nn.Module = ...
beta_space: diffusion.scheduling.BetaSpace = ...
time_steps: int = ...

# initialize optimizer and loss function
optimizer: torch.optim.Optimizer = ...
loss_fn: losses.Loss = ...

# compile the ddpm manager
manager = DDPMManager(model, beta_space, time_steps, optimizer=optimizer, loss_fn=loss_fn)

# initialize callbacks
callback_list: list[callbacks.Callback] = ...

# train the model
trained_model = manager.fit(dataset, epochs=..., callbacks=callback_list)
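
The beta space and the number of time steps together define the noise schedule. As a rough illustration of what such a schedule contains, the sketch below builds a linear schedule in plain Python. It does not use diffusion.scheduling.BetaSpace; the helper names linear_betas and alpha_bars are hypothetical, and the defaults of 1e-4 and 0.02 are the values commonly used for DDPM, assumed here rather than taken from this package.

```python
def linear_betas(time_steps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> list[float]:
    """Evenly spaced noise variances from beta_start to beta_end (hypothetical helper)."""
    if time_steps < 2:
        raise ValueError("time_steps must be at least 2")
    step = (beta_end - beta_start) / (time_steps - 1)
    return [beta_start + i * step for i in range(time_steps)]


def alpha_bars(betas: list[float]) -> list[float]:
    """Cumulative products of (1 - beta_t): the share of signal variance left at each step."""
    result, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        result.append(running)
    return result


betas = linear_betas(1000)
signal = alpha_bars(betas)
print(f"first beta: {betas[0]:.4f}, last beta: {betas[-1]:.4f}")
print(f"signal kept after the first step: {signal[0]:.4f}, after the last step: {signal[-1]:.6f}")
```

The cumulative product shrinks monotonically, which is why the data is close to pure noise by the final time step.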

Evaluate DDPM

Add the necessary metrics, then call the test method with sampling_images set to True to evaluate the trained model on generated images.

import torch
from diffusion import DDPMManager
from torchmanager import data, metrics
from torchvision import models

# load manager from checkpoints
manager = DDPMManager.from_checkpoint(...)
assert isinstance(manager, DDPMManager), "manager is not a DDPMManager."

# initialize dataset
testing_dataset: data.Dataset = ...

# add necessary metrics
inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)  # `pretrained=True` is deprecated since torchvision 0.13
inception.fc = torch.nn.Identity()  # type: ignore
inception.eval()
fid = metrics.FID(inception)
manager.metrics.update({"FID": fid})

# evaluate the model
summary = manager.test(testing_dataset, sampling_images=True)
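
For intuition about what an FID score measures, the sketch below computes the Fréchet distance between two Gaussians fitted to one-dimensional features. This is a simplified illustration, not the implementation behind metrics.FID: real FID uses multi-dimensional Inception features and full covariance matrices, and the function name frechet_distance_1d and the use of population variance are assumptions made for this example.

```python
import math
import statistics


def frechet_distance_1d(real: list[float], generated: list[float]) -> float:
    """Fréchet distance between two 1-D Gaussians fitted to the samples (hypothetical helper)."""
    mu_r, mu_g = statistics.fmean(real), statistics.fmean(generated)
    var_r, var_g = statistics.pvariance(real), statistics.pvariance(generated)
    # In one dimension the trace term reduces to var_r + var_g - 2 * sqrt(var_r * var_g).
    return (mu_r - mu_g) ** 2 + var_r + var_g - 2.0 * math.sqrt(var_r * var_g)


real_features = [0.0, 1.0, 2.0, 3.0]
shifted_features = [x + 1.0 for x in real_features]

same = frechet_distance_1d(real_features, real_features)
shifted = frechet_distance_1d(real_features, shifted_features)
print(f"identical sets: {same:.4f}, shifted set: {shifted:.4f}")
```

Identical feature sets give a distance of zero, and the distance grows as the generated statistics drift from the real ones, so lower scores indicate closer distributions.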

Customize Diffusion Algorithm

To customize the diffusion algorithm, subclass DiffusionManager and implement its two abstract methods, forward_diffusion and sampling_step.

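from typing import Any, Optional, Union

import torch
from diffusion import DiffusionManager
from diffusion.data import DiffusionData  # assumed import path; adjust to your installed version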

class CustomizedManager(DiffusionManager):
    def forward_diffusion(self, data: Any, condition: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None) -> tuple[Any, torch.Tensor]:
        ...

    def sampling_step(self, data: DiffusionData, i: int, /, *, return_noise: bool = False) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        ...
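
To make the roles of the two methods more concrete, here is a minimal plain-Python sketch of the standard DDPM forward diffusion and one reverse sampling step on a list of floats. It is an illustration under stated assumptions, not this package's implementation: the signatures are simplified and do not match the abstract methods above, time steps are 0-indexed, the reverse variance is taken to be beta_t, and make_schedule is a hypothetical helper.

```python
import math
import random


def make_schedule(time_steps: int, beta_start: float = 1e-4, beta_end: float = 0.02):
    """Linear betas and their cumulative alpha products (hypothetical helper)."""
    betas = [beta_start + (beta_end - beta_start) * i / (time_steps - 1) for i in range(time_steps)]
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return betas, alpha_bars


def forward_diffusion(x0, t, alpha_bars, rng):
    """Sample x_t from q(x_t | x_0); return the noised data and the noise that was added."""
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    scale, sigma = math.sqrt(alpha_bars[t]), math.sqrt(1.0 - alpha_bars[t])
    return [scale * x + sigma * n for x, n in zip(x0, noise)], noise


def sampling_step(xt, predicted_noise, t, betas, alpha_bars, rng):
    """One reverse step from index t: remove the predicted noise, then re-add noise unless t == 0."""
    coef = betas[t] / math.sqrt(1.0 - alpha_bars[t])
    mean = [(x - coef * e) / math.sqrt(1.0 - betas[t]) for x, e in zip(xt, predicted_noise)]
    if t == 0:
        return mean
    sigma = math.sqrt(betas[t])  # assumed choice: sigma_t^2 = beta_t
    return [m + sigma * rng.gauss(0.0, 1.0) for m in mean]


rng = random.Random(0)
betas, alpha_bars = make_schedule(1000)
x0 = [0.5, -1.0, 2.0]

# Noise the data at the first step, then undo it with a "perfect" noise prediction.
x_noisy, noise = forward_diffusion(x0, 0, alpha_bars, rng)
x_rec = sampling_step(x_noisy, noise, 0, betas, alpha_bars, rng)
max_error = max(abs(a - b) for a, b in zip(x0, x_rec))
print(f"max reconstruction error: {max_error:.2e}")
```

In training, a model learns to predict the noise returned by the forward step; in sampling, that prediction replaces the true noise used in this toy round trip.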
