Project description

tfvan

Keras (TensorFlow v2) reimplementation of the Visual Attention Network (VAN) model, based on the official PyTorch implementation.

Variable-shape inference is supported. All weights were obtained by converting the official checkpoints.
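
Variable-shape inference means the same converted weights can be run on inputs whose spatial size differs from the 224x224 pretraining resolution. A minimal sketch, assuming the headless backbone tolerates dynamic spatial dimensions (the (None, None, 3) input is an assumption, not taken from the official examples):

import numpy as np
from keras import layers, models
from tfvan import VanTiny, preprocess_input

# Assumption: the backbone accepts dynamic spatial dimensions when used without its head.
inputs = layers.Input(shape=(None, None, 3), dtype='uint8')
outputs = layers.Lambda(preprocess_input)(inputs)
outputs = VanTiny(include_top=False)(outputs)
model = models.Model(inputs=inputs, outputs=outputs)

# The same converted weights are reused at two different resolutions.
print(model.predict(np.zeros((1, 224, 224, 3), dtype='uint8')).shape)
print(model.predict(np.zeros((1, 256, 256, 3), dtype='uint8')).shape)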

Installation

pip install tfvan

Examples

Default usage (without preprocessing):

from tfvan import VanTiny  # + 3 other variants and input preprocessing

model = VanTiny()  # downloads ImageNet-pretrained weights by default
model.compile(...)
model.fit(...)
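
The compile(...) and fit(...) calls above are placeholders. One possible fill-in, with the optimizer, loss and synthetic stand-in data below chosen purely for illustration:

import numpy as np
import tensorflow as tf
from tfvan import VanTiny

# Stand-in data: a handful of random 224x224 images with random ImageNet labels.
images = np.random.randint(0, 256, size=(8, 224, 224, 3)).astype('float32')
labels = np.random.randint(0, 1000, size=(8,))
train_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

model = VanTiny()  # ImageNet-pretrained weights with the default 1000-class head
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
model.fit(train_ds, epochs=1)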

Custom classification (with preprocessing):

from keras import layers, models
from tfvan import VanTiny, preprocess_input

inputs = layers.Input(shape=(224, 224, 3), dtype='uint8')
outputs = layers.Lambda(preprocess_input)(inputs)
outputs = VanTiny(include_top=False)(outputs)
outputs = layers.Dense(100, activation='softmax')(outputs)

model = models.Model(inputs=inputs, outputs=outputs)
model.compile(...)
model.fit(...)
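
For transfer learning it is common to freeze the pretrained backbone first and train only the new head. A short sketch under that assumption (the two-stage schedule is illustrative, not part of the package documentation):

from keras import layers, models
from tfvan import VanTiny, preprocess_input

backbone = VanTiny(include_top=False)
backbone.trainable = False  # freeze converted weights while the new head warms up

inputs = layers.Input(shape=(224, 224, 3), dtype='uint8')
outputs = layers.Lambda(preprocess_input)(inputs)
outputs = backbone(outputs)
outputs = layers.Dense(100, activation='softmax')(outputs)

model = models.Model(inputs=inputs, outputs=outputs)
model.compile('adam', 'sparse_categorical_crossentropy', ['accuracy'])

# After the head converges, set backbone.trainable = True and re-compile
# with a lower learning rate to fine-tune the whole network.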

Evaluation

To verify correctness, the Tiny and Small models (original and ported) were evaluated on the ImageNet-v2 test set.

import tensorflow as tf
import tensorflow_datasets as tfds
from tfvan import VanTiny, preprocess_input


def _prepare(example):
    # Observation: +2.2% top-1 accuracy for the Tiny model when resizing with antialias=True
    image = tf.image.resize(example['image'], (248, 248), method=tf.image.ResizeMethod.BICUBIC)
    image = tf.image.central_crop(image, 0.9)
    image = preprocess_input(image)

    return image, example['label']


imagenet2 = tfds.load('imagenet_v2', split='test', shuffle_files=True)
imagenet2 = imagenet2.map(_prepare, num_parallel_calls=tf.data.AUTOTUNE)
imagenet2 = imagenet2.batch(8)

model = VanTiny()
model.compile('sgd', 'sparse_categorical_crossentropy', ['accuracy', 'sparse_top_k_categorical_accuracy'])
history = model.evaluate(imagenet2)

print(history)

name    original acc@1    ported acc@1    original acc@5    ported acc@5
Tiny    59.22             61.59           82.32             84.52
Small   70.17             68.62           89.17             88.54

Citation

@article{guo2022visual,
  title={Visual Attention Network},
  author={Guo, Meng-Hao and Lu, Cheng-Ze and Liu, Zheng-Ning and Cheng, Ming-Ming and Hu, Shi-Min},
  journal={arXiv preprint arXiv:2202.09741},
  year={2022}
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

tfvan-1.0.0.tar.gz (10.6 kB)

Uploaded Source

File details

Details for the file tfvan-1.0.0.tar.gz.

File metadata

  • Download URL: tfvan-1.0.0.tar.gz
  • Upload date:
  • Size: 10.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.6

File hashes

Hashes for tfvan-1.0.0.tar.gz
Algorithm Hash digest
SHA256 b614048cef6841ca09ab1d8d57a5092957c7bcd1e208164e879425a1adcb4e27
MD5 5ef29d4b748e3503533e921bac75f6d4
BLAKE2b-256 a01c529b77f0699ead51351cffc3216936cbb377a3de849e098f88fa08c5f227

See more details on using hashes here.
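
To check a downloaded archive against the SHA256 digest above, a plain hashlib comparison is enough (the local file path below is an assumption; adjust it to wherever the file was saved):

import hashlib

# Path to the locally downloaded source distribution.
path = 'tfvan-1.0.0.tar.gz'

with open(path, 'rb') as f:
    digest = hashlib.sha256(f.read()).hexdigest()

# Should print True if the download matches the published digest.
print(digest == 'b614048cef6841ca09ab1d8d57a5092957c7bcd1e208164e879425a1adcb4e27')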
