Minimal Imagen text-to-image model implementation.

MinImagen

A minimal implementation of the Imagen text-to-image model.

For a tutorial on building this model, see here.

Given a caption, Imagen generates an image that reflects it. The model is a cascading diffusion model: a T5 text encoder encodes the caption, and the resulting encoding conditions a base image generator followed by a sequence of super-resolution models.

Two notable contributions of Imagen are:

  1. Noise Conditioning Augmentation, which adds noise to the low-resolution conditioning images fed to the super-resolution models, and
  2. Dynamic Thresholding, which helps prevent image saturation at high classifier-free guidance weights.
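
The two techniques listed above can be sketched in a few lines of plain Python. This is an illustrative sketch only, not the code MinImagen runs: images are flat lists of pixel values in [-1, 1], the helper names (`noise_augment`, `quantile`, `dynamic_threshold`) are hypothetical, the noise mix is a simple variance-preserving blend assumed for illustration, and the 0.9 percentile is an assumed default.

```python
import math
import random


def noise_augment(lowres, aug_level, rng):
    """Noise a low-resolution conditioning image (flat list of pixel values).

    aug_level is in [0, 1]: 0 leaves the image untouched, 1 replaces it
    with pure Gaussian noise. The blend below is an assumed,
    variance-preserving mix, not the library's noise schedule.
    """
    keep = math.sqrt(1.0 - aug_level)
    mix = math.sqrt(aug_level)
    return [keep * p + mix * rng.gauss(0.0, 1.0) for p in lowres]


def quantile(values, q):
    """Linear-interpolation quantile of a non-empty list, with q in [0, 1]."""
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lo = math.floor(pos)
    hi = math.ceil(pos)
    frac = pos - lo
    return ordered[lo] * (1.0 - frac) + ordered[hi] * frac


def dynamic_threshold(x0_pred, percentile=0.9):
    """Clamp a predicted image to [-s, s] and divide by s, where s >= 1.

    s is the chosen percentile of the absolute pixel values, so an image
    that already lies within [-1, 1] passes through unchanged.
    """
    s = max(1.0, quantile([abs(p) for p in x0_pred], percentile))
    return [max(-s, min(s, p)) / s for p in x0_pred]


rng = random.Random(0)
noisy_lowres = noise_augment([0.2, -0.5, 0.9], aug_level=0.3, rng=rng)

# A prediction pushed out of range, as can happen at high guidance weights
saturated = [0.5, -3.0, 2.0, 4.0, -0.1]
thresholded = dynamic_threshold(saturated)
```

Unlike plain clipping to [-1, 1], dividing by `s` keeps the relative contrast among pixels that fell below the threshold instead of flattening them all to the boundary.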

See How Imagen Actually Works for a detailed explanation of Imagen's operating principles.


Attribution Note

This implementation is largely based on Phil Wang's Imagen implementation.

Installation

$ pip install minimagen

To use the training.py file, you will also need to install datasets and nonechucks:

$ pip install datasets nonechucks

Note that MinImagen requires Python 3.9 or higher.
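
If it helps, the version requirement can be checked before installing. The snippet below is a small sketch using only the standard library; the helper name `python_is_supported` is hypothetical and not part of MinImagen.

```python
import sys

MIN_PYTHON = (3, 9)  # minimum version stated in the installation notes


def python_is_supported(version_info=None):
    """Return True if the given (or current) interpreter meets the minimum."""
    if version_info is None:
        version_info = sys.version_info
    return tuple(version_info[:2]) >= MIN_PYTHON


if not python_is_supported():
    print("MinImagen requires Python 3.9 or higher; found", sys.version.split()[0])
```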

Documentation

Documentation can be found here

Usage

A minimal usage example:

import torch
from minimagen.Imagen import Imagen
from minimagen.Unet import Unet, Base, Super
from minimagen.t5 import t5_encode_text, get_encoded_dim
from torch import optim

# Name of the T5 encoder to use
encoder_name = 't5_small'

# Text captions of training images
train_texts = [
    'a pepperoni pizza',
    'a man riding a horse',
    'a Beluga whale',
    'a woman rock climbing'
]

# Training images (side length equal to Imagen final output image size)
train_images = torch.randn(4, 3, 64, 64)

# Create the Imagen instance
enc_dim = get_encoded_dim(encoder_name)
unets = (Base(text_embed_dim=enc_dim), Super(text_embed_dim=enc_dim))
imagen = Imagen(unets=unets, image_sizes=(32, 64), timesteps=10)

# Create an optimizer
optimizer = optim.Adam(imagen.parameters())

# Train the U-Nets in Imagen
for j in range(10):
    for i in range(len(unets)):
        optimizer.zero_grad()
        loss = imagen(train_images, texts=train_texts, unet_number=i)
        loss.backward()
        optimizer.step()

# Sample captions to generate images for
sample_captions = [
    'a happy dog',
    'a big red house',
    'a woman standing on a beach',
    'a man on a bike'
]

# Generate images
images = imagen.sample(texts=sample_captions, cond_scale=3., return_pil_images=True)

# Save images
for idx, img in enumerate(images):
    img.save(f'Generated_Image_{idx}.png')
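
To make the role of image_sizes in the example above more concrete, the sketch below lists what each stage of a cascade works with: the base U-Net generates at the first size from text alone, and each later U-Net generates at its own size while also being conditioned on the previous stage's lower-resolution image. The function `cascade_plan` is hypothetical and only illustrates the cascading structure described earlier; it does not call MinImagen.

```python
def cascade_plan(image_sizes):
    """Describe, per U-Net, its target resolution and low-res conditioning size."""
    plan = []
    for index, size in enumerate(image_sizes):
        # The first (base) U-Net has no low-resolution image to condition on
        lowres = image_sizes[index - 1] if index > 0 else None
        plan.append({
            "unet_index": index,
            "target_size": size,
            "lowres_cond_size": lowres,
        })
    return plan


for stage in cascade_plan((32, 64)):
    print(stage)
```

With image_sizes=(32, 64), the final stage's target size matches the 64x64 side length of the training images used above.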

Text embeddings and masks can be precomputed, and U-Net parameters can be specified explicitly rather than relying on Base and Super:

train_encs, train_mask = t5_encode_text(train_texts, name=encoder_name)

enc_dim = get_encoded_dim(encoder_name)

base_unet = Unet(
    dim=32,
    text_embed_dim=enc_dim,
    cond_dim=64,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=2,
    layer_attns=(False, False, True),
    layer_cross_attns=(False, False, True),
    attend_at_middle=True
)

super_res_unet = Unet(
    dim=32,
    text_embed_dim=enc_dim,
    cond_dim=512,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=(2, 4, 8),
    layer_attns=(False, False, True),
    layer_cross_attns=(False, False, True),
    attend_at_middle=False
)

# Create Imagen instance
imagen = Imagen((base_unet, super_res_unet), image_sizes=(32, 64), timesteps=10)
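
As a rough guide to reading the U-Net arguments above, the following sketch expands them into one entry per layer. It assumes the common U-Net convention that each layer's channel width is dim multiplied by the corresponding entry of dim_mults, and that a scalar num_resnet_blocks applies to every layer; MinImagen's internals may differ in detail. The helpers `cast_tuple` and `layer_summary` are written here for illustration only.

```python
def cast_tuple(value, length):
    """Repeat a scalar setting once per layer; pass matching tuples through."""
    if isinstance(value, tuple):
        if len(value) != length:
            raise ValueError(f"expected {length} entries, got {len(value)}")
        return value
    return (value,) * length


def layer_summary(dim, dim_mults, num_resnet_blocks, layer_attns, layer_cross_attns):
    """Return one dict per U-Net layer describing its assumed configuration."""
    n_layers = len(dim_mults)
    blocks = cast_tuple(num_resnet_blocks, n_layers)
    attns = cast_tuple(layer_attns, n_layers)
    cross_attns = cast_tuple(layer_cross_attns, n_layers)
    return [
        {
            "channels": dim * mult,
            "resnet_blocks": n_blocks,
            "self_attn": attn,
            "cross_attn": cross,
        }
        for mult, n_blocks, attn, cross in zip(dim_mults, blocks, attns, cross_attns)
    ]


# Settings copied from the super-resolution U-Net above
for layer in layer_summary(
    dim=32,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=(2, 4, 8),
    layer_attns=(False, False, True),
    layer_cross_attns=(False, False, True),
):
    print(layer)
```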

