Torchmanager Implementation for Diffusion Models (v1.2 Release Candidate 2)
Torchmanager Diffusion Models Plug-in
The torchmanager implementation for diffusion models.
Prerequisites
- Python >= 3.9
- SciPy >= 1.11.4
- PyTorch >= 2.0.1
- LPIPS
- torchmanager >= 1.2
- einops >= 0.6.1
Installation
- PyPI:
```shell
pip install torchmanager-diffusion
```
DDPM Manager Usage
Train DDPM
Compile a `DDPMManager` directly with a model, a beta space, and the number of time steps, then call the `fit` method to train the model.
```python
import torch

import diffusion
from diffusion import DDPMManager
from torchmanager import callbacks, data, losses

# initialize dataset
dataset: data.Dataset = ...

# initialize model, beta_space, and time_steps
model: torch.nn.Module = ...
beta_space: diffusion.scheduling.BetaSpace = ...
time_steps: int = ...

# initialize optimizer and loss function
optimizer: torch.optim.Optimizer = ...
loss_fn: losses.Loss = ...

# compile the ddpm manager
manager = DDPMManager(model, beta_space, time_steps, optimizer=optimizer, loss_fn=loss_fn)

# initialize callbacks
callback_list: list[callbacks.Callback] = ...

# train the model
trained_model = manager.fit(dataset, epochs=..., callbacks=callback_list)
```
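In DDPM, the betas and the number of time steps define the noise schedule. As a rough illustration of what `beta_space` and `time_steps` control, the sketch below implements the closed-form forward process q(x_t | x_0), i.e. x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 − ᾱ_t)·ε with ᾱ_t = ∏(1 − β_s), using a linear schedule in plain PyTorch. It is a generic DDPM sketch, not the internals of `DDPMManager` or `diffusion.scheduling.BetaSpace`; `linear_beta_space` and `forward_diffusion` are hypothetical helpers, and the 1e-4 to 0.02 range is an assumption borrowed from the original DDPM paper.

```python
import torch


def linear_beta_space(time_steps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
    """Hypothetical helper: linearly spaced betas (assumed DDPM-style range)."""
    return torch.linspace(beta_start, beta_end, time_steps)


def forward_diffusion(x0: torch.Tensor, t: torch.Tensor, betas: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: sample x_t ~ q(x_t | x_0) in closed form, return (x_t, noise)."""
    alphas_bar = torch.cumprod(1.0 - betas, dim=0)           # alpha_bar_t for every step
    a_bar = alphas_bar[t].view(-1, *([1] * (x0.dim() - 1)))  # reshape to (B, 1, 1, 1) for broadcasting
    noise = torch.randn_like(x0)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise
    return x_t, noise


# toy usage: a batch of 4 RGB 32x32 "images" at random time steps
time_steps = 1000
betas = linear_beta_space(time_steps)
x0 = torch.rand(4, 3, 32, 32) * 2 - 1                       # scale to [-1, 1]
t = torch.randint(0, time_steps, (4,))
x_t, noise = forward_diffusion(x0, t, betas)
```

A DDPM training step would then regress the model's prediction for (x_t, t) onto `noise`, typically with a mean-squared-error loss; whether the `loss_fn` passed to `DDPMManager` is used exactly this way is not stated here.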
Evaluate DDPM
Add the necessary metrics, then call the `test` method with `sampling_images=True` to evaluate the trained model.
```python
import torch
from diffusion import DDPMManager
from torchmanager import data, metrics
from torchvision import models

# load manager from checkpoints
manager = DDPMManager.from_checkpoint(...)
assert isinstance(manager, DDPMManager), "manager is not a DDPMManager."

# initialize dataset
testing_dataset: data.Dataset = ...

# add necessary metrics: ImageNet-pretrained Inception-v3 with its classifier head removed
# (the `weights=` argument requires torchvision >= 0.13; older versions use `pretrained=True`)
inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
inception.fc = torch.nn.Identity()  # type: ignore
inception.eval()
fid = metrics.FID(inception)
manager.metrics.update({"FID": fid})

# evaluate the model
summary = manager.test(testing_dataset, sampling_images=True)
```
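The FID metric compares statistics of Inception features (2048-dimensional once `fc` is replaced by `Identity`) extracted from real and generated images. For reference, the standard Fréchet distance between Gaussian fits of the two feature sets is FID = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^(1/2)). The NumPy/SciPy sketch below computes it from two feature matrices; `frechet_distance` is a hypothetical helper and not necessarily how `torchmanager.metrics.FID` is implemented internally.

```python
import numpy as np
from scipy import linalg


def frechet_distance(real_features: np.ndarray, fake_features: np.ndarray) -> float:
    """Hypothetical helper: FID between two feature sets of shape (N, D)."""
    mu_r, mu_g = real_features.mean(axis=0), fake_features.mean(axis=0)
    sigma_r = np.cov(real_features, rowvar=False)
    sigma_g = np.cov(fake_features, rowvar=False)
    # matrix square root of the covariance product; numerical error can add tiny imaginary parts
    covmean = np.real(linalg.sqrtm(sigma_r @ sigma_g))
    diff = mu_r - mu_g
    return float(diff @ diff + np.trace(sigma_r) + np.trace(sigma_g) - 2.0 * np.trace(covmean))


# toy usage with random 8-dimensional "features"
rng = np.random.default_rng(0)
real = rng.normal(size=(500, 8))
print(frechet_distance(real, real + 1.0))  # mean shift of 1 in each of 8 dims, so close to 8.0
```

Lower is better; identical feature distributions give a distance of zero.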
Customize Diffusion Algorithm
Subclass `DiffusionManager` and implement its abstract methods `forward_diffusion` and `sampling_step` to customize the diffusion algorithm.
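```python
from typing import Any, Optional, Union
import torch
from diffusion import DiffusionManager
from diffusion.data import DiffusionData  # assumed import path for DiffusionData; adjust to your version

class CustomizedManager(DiffusionManager):
    def forward_diffusion(self, data: Any, condition: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None) -> tuple[Any, torch.Tensor]:
        ...  # forward (noising) process

    def sampling_step(self, data: DiffusionData, i: int, /, *, return_noise: bool = False) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        ...  # one reverse (denoising) step at index i
```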
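For orientation when filling in `sampling_step`, the sketch below writes one ancestral DDPM reverse step as a standalone function: x_{t−1} = (x_t − β_t / sqrt(1 − ᾱ_t) · ε_θ) / sqrt(α_t) + σ_t·z, with σ_t² = β_t and no noise added at the last step. It is a generic illustration, not this package's implementation. Assumptions: `ddpm_sampling_step` is a hypothetical helper, the index `i` counts down from T − 1 to 0, and the model predicts noise; the real `sampling_step` receives a `DiffusionData` object and its index convention may differ.

```python
import torch


def ddpm_sampling_step(x_t: torch.Tensor, predicted_noise: torch.Tensor, i: int, betas: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: one ancestral DDPM step x_t -> x_{t-1} at index i."""
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    beta_t, alpha_t, alpha_bar_t = betas[i], alphas[i], alphas_bar[i]
    # posterior mean computed from the predicted noise
    mean = (x_t - beta_t / (1.0 - alpha_bar_t).sqrt() * predicted_noise) / alpha_t.sqrt()
    if i == 0:
        return mean  # final step: return the mean without adding noise
    # variance choice sigma_t^2 = beta_t (one of the two options discussed for DDPM)
    return mean + beta_t.sqrt() * torch.randn_like(x_t)


# toy reverse loop with a dummy "model" that always predicts zero noise
time_steps = 10
betas = torch.linspace(1e-4, 0.02, time_steps)
x = torch.randn(2, 3, 8, 8)
for i in reversed(range(time_steps)):
    eps = torch.zeros_like(x)  # stand-in for the network's noise prediction at step i
    x = ddpm_sampling_step(x, eps, i, betas)
```

In a real subclass, the zero tensor would be replaced by the wrapped model's output, and `return_noise=True` would presumably also hand back that prediction.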
Download files
File details
Details for the file torchmanager_diffusion-1.2rc2.tar.gz.
File metadata
- Download URL: torchmanager_diffusion-1.2rc2.tar.gz
- Upload date:
- Size: 31.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 327355c6a98aa489be14ae87e4b562041c51af55eed4e5f99aa4b158fd2d7c78 |
| MD5 | 72aaecafe9d58bbccae09854b892e81b |
| BLAKE2b-256 | 41847c4a07ec548f959cd0c9185289c2aa2b0f6dfcedc40944f7800b1e0b8bbe |
File details
Details for the file torchmanager_diffusion-1.2rc2-py3-none-any.whl.
File metadata
- Download URL: torchmanager_diffusion-1.2rc2-py3-none-any.whl
- Upload date:
- Size: 50.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 11e9b39c2505f8ca10de9faa62d92151db3cf35dfc586979a15c2bb9f90120f4 |
| MD5 | e8120669f3ce25c9b4952f4bef996e35 |
| BLAKE2b-256 | 174bd1336d6d2f751a328fbfab2a37b08a72d4b53c5fb6c8dd263c85b46b20ad |
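To check that a downloaded distribution file matches the published digests, a small standard-library script is enough. The sketch below streams a file through `hashlib.sha256` and compares it with the SHA256 values listed in the file details; `sha256_of`, `verify`, and `EXPECTED_SHA256` are hypothetical names chosen for illustration.

```python
import hashlib
from pathlib import Path

# SHA256 digests copied from the file details for release 1.2rc2
EXPECTED_SHA256 = {
    "torchmanager_diffusion-1.2rc2.tar.gz": "327355c6a98aa489be14ae87e4b562041c51af55eed4e5f99aa4b158fd2d7c78",
    "torchmanager_diffusion-1.2rc2-py3-none-any.whl": "11e9b39c2505f8ca10de9faa62d92151db3cf35dfc586979a15c2bb9f90120f4",
}


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hypothetical helper: hex SHA256 of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path) -> bool:
    """Hypothetical helper: True only if the file is a known release file with a matching digest."""
    expected = EXPECTED_SHA256.get(path.name)
    return expected is not None and sha256_of(path) == expected
```

Alternatively, `pip hash <file>` prints the same SHA256 digest, and pip's hash-checking mode (`--hash=sha256:...` entries in a requirements file, enforced with `--require-hashes`) can verify it at install time.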