Torchmanager Implementation for Diffusion Model (v1.2.1)

Torchmanager Diffusion Models Plug-in

The torchmanager implementation for diffusion models.

Pre-requisites

Installation

  • PyPI: pip install torchmanager-diffusion

DDPM Manager Usage

Train DDPM

Compile a DDPMManager directly with a model, a beta space, and a number of time steps. Then call the fit method to train the model.

import torch
import diffusion
from diffusion import DDPMManager
from torchmanager import callbacks, data, losses

# initialize dataset
dataset: data.Dataset = ...

# initialize model, beta_space, and time_steps
model: torch.nn.Module = ...
beta_space: diffusion.scheduling.BetaSpace = ...
time_steps: int = ...

# initialize optimizer and loss function
optimizer: torch.optim.Optimizer = ...
loss_fn: losses.Loss = ...

# compile the ddpm manager
manager = DDPMManager(model, beta_space, time_steps, optimizer=optimizer, loss_fn=loss_fn)

# initialize callbacks
callback_list: list[callbacks.Callback] = ...

# train the model
trained_model = manager.fit(dataset, epochs=..., callbacks=callback_list)
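
To make the roles of the beta space and the number of time steps more concrete, here is a minimal sketch of a linear noise schedule and the closed-form forward noising step from the standard DDPM formulation. It is written in plain Python and does not use diffusion.scheduling.BetaSpace; the function names (linear_betas, cumulative_alphas, forward_noise) and the default range 1e-4 to 0.02 are assumptions for illustration, not the library's API.

```python
import math
import random


def linear_betas(time_steps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> list[float]:
    """Linearly spaced noise variances beta_1..beta_T (hypothetical defaults)."""
    if time_steps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (time_steps - 1)
    return [beta_start + step * i for i in range(time_steps)]


def cumulative_alphas(betas: list[float]) -> list[float]:
    """alpha_bar_t is the running product of (1 - beta_s) for s <= t."""
    alpha_bars: list[float] = []
    running = 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def forward_noise(
    x0: list[float], t: int, alpha_bars: list[float], rng: random.Random
) -> tuple[list[float], list[float]]:
    """Sample x_t from q(x_t | x_0) for a 0-based step index t; returns (x_t, noise)."""
    signal = math.sqrt(alpha_bars[t])
    sigma = math.sqrt(1.0 - alpha_bars[t])
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [signal * x + sigma * e for x, e in zip(x0, noise)]
    return x_t, noise


x0 = [0.5, -0.25, 1.0]
betas = linear_betas(1000)
alpha_bars = cumulative_alphas(betas)
x_t, noise = forward_noise(x0, 999, alpha_bars, random.Random(0))
```

With this schedule the signal weight shrinks monotonically, so at the last step x_t is almost pure Gaussian noise. A noise-prediction model is typically trained to recover noise from x_t and t.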

Evaluate DDPM

Add the metrics you need, then call the test method with sampling_images=True to evaluate the trained model on sampled images.

import torch
from diffusion import DDPMManager
from torchmanager import data, metrics
from torchvision import models

# load manager from checkpoints
manager = DDPMManager.from_checkpoint(...)
assert isinstance(manager, DDPMManager), "manager is not a DDPMManager."

# initialize dataset
testing_dataset: data.Dataset = ...

# add necessary metrics
inception = models.inception_v3(pretrained=True)
inception.fc = torch.nn.Identity()  # type: ignore
inception.eval()
fid = metrics.FID(inception)
manager.metrics.update({"FID": fid})

# evaluate the model
summary = manager.test(testing_dataset, sampling_images=True)
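
As a rough intuition for what the FID metric above measures: it is the Fréchet distance between two Gaussians fitted to feature statistics of real and generated images. The sketch below shows only the univariate special case, where the distance reduces to the squared difference of means plus the squared difference of standard deviations. It is not the implementation behind metrics.FID, which works on multivariate Inception features; frechet_distance_1d is a hypothetical helper name, and the use of the population standard deviation is an assumption.

```python
import statistics


def frechet_distance_1d(real: list[float], generated: list[float]) -> float:
    """Squared Fréchet distance between 1-D Gaussians fitted to two samples."""
    mu_r = statistics.fmean(real)
    mu_g = statistics.fmean(generated)
    # Population standard deviations (an assumption; some implementations use the sample estimate)
    sigma_r = statistics.pstdev(real)
    sigma_g = statistics.pstdev(generated)
    return (mu_r - mu_g) ** 2 + (sigma_r - sigma_g) ** 2


same = frechet_distance_1d([0.0, 2.0], [0.0, 2.0])
shifted = frechet_distance_1d([0.0, 2.0], [1.0, 3.0])
wider = frechet_distance_1d([0.0, 2.0], [0.0, 4.0])
```

Lower values mean the two feature distributions are closer. Identical samples give a distance of zero.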

Customize Diffusion Algorithm

Subclass DiffusionManager and implement the abstract methods forward_diffusion and sampling_step to customize the diffusion algorithm.

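from typing import Any, Optional, Union

import torch
from diffusion import DiffusionManager
from diffusion.data import DiffusionData  # assumed import path for DiffusionData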

class CustomizedManager(DiffusionManager):
    def forward_diffusion(self, data: Any, condition: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None) -> tuple[Any, torch.Tensor]:
        ...

    def sampling_step(self, data: DiffusionData, i: int, /, *, return_noise: bool = False) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        ...
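
For orientation, a sampling step in the standard DDPM algorithm removes the predicted noise from the current sample and, except at the final step, adds fresh noise back. The sketch below shows that update as a standalone function on lists of floats. It does not subclass DiffusionManager or use DiffusionData, and the name ddpm_reverse_step, its arguments, and the choice of sigma_t = sqrt(beta_t) are assumptions for illustration only.

```python
import math
import random


def ddpm_reverse_step(
    x_t: list[float],
    predicted_noise: list[float],
    i: int,
    betas: list[float],
    alpha_bars: list[float],
    rng: random.Random,
) -> list[float]:
    """One ancestral sampling step from 0-based step i to the previous step."""
    beta = betas[i]
    alpha = 1.0 - beta
    noise_scale = beta / math.sqrt(1.0 - alpha_bars[i])
    # Posterior mean: subtract the scaled predicted noise, then rescale
    mean = [(x - noise_scale * e) / math.sqrt(alpha) for x, e in zip(x_t, predicted_noise)]
    if i == 0:
        return mean  # no noise is added at the final step
    sigma = math.sqrt(beta)
    return [m + sigma * rng.gauss(0.0, 1.0) for m in mean]


# Single-step toy schedule: noise a sample once, then undo it with the true noise
betas = [0.1]
alpha_bars = [0.9]
x0 = [0.5, -1.0]
true_noise = [0.3, -0.7]
x1 = [math.sqrt(0.9) * x + math.sqrt(0.1) * e for x, e in zip(x0, true_noise)]
recovered = ddpm_reverse_step(x1, true_noise, 0, betas, alpha_bars, random.Random(0))
```

In the toy example the exact noise is supplied, so the final step recovers x0 up to floating-point error. In practice predicted_noise would come from the trained model.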
