
Generate text captions for images from their CLIP embeddings.


clip-text-decoder

Train an image captioner with a BLEU-4 score of 0.30 in under an hour! Includes PyTorch model code, an example training script, and pretrained model weights.

Example Predictions

Example captions were computed with the pretrained model mentioned below.

"A man riding a wave on top of a surfboard."

[Image: A surfer riding a wave]

"A baseball player is swinging a bat at a ball."

[Image: Baseball player]

"A dog jumping in the air to catch a frisbee."

[Image: Dog with frisbee]

Installation

Installing the package gives easier access to the following classes:

  • clip_text_decoder.dataset.ClipCocoCaptionsDataset
  • clip_text_decoder.model.ClipDecoder
  • clip_text_decoder.model.ClipDecoderInferenceModel

The train.py script is not included in the installed package, since it lives in the repository's root directory. To train new models, either clone this repository or recreate train.py locally.

Using pip:

pip install clip-text-decoder

From source:

git clone https://github.com/fkodom/clip-text-decoder.git
cd clip-text-decoder
pip install .

NOTE: You'll also need to install openai/CLIP to encode images with CLIP. ClipCocoCaptionsDataset also requires it to build the captions dataset the first time; the dataset is then cached for subsequent calls.

pip install "clip @ git+https://github.com/openai/CLIP.git"

For technical reasons, the CLIP dependency can't be included in the PyPI package: CLIP is not published on PyPI, and PyPI does not accept packages that depend directly on a Git URL.
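Because CLIP has to be installed separately, a script may want to fail early with a clear message when it is missing. A minimal sketch, assuming the module is importable as `clip` (as in the inference example below); `clip_is_installed` and `require_clip` are hypothetical helpers, not part of clip-text-decoder. Note that it only checks that some module named `clip` can be found, not that it is the openai/CLIP one.

```python
import importlib.util


def clip_is_installed() -> bool:
    """Return True if a module named `clip` can be imported."""
    return importlib.util.find_spec("clip") is not None


def require_clip() -> None:
    """Raise an ImportError with install instructions if CLIP is missing."""
    if not clip_is_installed():
        raise ImportError(
            "openai/CLIP is required to encode images. Install it with:\n"
            '  pip install "clip @ git+https://github.com/openai/CLIP.git"'
        )
```

Calling `require_clip()` at the top of a training or inference script surfaces the missing dependency before any model loading starts.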

Training


Launch your own training session using the provided script (train.py):

python train.py --max-epochs 10

Training CLI arguments, along with their default values:

--max-epochs 10  # (int)
--batch-size 32  # (int)
--accumulate-grad-batches 4  # (int)
--precision 16  # (16 or 32)
--seed 0  # (int)
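The flags above can be mirrored with argparse. This is a hypothetical reconstruction (`build_parser` and `effective_batch_size` are illustrative names, not functions from train.py); it mainly shows that, with gradient accumulation, the optimizer takes one step per batch-size × accumulate-grad-batches samples, i.e. 32 × 4 = 128 with the defaults on a single GPU.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the documented train.py flags and defaults.
    parser = argparse.ArgumentParser(description="Train a CLIP text decoder (sketch).")
    parser.add_argument("--max-epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--accumulate-grad-batches", type=int, default=4)
    parser.add_argument("--precision", type=int, choices=[16, 32], default=16)
    parser.add_argument("--seed", type=int, default=0)
    return parser


def effective_batch_size(batch_size: int, accumulate_grad_batches: int, num_devices: int = 1) -> int:
    # Gradients from `accumulate_grad_batches` mini-batches are accumulated before
    # each optimizer step, so one step covers this many samples in total.
    return batch_size * accumulate_grad_batches * num_devices


# Parse an empty argv so only the defaults are used.
defaults = build_parser().parse_args([])
print(effective_batch_size(defaults.batch_size, defaults.accumulate_grad_batches))  # 128
```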

One epoch takes about 5 minutes on a T4 GPU, which is freely available in Google Colab. After about 10 training epochs (roughly 50 minutes), you'll reach a BLEU-4 score of roughly 0.30. So in under an hour, you can train an image captioning model that is competitive with, though not quite matching, state-of-the-art accuracy.
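For reference, BLEU-4 is the geometric mean of clipped 1- to 4-gram precisions, multiplied by a brevity penalty for captions shorter than their references. The README does not say which BLEU implementation, tokenization, or smoothing produced the 0.30 figure, so this plain-Python sketch (corpus-level, uniform weights, no smoothing; `corpus_bleu4` is a hypothetical name) is for intuition and may not reproduce that number exactly.

```python
import math
from collections import Counter
from typing import List, Sequence


def ngrams(tokens: Sequence[str], n: int) -> Counter:
    """Count all n-grams of length n in a token sequence."""
    return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu4(hypotheses: List[List[str]], references: List[List[List[str]]]) -> float:
    """Corpus-level BLEU-4 with uniform weights and no smoothing."""
    matches = [0] * 4
    totals = [0] * 4
    hyp_len = 0
    ref_len = 0
    for hyp, refs in zip(hypotheses, references):
        hyp_len += len(hyp)
        # Closest reference length; ties are broken toward the shorter reference.
        ref_len += min((abs(len(r) - len(hyp)), len(r)) for r in refs)[1]
        for n in range(1, 5):
            hyp_counts = ngrams(hyp, n)
            # Clip each n-gram count by its maximum count in any single reference.
            max_ref: Counter = Counter()
            for r in refs:
                max_ref |= ngrams(r, n)
            matches[n - 1] += sum(min(c, max_ref[g]) for g, c in hyp_counts.items())
            totals[n - 1] += max(len(hyp) - n + 1, 0)
    # Without smoothing, any zero n-gram precision makes the score zero.
    if hyp_len == 0 or min(matches) == 0:
        return 0.0
    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / 4
    brevity_penalty = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return brevity_penalty * math.exp(log_precision)


reference = "a man riding a wave on top of a surfboard".split()
hypothesis = "a man riding a wave on a surfboard".split()
print(corpus_bleu4([hypothesis], [[reference]]))
```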

TODO: Enable full end-to-end training, including the CLIP image backbone. This will dramatically increase training time, since the image encodings can no longer be pre-computed. In theory, though, it should improve the model's overall accuracy.
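The TODO implies that, currently, the CLIP backbone stays frozen and each image is encoded once before training. Assuming exactly that, this sketch counts encoder forward passes to show why end-to-end training is so much slower. `CountingEncoder`, `train_frozen`, and `train_end_to_end` are hypothetical stand-ins, and the decoder updates are left out.

```python
from typing import Callable, Dict, List, Sequence

Embedding = List[float]


class CountingEncoder:
    """Hypothetical stand-in for the CLIP image encoder that counts forward passes."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self, image_id: str) -> Embedding:
        self.calls += 1
        return [float(len(image_id))]  # dummy embedding


def train_frozen(image_ids: Sequence[str], encode: Callable[[str], Embedding], max_epochs: int) -> None:
    # Frozen backbone: encodings never change, so compute them once up front.
    cache: Dict[str, Embedding] = {i: encode(i) for i in image_ids}
    for _ in range(max_epochs):
        for image_id in image_ids:
            _ = cache[image_id]  # a decoder update on this embedding would go here


def train_end_to_end(image_ids: Sequence[str], encode: Callable[[str], Embedding], max_epochs: int) -> None:
    # Trainable backbone: its weights change every step, so each image is re-encoded
    # (and, in a real setup, back-propagated through) on every epoch.
    for _ in range(max_epochs):
        for image_id in image_ids:
            _ = encode(image_id)  # decoder and encoder updates would go here


images = ["img0", "img1", "img2"]
frozen, e2e = CountingEncoder(), CountingEncoder()
train_frozen(images, frozen, max_epochs=10)
train_end_to_end(images, e2e, max_epochs=10)
print(frozen.calls, e2e.calls)  # 3 30
```

With a frozen encoder the cost is one pass per image in total; end-to-end training needs one pass per image per epoch, plus the backward passes through the image backbone.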

Inference

The training script produces a model.zip archive containing the Tokenizer and the trained model parameters. To perform inference with it:

import clip
from PIL import Image
import torch

from clip_text_decoder.model import ClipDecoderInferenceModel

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ClipDecoderInferenceModel.load("path/to/model.zip").to(device)
clip_model, clip_preprocessor = clip.load("ViT-B/32", device=device, jit=False)

# Create a blank dummy image
dummy_image = Image.new("RGB", (224, 224))
preprocessed = clip_preprocessor(dummy_image).to(device)
# Add a batch dimension using '.unsqueeze(0)'; no gradients are needed for inference
with torch.no_grad():
    encoded = clip_model.encode_image(preprocessed.unsqueeze(0))
    text = model(encoded)

print(text)
# Probably some nonsense, because we used a dummy image.
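The README does not say how ClipDecoderInferenceModel turns an embedding into words. One common approach for autoregressive decoders is greedy decoding: repeatedly pick the highest-scoring next token until an end token appears. This sketch illustrates that technique only; `greedy_decode`, the `<bos>`/`<eos>` tokens, and the toy bigram table are hypothetical, and the toy scorer ignores the embedding, whereas a real decoder would condition on it.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical scorer: (image embedding, tokens so far) -> score per candidate token.
ScoreFn = Callable[[Sequence[float], List[str]], Dict[str, float]]


def greedy_decode(
    score_fn: ScoreFn,
    embedding: Sequence[float],
    bos: str = "<bos>",
    eos: str = "<eos>",
    max_len: int = 20,
) -> str:
    tokens = [bos]
    for _ in range(max_len):
        scores = score_fn(embedding, tokens)
        next_token = max(scores, key=lambda t: scores[t])  # single most likely token
        if next_token == eos:
            break
        tokens.append(next_token)
    return " ".join(tokens[1:])  # drop the start token


# Toy table keyed by the previous token (ignores the embedding).
TOY_BIGRAMS: Dict[str, Dict[str, float]] = {
    "<bos>": {"a": 0.9, "the": 0.1},
    "a": {"dog": 0.7, "cat": 0.3},
    "dog": {"catching": 0.8, "<eos>": 0.2},
    "catching": {"frisbee": 0.9, "<eos>": 0.1},
    "frisbee": {"<eos>": 0.95, "dog": 0.05},
}


def toy_scores(embedding: Sequence[float], tokens: List[str]) -> Dict[str, float]:
    return TOY_BIGRAMS[tokens[-1]]


print(greedy_decode(toy_scores, [0.0] * 512))  # a dog catching frisbee
```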

Pretrained Models

A pretrained CLIP decoder is hosted in my Google Drive and can be downloaded with:

from clip_text_decoder.model import ClipDecoderInferenceModel

model = ClipDecoderInferenceModel.download_pretrained()

To cache the pretrained model locally, so that it's not re-downloaded each time:

model = ClipDecoderInferenceModel.download_pretrained("/path/to/model.zip")
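The README does not describe how download_pretrained decides whether to fetch the archive again. This sketch only illustrates the general download-once pattern behind passing a local path; `load_or_download` is a hypothetical helper, and `fake_download` stands in for the Google Drive fetch.

```python
import tempfile
from pathlib import Path
from typing import Callable, List


def load_or_download(path: str, download: Callable[[Path], None]) -> Path:
    """Download to `path` only if no file exists there yet (hypothetical helper)."""
    target = Path(path)
    if not target.exists():
        target.parent.mkdir(parents=True, exist_ok=True)
        download(target)
    return target


calls: List[Path] = []


def fake_download(target: Path) -> None:
    # Stand-in for the real network download.
    calls.append(target)
    target.write_bytes(b"fake model archive")


with tempfile.TemporaryDirectory() as tmp:
    local = f"{tmp}/models/model.zip"
    load_or_download(local, fake_download)  # first call downloads
    load_or_download(local, fake_download)  # second call reuses the local file

print(len(calls))  # 1
```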

Shortcomings

  • Only works well with COCO-style images. If you go outside the distribution of COCO objects, you'll get nonsense text captions.
  • Relatively short training time. Even within the COCO domain, you'll occasionally see incorrect captions. Quite a few captions will have bad grammar, repetitive descriptors, etc.



Download files


Source Distribution

clip-text-decoder-1.1.0.tar.gz (7.7 kB)

Built Distribution


clip_text_decoder-1.1.0-py3-none-any.whl (8.3 kB)

File details

Details for the file clip-text-decoder-1.1.0.tar.gz.

File metadata

  • Download URL: clip-text-decoder-1.1.0.tar.gz
  • Upload date:
  • Size: 7.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.0 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.9

File hashes

Hashes for clip-text-decoder-1.1.0.tar.gz
Algorithm Hash digest
SHA256 96d5b44aea1667dcc1c38843879ec930458b7ddf9022428d890fef90d1226ac0
MD5 87f1967e339793db4a46ef3c3ab3a02c
BLAKE2b-256 0a4f915f0bee06a57a2f379f854fd67959133543d10ada1e186600877ce9e1df
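To check a downloaded file against the SHA256 digest listed above for the source distribution, Python's hashlib is enough. A minimal sketch, assuming the file has already been downloaded; `sha256_of` and `verify` are hypothetical helper names.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for clip-text-decoder-1.1.0.tar.gz
EXPECTED_SHA256 = "96d5b44aea1667dcc1c38843879ec930458b7ddf9022428d890fef90d1226ac0"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Compute a file's SHA256 hex digest without reading it all into memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """Return True if the file's SHA256 digest matches `expected`."""
    return sha256_of(path) == expected.lower()


# Example (assumes the sdist was downloaded to the current directory):
# verify(Path("clip-text-decoder-1.1.0.tar.gz"))
```

pip can also enforce this: adding `--hash=sha256:<digest>` to the `clip-text-decoder==1.1.0` line of a requirements file turns on hash-checking mode. List the wheel's digest as well, since pip may pick either file.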


File details

Details for the file clip_text_decoder-1.1.0-py3-none-any.whl.

File metadata

  • Download URL: clip_text_decoder-1.1.0-py3-none-any.whl
  • Upload date:
  • Size: 8.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.0 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.9

File hashes

Hashes for clip_text_decoder-1.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 f3e75d56784ea5cfb5af48ae3ffd9c0cee83117234a36c5b0ff81ceb44fc9066
MD5 5c851c0cbdbf6924ba459483c3ff3a67
BLAKE2b-256 af80628536a43ed49cff8653970b925fef7bb4b7333a5e76a0515e941738fb06

