
Torchmanager Implementation for Diffusion Models (v1.2 Release Candidate 2)

Project description

Torchmanager Diffusion Models Plug-in

A torchmanager implementation of diffusion models.

Pre-requisites

Installation

  • PyPI: pip install torchmanager-diffusion

DDPM Manager Usage

Train DDPM

Compile a DDPMManager directly with a model, a beta space, and the number of time steps. Then call the fit method to train the model.

import torch
import diffusion
from diffusion import DDPMManager
from torchmanager import callbacks, data, losses

# initialize dataset
dataset: data.Dataset = ...

# initialize model, beta_space, and time_steps
model: torch.nn.Module = ...
beta_space: diffusion.scheduling.BetaSpace = ...
time_steps: int = ...

# initialize optimizer and loss function
optimizer: torch.optim.Optimizer = ...
loss_fn: losses.Loss = ...

# compile the ddpm manager
manager = DDPMManager(model, beta_space, time_steps, optimizer=optimizer, loss_fn=loss_fn)

# initialize callbacks
callback_list: list[callbacks.Callback] = ...

# train the model
trained_model = manager.fit(dataset, epochs=..., callbacks=callback_list)
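For intuition about what the beta space and the number of time steps control, here is a minimal sketch of the standard DDPM forward (noising) process with a linear beta schedule, written in plain PyTorch. The helpers linear_beta_schedule and add_noise are hypothetical and not part of the torchmanager-diffusion API. The range 1e-4 to 0.02 is the linear schedule used in the original DDPM paper and is only an assumption about what a diffusion.scheduling.BetaSpace might contain.

```python
from typing import Optional

import torch


def linear_beta_schedule(time_steps: int, start: float = 1e-4, end: float = 0.02) -> torch.Tensor:
    """Hypothetical helper: evenly spaced betas from start to end."""
    return torch.linspace(start, end, time_steps)


def add_noise(
    x0: torch.Tensor, t: torch.Tensor, betas: torch.Tensor, noise: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)."""
    if noise is None:
        noise = torch.randn_like(x0)
    # alpha_bar_t is the cumulative product of (1 - beta_s) for s <= t
    alphas_bar = torch.cumprod(1.0 - betas, dim=0)
    # pick alpha_bar for each sample's time step and reshape so it broadcasts over C, H, W
    a_bar = alphas_bar[t].view(-1, *([1] * (x0.dim() - 1)))
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise
    return x_t, noise


betas = linear_beta_schedule(1000)
x0 = torch.randn(4, 3, 8, 8)      # a dummy batch of 4 "images"
t = torch.randint(0, 1000, (4,))  # one random time step per sample
x_t, eps = add_noise(x0, t, betas)
print(x_t.shape)  # torch.Size([4, 3, 8, 8])
```

A DDPM-style objective typically feeds x_t and t to the model and regresses the predicted noise against eps with a mean-squared-error loss; whether DDPMManager uses exactly this objective is determined by the library and the loss_fn you pass, not by this sketch.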

Evaluate DDPM

Add the necessary metrics, then call the test method with sampling_images=True to evaluate the trained model.

import torch
from diffusion import DDPMManager
from torchmanager import data, metrics
from torchvision import models

# load manager from checkpoints
manager = DDPMManager.from_checkpoint(...)
assert isinstance(manager, DDPMManager), "manager is not a DDPMManager."

# initialize dataset
testing_dataset: data.Dataset = ...

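# add necessary metrics (the weights argument replaces the deprecated pretrained=True in torchvision)
inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)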
inception.fc = torch.nn.Identity()  # type: ignore
inception.eval()
fid = metrics.FID(inception)
manager.metrics.update({"FID": fid})

# evaluate the model
summary = manager.test(testing_dataset, sampling_images=True)
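The FID metric compares statistics of Inception features extracted from real and generated images. As a rough illustration of the underlying formula, FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^(1/2)), the sketch below computes the Fréchet distance between two sets of feature vectors in plain PyTorch. The function frechet_distance is hypothetical; torchmanager's metrics.FID may differ in details such as numerical stabilization or how features are accumulated across batches.

```python
import torch


def _sqrtm_psd(mat: torch.Tensor) -> torch.Tensor:
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    eigvals, eigvecs = torch.linalg.eigh(mat)
    eigvals = eigvals.clamp(min=0)  # drop tiny negative eigenvalues caused by round-off
    return eigvecs @ torch.diag(eigvals.sqrt()) @ eigvecs.T


def frechet_distance(real: torch.Tensor, fake: torch.Tensor) -> float:
    """Fréchet distance between Gaussians fitted to two (N, D) feature matrices."""
    real, fake = real.double(), fake.double()
    mu_r, mu_g = real.mean(dim=0), fake.mean(dim=0)
    cov_r = torch.cov(real.T)  # torch.cov expects one variable per row
    cov_g = torch.cov(fake.T)
    # Tr((S_r S_g)^(1/2)) equals Tr((S_r^(1/2) S_g S_r^(1/2))^(1/2)); the latter argument is symmetric PSD
    sqrt_r = _sqrtm_psd(cov_r)
    cross = _sqrtm_psd(sqrt_r @ cov_g @ sqrt_r)
    diff = mu_r - mu_g
    return float(diff @ diff + torch.trace(cov_r) + torch.trace(cov_g) - 2 * torch.trace(cross))


torch.manual_seed(0)
a = torch.randn(500, 16)
b = torch.randn(500, 16) + 1.0
print(frechet_distance(a, a))  # close to 0: identical feature sets
print(frechet_distance(a, b))  # approximately 16: a mean shift of 1 in each of 16 dimensions
```

In a real evaluation the rows of real and fake would be Inception features (for example, the 2048-dimensional outputs produced once inception.fc is replaced by torch.nn.Identity()), collected over many images so the covariance estimates are stable.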

Customize Diffusion Algorithm

To customize the diffusion algorithm, inherit from DiffusionManager and implement the abstract methods forward_diffusion and sampling_step.

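from typing import Any, Optional, Union

import torch
from diffusion import DiffusionManager
from diffusion.data import DiffusionData  # assumed import path for DiffusionData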

class CustomizedManager(DiffusionManager):
    def forward_diffusion(self, data: Any, condition: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None) -> tuple[Any, torch.Tensor]:
        ...

    def sampling_step(self, data: DiffusionData, i: int, /, *, return_noise: bool = False) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        ...
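As an example of what a sampling_step implementation might compute, the sketch below shows the standard DDPM ancestral sampling update, x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) / sqrt(alpha_t) + sigma_t * z, as a standalone PyTorch function. It does not use DiffusionManager or DiffusionData; the names ddpm_reverse_step, sample, and noise_predictor are hypothetical, and the variance choice sigma_t^2 = beta_t (one of the two options in the DDPM paper) is an assumption.

```python
from typing import Callable

import torch


def ddpm_reverse_step(
    x_t: torch.Tensor,
    t: int,
    betas: torch.Tensor,
    noise_predictor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """One ancestral sampling step from x_t to x_{t-1} (hypothetical helper)."""
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    # the model predicts the noise that was added at step t
    t_batch = torch.full((x_t.shape[0],), t, dtype=torch.long)
    eps = noise_predictor(x_t, t_batch)
    mean = (x_t - betas[t] / (1.0 - alphas_bar[t]).sqrt() * eps) / alphas[t].sqrt()
    if t == 0:
        return mean  # no noise is added at the final step
    # assumed variance choice: sigma_t^2 = beta_t
    return mean + betas[t].sqrt() * torch.randn_like(x_t)


def sample(
    shape: tuple[int, ...],
    betas: torch.Tensor,
    noise_predictor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Run the full reverse chain starting from pure Gaussian noise (hypothetical helper)."""
    x = torch.randn(shape)
    for t in reversed(range(len(betas))):
        x = ddpm_reverse_step(x, t, betas, noise_predictor)
    return x


betas = torch.linspace(1e-4, 0.02, 50)
dummy_model = lambda x, t: torch.zeros_like(x)  # stands in for a trained noise-prediction network
with torch.no_grad():
    images = sample((2, 3, 8, 8), betas, dummy_model)
print(images.shape)  # torch.Size([2, 3, 8, 8])
```

Inside a CustomizedManager, logic like ddpm_reverse_step would go in sampling_step, with the manager's model playing the role of noise_predictor; how DiffusionManager passes the time step, condition, and return_noise flag is defined by the library, not by this sketch.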

Download files


Source Distribution

torchmanager_diffusion-1.2rc2.tar.gz (31.5 kB)

Uploaded Source

Built Distribution


torchmanager_diffusion-1.2rc2-py3-none-any.whl (50.4 kB)

Uploaded Python 3

File details

Details for the file torchmanager_diffusion-1.2rc2.tar.gz.

File metadata

  • Download URL: torchmanager_diffusion-1.2rc2.tar.gz
  • Upload date:
  • Size: 31.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.18

File hashes

Hashes for torchmanager_diffusion-1.2rc2.tar.gz
Algorithm    Hash digest
SHA256       327355c6a98aa489be14ae87e4b562041c51af55eed4e5f99aa4b158fd2d7c78
MD5          72aaecafe9d58bbccae09854b892e81b
BLAKE2b-256  41847c4a07ec548f959cd0c9185289c2aa2b0f6dfcedc40944f7800b1e0b8bbe


File details

Details for the file torchmanager_diffusion-1.2rc2-py3-none-any.whl.

File metadata

File hashes

Hashes for torchmanager_diffusion-1.2rc2-py3-none-any.whl
Algorithm    Hash digest
SHA256       11e9b39c2505f8ca10de9faa62d92151db3cf35dfc586979a15c2bb9f90120f4
MD5          e8120669f3ce25c9b4952f4bef996e35
BLAKE2b-256  174bd1336d6d2f751a328fbfab2a37b08a72d4b53c5fb6c8dd263c85b46b20ad

