Some utility functions I frequently use with 🤗 diffusers.

cjm-diffusers-utils

Install

pip install cjm_diffusers_utils

How to use

import torch
from cjm_pytorch_utils.core import get_torch_device
device = get_torch_device()
# Use half precision on the GPU and full precision elsewhere
dtype = torch.float16 if device == 'cuda' else torch.float32
device, dtype
('cuda', torch.float16)

pil_to_latent

from cjm_diffusers_utils.core import pil_to_latent
from PIL import Image
from diffusers import AutoencoderKL
model_name = "stabilityai/stable-diffusion-2-1"
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(device=device, dtype=dtype)
img_path = '../images/cat.jpg'
src_img = Image.open(img_path).convert('RGB')
print(f"Source Image Size: {src_img.size}")

img_latents = pil_to_latent(src_img, vae)
print(f"Latent Dimensions: {img_latents.shape}")
Source Image Size: (768, 512)
Latent Dimensions: torch.Size([1, 4, 64, 96])
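
The latent dimensions above follow from the image size. As a rough sketch, assuming the Stable Diffusion VAE downsamples each spatial dimension by a factor of 8 and produces 4 latent channels, the expected shape can be computed without loading the model. The helper `expected_latent_shape` is hypothetical and not part of cjm_diffusers_utils.

```python
def expected_latent_shape(img_size, scale_factor=8, latent_channels=4, batch_size=1):
    """Return the (batch, channels, height, width) latent shape for a PIL-style (width, height) size.

    scale_factor and latent_channels are assumptions about the VAE, not values read from it.
    """
    width, height = img_size
    if width % scale_factor or height % scale_factor:
        raise ValueError(f"Image size {img_size} is not divisible by {scale_factor}")
    # PIL reports (width, height); tensors are laid out as (height, width)
    return (batch_size, latent_channels, height // scale_factor, width // scale_factor)


print(expected_latent_shape((768, 512)))  # (1, 4, 64, 96)
```

Note the order swap: PIL reports the size as (width, height), while the latent tensor stores height before width.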

latent_to_pil

from cjm_diffusers_utils.core import latent_to_pil
decoded_img = latent_to_pil(img_latents, vae)
print(f"Decoded Image Size: {decoded_img.size}")
Decoded Image Size: (768, 512)

text_to_emb

from cjm_diffusers_utils.core import text_to_emb
from transformers import CLIPTextModel, CLIPTokenizer
# Load the tokenizer for the specified model
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
# Load the text encoder for the specified model
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(device=device, dtype=dtype)
prompt = "A cat sitting on the floor."
text_emb = text_to_emb(prompt, tokenizer, text_encoder)
text_emb.shape
torch.Size([2, 77, 1024])

prepare_noise_scheduler

from cjm_diffusers_utils.core import prepare_noise_scheduler
from diffusers import DEISMultistepScheduler
noise_scheduler = DEISMultistepScheduler.from_pretrained(model_name, subfolder='scheduler')
print(f"Number of timesteps: {len(noise_scheduler.timesteps)}")
print(noise_scheduler.timesteps[:10])

noise_scheduler = prepare_noise_scheduler(noise_scheduler, 70, 1.0)
print(f"Number of timesteps: {len(noise_scheduler.timesteps)}")
print(noise_scheduler.timesteps[:10])
Number of timesteps: 1000
tensor([999., 998., 997., 996., 995., 994., 993., 992., 991., 990.])
Number of timesteps: 70
tensor([999, 985, 970, 956, 942, 928, 913, 899, 885, 871])
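
The 70 timesteps printed above are consistent with spacing 71 points evenly over 0 to 999, rounding, reversing, and dropping the final 0. The sketch below reproduces those values in plain Python. It is inferred from the printed output only, and is not a description of how `prepare_noise_scheduler` or `DEISMultistepScheduler` is implemented. The function name `spaced_timesteps` is hypothetical, and the third argument (`1.0`) passed to `prepare_noise_scheduler` is not modeled here.

```python
def spaced_timesteps(num_inference_steps, num_train_timesteps=1000):
    """Evenly spaced, descending timesteps (an assumption inferred from the printed output)."""
    last = num_train_timesteps - 1
    # num_inference_steps + 1 evenly spaced points from 0 to last, rounded to integers
    points = [round(i * last / num_inference_steps) for i in range(num_inference_steps + 1)]
    # Reverse to descending order, then drop the trailing 0
    return points[::-1][:-1]


timesteps = spaced_timesteps(70)
print(len(timesteps))   # 70
print(timesteps[:10])   # [999, 985, 970, 956, 942, 928, 913, 899, 885, 871]
```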

prepare_depth_mask

from cjm_diffusers_utils.core import prepare_depth_mask
depth_map_path = '../images/depth-cat.png'
depth_map = Image.open(depth_map_path)
print(f"Depth map size: {depth_map.size}")

depth_mask = prepare_depth_mask(depth_map).to(device=device, dtype=dtype)
depth_mask.shape, depth_mask.min(), depth_mask.max()
Depth map size: (768, 512)

(torch.Size([1, 1, 64, 96]),
 tensor(-1., device='cuda:0', dtype=torch.float16),
 tensor(1., device='cuda:0', dtype=torch.float16))
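
The mask above spans exactly -1 to 1, and its 64x96 size matches the latent resolution of the 768x512 image. One way to get that value range is min-max scaling, sketched below on a few sample depth values in plain Python. This is an assumption about the kind of normalization involved, not the actual implementation of `prepare_depth_mask`. The helper `scale_to_unit_range` is hypothetical.

```python
def scale_to_unit_range(values):
    """Min-max scale a sequence of numbers to the range [-1, 1]."""
    lo, hi = min(values), max(values)
    if hi == lo:
        raise ValueError("Cannot scale a constant depth map")
    return [2.0 * (v - lo) / (hi - lo) - 1.0 for v in values]


depth_values = [0, 64, 128, 255]  # sample 8-bit depth values
scaled = scale_to_unit_range(depth_values)
print(min(scaled), max(scaled))  # -1.0 1.0
```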
