
PyTorch implementation of CLIP-guided bounding box refinement for object detection.

Project description

pytorch_clip_bbox: implementation of CLIP-guided bounding box refinement for object detection.

A PyTorch-based library that filters predicted bounding boxes using the user's text or image prompts.

Object detection models are usually trained to detect common classes of objects such as "car", "person", "cup" or "bottle". Sometimes, however, we need to detect more specific classes such as "lady in the red dress", "bottle of whiskey" or "where is my red cup" instead of "person", "bottle" and "cup" respectively. One way to solve this problem is to train detectors on these more complex classes. Instead, we propose text-driven object detection, which can detect any class that can be described in natural language. This library filters the bounding boxes predicted by a detector using text or image descriptions of such classes.
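The filtering idea can be pictured as ranking the detector's boxes by how closely each box's CLIP embedding matches the prompt's embedding, then keeping the best `top_k`. The sketch below is a minimal, library-independent illustration that assumes the image crop of every box and the prompt have already been encoded into vectors; `rank_boxes`, `cosine` and the toy vectors are hypothetical and are not part of the pytorch_clip_bbox API.

```python
import math
from typing import List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def rank_boxes(box_embeddings: List[Sequence[float]],
               prompt_embedding: Sequence[float],
               top_k: int = 1) -> List[Tuple[int, float]]:
    """Return (box_index, score) pairs for the top_k boxes most similar to the prompt."""
    scored = [(i, cosine(e, prompt_embedding)) for i, e in enumerate(box_embeddings)]
    # Highest similarity first.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]


# Toy 3-d vectors standing in for CLIP features of three detected boxes (hypothetical).
boxes = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.7, 0.7, 0.0]]
prompt = [0.0, 1.0, 0.1]  # hypothetical embedding of a prompt such as "bottle of whiskey"
print(rank_boxes(boxes, prompt, top_k=2))
```

Real CLIP embeddings have hundreds of dimensions, but the ranking logic is the same: box 1 points almost exactly along the prompt vector and wins, box 2 comes second.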

Install package

pip install pytorch_clip_bbox

Install the latest version

pip install --upgrade git+https://github.com/bes-dev/pytorch_clip_bbox.git

Features

  • The library supports multiple prompts (images or texts) as targets for filtering.
  • The library automatically detects the language of the input text and translates it via Google Translate.
  • The library supports the original CLIP model by OpenAI and ruCLIP model by SberAI.
  • Simple integration with different object detection models.
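With several prompts, each box still needs a single score. One plausible rule, assumed here rather than taken from the library's source, is to keep each box's best similarity over all prompts, so a box that matches any one prompt ranks highly. The helper `score_against_prompts` is hypothetical, and the vectors are assumed to be L2-normalized so that a dot product equals cosine similarity.

```python
from typing import List, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product; equals cosine similarity when both vectors are unit length."""
    return sum(x * y for x, y in zip(a, b))


def score_against_prompts(box_embeddings: List[Sequence[float]],
                          prompt_embeddings: List[Sequence[float]]) -> List[float]:
    """Score each box by its highest similarity to any prompt (max aggregation).

    Assumption: all vectors are already L2-normalized.
    """
    if not prompt_embeddings:
        raise ValueError("at least one prompt is required")
    return [max(dot(box, p) for p in prompt_embeddings) for box in box_embeddings]


# Two hypothetical prompts (e.g. one text prompt and one image prompt) as unit vectors.
prompts = [[1.0, 0.0], [0.0, 1.0]]
boxes = [[0.6, 0.8], [1.0, 0.0], [-1.0, 0.0]]
print(score_against_prompts(boxes, prompts))
```

Averaging the similarities instead of taking the maximum would favour boxes that match all prompts at once; which behaviour pytorch_clip_bbox actually uses should be checked in its source.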

Usage

A simple example of integrating pytorch_clip_bbox with a YOLOv5 model. Install the example's requirements first:

$ pip install -r examples/yolov5_requirements.txt

import argparse
import cv2
import torch
from pytorch_clip_bbox import ClipBBOX

def extract_boxes(detections):
    boxes = []
    for i in range(detections.xyxy[0].size(0)):
        x1, y1, x2, y2, confidence, idx = detections.xyxy[0][i]
        boxes.append([int(x1), int(y1), int(x2-x1), int(y2-y1)])
    return boxes

def main(args):
    # build detector
    detector = torch.hub.load("ultralytics/yolov5", "yolov5s").to(args.device)
    clip_bbox = ClipBBOX(clip_type=args.clip_type).to(args.device)
    # add prompts
    if args.text_prompt is not None:
        clip_bbox.add_prompt(text=args.text_prompt)
    if args.image_prompt is not None:
        image = cv2.cvtColor(cv2.imread(args.image_prompt), cv2.COLOR_BGR2RGB)
        image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
        image = image / 255.0  # scale pixel values to [0, 1]
        clip_bbox.add_prompt(image=image)
    image = cv2.cvtColor(cv2.imread(args.image), cv2.COLOR_BGR2RGB)
    detections = detector(image)
    boxes = extract_boxes(detections)
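    # keep the top_k detector boxes that best match the prompts
    filtered_boxes = clip_bbox(image, boxes, top_k=args.top_k)
    # convert back to BGR, the channel order OpenCV expects for display and saving
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)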
    for box in filtered_boxes:
        x, y, w, h = box["rect"]
        cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 0), 4)
    if args.output_image is None:
        cv2.imshow("image", image)
        cv2.waitKey()
    else:
        cv2.imwrite(args.output_image, image)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--image", type=str, help="Input image.")
    parser.add_argument("--device", type=str, default="cuda:0", help="inference device.")
    parser.add_argument("--text-prompt", type=str, default=None, help="Text prompt.")
    parser.add_argument("--image-prompt", type=str, default=None, help="Image prompt.")
    parser.add_argument("--clip-type", type=str, default="clip_vit_b32", help="Type of CLIP model [ruclip, clip_vit_b32, clip_vit_b16].")
    parser.add_argument("--top-k", type=int, default=1, help="top_k predictions will be returned.")
    parser.add_argument("--output-image", type=str, default=None, help="Output image name.")
    args = parser.parse_args()
    main(args)
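`extract_boxes` turns YOLOv5's corner format `[x1, y1, x2, y2]` into `[x, y, w, h]`, and the drawing loop converts back to corners for `cv2.rectangle`. The pure-Python sketch below shows the same round trip plus cropping a box region out of an image, with coordinates clamped to the image bounds. The helper names are hypothetical, and the clamping is an extra safety step that the example above does not perform.

```python
from typing import List, Tuple


def xyxy_to_xywh(x1: float, y1: float, x2: float, y2: float) -> List[int]:
    """Corner format -> [x, y, width, height], truncated to ints like extract_boxes."""
    return [int(x1), int(y1), int(x2 - x1), int(y2 - y1)]


def xywh_to_corners(box: List[int]) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """[x, y, w, h] -> ((x1, y1), (x2, y2)), the two points passed to cv2.rectangle."""
    x, y, w, h = box
    return (x, y), (x + w, y + h)


def crop(image: List[List[int]], box: List[int]) -> List[List[int]]:
    """Crop an H x W image (a list of rows) to the box, clamped to the image bounds."""
    height, width = len(image), len(image[0])
    x, y, w, h = box
    x1, y1 = max(0, x), max(0, y)
    x2, y2 = min(width, x + w), min(height, y + h)
    return [row[x1:x2] for row in image[y1:y2]]


image = [[10 * r + c for c in range(6)] for r in range(4)]  # 4 rows x 6 columns
box = xyxy_to_xywh(1.0, 1.0, 4.0, 3.0)
print(box, xywh_to_corners(box), crop(image, box))
```

With a NumPy image the crop is simply `image[y1:y2, x1:x2]`; the list-of-rows version only keeps the sketch dependency-free.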



Download files


Source Distributions

No source distribution files are available for this release.

Built Distribution


pytorch_clip_bbox-2021.12.24.0-py3-none-any.whl (10.0 kB)

Uploaded Python 3

File details

Details for the file pytorch_clip_bbox-2021.12.24.0-py3-none-any.whl.

File metadata

  • Download URL: pytorch_clip_bbox-2021.12.24.0-py3-none-any.whl
  • Upload date:
  • Size: 10.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.0 importlib_metadata/4.8.2 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.10

File hashes

Hashes for pytorch_clip_bbox-2021.12.24.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        3ef7e1047a313cc395cf81e69a3e0ce66bdc197d76a854b8219e0688f366e280
MD5           47191abc54cadebf33fe06b1d994ae57
BLAKE2b-256   e05251c016337b04fadae6a410e6087f3a38c8ca3cbeb9835fb05562e036f115

