
tensorflow keras computer vision attention models

Project description

Keras_cv_attention_models


Usage

Basic Usage

  • Install as pip package:
    pip install -U git+https://github.com/leondgarse/keras_cv_attention_models
    
    Refer to each subdirectory for detailed usage.
  • Basic model prediction
    import tensorflow as tf
    from tensorflow import keras
    from skimage.data import chelsea
    from keras_cv_attention_models import volo
    mm = volo.VOLO_d1(pretrained="imagenet")
    
    # Run prediction
    img = chelsea()  # Chelsea the cat
    imm = keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
    pred = mm(tf.expand_dims(tf.image.resize(imm, mm.input_shape[1:3]), 0)).numpy()
    pred = tf.nn.softmax(pred).numpy()  # Only needed if the classifier activation is not softmax
    print(keras.applications.imagenet_utils.decode_predictions(pred)[0])
    # [('n02124075', 'Egyptian_cat', 0.9692954),
    #  ('n02123045', 'tabby', 0.020203391),
    #  ('n02123159', 'tiger_cat', 0.006867502),
    #  ('n02127052', 'lynx', 0.00017674894),
    #  ('n02123597', 'Siamese_cat', 4.9493494e-05)]
    
  • Exclude the model's top layers by setting num_classes=0
    from keras_cv_attention_models import resnest
    mm = resnest.ResNest50(num_classes=0)
    print(mm.output_shape)
    # (None, 7, 7, 2048)
    
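In the prediction example, `preprocess_input(..., mode='torch')` scales pixels to [0, 1] and then normalizes each RGB channel with the ImageNet mean (0.485, 0.456, 0.406) and standard deviation (0.229, 0.224, 0.225). The `tf.nn.softmax` step turns raw logits into probabilities before `decode_predictions` picks the top five. The dependency-free sketch below reproduces those two steps for illustration only. The helper names `torch_mode_preprocess`, `softmax` and `top_k` are hypothetical, and real code should keep using the Keras utilities.

```python
import math

# ImageNet channel statistics used by Keras' mode='torch' preprocessing (RGB order)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def torch_mode_preprocess(pixel):
    """Normalize one RGB pixel given as 0-255 values: scale to [0, 1], then standardize per channel."""
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(pixel, IMAGENET_MEAN, IMAGENET_STD)
    )


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs, labels, k=5):
    """Return the k (label, probability) pairs with the highest probability, best first."""
    return sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)[:k]
```

For example, `top_k(softmax([2.0, 1.0, 0.1]), ["Egyptian_cat", "tabby", "lynx"], k=2)` ranks `Egyptian_cat` first, followed by `tabby`.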

Layers

  • attention_layers is an __init__.py-only module that re-exports the core layers defined in the model architectures, such as MHSAWithPositionEmbedding from botnet and HaloAttention from halonet.
    import tensorflow as tf
    from keras_cv_attention_models import attention_layers
    aa = attention_layers.MHSAWithPositionEmbedding(num_heads=4, key_dim=128, relative=True)
    print(f"{aa(tf.ones([1, 14, 16, 256])).shape = }")
    # aa(tf.ones([1, 14, 16, 256])).shape = TensorShape([1, 14, 16, 512])
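In the Layers example, a [1, 14, 16, 256] input comes out of MHSAWithPositionEmbedding(num_heads=4, key_dim=128) with 512 channels, i.e. num_heads * key_dim. The plain-Python sketch below of vanilla multi-head self-attention over an H x W x C feature map shows where that channel count comes from. It is not the library's implementation. It omits the (relative) position embedding, uses random projection weights, and has no output projection. The function name `mhsa_2d` is hypothetical.

```python
import math
import random


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def project(tokens, weight):
    """Multiply a (num_tokens x in_dim) list of lists by an (in_dim x out_dim) weight."""
    return [
        [sum(t[i] * weight[i][j] for i in range(len(weight))) for j in range(len(weight[0]))]
        for t in tokens
    ]


def mhsa_2d(feature_map, num_heads, key_dim, seed=0):
    """Self-attention over all H*W positions; returns H x W x (num_heads * key_dim)."""
    height, width = len(feature_map), len(feature_map[0])
    channels = len(feature_map[0][0])
    tokens = [pixel for row in feature_map for pixel in row]  # flatten to (H*W) x C
    rng = random.Random(seed)

    def random_weight():
        # Random C x key_dim projection standing in for learned query/key/value weights
        return [[rng.gauss(0.0, 1.0 / math.sqrt(channels)) for _ in range(key_dim)] for _ in range(channels)]

    scale = 1.0 / math.sqrt(key_dim)
    heads = []
    for _ in range(num_heads):
        query, key, value = (project(tokens, random_weight()) for _ in range(3))
        head_out = []
        for q in query:
            # Scaled dot-product attention weights of this position against every position
            weights = softmax([scale * sum(a * b for a, b in zip(q, k)) for k in key])
            head_out.append([sum(w * v[d] for w, v in zip(weights, value)) for d in range(key_dim)])
        heads.append(head_out)
    # Concatenate the heads along the channel axis: (H*W) x (num_heads * key_dim)
    merged = [[c for head in heads for c in head[i]] for i in range(height * width)]
    return [merged[r * width:(r + 1) * width] for r in range(height)]
```

With a 2 x 3 x 4 input, `num_heads=2` and `key_dim=3`, the result is 2 x 3 x 6, mirroring how 4 heads of `key_dim=128` give 512 channels above.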

Model surgery

  • model_surgery contains functions for changing model parameters after a model is built.
    from tensorflow import keras
    from keras_cv_attention_models import model_surgery
    # Replace all ReLU with PReLU
    mm = model_surgery.replace_ReLU(keras.applications.ResNet50(), target_activation='PReLU')
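The sketch below is not the library's code for replace_ReLU. It shows one common way to do this kind of post-build surgery: edit a serialized layer list so that every ReLU layer (an `Activation` layer with `activation="relu"`, or a `ReLU` layer) becomes a `PReLU` entry that keeps the original layer name. Connections that reference that name therefore stay valid. The config layout is a simplified assumption modeled loosely on Keras' `Model.get_config()` output, and `replace_relu_in_config` is a hypothetical helper.

```python
import copy


def is_relu_layer(layer):
    """True for a ReLU layer or an Activation layer configured with 'relu' (simplified config format)."""
    if layer["class_name"] == "ReLU":
        return True
    return layer["class_name"] == "Activation" and layer["config"].get("activation") == "relu"


def replace_relu_in_config(model_config, target_activation="PReLU"):
    """Return a copy of model_config with every ReLU layer replaced by target_activation.

    Assumes model_config = {"layers": [{"class_name": ..., "config": {"name": ...}, ...}, ...]};
    any other keys on a layer entry (e.g. connection info) are left untouched.
    """
    new_config = copy.deepcopy(model_config)  # never mutate the caller's config
    for layer in new_config["layers"]:
        if is_relu_layer(layer):
            layer["class_name"] = target_activation
            # Keep only the name so layers that reference it stay connected;
            # all other settings fall back to the target layer's defaults.
            layer["config"] = {"name": layer["config"]["name"]}
    return new_config
```

With real Keras, the edited config would then be rebuilt with `keras.Model.from_config` and the weights of unchanged layers copied across. That step is omitted here.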

AotNet

  • Keras AotNet is a ResNet / ResNetV2-like framework whose parameters, such as attn_types and se_ratio, control which types of attention layers are applied.
    # Mix SE, outlook, halo, MHSA and CoT attention; 21M parameters
    # 50 is just an arbitrary number larger than the corresponding `num_block`
    from keras_cv_attention_models import aotnet
    attn_types = [None, "outlook", ["mhsa", "halo"] * 50, "cot"]
    se_ratio = [0.25, 0, 0, 0]
    mm = aotnet.AotNet50V2(attn_types=attn_types, se_ratio=se_ratio, deep_stem=True, strides=1)
    
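Judging from the comments in the AotNet example, each entry of attn_types and se_ratio applies to one stack. A list entry presumably supplies one value per block, which is why an over-long list such as ["mhsa", "halo"] * 50 is safe. The sketch below shows that expansion under two assumptions: the [3, 4, 6, 3] block counts of a 50-layer ResNet, and the hypothetical helper `expand_per_block`. Neither comes from aotnet's own code.

```python
def expand_per_block(stack_values, num_blocks):
    """Expand one setting per stack into one setting per block.

    A scalar (including None or a string) is repeated for every block in its stack;
    a list is truncated to the stack's block count.
    """
    expanded = []
    for value, num_block in zip(stack_values, num_blocks):
        if isinstance(value, (list, tuple)):
            if len(value) < num_block:
                raise ValueError(f"need at least {num_block} values, got {len(value)}")
            expanded.append(list(value[:num_block]))
        else:
            expanded.append([value] * num_block)
    return expanded


RESNET50_NUM_BLOCKS = [3, 4, 6, 3]  # assumed block counts for the "50" depth
attn_types = [None, "outlook", ["mhsa", "halo"] * 50, "cot"]
se_ratio = [0.25, 0, 0, 0]
per_block_attn = expand_per_block(attn_types, RESNET50_NUM_BLOCKS)
per_block_se = expand_per_block(se_ratio, RESNET50_NUM_BLOCKS)
```

Under these assumptions, the third stack alternates mhsa and halo across its 6 blocks, and only the first stack uses SE with ratio 0.25.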

ResNetD

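| Model | Params | Image resolution | Top1 Acc | Download |
| ----- | ------ | ---------------- | -------- | -------- |
| ResNet50D | 25.58M | 224 | 80.530 | resnet50d.h5 |
| ResNet101D | 44.57M | 224 | 83.022 | resnet101d.h5 |
| ResNet152D | 60.21M | 224 | 83.680 | resnet152d.h5 |
| ResNet200D | 64.69M | 224 | 83.962 | resnet200d.h5 |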

ResNeXt

| Model | Params | Image resolution | Top1 Acc | Download |
| ----- | ------ | ---------------- | -------- | -------- |
| ResNeXt50 (32x4d) | 25M | 224 | 79.768 | resnext50_imagenet.h5 |
| - SWSL | 25M | 224 | 82.182 | resnext50_swsl.h5 |
| ResNeXt50D (32x4d + deep) | 25M | 224 | 79.676 | resnext50d_imagenet.h5 |
| ResNeXt101 (32x4d) | 42M | 224 | 80.334 | resnext101_imagenet.h5 |
| - SWSL | 42M | 224 | 83.230 | resnext101_swsl.h5 |
| ResNeXt101W (32x8d) | 89M | 224 | 79.308 | resnext101_imagenet.h5 |
| - SWSL | 89M | 224 | 84.284 | resnext101w_swsl.h5 |

ResNetQ

| Model | Params | Image resolution | Top1 Acc | Download |
| ----- | ------ | ---------------- | -------- | -------- |
| ResNet51Q | 35.7M | 224 | 82.36 | resnet51q.h5 |

BotNet

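| Model | Params | Image resolution | Top1 Acc | Download |
| ----- | ------ | ---------------- | -------- | -------- |
| botnet50 | 21M | 224 | 77.604 | botnet50_imagenet.h5 |
| botnet101 | 41M | 224 | | |
| botnet152 | 56M | 224 | | |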

VOLO

| Model | Params | Image resolution | Top1 Acc | Download |
| ----- | ------ | ---------------- | -------- | -------- |
| volo_d1 | 27M | 224 | 84.2 | volo_d1_224.h5 |
| volo_d1 ↑384 | 27M | 384 | 85.2 | volo_d1_384.h5 |
| volo_d2 | 59M | 224 | 85.2 | volo_d2_224.h5 |
| volo_d2 ↑384 | 59M | 384 | 86.0 | volo_d2_384.h5 |
| volo_d3 | 86M | 224 | 85.4 | volo_d3_224.h5 |
| volo_d3 ↑448 | 86M | 448 | 86.3 | volo_d3_448.h5 |
| volo_d4 | 193M | 224 | 85.7 | volo_d4_224.h5 |
| volo_d4 ↑448 | 193M | 448 | 86.8 | volo_d4_448.h5 |
| volo_d5 | 296M | 224 | 86.1 | volo_d5_224.h5 |
| volo_d5 ↑448 | 296M | 448 | 87.0 | volo_d5_448.h5 |
| volo_d5 ↑512 | 296M | 512 | 87.1 | volo_d5_512.h5 |

ResNeSt

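| Model | Params | Image resolution | Top1 Acc | Download |
| ----- | ------ | ---------------- | -------- | -------- |
| resnest50 | 28M | 224 | 81.03 | resnest50.h5 |
| resnest101 | 49M | 256 | 82.83 | resnest101.h5 |
| resnest200 | 71M | 320 | 83.84 | resnest200.h5 |
| resnest269 | 111M | 416 | 84.54 | resnest269.h5 |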

HaloNet

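| Model | Params | Image resolution | Top1 Acc |
| ----- | ------ | ---------------- | -------- |
| HaloNetH0 | 6.6M | 256 | |
| HaloNetH1 | 9.1M | 256 | |
| HaloNetH2 | 10.3M | 256 | |
| HaloNetH3 | 12.5M | 320 | |
| HaloNetH4 | 19.5M | 384 | 85.5 |
| HaloNetH5 | 31.6M | 448 | |
| HaloNetH6 | 44.3M | 512 | |
| HaloNetH7 | 67.9M | 640 | |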

CoTNet

| Model | Params | Image resolution | FLOPs (G) | Top1 Acc | Download |
| ----- | ------ | ---------------- | --------- | -------- | -------- |
| CoTNet-50 | 22.2M | 224 | 3.3 | 81.3 | cotnet50_224.h5 |
| CoTNeXt-50 | 30.1M | 224 | 4.3 | 82.1 | |
| SE-CoTNetD-50 | 23.1M | 224 | 4.1 | 81.6 | se_cotnetd50_224.h5 |
| CoTNet-101 | 38.3M | 224 | 6.1 | 82.8 | cotnet101_224.h5 |
| CoTNeXt-101 | 53.4M | 224 | 8.2 | 83.2 | |
| SE-CoTNetD-101 | 40.9M | 224 | 8.5 | 83.2 | se_cotnetd101_224.h5 |
| SE-CoTNetD-152 | 55.8M | 224 | 17.0 | 84.0 | se_cotnetd152_224.h5 |
| SE-CoTNetD-152 | 55.8M | 320 | 26.5 | 84.6 | se_cotnetd152_320.h5 |

CoAtNet

| Model | Params | Image resolution | Top1 Acc |
| ----- | ------ | ---------------- | -------- |
| CoAtNet-0 | 25M | 224 | 81.6 |
| CoAtNet-1 | 42M | 224 | 83.3 |
| CoAtNet-2 | 75M | 224 | 84.1 |
| CoAtNet-2, ImageNet-21k pretrain | 75M | 224 | 87.1 |
| CoAtNet-3 | 168M | 224 | 84.5 |
| CoAtNet-3, ImageNet-21k pretrain | 168M | 224 | 87.6 |
| CoAtNet-3, ImageNet-21k pretrain | 168M | 512 | 87.9 |
| CoAtNet-4, ImageNet-21k pretrain | 275M | 512 | 88.1 |
| CoAtNet-4, ImageNet-21K + PT-RA-E150 | 275M | 512 | 88.56 |

CoaT

| Model | Params | Image resolution | Top1 Acc | Download |
| ----- | ------ | ---------------- | -------- | -------- |
| CoaTLiteTiny | 5.7M | 224 | 77.5 | coat_lite_tiny_imagenet.h5 |
| CoaTLiteMini | 11M | 224 | 79.1 | coat_lite_mini_imagenet.h5 |
| CoaTLiteSmall | 20M | 224 | 81.9 | coat_lite_small_imagenet.h5 |
| CoaTTiny | 5.5M | 224 | 78.3 | coat_tiny_imagenet.h5 |
| CoaTMini | 10M | 224 | 81.0 | coat_mini_imagenet.h5 |

MLP mixer

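| Model | Params | Top1 Acc | ImageNet | ImageNet21k | ImageNet SAM |
| ----- | ------ | -------- | -------- | ----------- | ------------ |
| MLPMixerS32 | 19.1M | 68.70 | | | |
| MLPMixerS16 | 18.5M | 73.83 | | | |
| MLPMixerB32 | 60.3M | 75.53 | | | b32_imagenet_sam.h5 |
| MLPMixerB16 | 59.9M | 80.00 | b16_imagenet.h5 | b16_imagenet21k.h5 | b16_imagenet_sam.h5 |
| MLPMixerL32 | 206.9M | 80.67 | | | |
| MLPMixerL16 | 208.2M | 84.82 | l16_imagenet.h5 | l16_imagenet21k.h5 | |
| - input 448 | 208.2M | 86.78 | | | |
| MLPMixerH14 | 432.3M | 86.32 | | | |
| - input 448 | 432.3M | 87.94 | | | |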
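In MLP-Mixer, the number at the end of each model name (S32, B16, H14, ...) is the patch size. The number of tokens that the token-mixing MLPs operate on therefore depends on both the patch size and the input resolution, and it grows for the "- input 448" rows. A small sketch follows. It assumes square inputs, and it assumes 224 as the resolution for the rows that do not state one, since the table does not list it. The helper names are hypothetical.

```python
import re


def patch_size_from_name(model_name):
    """Read the trailing patch size from names such as 'MLPMixerB16' (assumed naming convention)."""
    match = re.search(r"(\d+)$", model_name)
    if match is None:
        raise ValueError(f"no trailing patch size in {model_name!r}")
    return int(match.group(1))


def num_patches(image_size, patch_size):
    """Number of non-overlapping square patches (tokens) for a square input image."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    return (image_size // patch_size) ** 2
```

For example, `num_patches(224, patch_size_from_name("MLPMixerB16"))` gives 196 tokens, while MLPMixerL16 at input 448 works on 784 tokens.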

ResMLP

| Model | Params | Image resolution | Top1 Acc | Download |
| ----- | ------ | ---------------- | -------- | -------- |
| ResMLP12 | 15M | 224 | 77.8 | resmlp12_imagenet.h5 |
| ResMLP24 | 30M | 224 | 80.8 | resmlp24_imagenet.h5 |
| ResMLP36 | 116M | 224 | 81.1 | resmlp36_imagenet.h5 |
| ResMLP_B24 | 129M | 224 | 83.6 | resmlp_b24_imagenet.h5 |
| - imagenet22k | 129M | 224 | 84.4 | resmlp_b24_imagenet22k.h5 |

GMLP

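| Model | Params | Image resolution | Top1 Acc | Download |
| ----- | ------ | ---------------- | -------- | -------- |
| GMLPTiny16 | 6M | 224 | 72.3 | |
| GMLPS16 | 20M | 224 | 79.6 | gmlp_s16_imagenet.h5 |
| GMLPB16 | 73M | 224 | 81.6 | |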

LeViT

| Model | Params | Image resolution | Top1 Acc | Download |
| ----- | ------ | ---------------- | -------- | -------- |
| LeViT128S | 7.8M | 224 | 76.6 | levit128s_imagenet.h5 |
| LeViT128 | 9.2M | 224 | 78.6 | levit128_imagenet.h5 |
| LeViT192 | 11M | 224 | 80.0 | levit192_imagenet.h5 |
| LeViT256 | 19M | 224 | 81.6 | levit256_imagenet.h5 |
| LeViT384 | 39M | 224 | 82.6 | levit384_imagenet.h5 |

Other implemented keras models


Project details



This version

1.0.3

Download files


Source Distribution

keras-cv-attention-models-1.0.3.tar.gz (67.4 kB)

Built Distribution


keras_cv_attention_models-1.0.3-py3-none-any.whl (84.8 kB)

File details

Details for the file keras-cv-attention-models-1.0.3.tar.gz.

File metadata

  • Download URL: keras-cv-attention-models-1.0.3.tar.gz
  • Size: 67.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.2 CPython/3.9.7

File hashes

Hashes for keras-cv-attention-models-1.0.3.tar.gz
Algorithm Hash digest
SHA256 a31b37abf4dab399fffc2d50eed0a0b98e9985103b36a6ae64b6cf71bc53c521
MD5 4dda481ffd3cadd82c2e4110bb2fc6d8
BLAKE2b-256 c898f08d145b11d29de609b817db7a981c90eed7d2521d433dbab83e1a3697d5
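The SHA256 digest listed above can be used to check a downloaded archive before installing it. The minimal standard-library sketch below assumes the sdist has been saved locally under its published file name. The helper names are hypothetical.

```python
import hashlib

# SHA256 digest listed above for keras-cv-attention-models-1.0.3.tar.gz
EXPECTED_SDIST_SHA256 = "a31b37abf4dab399fffc2d50eed0a0b98e9985103b36a6ae64b6cf71bc53c521"


def iter_file_chunks(path, chunk_size=1 << 16):
    """Yield a file's bytes in fixed-size chunks so large files need not fit in memory."""
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            yield chunk


def sha256_hexdigest(chunks):
    """Hash an iterable of byte chunks and return the lowercase hex digest."""
    digest = hashlib.sha256()
    for chunk in chunks:
        digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, expected=EXPECTED_SDIST_SHA256):
    """True if the file at path matches the expected SHA256 digest."""
    return sha256_hexdigest(iter_file_chunks(path)) == expected.lower()
```

Usage: `verify_download("keras-cv-attention-models-1.0.3.tar.gz")`. pip can also enforce the check itself in hash-checking mode. To do that, pin `keras-cv-attention-models==1.0.3 --hash=sha256:<digest>` in a requirements file and install with `pip install --require-hashes -r requirements.txt`.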


File details

Details for the file keras_cv_attention_models-1.0.3-py3-none-any.whl.

File metadata

  • Download URL: keras_cv_attention_models-1.0.3-py3-none-any.whl
  • Size: 84.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.2 CPython/3.9.7

File hashes

Hashes for keras_cv_attention_models-1.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 148c6fe06359602de7f555627d068b59e4c01090e84489f12c7fb33f0a29f241
MD5 608103e8cf7b7dbc8c26d53cfcc47529
BLAKE2b-256 268096a2aa6a6708cb84fb3f172b7acfa9bb85db462ad99aefd141d2409ee7a4

