
Torchmanager Implementation for Diffusion Model (v1.2.2)


Torchmanager Diffusion Models Plug-in

The torchmanager implementation for diffusion models.

Pre-requisites

Installation

  • PyPI: pip install torchmanager-diffusion

DDPM Manager Usage

Train DDPM

Directly compile a DDPMManager with a model, a beta space, and the number of time steps, then call its fit method to train the model.

import torch
import diffusion
from diffusion import DDPMManager
from torchmanager import callbacks, data, losses

# initialize dataset
dataset: data.Dataset = ...

# initialize model, beta_space, and time_steps
model: torch.nn.Module = ...
beta_space: diffusion.scheduling.BetaSpace = ...
time_steps: int = ...

# initialize optimizer and loss function
optimizer: torch.optim.Optimizer = ...
loss_fn: losses.Loss = ...

# compile the ddpm manager
manager = DDPMManager(model, beta_space, time_steps, optimizer=optimizer, loss_fn=loss_fn)

# initialize callbacks
callback_list: list[callbacks.Callback] = ...

# train the model
trained_model = manager.fit(dataset, epochs=..., callbacks=callback_list)
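
For readers new to DDPM, the following sketch shows in plain PyTorch what a beta space and a number of time steps typically define during training: the linear beta schedule from Ho et al. (2020), the closed-form forward diffusion x_t = √ᾱ_t · x_0 + √(1 − ᾱ_t) · ε, and the simple noise-prediction MSE loss. The helpers `linear_beta_schedule` and `forward_diffusion_sketch` are hypothetical and are not part of the torchmanager-diffusion API; DDPMManager and diffusion.scheduling.BetaSpace may implement these steps differently.

```python
import torch


def linear_beta_schedule(time_steps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
    """Hypothetical helper: linearly spaced betas, as in the original DDPM paper."""
    return torch.linspace(beta_start, beta_end, time_steps)


def forward_diffusion_sketch(x0: torch.Tensor, t: torch.Tensor, betas: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: sample x_t ~ q(x_t | x_0) and return (x_t, noise)."""
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)  # cumulative product ᾱ_t
    # reshape ᾱ_t to (batch, 1, 1, ...) so it broadcasts over the non-batch dimensions
    a_bar = alpha_bars[t].view(-1, *([1] * (x0.dim() - 1)))
    noise = torch.randn_like(x0)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise
    return x_t, noise


# usage: one noisy batch and the simple DDPM training loss (MSE on the noise)
betas = linear_beta_schedule(1000)
x0 = torch.randn(4, 3, 32, 32)
t = torch.randint(0, 1000, (4,))
x_t, noise = forward_diffusion_sketch(x0, t, betas)
predicted_noise = torch.zeros_like(noise)  # stand-in for model(x_t, t)
loss = torch.nn.functional.mse_loss(predicted_noise, noise)
```

Sampling a random t per example lets a single batch train the model across all noise levels at once.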

Evaluate DDPM

Add the necessary metrics, then call the test method with sampling_images=True to evaluate the trained model.

import torch
from diffusion import DDPMManager
from torchmanager import data, metrics
from torchvision import models

# load manager from checkpoints
manager = DDPMManager.from_checkpoint(...)
assert isinstance(manager, DDPMManager), "manager is not a DDPMManager."

# initialize dataset
testing_dataset: data.Dataset = ...

# add necessary metrics
inception = models.inception_v3(pretrained=True)
inception.fc = torch.nn.Identity()  # type: ignore
inception.eval()
fid = metrics.FID(inception)
manager.metrics.update({"FID": fid})

# evaluate the model
summary = manager.test(testing_dataset, sampling_images=True)
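
FID compares the statistics of Inception features extracted from real and generated images: FID = ‖μ_r − μ_g‖² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^½). The standalone sketch below computes this Fréchet distance from two feature matrices using only PyTorch. It is not torchmanager's metrics.FID implementation, and `frechet_distance_sketch` is a hypothetical name. It relies on the identity Tr((Σ_r Σ_g)^½) = Tr((Σ_r^½ Σ_g Σ_r^½)^½), so every matrix square root is taken of a symmetric positive semi-definite matrix and can be computed with an eigendecomposition.

```python
import torch


def _sqrtm_psd(mat: torch.Tensor) -> torch.Tensor:
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    eigvals, eigvecs = torch.linalg.eigh(mat)
    eigvals = eigvals.clamp(min=0)  # drop tiny negative eigenvalues caused by round-off
    return eigvecs @ torch.diag(eigvals.sqrt()) @ eigvecs.T


def frechet_distance_sketch(real_feats: torch.Tensor, fake_feats: torch.Tensor) -> float:
    """Hypothetical helper: Fréchet distance between two (N, D) feature sets, D >= 2."""
    real = real_feats.double()
    fake = fake_feats.double()
    mu_r, mu_g = real.mean(dim=0), fake.mean(dim=0)
    sigma_r = torch.cov(real.T)  # torch.cov treats each row as one variable
    sigma_g = torch.cov(fake.T)
    sqrt_r = _sqrtm_psd(sigma_r)
    covmean = _sqrtm_psd(sqrt_r @ sigma_g @ sqrt_r)  # trace equals Tr((Σ_r Σ_g)^½)
    dist = (mu_r - mu_g).pow(2).sum() + torch.trace(sigma_r + sigma_g - 2 * covmean)
    return dist.item()


# usage with random stand-ins for Inception features (the doc's model outputs 2048-d features)
real_feats = torch.randn(256, 16)
fake_feats = torch.randn(256, 16) + 0.5
print(frechet_distance_sketch(real_feats, fake_feats))
```

Lower values mean the generated feature distribution is closer to the real one; in practice FID is estimated from thousands of samples, since covariance estimates from small sets are noisy.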

Customize Diffusion Algorithm

Subclass DiffusionManager and implement its abstract methods, forward_diffusion and sampling_step, to customize the diffusion algorithm.

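from typing import Any, Optional, Union

import torch
from diffusion import DiffusionManager
from diffusion.data import DiffusionData  # assumption: DiffusionData is exported from diffusion.data

class CustomizedManager(DiffusionManager):
    def forward_diffusion(self, data: Any, condition: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None) -> tuple[Any, torch.Tensor]:
        # implement the forward (noising) process here
        ...

    def sampling_step(self, data: DiffusionData, i: int, /, *, return_noise: bool = False) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        # implement one reverse (denoising) step at step i here
        ...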
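
As a reference point for implementing sampling_step, the sketch below shows one reverse step of the original DDPM sampler (Ho et al., 2020) in plain PyTorch with σ_t² = β_t: x_{t−1} = (1/√α_t) · (x_t − β_t / √(1 − ᾱ_t) · ε_θ(x_t, t)) + σ_t · z, with z = 0 at the last step. It is a standalone function, not a DiffusionManager method; the name `ddpm_sampling_step_sketch`, its arguments, and the convention that `i` is a 0-based step index are assumptions for illustration. In a real manager the predicted noise would come from the trained model.

```python
import torch


def ddpm_sampling_step_sketch(x_t: torch.Tensor, predicted_noise: torch.Tensor, i: int, betas: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: compute x_{i-1} from x_i with the DDPM update (sigma_t^2 = beta_t)."""
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    beta_t, alpha_t, alpha_bar_t = betas[i], alphas[i], alpha_bars[i]
    # posterior mean: subtract the scaled noise estimate, then rescale by 1 / sqrt(alpha_t)
    mean = (x_t - beta_t / (1.0 - alpha_bar_t).sqrt() * predicted_noise) / alpha_t.sqrt()
    if i == 0:
        return mean  # no noise is added at the final step
    return mean + beta_t.sqrt() * torch.randn_like(x_t)


# usage: run the full reverse chain from pure noise with a stand-in noise predictor
time_steps = 50
betas = torch.linspace(1e-4, 0.02, time_steps)
x = torch.randn(2, 3, 8, 8)
with torch.no_grad():
    for i in reversed(range(time_steps)):
        eps = torch.zeros_like(x)  # stand-in for model(x, i)
        x = ddpm_sampling_step_sketch(x, eps, i, betas)
```

Recomputing alphas and ᾱ inside the step keeps the sketch self-contained; a manager would normally precompute them once from its beta space.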

Download files

Source Distribution

torchmanager_diffusion-1.2.2.tar.gz (31.8 kB)

Uploaded Source

Built Distribution


torchmanager_diffusion-1.2.2-py3-none-any.whl (50.7 kB)

Uploaded Python 3

File details

Details for the file torchmanager_diffusion-1.2.2.tar.gz.

File metadata

  • Download URL: torchmanager_diffusion-1.2.2.tar.gz
  • Size: 31.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.18

File hashes

Hashes for torchmanager_diffusion-1.2.2.tar.gz
Algorithm     Hash digest
SHA256        134149f6b9200efbdb7b9d00f8b1a61768ae85320fa378a526b48dbb0e3ee675
MD5           06cf1cc8bf63fbd58fd4b6d50885d88f
BLAKE2b-256   407d5b731f3bc491884ec31259fa4b03a62aa020f3d4fc74ba5c5a4620ecc7a0


File details

Details for the file torchmanager_diffusion-1.2.2-py3-none-any.whl.

File hashes

Hashes for torchmanager_diffusion-1.2.2-py3-none-any.whl
Algorithm     Hash digest
SHA256        8fa84fc70a38887143922901b0ad1b559de25bd4759ee525c9896e376c301d97
MD5           322bc6c62af05c64fc7eaad1641ae8a0
BLAKE2b-256   da46ce6cc89db9f31560de649cf815f194b3da6b4de0018f3451f4f5a7308cc7

