
RuCLIP

Zero-shot image classification model for Russian language


RuCLIP (Russian Contrastive Language–Image Pretraining) is a multimodal model for computing similarity between images and texts and for ranking captions against pictures. RuCLIP builds on a large body of work on zero-shot transfer, computer vision, natural language processing and multimodal learning. This repo contains prototype models that follow the original OpenAI CLIP paper, adapted to Russian.

Models

The following pretrained checkpoints are referenced throughout this README:

  • ruclip-vit-base-patch32-224
  • ruclip-vit-base-patch16-224
  • ruclip-vit-large-patch14-224
  • ruclip-vit-base-patch32-384
  • ruclip-vit-large-patch14-336
  • ruclip-vit-base-patch16-384

Installing

pip install ruclip==0.0.2

Usage

  • Standard RuCLIP API (Open In Colab)
  • RuCLIP + SberVqgan (Open In Colab)
  • ONNX example (Open In Colab)

Init models

import ruclip

device = 'cuda'
clip, processor = ruclip.load('ruclip-vit-base-patch32-384', device=device)
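The loaded model can also score arbitrary image–text pairs directly. Below is a minimal sketch using the Predictor helpers shown later in this README; the cat.jpg path is a placeholder, and we assume the returned latents are L2-normalized (as the zero-shot pipeline implies), so their dot product is the cosine similarity.

import torch
from PIL import Image

# reuse the clip model and processor loaded above
predictor = ruclip.Predictor(clip, processor, device, bs=8)

image = Image.open('cat.jpg')  # placeholder path, assumption
texts = ['кошка', 'собака', 'машина']

with torch.no_grad():
    image_latents = predictor.get_image_latents([image])  # shape (1, dim)
    text_latents = predictor.get_text_latents(texts)      # shape (3, dim)
    similarity = image_latents @ text_latents.T           # cosine similarities

print(similarity.cpu().numpy())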

Zero-Shot Classification [Minimal Example]

import torch
import base64
import requests
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO

# prepare images
bs4_urls = requests.get('https://raw.githubusercontent.com/sberbank-ai/ru-dolph/master/pics/pipelines/cats_vs_dogs_bs4.json').json()
images = [Image.open(BytesIO(base64.b64decode(bs4_url))) for bs4_url in bs4_urls]

# prepare classes
classes = ['кошка', 'собака']
templates = ['{}', 'это {}', 'на картинке {}', 'это {}, домашнее животное']

# predict
predictor = ruclip.Predictor(clip, processor, device, bs=8, templates=templates)
with torch.no_grad():
    text_latents = predictor.get_text_latents(classes)
    pred_labels = predictor.run(images, text_latents)

# show results
f, ax = plt.subplots(2,4, figsize=(12,6))
for i, (pil_img, pred_label) in enumerate(zip(images, pred_labels)):
    ax[i//4, i%4].imshow(pil_img)
    ax[i//4, i%4].set_title(classes[pred_label])

Cosine Similarity Visualization Example
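In the original README this section is a rendered figure. A minimal sketch of how such a heatmap can be produced, reusing images, classes, and predictor from the zero-shot example above:

import matplotlib.pyplot as plt

# pairwise cosine similarities between image latents and class latents
with torch.no_grad():
    image_latents = predictor.get_image_latents(images)
    text_latents = predictor.get_text_latents(classes)
    sim = (image_latents @ text_latents.T).cpu().numpy()

plt.figure(figsize=(6, 8))
plt.imshow(sim)
plt.xticks(range(len(classes)), classes)
plt.yticks(range(len(images)), [f'image {i}' for i in range(len(images))])
plt.colorbar()
plt.title('Cosine similarity')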

Softmax Scores Visualization Example
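This section is likewise a figure in the original README. A sketch of per-image softmax scores computed from the same latents; the fixed scale of 100 is an assumption standing in for the model's learned logit scale:

# softmax over the classes for each image; the factor 100. approximates
# CLIP's learned temperature (logit scale) and is an assumption here
probs = torch.softmax(100. * image_latents @ text_latents.T, dim=-1).cpu().numpy()
for i, row in enumerate(probs):
    print(f'image {i}:', {c: round(float(s), 3) for c, s in zip(classes, row)})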

Linear Probe and Zero-Shot Correlation Results

Linear Probe Example

import numpy as np
from sklearn.linear_model import LogisticRegression
from torchvision.datasets import CIFAR100

# load CIFAR100; `predictor` is reused from the zero-shot example above
root = '.'  # download location for the dataset
train = CIFAR100(root, download=True, train=True)
test = CIFAR100(root, download=True, train=False)

# embed all images with the frozen CLIP image encoder
with torch.no_grad():
    X_train = predictor.get_image_latents((pil_img for pil_img, _ in train)).cpu().numpy()
    X_test = predictor.get_image_latents((pil_img for pil_img, _ in test)).cpu().numpy()
    y_train, y_test = np.array(train.targets), np.array(test.targets)

# fit a logistic regression probe on the frozen features
clf = LogisticRegression(solver='lbfgs', penalty='l2', max_iter=1000, verbose=1)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
accuracy = np.mean((y_test == y_pred).astype(float)) * 100.
print(f"Accuracy = {accuracy:.3f}")

>>> Accuracy = 75.680

Performance

We evaluated zero-shot image classification performance on the following datasets. In each row, 💥 marks the best score among the Russian-capable models (every column except the original English CLIP):

| Dataset | ruCLIP Base [vit-base-patch32-224] | ruCLIP Base [vit-base-patch16-224] | ruCLIP Large [vit-large-patch14-224] | ruCLIP Base [vit-base-patch32-384] | ruCLIP Large [vit-large-patch14-336] | ruCLIP Base [vit-base-patch16-384] | CLIP [vit-base-patch16-224] original + OPUS-MT | CLIP [vit-base-patch16-224] original |
|---|---|---|---|---|---|---|---|---|
| Food101, acc | 0.505 | 0.552 | 0.597 | 0.642 | 0.712💥 | 0.689 | 0.664 | 0.883 |
| CIFAR10, acc | 0.818 | 0.810 | 0.878 | 0.862 | 0.906💥 | 0.845 | 0.859 | 0.893 |
| CIFAR100, acc | 0.504 | 0.496 | 0.511 | 0.529 | 0.591 | 0.569 | 0.603💥 | 0.647 |
| Birdsnap, acc | 0.115 | 0.117 | 0.172 | 0.161 | 0.213💥 | 0.195 | 0.126 | 0.396 |
| SUN397, acc | 0.452 | 0.462 | 0.484 | 0.510 | 0.523💥 | 0.521 | 0.447 | 0.631 |
| Stanford Cars, acc | 0.433 | 0.487 | 0.559 | 0.572 | 0.659💥 | 0.626 | 0.567 | 0.638 |
| DTD, acc | 0.380 | 0.401 | 0.370 | 0.390 | 0.408 | 0.421💥 | 0.243 | 0.432 |
| MNIST, acc | 0.447 | 0.464 | 0.337 | 0.404 | 0.242 | 0.478 | 0.559💥 | 0.559 |
| STL10, acc | 0.932 | 0.932 | 0.934 | 0.946 | 0.956 | 0.964 | 0.967💥 | 0.970 |
| PCam, acc | 0.501 | 0.505 | 0.520 | 0.506 | 0.554 | 0.501 | 0.603💥 | 0.573 |
| CLEVR, acc | 0.148 | 0.128 | 0.152 | 0.188 | 0.142 | 0.132 | 0.240💥 | 0.240 |
| Rendered SST2, acc | 0.489 | 0.527 | 0.529 | 0.508 | 0.539💥 | 0.525 | 0.484 | 0.484 |
| ImageNet, acc | 0.375 | 0.401 | 0.426 | 0.451 | 0.488💥 | 0.482 | 0.392 | 0.638 |
| FGVC Aircraft, mean-per-class | 0.033 | 0.043 | 0.046 | 0.053 | 0.075 | 0.046 | 0.220💥 | 0.244 |
| Oxford Pets, mean-per-class | 0.560 | 0.595 | 0.604 | 0.587 | 0.546 | 0.635💥 | 0.507 | 0.874 |
| Caltech101, mean-per-class | 0.786 | 0.775 | 0.777 | 0.834 | 0.835💥 | 0.835💥 | 0.792 | 0.883 |
| Flowers102, mean-per-class | 0.401 | 0.388 | 0.455 | 0.449 | 0.517💥 | 0.452 | 0.357 | 0.697 |
| Hateful Memes, roc-auc | 0.564 | 0.516 | 0.530 | 0.537 | 0.519 | 0.543 | 0.579💥 | 0.589 |

And for linear probe evaluation:

| Dataset | ruCLIP Base [vit-base-patch32-224] | ruCLIP Base [vit-base-patch16-224] | ruCLIP Large [vit-large-patch14-224] | ruCLIP Base [vit-base-patch32-384] | ruCLIP Large [vit-large-patch14-336] | ruCLIP Base [vit-base-patch16-384] | CLIP [vit-base-patch16-224] original |
|---|---|---|---|---|---|---|---|
| Food101 | 0.765 | 0.827 | 0.840 | 0.851 | 0.896💥 | 0.890 | 0.901 |
| CIFAR10 | 0.917 | 0.922 | 0.927 | 0.934 | 0.943💥 | 0.942 | 0.953 |
| CIFAR100 | 0.716 | 0.739 | 0.734 | 0.745 | 0.770 | 0.773💥 | 0.808 |
| Birdsnap | 0.347 | 0.503 | 0.567 | 0.434 | 0.609 | 0.612💥 | 0.664 |
| SUN397 | 0.683 | 0.721 | 0.731 | 0.721 | 0.759💥 | 0.758 | 0.777 |
| Stanford Cars | 0.697 | 0.776 | 0.797 | 0.766 | 0.831 | 0.840💥 | 0.866 |
| DTD | 0.690 | 0.734 | 0.711 | 0.703 | 0.731 | 0.749💥 | 0.770 |
| MNIST | 0.963 | 0.974💥 | 0.949 | 0.965 | 0.949 | 0.971 | 0.989 |
| STL10 | 0.957 | 0.962 | 0.973 | 0.968 | 0.981💥 | 0.974 | 0.982 |
| PCam | 0.827 | 0.823 | 0.791 | 0.835 | 0.807 | 0.846💥 | 0.830 |
| CLEVR | 0.356 | 0.360 | 0.358 | 0.308 | 0.318 | 0.378💥 | 0.604 |
| Rendered SST2 | 0.603 | 0.655 | 0.651 | 0.651 | 0.637 | 0.661💥 | 0.606 |
| FGVC Aircraft | 0.254 | 0.312 | 0.290 | 0.283 | 0.341 | 0.362💥 | 0.604 |
| Oxford Pets | 0.774 | 0.820 | 0.819 | 0.730 | 0.753 | 0.856💥 | 0.931 |
| Caltech101 | 0.904 | 0.917 | 0.914 | 0.922 | 0.937💥 | 0.932 | 0.956 |
| Hateful Memes | 0.545 | 0.568 | 0.563 | 0.581 | 0.585💥 | 0.578 | 0.645 |

We also compared inference speed on the CIFAR100 dataset, using an NVIDIA V100 GPU for evaluation:

| Model | iter/sec |
|---|---|
| ruclip-vit-base-patch32-224 | 308.84💥 |
| ruclip-vit-base-patch16-224 | 155.35 |
| ruclip-vit-large-patch14-224 | 49.95 |
| ruclip-vit-base-patch32-384 | 147.26 |
| ruclip-vit-large-patch14-336 | 22.11 |
| ruclip-vit-base-patch16-384 | 61.79 |
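For reference, a minimal sketch of how such a throughput number can be measured with the Predictor API; the batch size, warm-up count, and iteration count below are illustrative assumptions, not the exact benchmark setup.

import time
import torch

batch = images[:8]  # reuse the PIL images from the zero-shot example

with torch.no_grad():
    for _ in range(3):  # warm-up so CUDA kernels are compiled and cached
        predictor.get_image_latents(batch)
    torch.cuda.synchronize()

    n_iters = 100
    start = time.time()
    for _ in range(n_iters):
        predictor.get_image_latents(batch)
    torch.cuda.synchronize()  # wait for all GPU work before stopping the clock

print(f'{n_iters / (time.time() - start):.2f} iter/sec')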

Authors

Supported by

Social Media
