
MinImagen

A minimal implementation of the Imagen text-to-image model.

For a tutorial on building this model, see here.

Given a caption, Imagen generates an image that reflects it. The model is a cascading diffusion model: a T5 text encoder encodes the caption, the encoding conditions a base image generator, and a sequence of super-resolution models then upscales the base output to the final resolution.

In particular, two notable contributions of Imagen are:

  1. Noise Conditioning Augmentation, which noises the low-resolution conditioning images fed to the super-resolution models, and
  2. Dynamic Thresholding, which helps prevent image saturation at high classifier-free guidance weights (a minimal sketch of both appears below).
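
For intuition, here is a minimal PyTorch sketch of both ideas. It is illustrative only and does not reproduce minimagen's internal implementation:

import torch

def dynamic_threshold(x0, percentile=0.9):
    # Dynamic thresholding (sketch): clamp the predicted clean image x0 to
    # [-s, s], where s is the given percentile of |x0| per image (but at
    # least 1), then rescale by s so pixel values stay within [-1, 1].
    s = torch.quantile(x0.abs().reshape(x0.shape[0], -1), percentile, dim=-1)
    s = s.clamp(min=1.0).view(-1, 1, 1, 1)
    return x0.clamp(-s, s) / s

def noise_condition(low_res, aug_level):
    # Noise conditioning augmentation (sketch): corrupt the low-resolution
    # conditioning image with Gaussian noise. aug_level is a (batch, 1, 1, 1)
    # tensor in [0, 1]; 0 leaves the image clean, 1 yields pure noise. The
    # super-resolution model is also conditioned on this augmentation level.
    return (1 - aug_level ** 2).sqrt() * low_res + aug_level * torch.randn_like(low_res)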

See How Imagen Actually Works for a detailed explanation of Imagen's operating principles.


Attribution Note

This implementation is largely based on Phil Wang's Imagen implementation.

Installation

$ pip install minimagen

In order to use the training.py file, you will also need to install the datasets and nonechucks packages:

$ pip install datasets nonechucks

Note that MinImagen requires Python 3.9 or higher.

Documentation

Documentation can be found here.

Usage

A minimal usage example:

import torch
from minimagen.Imagen import Imagen
from minimagen.Unet import Unet, Base, Super
from minimagen.t5 import t5_encode_text, get_encoded_dim
from torch import optim

# Name of the T5 encoder to use
encoder_name = 't5_small'

# Text captions of training images
train_texts = [
    'a pepperoni pizza',
    'a man riding a horse',
    'a Beluga whale',
    'a woman rock climbing'
]

# Training images (side length equal to Imagen final output image size)
train_images = torch.randn(4, 3, 64, 64)

# Create the Imagen instance
enc_dim = get_encoded_dim(encoder_name)
unets = (Base(text_embed_dim=enc_dim), Super(text_embed_dim=enc_dim))
imagen = Imagen(unets=unets, image_sizes=(32, 64), timesteps=10)

# Create an optimizer
optimizer = optim.Adam(imagen.parameters())

# Train the U-Nets in Imagen
for j in range(10):
    for i in range(len(unets)):
        optimizer.zero_grad()
        loss = imagen(train_images, texts=train_texts, unet_number=i)
        loss.backward()
        optimizer.step()

# Sample captions to generate images for
sample_captions = [
    'a happy dog',
    'a big red house',
    'a woman standing on a beach',
    'a man on a bike'
]

# Generate images
images = imagen.sample(texts=sample_captions, cond_scale=3., return_pil_images=True)

# Save images
for idx, img in enumerate(images):
    img.save(f'Generated_Image_{idx}.png')

Text embeddings and masks can be precomputed, and Unet parameters can be specified explicitly rather than using the Base and Super presets:

train_encs, train_mask = t5_encode_text(train_texts, name=encoder_name)

enc_dim = get_encoded_dim(encoder_name)

base_unet = Unet(
    dim=32,
    text_embed_dim=enc_dim,
    cond_dim=64,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=2,
    layer_attns=(False, False, True),
    layer_cross_attns=(False, False, True),
    attend_at_middle=True
)

super_res_unet = Unet(
    dim=32,
    text_embed_dim=enc_dim,
    cond_dim=512,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=(2, 4, 8),
    layer_attns=(False, False, True),
    layer_cross_attns=(False, False, True),
    attend_at_middle=False
)

# Create Imagen instance
imagen = Imagen((base_unet, super_res_unet), image_sizes=(32, 64), timesteps=10)
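
The precomputed embeddings and masks can then be passed to Imagen in place of raw captions. Note that the keyword arguments in the sketch below (text_embeds and text_masks) follow the Phil Wang implementation this project is based on; verify them against your installed minimagen version:

# Train with the precomputed text embeddings/masks instead of raw captions
optimizer = optim.Adam(imagen.parameters())
for i in range(2):
    optimizer.zero_grad()
    loss = imagen(train_images, text_embeds=train_encs, text_masks=train_mask, unet_number=i)
    loss.backward()
    optimizer.step()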

