TensorFlow 2.0 Implementation of GCViT: Global Context Vision Transformer. https://github.com/awsaf49/gcvit-tf
GCViT: Global Context Vision Transformer
TensorFlow 2.0 Implementation of GCViT
This library implements GCViT in TensorFlow 2.0, specifically as a subclassed tf.keras.Model,
to give it a PyTorch flavor.
Model
- Architecture: [figure]
- Local vs. Global Attention: [figure]
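The difference between the two attention types can be sketched in NumPy. This is a simplified illustration, not the gcvit implementation. It assumes single-head attention, random projection matrices, and no relative position bias or output projection. All function names (`window_partition`, `global_tokens`, `local_window_attention`, `global_window_attention`) are hypothetical. In local attention, queries, keys, and values all come from the same window. In global attention, keys and values still come from each local window, but the queries are global tokens shared across all windows. The paper generates those tokens with a small downsampling network; this sketch substitutes a single max-pool.

```python
import numpy as np

def window_partition(x, ws):
    """Split an (H, W, C) feature map into (num_windows, ws * ws, C) non-overlapping windows."""
    H, W, C = x.shape
    x = x.reshape(H // ws, ws, W // ws, ws, C).transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, ws * ws, C)

def global_tokens(x, ws):
    """Hypothetical stand-in for the global token generator: max-pool (H, W, C) to (ws * ws, C).
    Assumes H and W are divisible by ws."""
    H, W, C = x.shape
    pooled = x.reshape(ws, H // ws, ws, W // ws, C).max(axis=(1, 3))
    return pooled.reshape(ws * ws, C)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v):
    """Single-head scaled dot-product attention, batched over windows."""
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v

def local_window_attention(x, ws, wq, wk, wv):
    # Local: queries, keys and values are all projected from the same window.
    win = window_partition(x, ws)
    return attention(win @ wq, win @ wk, win @ wv)

def global_window_attention(x, g, ws, wk, wv):
    # Global: keys/values still come from each local window, but the queries
    # are the shared global tokens g, repeated for every window.
    win = window_partition(x, ws)
    q = np.broadcast_to(g, (win.shape[0],) + g.shape)
    return attention(q, win @ wk, win @ wv)

rng = np.random.default_rng(0)
H = W = 8
C, ws = 16, 4
x = rng.standard_normal((H, W, C))
wq, wk, wv = (rng.standard_normal((C, C)) for _ in range(3))
g = global_tokens(x, ws)
print(local_window_attention(x, ws, wq, wk, wv).shape)  # (4, 16, 16)
print(global_window_attention(x, g, ws, wk, wv).shape)  # (4, 16, 16)
```

Both variants cost the same per window. The global variant simply swaps each window's own queries for tokens that summarize the whole feature map, which lets every window attend with image-level context.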
Result
The results reported in the paper are shown in the figure. However, due to issues in the codebase, the actual results differ from those reported.
Installation
pip install -U gcvit
# or
# pip install -U git+https://github.com/awsaf49/gcvit-tf
Usage
Load the model with the following code:
from gcvit import GCViTTiny
model = GCViTTiny(pretrain=True)
A simple check of the model's prediction:
import tensorflow as tf
from skimage.data import chelsea
img = tf.keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='torch') # Chelsea the cat
img = tf.image.resize(img, (224, 224))[None,] # resize & create batch
pred = model(img).numpy()
print(tf.keras.applications.imagenet_utils.decode_predictions(pred)[0])
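For reference, `mode='torch'` in Keras's `preprocess_input` scales pixel values to [0, 1] and then normalizes each channel with the ImageNet mean and standard deviation. The NumPy sketch below reproduces that arithmetic for a channels-last RGB array, as in the example above, so the input range the model expects is explicit. `torch_style_preprocess` is a hypothetical helper, not part of gcvit or Keras.

```python
import numpy as np

# ImageNet per-channel statistics used by Keras's "torch" preprocessing mode.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def torch_style_preprocess(img):
    """Scale a uint8 RGB image (H, W, 3) to [0, 1], then normalize each channel."""
    x = img.astype(np.float32) / 255.0
    return (x - IMAGENET_MEAN) / IMAGENET_STD

# A constant mid-gray image: every pixel maps to the same normalized vector.
img = np.full((2, 2, 3), 128, dtype=np.uint8)
out = torch_style_preprocess(img)
print(out.shape)  # (2, 2, 3)
```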
Prediction:
[('n02124075', 'Egyptian_cat', 0.9194835),
('n02123045', 'tabby', 0.009686623),
('n02123159', 'tiger_cat', 0.0061576385),
('n02127052', 'lynx', 0.0011503297),
('n02883205', 'bow_tie', 0.00042479983)]
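`decode_predictions` returns, for each image, the five highest-scoring classes in descending order as `(class_id, class_name, score)` tuples, which is the shape of the output above. A minimal NumPy sketch of that top-k step follows. The three-entry `LABELS` table and the `decode_top_k` helper are hypothetical; they stand in for the 1,000-class ImageNet index that Keras downloads.

```python
import numpy as np

# Hypothetical label table: class index -> (WordNet ID, human-readable name).
LABELS = {
    0: ("n02123045", "tabby"),
    1: ("n02123159", "tiger_cat"),
    2: ("n02124075", "Egyptian_cat"),
}

def decode_top_k(preds, labels, k=5):
    """Return the k highest-scoring (wnid, name, score) tuples per row, best first."""
    results = []
    for row in preds:
        top = np.argsort(row)[::-1][:k]  # indices sorted by descending score
        results.append([(*labels[int(i)], float(row[i])) for i in top])
    return results

preds = np.array([[0.1, 0.2, 0.7]])
print(decode_top_k(preds, LABELS, k=2)[0])
# [('n02124075', 'Egyptian_cat', 0.7), ('n02123159', 'tiger_cat', 0.2)]
```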
To Do
- Working training example.
- GradCAM showcase.
- Gradio Demo.
- Build model with tf.keras.Model.
- Port weights from official repo.
Acknowledgement
- GCViT (Official)
- Swin-Transformer-TF
- tfgcvit
- keras_cv_attention_models
Citation
@article{hatamizadeh2022global,
title={Global Context Vision Transformers},
author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
journal={arXiv preprint arXiv:2206.09959},
year={2022}
}
Download files
Source Distribution: gcvit-1.0.2.tar.gz (9.5 kB)
Built Distribution: gcvit-1.0.2-py3-none-any.whl (13.3 kB)