Skip to main content

Tensorflow 2.0 Implementation of GCViT: Global Context Vision Transformer. https://github.com/awsaf49/gcvit-tf

Project description

GCViT: Global Context Vision Transformer

python tensorflow

Open In Colab Open In Kaggle

Tensorflow 2.0 Implementation of GCViT

This library implements GCViT using Tensorflow 2.0 specifally in tf.keras.Model manner to get PyTorch flavor.

Model

  • Architecture:
  • Local Vs Global Attention:

Result

The reported result in the paper is shown in the figure. But due to issues in the codebase actual result differs from the reported result.

Installation

pip install -U gcvit
# or
# pip install -U git+https://github.com/awsaf49/gcvit-tf

Usage

Load model using following codes,

from gcvit import GCViTTiny
model = GCViTTiny(pretrain=True)

Simple code to check model's prediction,

from skimage.data import chelsea
img = tf.keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='torch') # Chelsea the cat
img = tf.image.resize(img, (224, 224))[None,] # resize & create batch
pred = model(img).numpy()
print(tf.keras.applications.imagenet_utils.decode_predictions(pred)[0])

Prediction:

[('n02124075', 'Egyptian_cat', 0.9194835),
('n02123045', 'tabby', 0.009686623), 
('n02123159', 'tiger_cat', 0.0061576385),
('n02127052', 'lynx', 0.0011503297), 
('n02883205', 'bow_tie', 0.00042479983)]

For feature extraction:

model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
model.reset_classifer(num_classes=0, head_act=None)
feature = model(img)
print(feature.shape)

Feature:

(None, 7, 7, 512)

Example

  • For live demo on Image Classification & Grad-CAM, with ImageNet weights, click powered by 🤗 Space and Gradio. here's an example,

To Do

  • Working training example in Colab & Kaggle.
  • GradCAM showcase.
  • Gradio Demo.
  • Build model with tf.keras.Model.
  • Port weights from official repo.
  • Support for TPU.

Acknowledgement

Citation

@article{hatamizadeh2022global,
  title={Global Context Vision Transformers},
  author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
  journal={arXiv preprint arXiv:2206.09959},
  year={2022}
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

gcvit-1.0.5.tar.gz (11.5 kB view hashes)

Uploaded Source

Built Distribution

gcvit-1.0.5-py3-none-any.whl (15.8 kB view hashes)

Uploaded Python 3

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page