Minimal Imagen text-to-image model implementation.
MinImagen
A minimal implementation of the Imagen text-to-image model.
For a tutorial on building this model, see here.
Given a caption, Imagen generates an image that reflects it. The model is a simple cascading diffusion model: a T5 text encoder encodes the caption, and the resulting encoding conditions a base image generator followed by a sequence of super-resolution models.
In particular, two notable contributions of Imagen are:
- Noise Conditioning Augmentation, which adds noise to the low-resolution conditioning images in the super-resolution models, and
- Dynamic Thresholding, which helps prevent image saturation at high classifier-free guidance weights.
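As a rough illustration of dynamic thresholding, the sketch below clamps a predicted image to `[-s, s]` and divides by `s`, where `s` is a percentile of the absolute pixel values, floored at 1. It is written against plain Python lists so that it runs without PyTorch. The function name `dynamic_threshold`, the nearest-rank percentile, and the default of 0.95 are assumptions made for this example, not MinImagen's API.

```python
import math


def dynamic_threshold(pixels, percentile=0.95):
    """Clamp predicted pixel values to [-s, s], then rescale by s.

    `pixels` is a flat, non-empty list of floats for one predicted image.
    `s` is the given percentile of the absolute values, floored at 1.0 so
    that predictions already inside [-1, 1] are left unchanged.
    """
    magnitudes = sorted(abs(p) for p in pixels)
    # Nearest-rank percentile (illustrative; tensor libraries usually interpolate)
    rank = max(math.ceil(percentile * len(magnitudes)) - 1, 0)
    s = max(magnitudes[rank], 1.0)
    return [max(-s, min(s, p)) / s for p in pixels]


# A saturated prediction: several values fall well outside [-1, 1]
predicted = [-3.0, -0.5, 0.0, 0.25, 0.5, 1.0, 2.0, 4.0]
thresholded = dynamic_threshold(predicted, percentile=0.75)
print(thresholded)
```

Compared with clipping straight to `[-1, 1]`, dividing by `s` keeps the relative contrast of the pixels below the threshold instead of pushing many of them to the extremes.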
See How Imagen Actually Works for a detailed explanation of Imagen's operating principles.
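Noise conditioning augmentation, mentioned above, can be sketched in a few lines. The example corrupts a low-resolution conditioning image with Gaussian noise at a chosen level and returns that level alongside the image, since the super-resolution model is conditioned on both. It is a simplified stand-in: the name `noise_conditioning_augmentation` and the variance-preserving mix are assumptions for illustration, whereas a real implementation would draw the noise from the diffusion noise schedule.

```python
import math
import random


def noise_conditioning_augmentation(low_res_pixels, aug_level, rng=random):
    """Mix a low-resolution image with Gaussian noise.

    `aug_level` lies in [0, 1]: 0 leaves the image unchanged and 1 replaces
    it with pure noise. The level is returned too, so it can be passed to
    the super-resolution model as an extra conditioning signal.
    """
    if not 0.0 <= aug_level <= 1.0:
        raise ValueError("aug_level must lie in [0, 1]")
    signal_scale = math.sqrt(1.0 - aug_level)
    noise_scale = math.sqrt(aug_level)
    noised = [
        signal_scale * p + noise_scale * rng.gauss(0.0, 1.0)
        for p in low_res_pixels
    ]
    return noised, aug_level


# During training the level is typically drawn at random for each example
rng = random.Random(0)
low_res = [0.1, -0.4, 0.9, 0.0]
noised, level = noise_conditioning_augmentation(low_res, rng.random(), rng=rng)
```

Training on deliberately degraded inputs makes the super-resolution model less sensitive to artifacts in the images produced by the previous stage of the cascade.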
Attribution Note
This implementation is largely based on Phil Wang's Imagen implementation.
Installation
$ pip install minimagen
To use the training.py file, you will also need to install datasets and nonechucks:
$ pip install datasets nonechucks
Note that MinImagen requires Python 3.9 or higher.
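If you are unsure which interpreter an environment uses, a small check such as the following can be run before installing. The helper `python_is_supported` is a hypothetical name used only for this example.

```python
import sys

MIN_PYTHON = (3, 9)  # minimum version stated in the installation note


def python_is_supported(version_info=sys.version_info):
    """Return True if the major.minor version meets the minimum."""
    return tuple(version_info[:2]) >= MIN_PYTHON


if not python_is_supported():
    print(f"Python {sys.version_info[0]}.{sys.version_info[1]} is older than 3.9")
```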
Documentation
Documentation can be found here
Usage
A minimal usage example:
import torch
from minimagen.Imagen import Imagen
from minimagen.Unet import Unet, Base, Super
from minimagen.t5 import t5_encode_text, get_encoded_dim
from torch import optim
# Name of the T5 encoder to use
encoder_name = 't5_small'
# Text captions of training images
train_texts = [
    'a pepperoni pizza',
    'a man riding a horse',
    'a Beluga whale',
    'a woman rock climbing'
]
# Training images (side length equal to Imagen final output image size)
train_images = torch.randn(4, 3, 64, 64)
# Create the Imagen instance
enc_dim = get_encoded_dim(encoder_name)
unets = (Base(text_embed_dim=enc_dim), Super(text_embed_dim=enc_dim))
imagen = Imagen(unets=unets, image_sizes=(32, 64), timesteps=10)
# Create an optimizer
optimizer = optim.Adam(imagen.parameters())
# Train the U-Nets in Imagen
for j in range(10):
    for i in range(len(unets)):
        optimizer.zero_grad()
        loss = imagen(train_images, texts=train_texts, unet_number=i)
        loss.backward()
        optimizer.step()
# Sample captions to generate images for
sample_captions = [
    'a happy dog',
    'a big red house',
    'a woman standing on a beach',
    'a man on a bike'
]
# Generate images
images = imagen.sample(texts=sample_captions, cond_scale=3., return_pil_images=True)
# Save images
for idx, img in enumerate(images):
    img.save(f'Generated_Image_{idx}.png')
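The `cond_scale` argument passed to `sample` above is the classifier-free guidance weight. As a minimal sketch of the standard guidance formula, the function below combines a caption-conditioned prediction with an unconditional one. The name `guided_prediction` and the use of plain lists are assumptions for illustration and are not part of MinImagen's API.

```python
def guided_prediction(cond_pred, uncond_pred, cond_scale):
    """Classifier-free guidance: move from the unconditional prediction
    towards (and, for cond_scale > 1, past) the conditional one."""
    return [u + cond_scale * (c - u) for c, u in zip(cond_pred, uncond_pred)]


cond = [1.0, 2.0]    # prediction made with the caption
uncond = [0.0, 1.0]  # prediction made with the caption dropped

print(guided_prediction(cond, uncond, 1.0))  # same as the conditional prediction
print(guided_prediction(cond, uncond, 3.0))  # extrapolates beyond it
```

Weights above 1 strengthen the influence of the caption but can push pixel values outside the valid range, which is the saturation that dynamic thresholding is meant to counter.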
Text embeddings and masks can be precomputed, and U-Net parameters can be specified explicitly rather than using the Base and Super presets:
train_encs, train_mask = t5_encode_text(train_texts, name=encoder_name)
enc_dim = get_encoded_dim(encoder_name)
base_unet = Unet(
    dim=32,
    text_embed_dim=enc_dim,
    cond_dim=64,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=2,
    layer_attns=(False, False, True),
    layer_cross_attns=(False, False, True),
    attend_at_middle=True
)
super_res_unet = Unet(
    dim=32,
    text_embed_dim=enc_dim,
    cond_dim=512,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=(2, 4, 8),
    layer_attns=(False, False, True),
    layer_cross_attns=(False, False, True),
    attend_at_middle=False
)
# Create Imagen instance
imagen = Imagen((base_unet, super_res_unet), image_sizes=(32, 64), timesteps=10)
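To make the per-layer arguments above more concrete, the sketch below shows one plausible reading of `dim`, `dim_mults`, and `num_resnet_blocks`. It assumes, as is common in U-Net implementations of this style, that `dim` is the base channel count, that each entry of `dim_mults` scales it for one layer, and that a single integer for `num_resnet_blocks` is repeated for every layer. The helper names are hypothetical and this is not MinImagen's own code.

```python
def layer_channels(dim, dim_mults):
    """Channel width of each U-Net layer: the base width times its multiplier."""
    return tuple(dim * m for m in dim_mults)


def resnet_blocks_per_layer(num_resnet_blocks, num_layers):
    """Expand an int to one entry per layer; validate an explicit tuple."""
    if isinstance(num_resnet_blocks, int):
        return (num_resnet_blocks,) * num_layers
    if len(num_resnet_blocks) != num_layers:
        raise ValueError("need one num_resnet_blocks entry per layer")
    return tuple(num_resnet_blocks)


dim_mults = (1, 2, 4)
print(layer_channels(32, dim_mults))                           # (32, 64, 128)
print(resnet_blocks_per_layer(2, len(dim_mults)))              # (2, 2, 2)
print(resnet_blocks_per_layer((2, 4, 8), len(dim_mults)))      # (2, 4, 8)
```

Under this reading, `layer_attns` and `layer_cross_attns` are likewise given one entry per layer, which is why they have the same length as `dim_mults` in the configurations above.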