GEM
Grounding Everything: Emerging Localization Properties in Vision-Language Transformers
Walid Bousselham, Felix Petersen, Vittorio Ferrari, Hilde Kuehne
Vision-Language foundation models have shown remarkable performance in various zero-shot settings such as image retrieval, classification, or captioning. But so far, those models seem to fall behind when it comes to zero-shot localization of referential expressions and objects in images.
GEM enables a training-free adaptation of Vision-Language models (e.g., CLIP) to zero-shot open-vocabulary segmentation. Because no training is involved, the vocabulary learned by the Vision-Language model during pretraining is fully preserved, which allows the segmentation of uncommon classes (e.g., Elon Musk, Mark Zuckerberg, Jeff Bezos).
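GEM's contribution lies in how the visual tokens are computed. Once the patch tokens live in the joint image-text space, a per-pixel map for each prompt is typically obtained by comparing every patch with every text embedding. The sketch below illustrates only that final scoring step. Random tensors stand in for CLIP's projected patch and text embeddings, and the helper `patch_text_similarity` is hypothetical, not part of the gem API.

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def patch_text_similarity(patch_tokens: torch.Tensor,
                          text_embeds: torch.Tensor,
                          grid_hw: Tuple[int, int],
                          out_hw: Tuple[int, int]) -> torch.Tensor:
    """Cosine similarity between every patch token and every text prompt.

    patch_tokens: [B, N, D] projected patch embeddings (N = grid_h * grid_w)
    text_embeds:  [P, D] projected text embeddings, one row per prompt
    returns:      [B, P, out_h, out_w] similarity maps
    """
    patches = F.normalize(patch_tokens, dim=-1)
    texts = F.normalize(text_embeds, dim=-1)
    sim = torch.einsum("bnd,pd->bpn", patches, texts)  # [B, P, N]
    b, p, _ = sim.shape
    sim = sim.reshape(b, p, *grid_hw)                   # back onto the patch grid
    # bilinear upsampling from the patch grid to pixel resolution
    return F.interpolate(sim, size=out_hw, mode="bilinear", align_corners=False)


if __name__ == "__main__":
    torch.manual_seed(0)
    patches = torch.randn(1, 28 * 28, 512)  # e.g. a 448x448 input split into 16x16 patches
    texts = torch.randn(2, 512)             # two prompts, e.g. 'cat' and 'remote control'
    maps = patch_text_similarity(patches, texts, (28, 28), (448, 448))
    print(maps.shape)
```

Because both sides are L2-normalised, every map value is a cosine similarity in [-1, 1]. Bilinear upsampling does not leave that range.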
:hammer: Installation
The gem library (published on PyPI as gem_torch) can be installed via pip:
$ pip install gem_torch
Demo
- Try out our web demo on Hugging Face Spaces.
- Run the demo on Google Colab.
- Run test_examples.py for a usage example.
To run the Gradio app locally, first install gradio and then run app.py:
$ pip install gradio
$ python app.py
Usage
To see which pretrained models are available, use the following code snippet:
import gem
gem.available_models()
Single Image
To process a single image with multiple text prompts, use the following code snippet:
import torch
import gem
import requests
from PIL import Image
model_name = 'ViT-B/16'  # alternative: 'ViT-B-16-quickgelu'
pretrained = 'openai'  # alternative: 'metaclip_400m'
device = "cuda" if torch.cuda.is_available() else "cpu"
# init model and image transform
preprocess = gem.get_gem_img_transform()
gem_model = gem.create_gem_model(model_name=model_name,
                                 pretrained=pretrained,
                                 device=device)
# load image and text
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = preprocess(
    Image.open(requests.get(url, stream=True).raw)
).unsqueeze(0).to(device)
text = ['cat', 'remote control']
with torch.no_grad():
    logits = gem_model(image, text)  # [B, num_prompt, W, H]
    gem_model.visualize(image, text, logits)  # (optional visualization)
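The heatmaps returned above have the resolution of the preprocessed input, not that of the original photo. The following sketch turns them into binary masks at the original size. The helper name, the per-prompt min-max normalisation and the 0.5 threshold are illustrative assumptions, not part of the gem API.

```python
import torch
import torch.nn.functional as F


def logits_to_masks(logits: torch.Tensor, image_size, threshold: float = 0.5) -> torch.Tensor:
    """Turn per-prompt heatmaps into boolean masks at the original image size.

    logits:     [B, P, h, w] heatmaps, e.g. the output of gem_model(image, text)
    image_size: (height, width) of the original, un-preprocessed image
    threshold:  cut-off applied after per-prompt min-max normalisation (hypothetical default)
    """
    maps = F.interpolate(logits, size=image_size, mode="bilinear", align_corners=False)
    flat = maps.flatten(2)                                   # [B, P, H*W]
    lo = flat.min(dim=-1, keepdim=True).values
    hi = flat.max(dim=-1, keepdim=True).values
    # rescale each prompt's map to [0, 1]; clamp avoids division by zero for constant maps
    norm = ((flat - lo) / (hi - lo).clamp_min(1e-8)).reshape(maps.shape)
    return norm >= threshold
```

A typical call would be `logits_to_masks(logits.float().cpu(), (pil_image.height, pil_image.width))`, where `pil_image` is a hypothetical variable that holds the image before preprocessing. PIL reports sizes as (width, height), but `F.interpolate` expects (height, width).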
Batched Inference
To process a batch of images with different number of prompts per image, one must use the batched_forward()
function of gem_model
:
import torch
import gem
import requests
from PIL import Image
model_name = 'ViT-B/16'  # alternative: 'ViT-B-16-quickgelu'
pretrained = 'openai'  # alternative: 'metaclip_400m'
device = "cuda" if torch.cuda.is_available() else "cpu"
# init model and image transform
gem_model = gem.create_gem_model(model_name=model_name,
                                 pretrained=pretrained,
                                 device=device)
preprocess = gem.get_gem_img_transform()
# load image and text
urls = [
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    "https://cdn.vietnambiz.vn/171464876016439296/2021/7/11/headshots16170695297430-1626006880779826347793.jpg",
    "https://preview.redd.it/do-you-think-joker-should-be-unpredictable-enough-to-put-up-v0-6a2ax4ngtlaa1.jpg?auto=webp&s=f8762e6a1b40642bcae5900bac184fc597131503",
]
texts = [
    ['remote control', 'cat'],
    ['elon musk', 'mark zuckerberg', 'jeff bezos', 'bill gates'],
    ['batman', 'joker', 'shoe', 'belt', 'purple suit'],
]  # note that the number of prompts per image can differ
# download images + convert to RGB PIL.Image so that every image has 3 channels
images_pil = [Image.open(requests.get(url, stream=True).raw).convert("RGB") for url in urls]
images = torch.stack([preprocess(img) for img in images_pil]).to(device)
with torch.no_grad():
    # returns a list of logits, each of size [1, num_prompt, W, H]
    logits_list = gem_model.batched_forward(images, texts)
for i, logits in enumerate(logits_list):  # (optional visualization)
    gem_model.visualize(images[i], texts[i], logits)
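Each image comes with its own list of prompts, so the prompts cannot simply be stacked into one rectangular tensor. One common workaround is to embed all prompts in a single pass and then split the rows back per image with `torch.split`. The sketch below shows this on raw embeddings. It is not necessarily how batched_forward is implemented, and the helper `split_scores_per_image` is hypothetical.

```python
from typing import List

import torch
import torch.nn.functional as F


def split_scores_per_image(patch_tokens: torch.Tensor,
                           text_embeds: torch.Tensor,
                           prompts_per_image: List[int]) -> List[torch.Tensor]:
    """Score each image only against its own prompts.

    patch_tokens:      [B, N, D] patch embeddings for B images
    text_embeds:       [sum(prompts_per_image), D] embeddings of all prompts, concatenated in image order
    prompts_per_image: number of prompts belonging to each image
    returns:           list of B tensors, the i-th of shape [1, prompts_per_image[i], N]
    """
    assert patch_tokens.shape[0] == len(prompts_per_image)
    patches = F.normalize(patch_tokens, dim=-1)
    texts = F.normalize(text_embeds, dim=-1)
    # all prompts were embedded together; split the rows back into one chunk per image
    chunks = torch.split(texts, prompts_per_image, dim=0)
    return [
        torch.einsum("nd,pd->pn", patches[i], chunk).unsqueeze(0)
        for i, chunk in enumerate(chunks)
    ]
```

For the `texts` list above, `prompts_per_image` would be `[len(t) for t in texts]`, i.e. `[2, 4, 5]`. Each output then has the same leading `[1, num_prompt, ...]` layout as the entries of `logits_list`.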
API
The library provides the following methods:
- gem.create_gem_model(model_name, pretrained, device, ...): returns the model_name Vision-Language model with the pretrained weights loaded and GEM applied. One can also specify the gem_depth, ss_attn_iter and ss_attn_temp parameters to control, respectively, GEM's depth, the number of self-self attention iterations and the self-self attention temperature (see the paper for more details).
- gem.get_gem_img_transform(img_size): returns a transform that takes in a PIL.Image and returns a torch.Tensor that can be used as input to the model.
- gem.visualize(image, prompts, logits, alpha=0.6, save_path=None): takes in a PIL.Image or a torch.Tensor, the list of text prompts and the logits output by GEM, and plots GEM's heatmap for each prompt. Alternatively, the heatmaps can be saved by specifying save_path. The transparency of the heatmaps can be changed via the alpha argument (0.6 by default).
By default, the models loaded by gem.create_gem_model() return the logits output by GEM, but they can also return the logits of the original Vision-Language model, which can be useful for visualization. To do so, set return_ori=True.
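The ss_attn_iter and ss_attn_temp parameters refer to GEM's self-self attention. The sketch below simplifies the idea. Attention weights are computed between a projection and itself (q-q, k-k, v-v) instead of between queries and keys. An optional iterative refinement and a temperature are applied, and the three branches are averaged. This is not the library's implementation. The exact normalisation, default temperature and ensembling are described in the paper and may differ. The function name is hypothetical.

```python
import torch
import torch.nn.functional as F


def self_self_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        n_iter: int = 1, temp: float = 0.1) -> torch.Tensor:
    """Simplified qkv self-self attention (illustrative sketch, not gem's code).

    q, k, v: [B, N, D] projected tokens of one attention head
    n_iter:  refinement iterations (playing the role of ss_attn_iter)
    temp:    softmax temperature (playing the role of ss_attn_temp)
    returns: [B, N, D]
    """
    outputs = []
    for p in (q, k, v):
        p = F.normalize(p, dim=-1)
        for _ in range(n_iter):
            # pull each token towards the tokens most similar to it, then re-normalise
            attn = torch.softmax(p @ p.transpose(-1, -2) / temp, dim=-1)
            p = F.normalize(attn @ p, dim=-1)
        attn = torch.softmax(p @ p.transpose(-1, -2) / temp, dim=-1)
        outputs.append(attn @ v)  # aggregate the values with the self-self weights
    return torch.stack(outputs).mean(dim=0)  # average of the q-q, k-k and v-v branches
```

A lower temperature makes the weights sharper, so each token aggregates values mostly from tokens that closely resemble it. Every row of the attention matrix sums to 1, so a constant value tensor passes through unchanged.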
More Examples
Semantic Segmentation
For the semantic segmentation task, given a list of foreground class names, one must predict a 2D map in which each location holds the id of the predicted class. Depending on the dataset, we may also want to predict a background class. However, the textual description "a photo of a background" does not describe what the background is actually composed of. Hence, we propose the following method: predict the background wherever the highest foreground-class probability falls below a threshold.
import torch
import gem
from PIL import Image
model_name = 'ViT-B/16'  # alternative: 'ViT-B-16-quickgelu'
pretrained = 'openai'  # alternative: 'metaclip_400m'
device = "cuda" if torch.cuda.is_available() else "cpu"
# init model and image transform
gem_model = gem.create_gem_model(model_name=model_name,
                                 pretrained=pretrained,
                                 device=device)
preprocess = gem.get_gem_img_transform()
predict_background = True  # whether the background is predicted
if predict_background:
    threshold = 0.85  # the threshold depends on the number of classes
# load image and text
image = preprocess(Image.open("CLIP.png").convert("RGB")).unsqueeze(0).to(device)
class_names = ['airplane', 'cat', 'dog', '...']  # foreground class names
with torch.no_grad():
    logits = gem_model(image, class_names)  # [1, num_class, W, H]
pred = logits.argmax(dim=1)
if predict_background:
    pred = pred + 1  # we assume the background's index is 0
    probs = logits.softmax(dim=1)
    max_prob = probs.max(dim=1)[0]
    pred[max_prob < threshold] = 0  # if the max prob is lower than the threshold, the background is predicted
Note that the threshold depends on the number of classes and should be determined via a hyperparameter sweep on held-out data.
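Here is one way to run such a sweep. The sketch assumes a held-out set of logits with ground-truth masks, where 0 is the background and 255 marks ignored pixels, as in PASCAL VOC. The helper names, the candidate thresholds and the use of mIoU as the selection criterion are illustrative assumptions. The background rule follows the snippet above: predict 0 when the highest class probability falls below the threshold.

```python
import torch


def predict_with_background(logits: torch.Tensor, threshold: float) -> torch.Tensor:
    """logits: [B, C, H, W] for C foreground classes -> labels in {0 (background), 1..C}."""
    probs = logits.softmax(dim=1)
    max_prob, fg = probs.max(dim=1)      # both [B, H, W]
    pred = fg + 1                        # shift so that label 0 is free for the background
    pred[max_prob < threshold] = 0
    return pred


def mean_iou(pred: torch.Tensor, target: torch.Tensor, num_classes: int,
             ignore_index: int = 255) -> float:
    valid = target != ignore_index
    ious = []
    for c in range(num_classes):
        p = (pred == c) & valid
        t = (target == c) & valid
        union = (p | t).sum().item()
        if union > 0:                    # skip classes absent from both prediction and target
            ious.append((p & t).sum().item() / union)
    return sum(ious) / len(ious) if ious else 0.0


def sweep_threshold(logits: torch.Tensor, target: torch.Tensor, thresholds):
    """Return (best_threshold, best_mIoU) over the candidate thresholds."""
    num_classes = logits.shape[1] + 1    # foreground classes + background
    scores = {t: mean_iou(predict_with_background(logits, t), target, num_classes)
              for t in thresholds}
    best = max(scores, key=scores.get)
    return best, scores[best]
```

More foreground classes spread the softmax over more entries, which lowers the typical maximum probability. This is why the threshold has to be re-tuned whenever the class list changes.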
Dataset
gem can also be used with a regular PyTorch dataset.
import torch
import gem
from torchvision.datasets import VOCSegmentation
model_name = 'ViT-B/16'  # alternative: 'ViT-B-16-quickgelu'
pretrained = 'openai'  # alternative: 'metaclip_400m'
device = "cuda" if torch.cuda.is_available() else "cpu"
# init model and image transform
gem_model = gem.create_gem_model(model_name=model_name,
                                 pretrained=pretrained,
                                 device=device)
preprocess = gem.get_gem_img_transform()
predict_background = True  # whether the background is predicted
if predict_background:
    threshold = 0.85  # the threshold depends on the number of classes
# load dataset
root = './data'  # path to save the dataset
dataset = VOCSegmentation(root=root, image_set='val', download=True, transform=preprocess)
def collate(batch):
    # stack the preprocessed images; keep the masks (PIL images of varying size) as a list
    images, masks = zip(*batch)
    return torch.stack(images), list(masks)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=collate)
class_names = ['airplane', 'cat', 'dog', '...']  # foreground class names
with torch.no_grad():
    for image, _ in dataloader:
        image = image.to(device)
        logits = gem_model(image, class_names)  # [B, num_class, W, H]
        pred = logits.argmax(dim=1)
        if predict_background:
            pred = pred + 1  # we assume the background's index is 0
            probs = logits.softmax(dim=1)
            max_prob = probs.max(dim=1)[0]
            pred[max_prob < threshold] = 0  # if the max prob is lower than the threshold, the background is predicted
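The loop above stops at the predicted label maps. Scoring them against the VOC masks takes two more steps. First, the heatmaps are resized to each mask's resolution. Second, pixels labelled 255 (the "void" label PASCAL VOC uses around object boundaries) are ignored. The sketch below accumulates a confusion matrix and computes mIoU from it. The helper names are hypothetical. It assumes each mask has already been converted to an integer tensor, e.g. with `torch.as_tensor(np.array(mask))`.

```python
import torch
import torch.nn.functional as F


def update_confusion(conf: torch.Tensor, logits: torch.Tensor, target: torch.Tensor,
                     threshold: float, ignore_index: int = 255) -> torch.Tensor:
    """Accumulate a [K, K] confusion matrix (rows: ground truth, columns: prediction).

    logits: [1, C, h, w] GEM output for one image (C foreground classes, K = C + 1)
    target: [H, W] integer mask at the original resolution (0 = background)
    """
    # resize the heatmaps to the mask resolution before taking the decision
    logits = F.interpolate(logits, size=target.shape[-2:], mode="bilinear", align_corners=False)
    probs = logits.softmax(dim=1)[0]     # [C, H, W]
    max_prob, fg = probs.max(dim=0)
    pred = fg + 1
    pred[max_prob < threshold] = 0       # background where no class is confident
    valid = target != ignore_index
    k = conf.shape[0]
    idx = target[valid].long() * k + pred[valid]
    conf += torch.bincount(idx, minlength=k * k).reshape(k, k)
    return conf


def miou_from_confusion(conf: torch.Tensor) -> float:
    inter = conf.diag().float()
    union = conf.sum(0).float() + conf.sum(1).float() - inter
    present = union > 0                  # ignore classes that never occur
    return (inter[present] / union[present]).mean().item()
```

For VOC, with its 20 foreground classes, the matrix would start as `torch.zeros(21, 21, dtype=torch.long)`. `update_confusion` is called once per image with that image's slice of `logits` (e.g. `logits[i:i + 1]`) and its mask.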
:star: Acknowledgement
This code is built as a wrapper around the OpenCLIP library from LAION; visit their repository for more vision-language models. This project takes inspiration from CLIP and CLIPSurgery; please visit their repositories. This repository also uses einops.
:books: Citation
If you find this repository useful, please consider citing our work :pencil: and giving a star :star2: :
@article{bousselham2023gem,
title={Grounding Everything: Emerging Localization Properties in Vision-Language Transformers},
author={Bousselham, Walid and Petersen, Felix and Ferrari, Vittorio and Kuehne, Hilde},
journal={arXiv preprint arXiv:2312.00878},
year={2023}
}
Download files
File details
Details for the file gem_torch-1.0.1.tar.gz.
File metadata
- Download URL: gem_torch-1.0.1.tar.gz
- Upload date:
- Size: 14.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.18
File hashes
Algorithm | Hash digest
---|---
SHA256 | 560ff1f51101e9499042b976289cb20cd6ebc5950186d3cfa9df5a0f987c21f8
MD5 | d3d903ca692aac434bc80b6a9d0d7863
BLAKE2b-256 | bf6a5396e1494e84fad411d060c01f969d3c679ea0d2cb3638dc534ca23efda9
File details
Details for the file gem_torch-1.0.1-py3-none-any.whl.
File metadata
- Download URL: gem_torch-1.0.1-py3-none-any.whl
- Upload date:
- Size: 11.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.18
File hashes
Algorithm | Hash digest
---|---
SHA256 | e9e18c9b3bcb3a1cd1827ce8d95415a0b3d21e19d87ef083be11ccc81dff4f11
MD5 | 199dc3b14e4fcf7734f62896132d8c79
BLAKE2b-256 | 4c39f362a75f13104011ce460fa553b79395edd6fdadedfd7c721454ba71a789