
TensorFlow 2.0 Implementation of GCViT: Global Context Vision Transformer. https://github.com/awsaf49/gcvit-tf

Project description

GCViT: Global Context Vision Transformer


TensorFlow 2.0 Implementation of GCViT

This library implements GCViT in TensorFlow 2.0 as a tf.keras.Model subclass, giving it a PyTorch-like flavor.
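Because the model is a regular tf.keras.Model, it supports the usual Keras workflow while still being callable like a PyTorch module. A minimal smoke-test sketch (random weights; installation and pretrained weights are covered below):

import tensorflow as tf
from gcvit import GCViTTiny

model = GCViTTiny(pretrain=False)              # random weights are enough for a smoke test
preds = model(tf.zeros((1, 224, 224, 3)))      # call it directly, PyTorch-style; this also builds the model
print(preds.shape)                             # (1, 1000)
model.summary()                                # standard Keras introspection
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')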

Update

Paper Implementation & Explanation

I have explained the GCViT paper in a Kaggle notebook GCViT: Global Context Vision Transformer, which also includes a detailed implementation of the model from scratch. The notebook provides a comprehensive explanation of each part of the model, with intuition.

Do check it out, especially if you are interested in learning more about GCViT or implementing it yourself. Note that this notebook won the Kaggle ML Research Award 2022.

Model

  • Architecture:
  • Local vs Global Attention:

Result

The official codebase had an issue that was fixed recently (12 August 2022). Here are the results of the ported weights on the ImageNetV2-Test data:

Model         Acc@1  Acc@5  #Params
GCViT-XXTiny  0.663  0.873  12M
GCViT-XTiny   0.685  0.885  20M
GCViT-Tiny    0.708  0.899  28M
GCViT-Small   0.720  0.901  51M
GCViT-Base    0.731  0.907  90M
GCViT-Large   0.734  0.913  202M
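For reference, here is a minimal sketch of how such numbers could be checked with TensorFlow Datasets (an illustration only, not the exact evaluation script; it assumes the TFDS imagenet_v2 test split, plain 224x224 resizing, and 'torch'-style preprocessing):

import tensorflow as tf
import tensorflow_datasets as tfds
from gcvit import GCViTTiny

model = GCViTTiny(pretrain=True)

def prepare(example):
    img = tf.image.resize(example['image'], (224, 224))
    img = tf.keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
    return img, example['label']

ds = tfds.load('imagenet_v2', split='test').map(prepare).batch(64)

model.compile(loss='sparse_categorical_crossentropy',
              metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1, name='acc@1'),
                       tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name='acc@5')])
print(model.evaluate(ds, return_dict=True))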

Installation

pip install -U gcvit
# or
# pip install -U git+https://github.com/awsaf49/gcvit-tf

Usage

Load a model using the following code:

from gcvit import GCViTTiny
model = GCViTTiny(pretrain=True)

For any input size other than 224x224:

from gcvit import GCViTTiny
model = GCViTTiny(input_shape=(512,512,3), pretrain=True, resize_query=True)
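As a quick sanity check (a small sketch reusing the model built above), the resized model accepts 512x512 inputs and still produces 1000-way ImageNet predictions:

import tensorflow as tf

preds = model(tf.zeros((1, 512, 512, 3)))  # dummy 512x512 batch
print(preds.shape)                         # expected: (1, 1000)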

Simple code to check the model's prediction:

import tensorflow as tf
from skimage.data import chelsea
img = tf.keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='torch') # Chelsea the cat
img = tf.image.resize(img, (224, 224))[None,] # resize & create batch
pred = model(img).numpy()
print(tf.keras.applications.imagenet_utils.decode_predictions(pred)[0])

Prediction:

[('n02124075', 'Egyptian_cat', 0.9194835),
('n02123045', 'tabby', 0.009686623), 
('n02123159', 'tiger_cat', 0.0061576385),
('n02127052', 'lynx', 0.0011503297), 
('n02883205', 'bow_tie', 0.00042479983)]

For feature extraction:

model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
model.reset_classifier(num_classes=0, head_act=None)
feature = model(img)
print(feature.shape)

Feature:

(None, 512)
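These pooled features can be used directly as image embeddings, for example in a quick similarity check. A minimal sketch (the two skimage sample images are just stand-ins for your own data):

import tensorflow as tf
from skimage.data import chelsea, astronaut
from gcvit import GCViTTiny

model = GCViTTiny(pretrain=True)
model.reset_classifier(num_classes=0, head_act=None)     # model now outputs (None, 512) features

def embed(image):
    img = tf.keras.applications.imagenet_utils.preprocess_input(image, mode='torch')
    img = tf.image.resize(img, (224, 224))[None,]
    return tf.math.l2_normalize(model(img), axis=-1)      # unit-norm 512-d embedding

sim = tf.reduce_sum(embed(chelsea()) * embed(astronaut()), axis=-1)  # cosine similarity
print(float(sim[0]))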

For feature map:

model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
feature = model.forward_features(img)
print(feature.shape)

Feature map:

(None, 7, 7, 512)
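One simple use of this spatial map, sketched below, is a crude activation heatmap from the channel-wise L2 norm (an illustration only; this is not the Grad-CAM demo mentioned later):

import tensorflow as tf
from skimage.data import chelsea
from gcvit import GCViTTiny

model = GCViTTiny(pretrain=True)
img = tf.keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='torch')
img = tf.image.resize(img, (224, 224))[None,]

fmap = model.forward_features(img)                 # (1, 7, 7, 512)
heat = tf.norm(fmap, axis=-1, keepdims=True)       # channel-wise L2 norm -> (1, 7, 7, 1)
heat = tf.image.resize(heat, (224, 224))           # upsample to input resolution
heat = (heat - tf.reduce_min(heat)) / (tf.reduce_max(heat) - tf.reduce_min(heat) + 1e-8)
print(heat.shape)                                  # (1, 224, 224, 1), values in [0, 1]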

Kaggle Models

These pre-trained models can also be loaded via Kaggle Models. Setting from_kaggle=True forces the model to load weights from Kaggle Models rather than downloading them, so it can be used without internet access in Kaggle notebooks.

from gcvit import GCViTTiny
model = GCViTTiny(pretrain=True, from_kaggle=True)

Live-Demo

  • For a live demo of image classification & Grad-CAM with ImageNet weights, check out the Hugging Face Space (powered by 🤗 Spaces and Gradio). Here's an example:

Example

For a working training example, check out these notebooks on Google Colab & Kaggle.
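For a self-contained flavor of what those notebooks do, here is a compact fine-tuning sketch (an illustration only; it uses the TFDS tf_flowers dataset, 5 classes, as a stand-in for the flower data used in the notebooks):

import tensorflow as tf
import tensorflow_datasets as tfds
from gcvit import GCViTTiny

def prepare(example):
    img = tf.image.resize(example['image'], (224, 224))
    img = tf.keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
    return img, example['label']

train_ds = tfds.load('tf_flowers', split='train').map(prepare).shuffle(1024).batch(32)

backbone = GCViTTiny(pretrain=True)
backbone.reset_classifier(num_classes=0, head_act=None)   # strip the ImageNet head

inputs = tf.keras.Input(shape=(224, 224, 3))
outputs = tf.keras.layers.Dense(5, activation='softmax')(backbone(inputs))
model = tf.keras.Model(inputs, outputs)

model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(train_ds, epochs=3)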

Here is the Grad-CAM result after training on the Flower Classification Dataset:

To Do

  • Convert it to multi-backend Keras 3.0
  • Segmentation Pipeline
  • Support for Kaggle Models
  • Remove tensorflow_addons
  • New updated weights have been added.
  • Working training example in Colab & Kaggle.
  • GradCAM showcase.
  • Gradio Demo.
  • Build model with tf.keras.Model.
  • Port weights from official repo.
  • Support for TPU.

Acknowledgement

Citation

@article{hatamizadeh2022global,
  title={Global Context Vision Transformers},
  author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
  journal={arXiv preprint arXiv:2206.09959},
  year={2022}
}



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

gcvit-1.1.6.tar.gz (19.8 kB)

Uploaded Source

Built Distribution

gcvit-1.1.6-py3-none-any.whl (21.2 kB)

Uploaded Python 3

File details

Details for the file gcvit-1.1.6.tar.gz.

File metadata

  • Download URL: gcvit-1.1.6.tar.gz
  • Upload date:
  • Size: 19.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.7

File hashes

Hashes for gcvit-1.1.6.tar.gz
Algorithm Hash digest
SHA256 61dedd832162420dc35a727c173f922dec5642aeca6bb93f394c9caf7f922dc2
MD5 42ecc2cc7248cac5f6a80f8ae5b809fc
BLAKE2b-256 f4cabd2515c4eaba86e5262e79880bdae80ad509fbcb721033841eef303046ea

See more details on using hashes here.

File details

Details for the file gcvit-1.1.6-py3-none-any.whl.

File metadata

  • Download URL: gcvit-1.1.6-py3-none-any.whl
  • Upload date:
  • Size: 21.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.7

File hashes

Hashes for gcvit-1.1.6-py3-none-any.whl
Algorithm Hash digest
SHA256 4e30e86c207e9d54100721401ba6d7328876f9ed14d6b214a8f783ccedab8bba
MD5 fa5c2a939ba40486ac2606ddcdfd358d
BLAKE2b-256 968c7426ef028c8b41b44452ccd2274374364f4bd1d14e1802e6d2d313bc98ec

See more details on using hashes here.
