Torchmanager Implementation for Diffusion Models (v1.0)
Torchmanager Diffusion Models Plug-in
The torchmanager implementation for diffusion models.
Prerequisites
- Python >= 3.9
- SciPy >= 1.11.4
- PyTorch >= 2.0.1
- LPIPS
- torchmanager >= 1.2
- einops >= 0.6.1
Installation
- PyPI:
```
pip install --pre torchmanager-diffusion
```
DDPM Manager Usage
Train DDPM
Directly compile `DDPMManager` with a model, a beta space, and a number of time steps. Then, use the `fit` method to train the model.
```python
import torch

import diffusion
from diffusion import DDPMManager
from torchmanager import callbacks, data, losses

# initialize dataset
dataset: data.Dataset = ...

# initialize model, beta_space, and time_steps
model: torch.nn.Module = ...
beta_space: diffusion.scheduling.BetaSpace = ...
time_steps: int = ...

# initialize optimizer and loss function
optimizer: torch.optim.Optimizer = ...
loss_fn: losses.Loss = ...

# compile the ddpm manager
manager = DDPMManager(model, beta_space, time_steps, optimizer=optimizer, loss_fn=loss_fn)

# initialize callbacks
callback_list: list[callbacks.Callback] = ...

# train the model
trained_model = manager.fit(dataset, epochs=..., callbacks=callback_list)
```
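`DDPMManager` hides the forward (noising) process behind `fit`. As a rough, library-independent illustration of what the beta space and the number of time steps control, the sketch below builds a linear beta schedule (the schedule used in the original DDPM paper) and samples x_t from x_0 in closed form, x_t = sqrt(ᾱ_t) x_0 + sqrt(1 - ᾱ_t) ε, where ᾱ_t is the cumulative product of (1 - β_s). The helpers `linear_beta_space`, `alpha_bars`, `forward_diffusion`, and `noise_prediction_loss` are hypothetical names for this example only; they are not part of `diffusion.scheduling.BetaSpace` or `DDPMManager`, and plain Python lists stand in for tensors.

```python
import math
import random


def linear_beta_space(time_steps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> list[float]:
    """Hypothetical helper: linearly spaced betas from beta_start to beta_end."""
    if time_steps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (time_steps - 1)
    return [beta_start + i * step for i in range(time_steps)]


def alpha_bars(betas: list[float]) -> list[float]:
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    result: list[float] = []
    running = 1.0
    for beta in betas:
        running *= 1.0 - beta
        result.append(running)
    return result


def forward_diffusion(x0: list[float], t: int, betas: list[float], rng: random.Random) -> tuple[list[float], list[float]]:
    """Sample x_t ~ q(x_t | x_0) in one shot; returns (x_t, noise)."""
    alpha_bar_t = alpha_bars(betas)[t]
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [math.sqrt(alpha_bar_t) * x + math.sqrt(1.0 - alpha_bar_t) * n for x, n in zip(x0, noise)]
    return x_t, noise


def noise_prediction_loss(predicted: list[float], noise: list[float]) -> float:
    """Mean squared error between predicted and true noise (the simple DDPM objective)."""
    return sum((p - n) ** 2 for p, n in zip(predicted, noise)) / len(noise)


rng = random.Random(0)
betas = linear_beta_space(1000)
x0 = [0.5, -0.25, 1.0]
x_t, noise = forward_diffusion(x0, t=999, betas=betas, rng=rng)
print(alpha_bars(betas)[-1])  # close to 0, so x_t at the last step is almost pure noise
```

During training, a model would receive `x_t` and `t` and be penalized with `noise_prediction_loss` against `noise`; whether `DDPMManager` uses exactly this objective depends on the `loss_fn` you pass in.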
Evaluate DDPM
Add the necessary metrics, then call the `test` method with `sampling_images` set to `True` to evaluate the trained model.
```python
import torch
from diffusion import DDPMManager
from torchmanager import data, metrics
from torchvision import models

# load manager from checkpoints
manager = DDPMManager.from_checkpoint(...)
assert isinstance(manager, DDPMManager), "manager is not a DDPMManager."

# initialize dataset
testing_dataset: data.Dataset = ...

# add necessary metrics (Inception-v3 features with the classifier head removed)
inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
inception.fc = torch.nn.Identity() # type: ignore
inception.eval()
fid = metrics.FID(inception)
manager.metrics.update({"FID": fid})

# evaluate the model
summary = manager.test(testing_dataset, sampling_images=True)
```
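`metrics.FID` takes the feature extractor; the score itself is conventionally the Fréchet distance between Gaussians fitted to real and generated features, ||μ_r - μ_g||² + Tr(Σ_r + Σ_g - 2(Σ_r Σ_g)^{1/2}). A minimal NumPy/SciPy sketch, assuming you already have two (N, D) feature arrays (for example, outputs of the Inception network above). `frechet_distance` is a hypothetical helper, not torchmanager's implementation, which may differ in details such as numerical stabilization.

```python
import numpy as np
from scipy import linalg


def frechet_distance(real_features: np.ndarray, fake_features: np.ndarray) -> float:
    """Hypothetical helper: FID between two (N, D) arrays of features."""
    mu_real = real_features.mean(axis=0)
    mu_fake = fake_features.mean(axis=0)
    sigma_real = np.cov(real_features, rowvar=False)
    sigma_fake = np.cov(fake_features, rowvar=False)
    # matrix square root of the covariance product; drop tiny imaginary parts from round-off
    covmean = np.real(linalg.sqrtm(sigma_real @ sigma_fake))
    diff = mu_real - mu_fake
    return float(diff @ diff + np.trace(sigma_real) + np.trace(sigma_fake) - 2.0 * np.trace(covmean))


rng = np.random.default_rng(0)
real = rng.normal(size=(500, 8))
fake = rng.normal(loc=0.5, size=(500, 8))
print(frechet_distance(real, fake))
```

Identical feature sets give a distance of (numerically) zero, and shifting every feature by a constant c adds D·c² while leaving the covariance term unchanged, which is a quick sanity check for any FID implementation.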
Customize Diffusion Algorithm
Inherit `DiffusionManager` and implement the abstract methods `forward_diffusion` and `sampling_step` to customize the diffusion algorithm.
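```python
from typing import Any, Optional, Union
import torch
from diffusion import DiffusionManager
from diffusion.data import DiffusionData  # assumed import path

class CustomizedManager(DiffusionManager):
    def forward_diffusion(self, data: Any, condition: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None) -> tuple[Any, torch.Tensor]:
        ...  # add noise to `data` at time step(s) `t`
    def sampling_step(self, data: DiffusionData, i: int, /, *, return_noise: bool = False) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        ...  # compute one reverse (denoising) step at time step `i`
```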
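`sampling_step` performs one reverse (denoising) update. For reference, the standard DDPM ancestral step computes x_{t-1} = (x_t - β_t / sqrt(1 - ᾱ_t) · ε_θ(x_t, t)) / sqrt(α_t) + σ_t z, with α_t = 1 - β_t, σ_t² = β_t, and no noise added at the final step. The sketch below implements that update and a full sampling loop on plain Python lists. `ddpm_sampling_step`, `sample`, and the `predict_noise` callable are hypothetical stand-ins; wiring this into a real `sampling_step` override would require the actual fields of `DiffusionData`, which are not shown here.

```python
import math
import random
from typing import Callable

# hypothetical signature for a noise-prediction model: (x_t, t) -> predicted noise
NoisePredictor = Callable[[list[float], int], list[float]]


def ddpm_sampling_step(x_t: list[float], t: int, betas: list[float], predict_noise: NoisePredictor, rng: random.Random) -> list[float]:
    """One ancestral DDPM step from x_t to x_{t-1} (sigma_t^2 = beta_t)."""
    beta_t = betas[t]
    alpha_t = 1.0 - beta_t
    alpha_bar_t = math.prod(1.0 - b for b in betas[: t + 1])
    eps = predict_noise(x_t, t)
    coef = beta_t / math.sqrt(1.0 - alpha_bar_t)
    mean = [(x - coef * e) / math.sqrt(alpha_t) for x, e in zip(x_t, eps)]
    if t == 0:
        return mean  # no noise is added at the last step
    sigma_t = math.sqrt(beta_t)
    return [m + sigma_t * rng.gauss(0.0, 1.0) for m in mean]


def sample(dim: int, betas: list[float], predict_noise: NoisePredictor, rng: random.Random) -> list[float]:
    """Start from x_T ~ N(0, I) and apply the reverse step for t = T-1, ..., 0."""
    x = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    for t in reversed(range(len(betas))):
        x = ddpm_sampling_step(x, t, betas, predict_noise, rng)
    return x


# toy "model" that always predicts zero noise, only to exercise the loop
betas = [1e-4 + i * (0.02 - 1e-4) / 99 for i in range(100)]
result = sample(4, betas, lambda x, t: [0.0] * len(x), random.Random(0))
```

The DDPM paper reports two choices for σ_t², namely β_t (used here) and the posterior variance β̃_t = (1 - ᾱ_{t-1}) / (1 - ᾱ_t) · β_t; a customized manager is free to pick either.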