
TensorFlow 2.X reimplementation of Global Context Vision Transformers by Ali Hatamizadeh, Hongxu (Danny) Yin, Jan Kautz and Pavlo Molchanov.

Project description

GCViT-TensorFlow


  • Exact TensorFlow reimplementation of the official PyTorch repo, including the timm modules used by the authors, preserving the structure of models and layers.
  • ImageNet pretrained weights ported from the official PyTorch implementation.


Abstract

GC ViT achieves state-of-the-art results across image classification, object detection and semantic segmentation tasks. On the ImageNet-1K classification dataset, the tiny, small and base variants of GC ViT, with 28M, 51M and 90M parameters respectively, surpass comparably sized prior art such as the CNN-based ConvNeXt and the ViT-based Swin Transformer by a large margin. Pre-trained GC ViT backbones consistently outperform prior work, sometimes by large margins, on the downstream tasks of object detection, instance segmentation and semantic segmentation on the MS COCO and ADE20K datasets.


Top-1 accuracy vs. model FLOPs/parameter size on the ImageNet-1K dataset. GC ViT sets new state-of-the-art results across model sizes and FLOPs budgets, outperforming competing approaches by a significant margin.


Architecture of the Global Context ViT. The authors use alternating blocks of local and global context self-attention layers in each stage of the architecture.
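The alternation between local and global context attention can be pictured as a per-stage block plan. This is a minimal illustration, not code from `gcvit_tensorflow`: it assumes that, within each stage, even-indexed blocks use local window self-attention and odd-indexed blocks use global context self-attention. `block_plan` is a hypothetical helper, and the stage depths `[2, 2, 6, 2]` are taken from the configuration example in the Usage section.

```python
from typing import List


def block_plan(depths: List[int]) -> List[List[str]]:
    """Return, for each stage, the attention type of every block.

    Assumption (not taken from the package): block i of a stage is
    "local" when i is even and "global" when i is odd.
    """
    plan = []
    for depth in depths:
        plan.append(["local" if i % 2 == 0 else "global" for i in range(depth)])
    return plan


if __name__ == "__main__":
    for stage, blocks in enumerate(block_plan([2, 2, 6, 2])):
        print(f"stage {stage}: {blocks}")
```

Because every stage has an even depth in this configuration, each stage contains as many local blocks as global ones.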

Results

The TensorFlow implementation and the ported ImageNet weights have been compared to the official PyTorch implementation on the ImageNet-V2 test set.

| Configuration | Top-1 Original (%) | Top-1 Ported (%) | Top-5 Original (%) | Top-5 Ported (%) | #Params |
|---|---|---|---|---|---|
| GCViT-XXTiny | 68.79 | 68.73 | 88.52 | 88.47 | 12M |
| GCViT-XTiny | 70.97 | 71.00 | 89.80 | 89.79 | 20M |
| GCViT-Tiny | 72.93 | 72.90 | 90.70 | 90.70 | 28M |
| GCViT-Small | 73.46 | 73.50 | 91.14 | 91.08 | 51M |
| GCViT-Base | 74.13 | 74.16 | 91.66 | 91.69 | 90M |

Mean absolute difference between original and ported metrics: 3e-4 (about 0.03 percentage points).
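The 3e-4 figure above can be reproduced from the results table. A minimal sketch, assuming the figure is the mean absolute difference between original and ported scores over all ten Top-1/Top-5 pairs, expressed as a fraction rather than in percentage points:

```python
# (original, ported) accuracy pairs in percent, copied from the results table
PAIRS = {
    "GCViT-XXTiny": [(68.79, 68.73), (88.52, 88.47)],
    "GCViT-XTiny": [(70.97, 71.00), (89.80, 89.79)],
    "GCViT-Tiny": [(72.93, 72.90), (90.70, 90.70)],
    "GCViT-Small": [(73.46, 73.50), (91.14, 91.08)],
    "GCViT-Base": [(74.13, 74.16), (91.66, 91.69)],
}


def mean_abs_difference(pairs: dict) -> float:
    """Mean |original - ported| over all (Top-1, Top-5) pairs, as a fraction."""
    diffs = [abs(orig - ported) for rows in pairs.values() for orig, ported in rows]
    # Divide by 100 to convert percentage points into a fraction
    return sum(diffs) / len(diffs) / 100.0


if __name__ == "__main__":
    print(f"{mean_abs_difference(PAIRS):.1e}")
```

The absolute differences sum to 0.34 percentage points over ten pairs, i.e. 3.4e-4 as a fraction, which rounds to the 3e-4 quoted above.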

Installation

  • Install from PyPI
pip install gcvit-tensorflow
  • Install from GitHub
pip install git+https://github.com/EMalagoli92/GCViT-TensorFlow
  • Clone the repo and install the necessary packages
git clone https://github.com/EMalagoli92/GCViT-TensorFlow.git
cd GCViT-TensorFlow
pip install -r requirements.txt

Tested on Ubuntu 20.04.4 LTS x86_64 with Python 3.9.7.

Usage

  • Define a custom GCViT configuration.
from gcvit_tensorflow import GCViT

# Define a custom GCViT configuration
model = GCViT(
    depths=[2, 2, 6, 2],
    num_heads=[2, 4, 8, 16],
    window_size=[7, 7, 14, 7],
    dim=64,
    resolution=224,
    in_chans=3,
    mlp_ratio=3,
    drop_path_rate=0.2,
    data_format="channels_last",
    num_classes=100,
    classifier_activation="softmax",
)
  • Use a predefined GCViT configuration.
from gcvit_tensorflow import GCViT

model = GCViT(configuration="xxtiny")
model.build((None, 224, 224, 3))
model.summary()  # prints the table itself; wrapping it in print() would also print "None"
Model: "xxtiny"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 patch_embed (PatchEmbed)    (None, 56, 56, 64)        45632     
                                                                 
 pos_drop (Dropout)          (None, 56, 56, 64)        0         
                                                                 
 levels/0 (GCViTLayer)       (None, 28, 28, 128)       185766    
                                                                 
 levels/1 (GCViTLayer)       (None, 14, 14, 256)       693258    
                                                                 
 levels/2 (GCViTLayer)       (None, 7, 7, 512)         5401104   
                                                                 
 levels/3 (GCViTLayer)       (None, 7, 7, 512)         5400546   
                                                                 
 norm (LayerNorm_)           (None, 7, 7, 512)         1024      
                                                                 
 avgpool (AdaptiveAveragePoo  (None, 512, 1, 1)        0         
 ling2D)                                                         
                                                                 
 head (Linear_)              (None, 1000)              513000    
                                                                 
=================================================================
Total params: 12,240,330
Trainable params: 11,995,428
Non-trainable params: 244,902
_________________________________________________________________
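  • Train the model from scratch.
# Example: the loss below expects probabilities, so this assumes the model
# was created with classifier_activation="softmax"
model.compile(
    optimizer="sgd",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
)
# x: input images of shape (N, 224, 224, 3); y: integer class labels of shape (N,)
model.fit(x, y)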
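  • Use the ported ImageNet pretrained weights.
# Example
from gcvit_tensorflow import GCViT

model = GCViT(configuration="base", pretrained=True, classifier_activation="softmax")
# image: a batch of RGB images of shape (N, 224, 224, 3)
y_pred = model(image)  # class probabilities of shape (N, 1000)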
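The output shapes in the `xxtiny` summary above follow from simple arithmetic on the configuration. The sketch below is a hypothetical helper, not part of `gcvit_tensorflow`. It assumes, consistently with the summary, that the patch embedding reduces the input resolution by 4, that every level except the last halves the spatial size and doubles the channels at its end, and that the classifier head is a dense layer with bias. It uses the values from the custom configuration example (`resolution=224`, `dim=64`, `num_heads`, `window_size`), whose resolution and embedding width match the `xxtiny` summary.

```python
from typing import Dict, List


def stage_plan(resolution: int, dim: int, num_heads: List[int],
               window_size: List[int]) -> List[Dict[str, int]]:
    """Per-level geometry of a 4-level GCViT-style backbone (assumed layout)."""
    size = resolution // 4  # assumed: patch embedding downsamples by 4
    plan = []
    for level, (heads, window) in enumerate(zip(num_heads, window_size)):
        channels = dim * 2 ** level
        if size % window != 0:
            raise ValueError(f"level {level}: window {window} does not tile {size}x{size}")
        last = level == len(num_heads) - 1
        plan.append({
            "size": size,                    # spatial size the level works at
            "channels": channels,            # embedding width at this level
            "head_dim": channels // heads,   # width of each attention head
            "windows_per_side": size // window,
            # assumed: every level but the last downsamples at its end
            "out_size": size if last else size // 2,
            "out_channels": channels if last else channels * 2,
        })
        if not last:
            size //= 2
    return plan


def head_params(in_features: int, num_classes: int) -> int:
    """Weights plus biases of a dense classification layer."""
    return in_features * num_classes + num_classes


if __name__ == "__main__":
    for level, info in enumerate(stage_plan(224, 64, [2, 4, 8, 16], [7, 7, 14, 7])):
        print(level, info)
    print(head_params(512, 1000))
```

The per-level outputs come out as 28x28x128, 14x14x256, 7x7x512 and 7x7x512, matching the `levels/0` to `levels/3` rows, and `head_params(512, 1000)` gives the 513,000 parameters of the `head` row. Under these assumptions every window size tiles its feature map exactly and each attention head is 32 channels wide.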

Acknowledgement

Citations

@article{hatamizadeh2022global,
  title={Global Context Vision Transformers},
  author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
  journal={arXiv preprint arXiv:2206.09959},
  year={2022}
}

License

This work is made available under the MIT License.

The pre-trained weights are shared under CC-BY-NC-SA-4.0.



Download files


Source Distribution

gcvit_tensorflow-1.2.1.tar.gz (20.4 kB)

Uploaded Source

Built Distribution

gcvit_tensorflow-1.2.1-py3-none-any.whl (28.0 kB)

Uploaded Python 3

File details

Details for the file gcvit_tensorflow-1.2.1.tar.gz.

File metadata

  • Download URL: gcvit_tensorflow-1.2.1.tar.gz
  • Upload date:
  • Size: 20.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.11.1

File hashes

Hashes for gcvit_tensorflow-1.2.1.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | b55527c393e8a2e2e5706f24b4ee1f382b7e5762f380344623634b22bce20aa3 |
| MD5 | 3f3f754d760a09ea13d4c07af70b21ee |
| BLAKE2b-256 | 3dd010c99d70c2903c1a13b476b893d228a4f8b02a6a165caf405293f1e36b67 |
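The published digests can be used to check a downloaded file before installing it. A minimal sketch using Python's standard `hashlib`; the local file name is an assumption (wherever the archive was saved), and the expected value is the SHA256 digest listed above for `gcvit_tensorflow-1.2.1.tar.gz`.

```python
import hashlib
from pathlib import Path

# SHA256 digest of gcvit_tensorflow-1.2.1.tar.gz, as published in the table above
EXPECTED_SHA256 = "b55527c393e8a2e2e5706f24b4ee1f382b7e5762f380344623634b22bce20aa3"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of(path) == expected.lower()


if __name__ == "__main__":
    # Hypothetical location of the downloaded archive
    archive = Path("gcvit_tensorflow-1.2.1.tar.gz")
    if archive.exists():
        print("OK" if verify(archive) else "MISMATCH")
    else:
        print(f"{archive} not found")
```

pip can enforce the same check at install time through its hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file).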


File details

Details for the file gcvit_tensorflow-1.2.1-py3-none-any.whl.

File metadata

File hashes

Hashes for gcvit_tensorflow-1.2.1-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | b226db773991472a106aa9bf30b7379618b4cd2f4fedc91f01508438852a2394 |
| MD5 | b958f1e39ac0ba6560783cce80a9a202 |
| BLAKE2b-256 | eabe52c8d0839e566d0c528d54f810f441d714c02bf9dabd721ad21af8235b26 |

