
CvT-TensorFlow

TensorFlow 2.X reimplementation of CvT: Introducing Convolutions to Vision Transformers, Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, Lei Zhang.

  • Exact TensorFlow reimplementation of the official PyTorch repo, including the timm modules used by the authors, preserving model and layer structure.
  • ImageNet pretrained weights ported from the official PyTorch implementation.

Abstract

Convolutional vision Transformers (CvT) improve Vision Transformers (ViT) in both performance and efficiency by introducing convolutions into ViT to yield the best of both designs. This is accomplished through two primary modifications: a hierarchy of Transformers containing a new convolutional token embedding, and a convolutional Transformer block leveraging a convolutional projection. These changes introduce desirable properties of convolutional neural networks (CNNs) into the ViT architecture (e.g. shift, scale, and distortion invariance) while maintaining the merits of Transformers (e.g. dynamic attention, global context, and better generalization). Moreover, the results show that positional encoding, a crucial component in existing Vision Transformers, can be safely removed from the model, simplifying the design for higher-resolution vision tasks.

Figure: The pipeline of the CvT architecture. (a) Overall architecture, showing the hierarchical multi-stage structure facilitated by the Convolutional Token Embedding layer. (b) Details of the Convolutional Transformer Block, which contains the convolutional projection as the first layer.
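The convolutional projection shown in (b) is what lets CvT subsample keys and values: a strided depthwise convolution shrinks the key/value token grid while queries keep full resolution (STRIDE_KV=2 vs. STRIDE_Q=1 in the default specs). A naive NumPy sketch of the shape effect only (the real implementation uses depthwise conv + batch norm inside the attention layer):

```python
import numpy as np

def depthwise_conv2d(x, kernel, stride=1, padding=1):
    """Naive depthwise 2-D convolution over an (H, W, C) feature map."""
    H, W, C = x.shape
    k = kernel.shape[0]
    xp = np.pad(x, ((padding, padding), (padding, padding), (0, 0)))
    Ho = (H + 2 * padding - k) // stride + 1
    Wo = (W + 2 * padding - k) // stride + 1
    out = np.zeros((Ho, Wo, C))
    for i in range(Ho):
        for j in range(Wo):
            patch = xp[i * stride:i * stride + k, j * stride:j * stride + k, :]
            # One 3x3 filter per channel, no cross-channel mixing (depthwise).
            out[i, j] = np.einsum("hwc,hwc->c", patch, kernel)
    return out

# A stage-2-like feature map: 14x14 grid, 384 channels.
x = np.random.randn(14, 14, 384)
w = np.random.randn(3, 3, 384)

q = depthwise_conv2d(x, w, stride=1, padding=1)   # queries: 14x14 = 196 tokens
kv = depthwise_conv2d(x, w, stride=2, padding=1)  # keys/values: 7x7 = 49 tokens
print(q.shape, kv.shape)  # → (14, 14, 384) (7, 7, 384)
```

Attention cost scales with (query tokens) x (key tokens), so the stride-2 projection cuts the attention matrix roughly 4x per layer.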

Results

The TensorFlow implementation and ported ImageNet weights have been compared to the official PyTorch implementation on the ImageNet-V2 test set.

Models pre-trained on ImageNet-1K

| Configuration | Resolution | Top-1 (Original) | Top-1 (Ported) | Top-5 (Original) | Top-5 (Ported) | #Params |
|---------------|------------|------------------|----------------|------------------|----------------|---------|
| CvT-13        | 224x224    | 69.81            | 69.81          | 89.13            | 89.13          | 20M     |
| CvT-13        | 384x384    | 71.31            | 71.31          | 89.97            | 89.97          | 20M     |
| CvT-21        | 224x224    | 71.18            | 71.17          | 89.31            | 89.31          | 32M     |
| CvT-21        | 384x384    | 71.61            | 71.61          | 89.71            | 89.71          | 32M     |

Models pre-trained on ImageNet-22K

| Configuration | Resolution | Top-1 (Original) | Top-1 (Ported) | Top-5 (Original) | Top-5 (Ported) | #Params |
|---------------|------------|------------------|----------------|------------------|----------------|---------|
| CvT-13        | 384x384    | 71.76            | 71.76          | 91.39            | 91.39          | 20M     |
| CvT-21        | 384x384    | 74.97            | 74.97          | 92.63            | 92.63          | 32M     |
| CvT-W24       | 384x384    | 78.15            | 78.15          | 94.48            | 94.48          | 277M    |

Maximum difference between original and ported metrics: 9e-5.
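The reported figure is the largest absolute gap across all matching metric values; a minimal sketch of that comparison (the metric values below are illustrative, not the measured ones):

```python
def max_metric_diff(original, ported):
    """Largest absolute difference between matching metric values."""
    keys = original.keys() & ported.keys()
    return max(abs(original[k] - ported[k]) for k in keys)

# Illustrative (not the actual measured) unrounded metric fractions:
orig = {"top1": 0.711769, "top5": 0.893121}
port = {"top1": 0.711859, "top5": 0.893121}
print(max_metric_diff(orig, port))
```

Note the table values are rounded to two decimals, so a 9e-5 discrepancy only shows up when comparing the unrounded metrics.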

Installation

  • Install from PyPI
pip install cvt-tensorflow
  • Install from Github
pip install git+https://github.com/EMalagoli92/CvT-TensorFlow
  • Clone the repo and install necessary packages
git clone https://github.com/EMalagoli92/CvT-TensorFlow.git
pip install -r requirements.txt

Tested on Ubuntu 20.04.4 LTS x86_64, Python 3.9.7.

Usage

  • Define a custom CvT configuration.
from cvt_tensorflow import CvT

# Define a custom CvT configuration
model = CvT(
    in_chans=3,
    num_classes=1000,
    classifier_activation="softmax",
    data_format="channels_last",
    spec={
        "INIT": "trunc_norm",
        "NUM_STAGES": 3,
        "PATCH_SIZE": [7, 3, 3],
        "PATCH_STRIDE": [4, 2, 2],
        "PATCH_PADDING": [2, 1, 1],
        "DIM_EMBED": [64, 192, 384],
        "NUM_HEADS": [1, 3, 6],
        "DEPTH": [1, 2, 10],
        "MLP_RATIO": [4.0, 4.0, 4.0],
        "ATTN_DROP_RATE": [0.0, 0.0, 0.0],
        "DROP_RATE": [0.0, 0.0, 0.0],
        "DROP_PATH_RATE": [0.0, 0.0, 0.1],
        "QKV_BIAS": [True, True, True],
        "CLS_TOKEN": [False, False, True],
        "QKV_PROJ_METHOD": ["dw_bn", "dw_bn", "dw_bn"],
        "KERNEL_QKV": [3, 3, 3],
        "PADDING_KV": [1, 1, 1],
        "STRIDE_KV": [2, 2, 2],
        "PADDING_Q": [1, 1, 1],
        "STRIDE_Q": [1, 1, 1],
    },
)
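Each stage's convolutional token embedding downsamples the spatial grid by its PATCH_STRIDE, following the usual convolution output-size formula. A quick sanity check of the arithmetic for a 224x224 input (a standalone sketch, not part of the library):

```python
def stage_resolutions(size, patch_size, patch_stride, patch_padding):
    """Spatial size after each stage's convolutional token embedding,
    using the conv output formula: (H + 2*pad - k) // stride + 1."""
    sizes = []
    for k, s, p in zip(patch_size, patch_stride, patch_padding):
        size = (size + 2 * p - k) // s + 1
        sizes.append(size)
    return sizes

# Values from the spec above: PATCH_SIZE, PATCH_STRIDE, PATCH_PADDING.
print(stage_resolutions(224, [7, 3, 3], [4, 2, 2], [2, 1, 1]))  # → [56, 28, 14]
```

The last stage therefore operates on a 14x14 token grid with 384-dimensional embeddings.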
  • Use a predefined CvT configuration.
from cvt_tensorflow import CvT

model = CvT(
    configuration="cvt-21", data_format="channels_last", classifier_activation="softmax"
)
model.build((None, 224, 224, 3))
model.summary()
Model: "cvt-21"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 stage0 (VisionTransformer)  multiple                  62080     
                                                                 
 stage1 (VisionTransformer)  multiple                  1920576   
                                                                 
 stage2 (VisionTransformer)  ((None, 384, 14, 14),     29296128  
                              (None, 1, 384))                    
                                                                 
 norm (LayerNorm_)           (None, 1, 384)            768       
                                                                 
 head (Linear_)              (None, 1000)              385000    
                                                                 
 pred (Activation)           (None, 1000)              0         
                                                                 
=================================================================
Total params: 31,664,552
Trainable params: 31,622,696
Non-trainable params: 41,856
_________________________________________________________________
  • Train the model from scratch.
# Example
model.compile(
    optimizer="sgd",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
)
model.fit(x, y)
  • Use ported ImageNet pretrained weights.
# Example
from cvt_tensorflow import CvT

# Use cvt-13-384x384_22k ImageNet pretrained weights
model = CvT(
    configuration="cvt-13",
    pretrained=True,
    pretrained_resolution=384,
    pretrained_version="22k",
    classifier_activation="softmax",
)
y_pred = model(image)
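With classifier_activation="softmax", each output row is a probability distribution over the 1000 ImageNet classes; the top-5 predictions can be decoded with plain NumPy (the probabilities below are synthetic, and mapping class indices to label strings is omitted):

```python
import numpy as np

def top_k(probs, k=5):
    """Return (indices, probabilities) of the k most likely classes."""
    idx = np.argsort(probs)[::-1][:k]
    return idx, probs[idx]

# Synthetic probability row for a 1000-class output.
probs = np.full(1000, 0.0005)
probs[[281, 282]] = [0.30, 0.20]  # two hypothetical dominant classes
probs /= probs.sum()

indices, scores = top_k(probs, k=5)
print(indices[:2])  # the two dominant classes come first
```

For a real prediction, replace `probs` with one row of `y_pred` (converted via `.numpy()` if it is a tensor).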

Acknowledgement

CvT (Official PyTorch implementation)

Citations

@article{wu2021cvt,
  title={Cvt: Introducing convolutions to vision transformers},
  author={Wu, Haiping and Xiao, Bin and Codella, Noel and Liu, Mengchen and Dai, Xiyang and Yuan, Lu and Zhang, Lei},
  journal={arXiv preprint arXiv:2103.15808},
  year={2021}
}

License

This work is made available under the MIT License.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

cvt_tensorflow-1.1.4.tar.gz (21.7 kB)

Uploaded Source

Built Distribution

cvt_tensorflow-1.1.4-py3-none-any.whl (24.8 kB)

Uploaded Python 3

File details

Details for the file cvt_tensorflow-1.1.4.tar.gz.

File metadata

  • Download URL: cvt_tensorflow-1.1.4.tar.gz
  • Upload date:
  • Size: 21.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.11.1

File hashes

Hashes for cvt_tensorflow-1.1.4.tar.gz
Algorithm Hash digest
SHA256 daa46df438c0a2a822fe37b5c07fdfb6a5c076f7769ce73512f155aedc950e7d
MD5 f3ca5573cc367576acb97debc0f34818
BLAKE2b-256 5a3b9d2368399895a4d4dfc2d202a09761795d93d07dad10e4e3e6ea0c45f942

See more details on using hashes here.
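To check a downloaded distribution against the digests above before installing, the file can be hashed locally; a sketch using Python's standard hashlib (the final assert is commented out because it needs the actual file on disk):

```python
import hashlib

def sha256_of(path, chunk=1 << 20):
    """Stream a file in chunks and return its hex SHA256 digest."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while block := f.read(chunk):
            h.update(block)
    return h.hexdigest()

# SHA256 digest published above for the sdist.
expected = "daa46df438c0a2a822fe37b5c07fdfb6a5c076f7769ce73512f155aedc950e7d"
# assert sha256_of("cvt_tensorflow-1.1.4.tar.gz") == expected
```

Equivalently, `sha256sum cvt_tensorflow-1.1.4.tar.gz` on the command line, or pip's hash-checking mode (`pip install --require-hashes -r requirements.txt`), can enforce the same digest.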

File details

Details for the file cvt_tensorflow-1.1.4-py3-none-any.whl.

File metadata

File hashes

Hashes for cvt_tensorflow-1.1.4-py3-none-any.whl
Algorithm Hash digest
SHA256 c66ca4ee364339c40fd6db954c84800fb184f536590db6780b5ae0ef0a7eee64
MD5 1d4533467ec981738022e2662f22cce4
BLAKE2b-256 81ddeb15627edce6475095949ee82b60fbd15d5085cf61ac3354304897dc52b5

See more details on using hashes here.
