VAN-Classification-TensorFlow
TensorFlow 2.X reimplementation of Visual Attention Network by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
- Exact TensorFlow reimplementation of the official PyTorch repo, including the `timm` modules used by the authors, preserving model and layer structure.
- ImageNet pretrained weights ported from the official PyTorch implementation.
Abstract
While originally designed for natural language processing (NLP) tasks, the self-attention mechanism has recently taken various computer vision areas by storm. However, the 2D nature of images brings three challenges for applying self-attention in computer vision. (1) Treating images as 1D sequences neglects their 2D structure. (2) The quadratic complexity is too expensive for high-resolution images. (3) It only captures spatial adaptability but ignores channel adaptability. In this paper, the authors propose a novel large kernel attention (LKA) module to enable self-adaptive and long-range correlations in self-attention while avoiding the above issues. The authors further introduce a novel neural network based on LKA, namely Visual Attention Network (VAN). While extremely simple and efficient, VAN outperforms state-of-the-art vision transformers (ViTs) and convolutional neural networks (CNNs) by a large margin in extensive experiments, including image classification, object detection, semantic segmentation, instance segmentation, etc.
Figure 1. Comparison of different vision backbones on the ImageNet-1K validation set.
Figure 2. Decomposition diagram of large-kernel convolution. A standard convolution can be decomposed into three parts: a depth-wise convolution (DW-Conv), a depth-wise dilation convolution (DW-D-Conv) and a 1×1 convolution (1×1 Conv).
Figure 3. The structure of different modules: (a) the proposed Large Kernel Attention (LKA); (b) a non-attention module; (c) the self-attention module; (d) a stage of the Visual Attention Network (VAN). CFF means convolutional feed-forward network. The difference between (a) and (b) is the element-wise multiplication. It is worth noting that (c) is designed for 1D sequences.
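To make the decomposition concrete, here is a minimal TensorFlow sketch of an LKA block following Figures 2 and 3(a). The kernel sizes (5×5 depth-wise, 7×7 depth-wise with dilation 3, then 1×1) follow the configuration the paper uses to approximate a 21×21 convolution; this is an illustrative sketch, not the layer shipped in this package.
```python
import tensorflow as tf

class LKA(tf.keras.layers.Layer):
    """Sketch of Large Kernel Attention: DW-Conv -> DW-D-Conv -> 1x1 Conv,
    followed by element-wise multiplication with the input."""

    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        # 5x5 depth-wise convolution (DW-Conv)
        self.dw_conv = tf.keras.layers.DepthwiseConv2D(5, padding="same")
        # 7x7 depth-wise dilated convolution with dilation 3 (DW-D-Conv)
        self.dw_d_conv = tf.keras.layers.DepthwiseConv2D(7, padding="same", dilation_rate=3)
        # 1x1 convolution mixing channels (1x1 Conv)
        self.pw_conv = tf.keras.layers.Conv2D(dim, 1)

    def call(self, x):
        attn = self.pw_conv(self.dw_d_conv(self.dw_conv(x)))
        # The element-wise multiply is what distinguishes (a) from (b) in Figure 3
        return x * attn
```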
Results
The TensorFlow implementation and ported ImageNet weights have been compared to the official PyTorch implementation on the ImageNet-V2 test set.
Models pre-trained on ImageNet-1K
Configuration | Resolution | Top-1 Acc. (Original) | Top-1 Acc. (Ported) | Top-5 Acc. (Original) | Top-5 Acc. (Ported) | #Params |
---|---|---|---|---|---|---|
VAN-B0 | 224x224 | 0.59 | 0.59 | 0.81 | 0.81 | 4.1M |
VAN-B1 | 224x224 | 0.64 | 0.64 | 0.84 | 0.84 | 13.9M |
VAN-B2 | 224x224 | 0.69 | 0.69 | 0.88 | 0.88 | 26.6M |
VAN-B3 | 224x224 | 0.71 | 0.71 | 0.89 | 0.89 | 44.8M |
Difference between original and ported metrics: 0.
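The comparison can be reproduced along these lines; loading ImageNet-V2 through `tensorflow_datasets` and the simplified preprocessing below are assumptions, so the exact pipeline in the repo may differ.
```python
import tensorflow as tf
import tensorflow_datasets as tfds  # assumption: ImageNet-V2 fetched via TFDS
from van_classification_tensorflow import VAN

model = VAN(configuration="van_b0", pretrained=True, classifier_activation="softmax")
model.compile(
    loss="sparse_categorical_crossentropy",
    metrics=[
        "sparse_categorical_accuracy",        # top-1
        "sparse_top_k_categorical_accuracy",  # top-5 (k=5 by default)
    ],
)

def preprocess(example):
    # Simplified: resize only; normalization/center crop may differ in the repo
    image = tf.image.resize(example["image"], (224, 224))
    return tf.cast(image, tf.float32), example["label"]

ds = tfds.load("imagenet_v2", split="test").map(preprocess).batch(64)
model.evaluate(ds)
```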
Installation
- Install from PyPI.
```shell
pip install van-classification-tensorflow
```
- Install from GitHub.
```shell
pip install git+https://github.com/EMalagoli92/VAN-Classification-TensorFlow
```
- Clone the repo and install the necessary packages.
```shell
git clone https://github.com/EMalagoli92/VAN-Classification-TensorFlow.git
pip install -r requirements.txt
```
Tested on Ubuntu 20.04.4 LTS x86_64, Python 3.9.7.
Usage
- Define a custom VAN configuration.
```python
from van_classification_tensorflow import VAN

# Define a custom VAN configuration
model = VAN(
    in_chans=3,
    num_classes=1000,
    embed_dims=[64, 128, 256, 512],
    mlp_ratios=[4, 4, 4, 4],
    drop_rate=0.0,
    drop_path_rate=0.0,
    depths=[3, 4, 6, 3],
    num_stages=4,
    include_top=True,
    classifier_activation="softmax",
    data_format="channels_last",
)
```
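The custom model can then be run on a dummy batch as a quick shape check (illustrative only, not from the original docs):
```python
import tensorflow as tf

# Sanity check: a random batch through the custom configuration
dummy = tf.random.uniform((1, 224, 224, 3))
probs = model(dummy)
print(probs.shape)  # expected: (1, 1000)
```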
- Use a predefined VAN configuration.
```python
from van_classification_tensorflow import VAN

model = VAN(
    configuration="van_b0",
    data_format="channels_last",
    classifier_activation="softmax",
)
model.build((None, 224, 224, 3))
print(model.summary())
```
```
Model: "van_b0"
_________________________________________________________________
 Layer (type)                      Output Shape                  Param #
=================================================================
 patch_embed1 (OverlapPatchEmbed)  ((None, 32, 56, 56), (), ())     4864
 block1/0 (Block)                  (None, 32, 56, 56)              25152
 block1/1 (Block)                  (None, 32, 56, 56)              25152
 block1/2 (Block)                  (None, 32, 56, 56)              25152
 norm1 (LayerNorm_)                (None, 3136, 32)                   64
 patch_embed2 (OverlapPatchEmbed)  ((None, 64, 28, 28), (), ())    18752
 block2/0 (Block)                  (None, 64, 28, 28)              89216
 block2/1 (Block)                  (None, 64, 28, 28)              89216
 block2/2 (Block)                  (None, 64, 28, 28)              89216
 norm2 (LayerNorm_)                (None, 784, 64)                   128
 patch_embed3 (OverlapPatchEmbed)  ((None, 160, 14, 14), (), ())   92960
 block3/0 (Block)                  (None, 160, 14, 14)            303040
 block3/1 (Block)                  (None, 160, 14, 14)            303040
 block3/2 (Block)                  (None, 160, 14, 14)            303040
 block3/3 (Block)                  (None, 160, 14, 14)            303040
 block3/4 (Block)                  (None, 160, 14, 14)            303040
 norm3 (LayerNorm_)                (None, 196, 160)                  320
 patch_embed4 (OverlapPatchEmbed)  ((None, 256, 7, 7), (), ())    369920
 block4/0 (Block)                  (None, 256, 7, 7)              755200
 block4/1 (Block)                  (None, 256, 7, 7)              755200
 norm4 (LayerNorm_)                (None, 49, 256)                   512
 head (Linear_)                    (None, 1000)                   257000
 pred (Activation)                 (None, 1000)                        0
=================================================================
Total params: 4,113,224
Trainable params: 4,105,800
Non-trainable params: 7,424
_________________________________________________________________
```
- Train the model from scratch.
```python
# Example
model.compile(
    optimizer="sgd",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
)
# x: batch of input images, y: integer class labels
model.fit(x, y)
```
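For a quick smoke test, the `x` and `y` placeholders above can be random arrays (illustrative only, not a training recipe):
```python
import numpy as np

# Random placeholder data matching the expected input shape and label range
x = np.random.rand(8, 224, 224, 3).astype("float32")
y = np.random.randint(0, 1000, size=(8,))
model.fit(x, y, epochs=1)
```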
- Use ported ImageNet pretrained weights.
```python
# Example
from van_classification_tensorflow import VAN

model = VAN(
    configuration="van_b1",
    pretrained=True,
    include_top=True,
    classifier_activation="softmax",
)
y_pred = model(image)
```
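Here `image` is assumed to be a batched float32 tensor of shape (1, 224, 224, 3). One hypothetical way to produce it (the ImageNet mean/std normalization is an assumption; check the repo's preprocessing before relying on it):
```python
import tensorflow as tf

# Hypothetical input pipeline; "example.jpg" is a placeholder path
image = tf.io.decode_image(
    tf.io.read_file("example.jpg"), channels=3, expand_animations=False
)
image = tf.image.resize(image, (224, 224)) / 255.0
# ImageNet mean/std normalization (assumption; verify against the repo)
image = (image - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
image = tf.expand_dims(tf.cast(image, tf.float32), 0)  # add batch dimension
```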
- Use ported ImageNet pretrained weights for feature extraction (`include_top=False`).
```python
import tensorflow as tf
from van_classification_tensorflow import VAN

# Get features
inputs = tf.keras.layers.Input(shape=(224, 224, 3), dtype="float32")
features = VAN(configuration="van_b0", pretrained=True, include_top=False)(inputs)
# Custom classification head
num_classes = 10
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(features)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
```
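A common follow-up for transfer learning is to freeze the pretrained backbone and train only the new head. The layer index below assumes the functional model built above (InputLayer, VAN, Dense):
```python
# Freeze the VAN backbone; only the Dense head remains trainable
model.layers[1].trainable = False
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
```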
Acknowledgement
VAN-Classification (Official PyTorch implementation).
Citations
```bibtex
@article{guo2022visual,
  title={Visual Attention Network},
  author={Guo, Meng-Hao and Lu, Cheng-Ze and Liu, Zheng-Ning and Cheng, Ming-Ming and Hu, Shi-Min},
  journal={arXiv preprint arXiv:2202.09741},
  year={2022}
}
```
License
This work is made available under the MIT License.