Generate text captions for images from their CLIP embeddings.
clip-text-decoder
Train an image captioner with 0.323 BLEU-4 on COCO Captions in under one hour! (0.352 BLEU-4 with beam search 🙂)
Generates text captions for images from their embeddings. Now includes BLIP as an available vision backbone!
Example Predictions
Computed using the pretrained model mentioned below.
"A man riding a wave on top of a surfboard."
"A baseball player is swinging a bat at a ball."
"A dog jumping in the air to catch a frisbee."
Installation
Using pip:
pip install "clip @ git+https://github.com/openai/CLIP.git"
pip install "lavis @ git+https://github.com/salesforce/LAVIS.git"
pip install clip-text-decoder
From source:
pip install "clip @ git+https://github.com/openai/CLIP.git"
pip install "lavis @ git+https://github.com/salesforce/LAVIS.git"
git clone https://github.com/fkodom/clip-text-decoder.git
cd clip-text-decoder
pip install .
Inference
Pretrained Model
from PIL import Image
import torch
from clip_text_decoder.model import ImageCaptionInferenceModel
model = ImageCaptionInferenceModel.download_pretrained()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
image = Image.open("path/to/image.jpeg")
# The beam_size argument is optional. A larger beam_size is slower but gives
# slightly higher accuracy. We recommend using beam_size <= 3.
caption = model(image, beam_size=1)
To cache the pretrained model locally, so that it's not re-downloaded each time:
model = ImageCaptionInferenceModel.download_pretrained("path/to/model.pt")
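Passing a path makes download_pretrained keep a local copy of the model. We have not checked how the library implements this. The sketch below only illustrates the generic "download only if the file is missing" pattern, using a hypothetical load_or_download helper and a fake downloader in place of a real network fetch.

```python
import os
import tempfile
from typing import Callable


def load_or_download(path: str, download: Callable[[str], None]) -> bytes:
    """Hypothetical helper: fetch `path` only on a cache miss, then read it from disk."""
    if not os.path.exists(path):
        download(path)  # cache miss: e.g. an HTTP download that writes to `path`
    with open(path, "rb") as f:
        return f.read()


download_calls = []


def fake_download(path: str) -> None:
    # Stand-in for a network fetch; records each call so we can count them.
    download_calls.append(path)
    with open(path, "wb") as f:
        f.write(b"fake model weights")


with tempfile.TemporaryDirectory() as tmp_dir:
    cache_path = os.path.join(tmp_dir, "model.pt")
    load_or_download(cache_path, fake_download)  # downloads
    load_or_download(cache_path, fake_download)  # reads the cached copy
print(len(download_calls))  # 1
```

Only the first call triggers a download. Every later call with the same path reads the file that is already on disk.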
Custom Trained Model
Training produces a model.pt archive containing a Tokenizer and model parameters. To reload the trained inference model:
import torch
from clip_text_decoder.model import ImageCaptionInferenceModel
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ImageCaptionInferenceModel.load("path/to/model.pt").to(device)
# Load an image and get predictions as in the pretrained example above...
Ablation: Beam Size
Below, we measure the BLEU-4 score for different beam_size arguments. By default, the inference model uses a beam size of 1:
from PIL import Image
from clip_text_decoder.model import ImageCaptionInferenceModel
model = ImageCaptionInferenceModel.load("path/to/model.pt")
image = Image.open("path/to/image.jpeg")
caption = model(image, beam_size=1)
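To see why a larger beam_size can help, here is a minimal beam search over a made-up next-token table. TOY_LM and toy_next_probs are hypothetical and not part of clip-text-decoder, and this is a generic textbook variant (a beam is set aside as soon as it emits the end token, with no length normalization); we have not checked how the package's decoder implements beam search. With beam_size=1 it reduces to greedy decoding.

```python
import math
from typing import Callable, Dict, List, Tuple

EOS = "<eos>"
Seq = Tuple[str, ...]


def beam_search(
    next_probs: Callable[[Seq], Dict[str, float]],
    beam_size: int,
    max_len: int = 10,
) -> Tuple[Seq, float]:
    """Return the highest-scoring sequence and its total log-probability."""
    beams: List[Tuple[Seq, float]] = [((), 0.0)]  # (tokens so far, log-prob)
    finished: List[Tuple[Seq, float]] = []
    for _ in range(max_len):
        # Expand every live beam by every possible next token.
        candidates = [
            (seq + (tok,), score + math.log(p))  # assumes all p > 0
            for seq, score in beams
            for tok, p in next_probs(seq).items()
        ]
        # Keep only the `beam_size` best candidates.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:beam_size]:
            if seq[-1] == EOS:
                finished.append((seq, score))
            else:
                beams.append((seq, score))
        if not beams:
            break
    return max(finished + beams, key=lambda c: c[1])


# Hypothetical toy "language model": the greedy first token "a" leads to
# weaker continuations than "b" does.
TOY_LM = {
    (): {"a": 0.6, "b": 0.4},
    ("a",): {"x": 0.55, "y": 0.45},
    ("b",): {"z": 1.0},
}


def toy_next_probs(seq: Seq) -> Dict[str, float]:
    return TOY_LM.get(seq, {EOS: 1.0})  # any other prefix ends the caption


greedy_seq, greedy_score = beam_search(toy_next_probs, beam_size=1)
beam_seq, beam_score = beam_search(toy_next_probs, beam_size=2)
print(greedy_seq, round(math.exp(greedy_score), 2))  # ('a', 'x', '<eos>') 0.33
print(beam_seq, round(math.exp(beam_score), 2))  # ('b', 'z', '<eos>') 0.4
```

Greedy decoding commits to "a" because it is locally most likely, while a beam of 2 keeps "b" alive long enough to find the more probable full sequence. Each extra beam adds more candidates to score per step, which is where the slowdown comes from.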
Using a larger beam_size gives a better BLEU score at the cost of slower inference. The metrics below were collected from the same model, which uses a BLIP vision backbone and was trained for 10 epochs (roughly 1 hour on a T4 GPU):
Beam size | BLEU-4 |
---|---|
1 (default) | 0.323 |
2 | 0.343 |
3 | 0.350 |
4 | 0.352 |
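BLEU-4 is the geometric mean of clipped 1- to 4-gram precisions, multiplied by a brevity penalty that punishes captions shorter than their references. Below is a minimal pure-Python, corpus-level version for intuition. It assumes whitespace tokenization and no smoothing, and it is not necessarily the implementation that produced the numbers in the table (tokenization and smoothing choices can shift BLEU noticeably).

```python
import math
from collections import Counter
from typing import List, Sequence


def ngrams(tokens: Sequence[str], n: int) -> Counter:
    return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu4(hypotheses: List[List[str]], references: List[List[List[str]]]) -> float:
    """Corpus-level BLEU-4: uniform weights, brevity penalty, no smoothing."""
    matches = [0] * 4
    totals = [0] * 4
    hyp_len = ref_len = 0
    for hyp, refs in zip(hypotheses, references):
        hyp_len += len(hyp)
        # Use the reference length closest to the hypothesis (ties: shorter one).
        ref_len += min((abs(len(r) - len(hyp)), len(r)) for r in refs)[1]
        for n in range(1, 5):
            # Clip each n-gram count by its max count in any single reference.
            max_ref = Counter()
            for r in refs:
                max_ref |= ngrams(r, n)
            hyp_counts = ngrams(hyp, n)
            matches[n - 1] += sum(min(c, max_ref[g]) for g, c in hyp_counts.items())
            totals[n - 1] += max(len(hyp) - n + 1, 0)
    if min(matches) == 0:
        return 0.0  # some n-gram order has zero matches
    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / 4
    brevity = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return brevity * math.exp(log_precision)


hyp = "the cat sat on the mat".split()
ref = "the cat sat on a mat".split()
print(round(corpus_bleu4([hyp], [[ref]]), 4))  # 0.5373
```

Here the clipped precisions are 5/6, 3/5, 2/4 and 1/3 and the lengths match, so the score is (1/12)^(1/4). Higher-order n-grams dominate the penalty, which is why small wording changes between captions can move BLEU-4 by several points.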
Training
Launch your own training session using train.py:
python train.py --max-epochs 10
Training CLI arguments, along with their default values:
--vision-backbone blip:base # (str)
--language-model distilgpt2 # (str)
--max-epochs 10 # (int)
--beam-size 1 # (int)
--batch-size 32 # (int)
--accumulate-grad-batches 4 # (int)
--precision 16 # (16 or 32)
--seed 0 # (int)
One epoch takes about 5-6 minutes using a T4 GPU, which is usually free in Google Colab (depending on availability). After about 10 training epochs, you'll reach a BLEU-4 score just over 0.30 (without beam search). So, in under an hour, you can train a pretty good image captioning model. 😎
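With the defaults above, each optimizer step sees batch_size × accumulate_grad_batches = 32 × 4 = 128 examples. This assumes --accumulate-grad-batches has its usual meaning (it mirrors the accumulate_grad_batches argument of PyTorch Lightning's Trainer, though we have not checked how train.py wires it up): gradients from 4 micro-batches of 32 are combined before a single update. The pure-Python sketch below uses made-up 1-D linear-regression data and hypothetical helper names to check that averaging equal-size micro-batch gradients reproduces the full-batch gradient.

```python
import random
from typing import List, Tuple

Pair = Tuple[float, float]  # (x, y) training example


def mse_grad(w: float, batch: List[Pair]) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, batch: List[Pair], micro_batch_size: int) -> float:
    """Combine per-micro-batch gradients the way gradient accumulation does."""
    micro_batches = [
        batch[i : i + micro_batch_size] for i in range(0, len(batch), micro_batch_size)
    ]
    # Each micro-batch gradient is scaled by 1 / (number of micro-batches) and
    # summed, which equals the full-batch mean gradient for equal-size chunks.
    return sum(mse_grad(w, mb) / len(micro_batches) for mb in micro_batches)


batch_size, accumulate_grad_batches = 32, 4  # CLI defaults listed above
rng = random.Random(0)
xs = [rng.uniform(-1, 1) for _ in range(batch_size * accumulate_grad_batches)]
data = [(x, 3 * x + rng.gauss(0, 0.1)) for x in xs]  # made-up data, true slope 3

full = mse_grad(0.5, data)
accum = accumulated_grad(0.5, data, batch_size)
print(len(data), abs(full - accum) < 1e-9)  # 128 True
```

Accumulation lets a GPU that only fits 32 examples at a time train with the gradient statistics of a 128-example batch, at the cost of fewer optimizer steps per epoch.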
Notes
BLEU doesn't increase much beyond 1 hour of training. Training and validation loss will continue to decrease, but the resulting image captions are effectively equivalent.
This appears to be a limitation of the image embeddings, rather than a limitation of the language model. Changing the vision backbone gives the biggest improvement in BLEU score. (BLIP gets 5-10% better BLEU than CLIP backbones using the same language model head.) Larger language models (e.g. GPT-2 Large) don't improve the BLEU score by much.
TODO
- Plan to train on Conceptual Captions for more generic image captioning.
Built Distribution
Hashes for clip_text_decoder-1.4.1-py3-none-any.whl
Algorithm | Hash digest |
---|---|
SHA256 | a76a7961ded8cc97a74ec2f9638813ff62062d1f7f2655e50e8996a83ff49e8d |
MD5 | 3c9fe3c3e058bb432dc44e35098ffd9d |
BLAKE2b-256 | f040c5205d4d68d52767fec81e1beecd02553489c80d8845640d4412803206d1 |