Minimal Imagen text-to-image model implementation.

MinImagen

A minimal implementation of the Imagen text-to-image model.



See Build Your Own Imagen Text-to-Image Model for a tutorial on how to build MinImagen.

See How Imagen Actually Works for a detailed explanation of Imagen's operating principles.


Given a caption, the text-to-image model Imagen generates an image that reflects the scene the caption describes. The model is a cascading diffusion model: a T5 text encoder produces a caption encoding, which conditions a base image generator and then a sequence of super-resolution models through which the output of the base image generator is passed.

In particular, two notable contributions are the developments of:

  1. Noise Conditioning Augmentation, which noises low-resolution conditioning images in the super-resolution models, and
  2. Dynamic Thresholding, which helps prevent image saturation at high classifier-free guidance weights.
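
To make the second item concrete, here is a minimal, framework-free sketch of dynamic thresholding as described in the Imagen paper: the predicted clean image is clipped to [-s, s], where s is a high percentile of the absolute pixel values (never less than 1), and then divided by s. The function name, the 0.95 default percentile, the nearest-rank percentile rule, and the flat list-of-floats representation are all assumptions made for illustration; this is not MinImagen's own code, which operates on PyTorch tensors.

```python
import math


def dynamic_threshold(pixels, percentile=0.95):
    """Clip predicted pixel values to [-s, s] and rescale by s.

    `pixels` is a flat list of floats for one predicted image (a hypothetical
    representation; a real implementation would work on batched tensors).
    """
    if not pixels:
        return []
    magnitudes = sorted(abs(p) for p in pixels)
    # Nearest-rank percentile of the absolute pixel values (assumed rule).
    rank = max(1, math.ceil(percentile * len(magnitudes)))
    s = magnitudes[rank - 1]
    # Never shrink the range below [-1, 1]; when s <= 1 this is plain clipping.
    s = max(s, 1.0)
    return [max(-s, min(s, p)) / s for p in pixels]
```

Because s grows with the magnitude of the prediction, strongly saturated predictions are pulled back into [-1, 1] rather than being hard-clipped at the boundary.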


Attribution Note

This implementation is largely based on Phil Wang's Imagen implementation.


Installation

To install MinImagen, run the following command in the terminal:

$ pip install minimagen

Note that MinImagen requires Python 3.9 or higher.


Documentation

See the MinImagen Documentation to learn more about the package.


Usage - Command Line

If you have cloned this repo, you can use the provided scripts to get started with MinImagen.


main.py

For the most basic usage, simply enter the MinImagen directory and run the following in the terminal:

$ python main.py

This will create a small MinImagen instance and train it on a tiny amount of data, and then use this MinImagen instance to generate an image.

After running the script, you will see a directory called training_<TIMESTAMP>.

  1. This directory is called a Training Directory and is generated when training a MinImagen instance.
  2. It contains information about the configuration (parameters subdirectory), and contains the model checkpoints (state_dicts and tmp directories).
  3. It also contains a training_progress.txt file that records training progress.

You will also see a directory called generated_images_<TIMESTAMP>.

  1. This directory contains a folder of images generated by the model (generated_images).
  2. It also contains a captions.txt file, which documents the captions that were input to generate the images (the line index of a given caption corresponds to the image number in the generated_images folder).
  3. Finally, this directory also contains imagen_training_directory.txt, which specifies the name of the Training Directory used to load the MinImagen instance / generate images.
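
The correspondence between caption lines and image numbers can be read back with a few lines of Python. This is a sketch under assumptions: the helper name `index_captions` is hypothetical, and it assumes zero-based numbering (line 0 pairs with image 0). Check your own generated_images folder to confirm the numbering and file naming before relying on it.

```python
from pathlib import Path


def index_captions(captions_file, start=0):
    """Return {image_number: caption} from a captions.txt-style file.

    Assumes one caption per line; `start` is the number given to the first
    line (0 here is an assumption, not confirmed MinImagen behavior).
    """
    lines = Path(captions_file).read_text(encoding="utf-8").splitlines()
    # Blank lines are skipped but still consume an index, so numbering
    # stays aligned with the line positions in the file.
    return {i: line.strip() for i, line in enumerate(lines, start=start) if line.strip()}
```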

train.py

main.py simply runs train.py and inference.py in series, the former to train the model and the latter to generate the image.

To train a model, simply run train.py and specify relevant command line arguments. The possible arguments are:

  • --PARAMETERS or -p, which specifies a directory that specifies the MinImagen configuration to use. It should be structured like a parameters subdirectory within a Training Directory (example in parameters).
  • --NUM_WORKERS or -n, which specifies the number of workers to use for the DataLoaders.
  • --BATCH_SIZE or -b, which specifies the batch size to use during training.
  • --MAX_NUM_WORDS or -mw, which specifies the maximum number of words allowed in a caption.
  • --IMG_SIDE_LEN or -s, which specifies the final side length of the square images that MinImagen will output.
  • --EPOCHS or -e, which specifies the number of training epochs.
  • --T5_NAME or -t5, which specifies the name of the T5 encoder to use.
  • --TRAIN_VALID_FRAC or -f, which specifies the fraction of the dataset to use for training (vs. validation).
  • --TIMESTEPS or -t, which specifies the number of timesteps in the diffusion process.
  • --OPTIM_LR or -lr, which specifies the learning rate for the Adam optimizer.
  • --ACCUM_ITER or -ai, which specifies the number of batches to accumulate for gradient accumulation.
  • --CHCKPT_NUM or -cn, which specifies the interval, in batches, at which a temporary model checkpoint is created during training.
  • --VALID_NUM or -vn, which specifies the number of validation images to use. If None, the full amount from the train/valid split is used. This option exists because, even with a --TRAIN_VALID_FRAC of e.g. 0.99, a very large dataset could still leave a prohibitively large number of images for validation.
  • --RESTART_DIRECTORY or -rd, which specifies the Training Directory from which to load a MinImagen instance when resuming training. A new Training Directory will be created for the resumed training, leaving the previous Training Directory, from which the checkpoint is loaded, untouched.
  • --TESTING or -test, which is used to run the script with a small MinImagen instance and small dataset for testing.
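
The interaction between --TRAIN_VALID_FRAC and --VALID_NUM is easiest to see with numbers. The sketch below is only an illustration of the arithmetic described above: the function name `split_sizes` and the rounding rule are assumptions, not MinImagen's actual implementation.

```python
def split_sizes(dataset_len, train_valid_frac, valid_num=None):
    """Return (n_train, n_valid) for a dataset of `dataset_len` items.

    Hypothetical helper: the training share is rounded to the nearest
    integer, and `valid_num` (if given) caps the validation set.
    """
    n_train = round(dataset_len * train_valid_frac)
    n_valid = dataset_len - n_train
    if valid_num is not None:
        # Cap validation so huge datasets do not yield huge validation sets.
        n_valid = min(n_valid, valid_num)
    return n_train, n_valid


# With 3,000,000 images and a 0.99 split, 30,000 images remain for
# validation unless a cap such as valid_num=1000 is applied.
uncapped = split_sizes(3_000_000, 0.99)
capped = split_sizes(3_000_000, 0.99, valid_num=1000)
```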

For example, to run a small training using the provided example parameters folder, run the following in the terminal:

$ python train.py --PARAMETERS ./parameters --BATCH_SIZE 2 --TIMESTEPS 25 --TESTING

After execution, you will see a new training_<TIMESTAMP> Training Directory that contains the files listed above.


inference.py

To generate images using a model from a Training Directory, we can use inference.py. Simply run inference.py and specify relevant command line arguments. The possible arguments are:

  • --TRAINING_DIRECTORY or -d, which specifies the training directory from which to load the MinImagen instance for inference.
  • --CAPTIONS or -c, which specifies either (a) a single caption to generate an image for, or (b) a filepath to a .txt file that contains a list of captions to generate images for, where each caption is on a new line.
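
The two accepted forms of --CAPTIONS can be told apart with a small check. The following is a sketch of one plausible way to do this, not the logic inference.py necessarily uses; the function name `resolve_captions` and the rule "ends in .txt and exists on disk" are assumptions.

```python
import os


def resolve_captions(captions_arg):
    """Interpret a --CAPTIONS value as a .txt file of captions or one caption.

    Hypothetical helper: a value that ends in .txt and names an existing
    file is read line by line; anything else is treated as a single caption.
    """
    if captions_arg.endswith(".txt") and os.path.isfile(captions_arg):
        with open(captions_arg, encoding="utf-8") as f:
            # One caption per line; skip empty lines.
            return [line.strip() for line in f if line.strip()]
    return [captions_arg]
```

Returning a list in both cases lets the rest of a script treat single captions and caption files uniformly.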

For example, to generate images for the example captions provided in captions.txt using the model produced by the training command above, run:

$ python inference.py --CAPTIONS captions.txt --TRAINING_DIRECTORY training_<TIMESTAMP>

where TIMESTAMP is replaced with the appropriate value from your training.
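
If you do not want to copy the timestamp by hand, the most recent Training Directory can be located programmatically. The sketch below assumes the directories are named with the `%Y%m%d_%H%M%S` timestamp format used in the package training example in this README, so that names sort chronologically as plain strings; the helper name `latest_training_directory` is hypothetical.

```python
from pathlib import Path


def latest_training_directory(root="."):
    """Return the newest training_<TIMESTAMP> directory under `root`, or None.

    Assumes timestamps such as 20220901_153000, which sort chronologically
    when compared as strings.
    """
    candidates = [p for p in Path(root).glob("training_*") if p.is_dir()]
    return max(candidates, key=lambda p: p.name, default=None)
```

The returned path can then be passed as the --TRAINING_DIRECTORY value.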


Usage - Package

Training

A minimal training script using the minimagen package is shown below. See train.py for a more complete version of this code.

import os
from datetime import datetime

import torch.utils.data
from torch import optim

from minimagen.Imagen import Imagen
from minimagen.Unet import Unet, Base, Super, BaseTest, SuperTest
from minimagen.generate import load_minimagen, load_params
from minimagen.t5 import get_encoded_dim
from minimagen.training import get_minimagen_parser, ConceptualCaptions, get_minimagen_dl_opts, \
    create_directory, get_model_params, get_model_size, save_training_info, get_default_args, MinimagenTrain, \
    load_restart_training_parameters, load_testing_parameters

# Get device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Command line argument parser
parser = get_minimagen_parser()
args = parser.parse_args()

# Create training directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
dir_path = f"./training_{timestamp}"
training_dir = create_directory(dir_path)

# Replace some cmd line args to lower computational load.
args = load_testing_parameters(args)

# Load subset of Conceptual Captions dataset.
train_dataset, valid_dataset = ConceptualCaptions(args, smalldata=True)

# Create dataloaders
dl_opts = {**get_minimagen_dl_opts(device), 'batch_size': args.BATCH_SIZE, 'num_workers': args.NUM_WORKERS}
train_dataloader = torch.utils.data.DataLoader(train_dataset, **dl_opts)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, **dl_opts)

# Use small U-Nets to lower computational load.
unets_params = [get_default_args(BaseTest), get_default_args(SuperTest)]
unets = [Unet(**unet_params).to(device) for unet_params in unets_params]

# Specify MinImagen parameters
imagen_params = dict(
    image_sizes=(int(args.IMG_SIDE_LEN / 2), args.IMG_SIDE_LEN),
    timesteps=args.TIMESTEPS,
    cond_drop_prob=0.15,
    text_encoder_name=args.T5_NAME
)

# Create MinImagen from UNets with specified imagen parameters
imagen = Imagen(unets=unets, **imagen_params).to(device)

# Fill in unspecified arguments with defaults to record complete config (parameters) file
unets_params = [{**get_default_args(Unet), **i} for i in unets_params]
imagen_params = {**get_default_args(Imagen), **imagen_params}

# Get the size of the Imagen model in megabytes
model_size_MB = get_model_size(imagen)

# Save all training info (config files, model size, etc.)
save_training_info(args, timestamp, unets_params, imagen_params, model_size_MB, training_dir)

# Create optimizer
optimizer = optim.Adam(imagen.parameters(), lr=args.OPTIM_LR)

# Train the MinImagen instance
MinimagenTrain(timestamp, args, unets, imagen, train_dataloader, valid_dataloader, training_dir, optimizer, timeout=30)

Image Generation

A minimal inference script using the minimagen package is shown below. See inference.py for a more complete version of this code.

from argparse import ArgumentParser
from minimagen.generate import load_minimagen, sample_and_save

# Command line argument parser
parser = ArgumentParser()
parser.add_argument("-d", "--TRAINING_DIRECTORY", dest="TRAINING_DIRECTORY", help="Training directory to use for inference", type=str)
args = parser.parse_args()

# Specify the caption(s) to generate images for
captions = ['a happy dog']

# Use `sample_and_save` to generate and save the images
sample_and_save(captions, training_directory=args.TRAINING_DIRECTORY)



# Alternatively, rather than specifying a Training Directory, you can input just a MinImagen instance to use for image generation.
# In this case, information about the MinImagen instance used to generate the images will not be saved.
minimagen = load_minimagen(args.TRAINING_DIRECTORY)
sample_and_save(captions, minimagen=minimagen)    

To see more of what MinImagen has to offer, or to get additional details on the scripts above, check out the MinImagen Documentation.

Socials

  • Follow us on Twitter for more Deep Learning content.
