tfvan
Keras (TensorFlow v2) reimplementation of the Visual Attention Network (VAN) model, based on the official PyTorch implementation.
Supports variable-shape inference (see the sketch at the end of the Examples section). All weights were obtained by converting the official checkpoints.
Installation
pip install tfvan
Examples
Default usage (without preprocessing):
from tfvan import VanTiny # + 3 other variants and input preprocessing
model = VanTiny() # by default will download imagenet-pretrained weights
model.compile(...)
model.fit(...)
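The comment above refers to the three larger variants from the paper; a minimal sketch of instantiating one of them, assuming they are exported as VanSmall, VanBase and VanLarge (names inferred from the paper's naming, not confirmed here):
from tfvan import VanSmall, VanBase, VanLarge  # assumed variant names, by analogy with VanTiny

# Each constructor is assumed to download its own converted ImageNet-pretrained weights by default.
model = VanSmall()
model.compile(...)
model.fit(...)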
Custom classification (with preprocessing):
from keras import layers, models
from tfvan import VanTiny, preprocess_input
inputs = layers.Input(shape=(224, 224, 3), dtype='uint8')
outputs = layers.Lambda(preprocess_input)(inputs)
outputs = VanTiny(include_top=False)(outputs)
outputs = layers.Dense(100, activation='softmax')(outputs)
model = models.Model(inputs=inputs, outputs=outputs)
model.compile(...)
model.fit(...)
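Variable-shape inference (mentioned in the description above) can be exercised by leaving the spatial dimensions undefined; a minimal sketch, assuming the backbone accepts inputs with None height and width when include_top=False:
import numpy as np
from keras import layers, models
from tfvan import VanTiny, preprocess_input

# Assumption: undefined spatial dimensions are accepted once the classification head is dropped.
inputs = layers.Input(shape=(None, None, 3), dtype='uint8')
outputs = layers.Lambda(preprocess_input)(inputs)
outputs = VanTiny(include_top=False)(outputs)
model = models.Model(inputs=inputs, outputs=outputs)

# The same model then handles different input resolutions at inference time.
print(model(np.zeros((1, 224, 224, 3), dtype='uint8')).shape)
print(model(np.zeros((1, 288, 288, 3), dtype='uint8')).shape)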
Evaluation
For correctness, the Tiny and Small models (original and ported) were tested on the ImageNet-v2 test set.
import tensorflow as tf
import tensorflow_datasets as tfds
from tfvan import VanTiny, preprocess_input

def _prepare(example):
    # Observation: +2.2% top-1 accuracy for the Tiny model with antialias=True
    image = tf.image.resize(example['image'], (248, 248), method=tf.image.ResizeMethod.BICUBIC)
    image = tf.image.central_crop(image, 0.9)
    image = preprocess_input(image)
    return image, example['label']

imagenet2 = tfds.load('imagenet_v2', split='test', shuffle_files=True)
imagenet2 = imagenet2.map(_prepare, num_parallel_calls=tf.data.AUTOTUNE)
imagenet2 = imagenet2.batch(8)

model = VanTiny()
model.compile('sgd', 'sparse_categorical_crossentropy', ['accuracy', 'sparse_top_k_categorical_accuracy'])

history = model.evaluate(imagenet2)
print(history)
name | original acc@1 | ported acc@1 | original acc@5 | ported acc@5 |
---|---|---|---|---|
Tiny | 59.22 | 61.59 | 82.32 | 84.52 |
Small | 70.17 | 68.62 | 89.17 | 88.54 |
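The Small row can be reproduced with the same pipeline by swapping the model; a minimal sketch reusing the imagenet2 dataset built above and assuming the variant is exported as VanSmall:
from tfvan import VanSmall  # assumed variant name, by analogy with VanTiny

model = VanSmall()
model.compile('sgd', 'sparse_categorical_crossentropy', ['accuracy', 'sparse_top_k_categorical_accuracy'])
print(model.evaluate(imagenet2))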
Citation
@article{guo2022visual,
  title={Visual Attention Network},
  author={Guo, Meng-Hao and Lu, Cheng-Ze and Liu, Zheng-Ning and Cheng, Ming-Ming and Hu, Shi-Min},
  journal={arXiv preprint arXiv:2202.09741},
  year={2022}
}