Tensorflow keras computer vision attention models. https://github.com/leondgarse/keras_cv_attention_models
Keras_cv_attention_models
Roadmap and todo list
Usage
Basic Usage
- Install as pip package:
```sh
pip install -U keras-cv-attention-models
# Or: pip install -U git+https://github.com/leondgarse/keras_cv_attention_models
```
Refer to each sub-directory for detailed usage.
- Basic model prediction:
```py
from keras_cv_attention_models import volo
mm = volo.VOLO_d1(pretrained="imagenet")

""" Run predict """
import tensorflow as tf
from tensorflow import keras
from skimage.data import chelsea

img = chelsea()  # Chelsea the cat
imm = keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
pred = mm(tf.expand_dims(tf.image.resize(imm, mm.input_shape[1:3]), 0)).numpy()
pred = tf.nn.softmax(pred).numpy()  # If classifier activation is not softmax
print(keras.applications.imagenet_utils.decode_predictions(pred)[0])
# [('n02124075', 'Egyptian_cat', 0.9692954),
#  ('n02123045', 'tabby', 0.020203391),
#  ('n02123159', 'tiger_cat', 0.006867502),
#  ('n02127052', 'lynx', 0.00017674894),
#  ('n02123597', 'Siamese_cat', 4.9493494e-05)]
```
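For intuition, the `mode='torch'` preprocessing used above can be written out by hand. This is a minimal numpy sketch of what that convention does (scale to [0, 1], then normalize per channel with the ImageNet statistics), not the library call itself:

```python
import numpy as np

def preprocess_torch(img):
    # Scale uint8 RGB to [0, 1], then normalize channel-wise with the
    # ImageNet mean / std -- the same convention as preprocess_input(mode='torch').
    img = np.asarray(img, dtype="float32") / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype="float32")
    std = np.array([0.229, 0.224, 0.225], dtype="float32")
    return (img - mean) / std

out = preprocess_torch(np.zeros((2, 2, 3), dtype="uint8"))
print(out.shape)  # (2, 2, 3)
```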
- Exclude model top layers by setting `num_classes=0`:
```py
from keras_cv_attention_models import resnest
mm = resnest.ResNest50(num_classes=0)
print(mm.output_shape)
# (None, 7, 7, 2048)
```
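With the top removed, the output is a feature map rather than class scores. A common next step is global average pooling plus a new linear head; here is a plain-numpy sketch of that shape flow (the 10-class head and zero weights are hypothetical, for illustration only):

```python
import numpy as np

# Stand-in for a headless backbone output of shape (batch, 7, 7, 2048).
features = np.random.default_rng(0).normal(size=(1, 7, 7, 2048)).astype("float32")

pooled = features.mean(axis=(1, 2))        # global average pooling -> (1, 2048)
w = np.zeros((2048, 10), dtype="float32")  # hypothetical 10-class linear head
logits = pooled @ w                        # (1, 10)
print(pooled.shape, logits.shape)
```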
Layers
- `attention_layers` is an `__init__.py` only, which imports core layers defined in the model architectures, like `RelativePositionalEmbedding` from `botnet` or `outlook_attention` from `volo`.
```py
import tensorflow as tf
from keras_cv_attention_models import attention_layers

aa = attention_layers.RelativePositionalEmbedding()
print(f"{aa(tf.ones([1, 4, 14, 16, 256])).shape = }")
# aa(tf.ones([1, 4, 14, 16, 256])).shape = TensorShape([1, 4, 14, 16, 14, 16])
```
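For intuition, the core trick behind relative positional embeddings is indexing a learned table by the offset between query and key positions. This 1-D numpy sketch mirrors only the idea; it is not the library's `RelativePositionalEmbedding` implementation:

```python
import numpy as np

L = 5                                                     # sequence length
rel = np.arange(L)[:, None] - np.arange(L)[None, :]       # offsets in [-(L-1), L-1]
table = np.random.default_rng(0).normal(size=2 * L - 1)   # one learned value per offset
bias = table[rel + (L - 1)]                               # (L, L) relative-position bias
print(bias.shape)  # (5, 5)
```

Pairs of positions with the same offset share the same table entry, which is what makes the embedding "relative".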
Model surgery
- `model_surgery` contains functions for changing model parameters after the model is built.
```py
from tensorflow import keras
from keras_cv_attention_models import model_surgery

# Replace all ReLU with PReLU
mm = model_surgery.replace_ReLU(keras.applications.ResNet50(), target_activation='PReLU')
```
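Conceptually, this kind of surgery amounts to walking the model's config and swapping activation identifiers before rebuilding. The sketch below shows only that idea in plain Python; the library's `replace_ReLU` operates on the built model and covers more cases:

```python
def swap_activation(config, old="relu", new="prelu"):
    # Recursively replace `activation: old` entries in a nested Keras-style config.
    if isinstance(config, dict):
        return {k: (new if k == "activation" and v == old else swap_activation(v, old, new))
                for k, v in config.items()}
    if isinstance(config, list):
        return [swap_activation(v, old, new) for v in config]
    return config

cfg = {"layers": [{"class_name": "Dense", "config": {"activation": "relu"}}]}
print(swap_activation(cfg)["layers"][0]["config"]["activation"])  # prelu
```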
ImageNet Training
- Initialize the ImageNet dataset using tensorflow_datasets.
- Default parameters for `train_script.py` resemble the `A3` configuration from ResNet strikes back: An improved training procedure in timm. Please note there is still a gap compared with the timm results.
```sh
CUDA_VISIBLE_DEVICES='0' TF_XLA_FLAGS="--tf_xla_auto_jit=2" ./train_script.py -h
```
AotNet
- Keras AotNet is just a `ResNet` / `ResNetV2`-like framework that exposes parameters such as `attn_types` and `se_ratio`, used to apply different types of attention layers. It works like `byoanet` / `byobnet` from `timm`.
```py
from keras_cv_attention_models import aotnet

# Mixing se and outlook and halo and mhsa and cot_attention, 21M parameters.
# 50 is just a picked number larger than the relevant `num_block`.
attn_types = [None, "outlook", ["bot", "halo"] * 50, "cot"]
se_ratio = [0.25, 0, 0, 0]
model = aotnet.AotNet50V2(attn_types=attn_types, se_ratio=se_ratio, stem_type="deep", strides=1)
model.summary()
```
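A way to read those list-valued parameters: one entry per stack, and a nested list spreads across the blocks inside that stack (hence the `* 50` padding, which is simply truncated). This is a hypothetical sketch of that expansion, assuming ResNet50-like stack depths, not the library's actual parsing code:

```python
num_blocks = [3, 4, 6, 3]  # ResNet50-like stack depths (assumption)
attn_types = [None, "outlook", ["bot", "halo"] * 50, "cot"]

expanded = []
for stack_attn, blocks in zip(attn_types, num_blocks):
    if isinstance(stack_attn, list):
        expanded.append(stack_attn[:blocks])    # nested list: one entry per block
    else:
        expanded.append([stack_attn] * blocks)  # scalar: same type for the whole stack
print(expanded[2])  # ['bot', 'halo', 'bot', 'halo', 'bot', 'halo']
```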
BEIT
| Model | Params | Image resolution | Top1 Acc | Download |
|---|---|---|---|---|
| BeitBasePatch16 | 86.53M | 224 | 85.240 | beit_base_patch16_224.h5 |
| - 384 | 86.74M | 384 | 86.808 | beit_base_patch16_384.h5 |
| BeitLargePatch16 | 304.43M | 224 | 87.476 | beit_large_patch16_224.h5 |
| - 384 | 305.00M | 384 | 88.382 | beit_large_patch16_384.h5 |
| - 512 | 305.67M | 512 | 88.584 | beit_large_patch16_512.h5 |
BotNet
| Model | Params | Image resolution | Top1 Acc | Download |
|---|---|---|---|---|
| BotNet50 | 21M | 224 | | |
| BotNet101 | 41M | 224 | | |
| BotNet152 | 56M | 224 | | |
| BotNet26T | 12.5M | 256 | 79.246 | botnet26t_imagenet.h5 |
| BotNextECA26T | 10.59M | 256 | 79.270 | botnext_eca26t_imagenet.h5 |
CMT
| Model | Params | Image resolution | Top1 Acc |
|---|---|---|---|
| CMTTiny | 9.5M | 160 | 79.2 |
| CMTXS | 15.2M | 192 | 81.8 |
| CMTSmall | 25.1M | 224 | 83.5 |
| CMTBig | 45.7M | 256 | 84.5 |
CoaT
| Model | Params | Image resolution | Top1 Acc | Download |
|---|---|---|---|---|
| CoaTLiteTiny | 5.7M | 224 | 77.5 | coat_lite_tiny_imagenet.h5 |
| CoaTLiteMini | 11M | 224 | 79.1 | coat_lite_mini_imagenet.h5 |
| CoaTLiteSmall | 20M | 224 | 81.9 | coat_lite_small_imagenet.h5 |
| CoaTTiny | 5.5M | 224 | 78.3 | coat_tiny_imagenet.h5 |
| CoaTMini | 10M | 224 | 81.0 | coat_mini_imagenet.h5 |
CoAtNet
| Model | Params | Image resolution | Top1 Acc |
|---|---|---|---|
| CoAtNet-0 | 25M | 224 | 81.6 |
| CoAtNet-1 | 42M | 224 | 83.3 |
| CoAtNet-2 | 75M | 224 | 84.1 |
| CoAtNet-2, ImageNet-21k pretrain | 75M | 224 | 87.1 |
| CoAtNet-3 | 168M | 224 | 84.5 |
| CoAtNet-3, ImageNet-21k pretrain | 168M | 224 | 87.6 |
| CoAtNet-3, ImageNet-21k pretrain | 168M | 512 | 87.9 |
| CoAtNet-4, ImageNet-21k pretrain | 275M | 512 | 88.1 |
| CoAtNet-4, ImageNet-21K + PT-RA-E150 | 275M | 512 | 88.56 |
JFT pre-trained models accuracy
| Model | Image resolution | Reported Params | self-defined Params | Top1 Acc |
|---|---|---|---|---|
| CoAtNet3 | 384 | 168M | 162.96M | 88.52 |
| CoAtNet3 | 512 | 168M | 163.57M | 88.81 |
| CoAtNet4 | 512 | 275M | 273.10M | 89.11 |
| CoAtNet5 | 512 | 688M | 680.47M | 89.77 |
| CoAtNet6 | 512 | 1.47B | 1.340B | 90.45 |
| CoAtNet7 | 512 | 2.44B | 2.422B | 90.88 |
CoTNet
| Model | Params | Image resolution | FLOPs (G) | Top1 Acc | Download |
|---|---|---|---|---|---|
| CotNet50 | 22.2M | 224 | 3.3 | 81.3 | cotnet50_224.h5 |
| CoTNeXt50 | 30.1M | 224 | 4.3 | 82.1 | |
| CotNetSE50D | 23.1M | 224 | 4.1 | 81.6 | cotnet_se50d_224.h5 |
| CotNet101 | 38.3M | 224 | 6.1 | 82.8 | cotnet101_224.h5 |
| CoTNeXt-101 | 53.4M | 224 | 8.2 | 83.2 | |
| CotNetSE101D | 40.9M | 224 | 8.5 | 83.2 | cotnet_se101d_224.h5 |
| CotNetSE152D | 55.8M | 224 | 17.0 | 84.0 | cotnet_se152d_224.h5 |
| CotNetSE152D | 55.8M | 320 | 26.5 | 84.6 | cotnet_se152d_320.h5 |
GMLP
- Keras GMLP includes implementation of PDF 2105.08050 Pay Attention to MLPs.
| Model | Params | Image resolution | Top1 Acc | ImageNet |
|---|---|---|---|---|
| GMLPTiny16 | 6M | 224 | 72.3 | |
| GMLPS16 | 20M | 224 | 79.6 | gmlp_s16_imagenet.h5 |
| GMLPB16 | 73M | 224 | 81.6 |
HaloNet
- Keras HaloNet is for PDF 2103.12731 Scaling Local Self-Attention for Parameter Efficient Visual Backbones.
| Model | Params | Image resolution | Top1 Acc | Download |
|---|---|---|---|---|
| HaloNetH0 | 5.5M | 256 | 77.9 | |
| HaloNetH1 | 8.1M | 256 | 79.9 | |
| HaloNetH2 | 9.4M | 256 | 80.4 | |
| HaloNetH3 | 11.8M | 320 | 81.9 | |
| HaloNetH4 | 19.1M | 384 | 83.3 | |
| - 21k | 19.1M | 384 | 85.5 | |
| HaloNetH5 | 30.7M | 448 | 84.0 | |
| HaloNetH6 | 43.4M | 512 | 84.4 | |
| HaloNetH7 | 67.4M | 600 | 84.9 | |
| HaloNextECA26T | 10.7M | 256 | 78.84 | halonext_eca26t_imagenet.h5 |
| HaloNet26T | 12.5M | 256 | 79.13 | halonet26t_imagenet.h5 |
| HaloNetSE33T | 13.7M | 256 | 80.99 | halonet_se33t_imagenet.h5 |
| HaloRegNetZB | 11.68M | 224 | 81.042 | haloregnetz_b_imagenet.h5 |
| HaloNet50T | 22.7M | 256 | 81.35 | halonet50t_imagenet.h5 |
LeViT
- Keras LeViT is for PDF 2104.01136 LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference.
| Model | Params | Image resolution | Top1 Acc | ImageNet |
|---|---|---|---|---|
| LeViT128S | 7.8M | 224 | 76.6 | levit128s_imagenet.h5 |
| LeViT128 | 9.2M | 224 | 78.6 | levit128_imagenet.h5 |
| LeViT192 | 11M | 224 | 80.0 | levit192_imagenet.h5 |
| LeViT256 | 19M | 224 | 81.6 | levit256_imagenet.h5 |
| LeViT384 | 39M | 224 | 82.6 | levit384_imagenet.h5 |
MLP mixer
- Keras MLP mixer includes implementation of PDF 2105.01601 MLP-Mixer: An all-MLP Architecture for Vision.
- Models' `Top1 Acc` is the `Pre-trained on JFT-300M` model accuracy on `ImageNet 1K`, as reported in the paper.
| Model | Params | Top1 Acc | ImageNet | Imagenet21k | ImageNet SAM |
|---|---|---|---|---|---|
| MLPMixerS32 | 19.1M | 68.70 | |||
| MLPMixerS16 | 18.5M | 73.83 | |||
| MLPMixerB32 | 60.3M | 75.53 | | | b32_imagenet_sam.h5 |
| MLPMixerB16 | 59.9M | 80.00 | b16_imagenet.h5 | b16_imagenet21k.h5 | b16_imagenet_sam.h5 |
| MLPMixerL32 | 206.9M | 80.67 | |||
| MLPMixerL16 | 208.2M | 84.82 | l16_imagenet.h5 | l16_imagenet21k.h5 | |
| - input 448 | 208.2M | 86.78 | |||
| MLPMixerH14 | 432.3M | 86.32 | |||
| - input 448 | 432.3M | 87.94 |
NFNets
- Keras NFNets is for PDF 2102.06171 High-Performance Large-Scale Image Recognition Without Normalization.
| Model | Params | Image resolution | Top1 Acc | Download |
|---|---|---|---|---|
| NFNetL0 | 35.07M | 288 | 82.75 | nfnetl0_imagenet.h5 |
| NFNetF0 | 71.5M | 256 | 83.6 | nfnetf0_imagenet.h5 |
| NFNetF1 | 132.6M | 320 | 84.7 | nfnetf1_imagenet.h5 |
| NFNetF2 | 193.8M | 352 | 85.1 | nfnetf2_imagenet.h5 |
| NFNetF3 | 254.9M | 416 | 85.7 | nfnetf3_imagenet.h5 |
| NFNetF4 | 316.1M | 512 | 85.9 | nfnetf4_imagenet.h5 |
| NFNetF5 | 377.2M | 544 | 86.0 | nfnetf5_imagenet.h5 |
| NFNetF6 SAM | 438.4M | 576 | 86.5 | nfnetf6_imagenet.h5 |
| NFNetF7 | 499.5M | 608 | | |
| ECA_NFNetL0 | 24.14M | 288 | 82.58 | eca_nfnetl0_imagenet.h5 |
| ECA_NFNetL1 | 41.41M | 320 | 84.01 | eca_nfnetl1_imagenet.h5 |
| ECA_NFNetL2 | 56.72M | 384 | 84.70 | eca_nfnetl2_imagenet.h5 |
| ECA_NFNetL3 | 72.04M | 448 | | |
RegNetY
| Model | Params | Image resolution | Top1 Acc | Download |
|---|---|---|---|---|
| RegNetY040 | 20.65M | 224 | 81.5 | regnety_040_imagenet.h5 |
| RegNetY080 | 39.18M | 224 | 82.2 | regnety_080_imagenet.h5 |
| RegNetY160 | 83.59M | 224 | 82.0 | regnety_160_imagenet.h5 |
| RegNetY320 | 145.05M | 224 | 82.5 | regnety_320_imagenet.h5 |
RegNetZ
- Keras RegNetZ includes implementation of Github timm/models/byobnet.py.
| Model | Params | Image resolution | Top1 Acc | Download |
|---|---|---|---|---|
| RegNetZB | 9.72M | 224 | 79.868 | regnetz_b_imagenet.h5 |
| RegNetZC | 13.46M | 256 | 82.164 | regnetz_c_imagenet.h5 |
| RegNetZD | 27.58M | 256 | 83.422 | regnetz_d_imagenet.h5 |
ResMLP
- Keras ResMLP includes implementation of PDF 2105.03404 ResMLP: Feedforward networks for image classification with data-efficient training
| Model | Params | Image resolution | Top1 Acc | ImageNet |
|---|---|---|---|---|
| ResMLP12 | 15M | 224 | 77.8 | resmlp12_imagenet.h5 |
| ResMLP24 | 30M | 224 | 80.8 | resmlp24_imagenet.h5 |
| ResMLP36 | 116M | 224 | 81.1 | resmlp36_imagenet.h5 |
| ResMLP_B24 | 129M | 224 | 83.6 | resmlp_b24_imagenet.h5 |
| - imagenet22k | 129M | 224 | 84.4 | resmlp_b24_imagenet22k.h5 |
ResNeSt
| Model | Params | Image resolution | Top1 Acc | Download |
|---|---|---|---|---|
| resnest50 | 28M | 224 | 81.03 | resnest50.h5 |
| resnest101 | 49M | 256 | 82.83 | resnest101.h5 |
| resnest200 | 71M | 320 | 83.84 | resnest200.h5 |
| resnest269 | 111M | 416 | 84.54 | resnest269.h5 |
ResNetD
- Keras ResNetD includes implementation of PDF 1812.01187 Bag of Tricks for Image Classification with Convolutional Neural Networks
| Model | Params | Image resolution | Top1 Acc | Download |
|---|---|---|---|---|
| ResNet50D | 25.58M | 224 | 80.530 | resnet50d.h5 |
| ResNet101D | 44.57M | 224 | 83.022 | resnet101d.h5 |
| ResNet152D | 60.21M | 224 | 83.680 | resnet152d.h5 |
| ResNet200D | 64.69M | 224 | 83.962 | resnet200d.h5 |
ResNetQ
- Keras ResNetQ includes implementation of Github timm/models/resnet.py
| Model | Params | Image resolution | Top1 Acc | Download |
|---|---|---|---|---|
| ResNet51Q | 35.7M | 224 | 82.36 | resnet51q.h5 |
ResNeXt
- Keras ResNeXt includes implementation of PDF 1611.05431 Aggregated Residual Transformations for Deep Neural Networks
- `SWSL` means `Semi-Weakly Supervised ResNe*t` from Github facebookresearch/semi-supervised-ImageNet1K-models. Please note the CC-BY-NC 4.0 license on these weights: non-commercial use only.
| Model | Params | Image resolution | Top1 Acc | Download |
|---|---|---|---|---|
| ResNeXt50 (32x4d) | 25M | 224 | 79.768 | resnext50_imagenet.h5 |
| - SWSL | 25M | 224 | 82.182 | resnext50_swsl.h5 |
| ResNeXt50D (32x4d + deep) | 25M | 224 | 79.676 | resnext50d_imagenet.h5 |
| ResNeXt101 (32x4d) | 42M | 224 | 80.334 | resnext101_imagenet.h5 |
| - SWSL | 42M | 224 | 83.230 | resnext101_swsl.h5 |
| ResNeXt101W (32x8d) | 89M | 224 | 79.308 | resnext101_imagenet.h5 |
| - SWSL | 89M | 224 | 84.284 | resnext101w_swsl.h5 |
VOLO
| Model | Params | Image resolution | Top1 Acc | Download |
|---|---|---|---|---|
| volo_d1 | 27M | 224 | 84.2 | volo_d1_224.h5 |
| volo_d1 ↑384 | 27M | 384 | 85.2 | volo_d1_384.h5 |
| volo_d2 | 59M | 224 | 85.2 | volo_d2_224.h5 |
| volo_d2 ↑384 | 59M | 384 | 86.0 | volo_d2_384.h5 |
| volo_d3 | 86M | 224 | 85.4 | volo_d3_224.h5 |
| volo_d3 ↑448 | 86M | 448 | 86.3 | volo_d3_448.h5 |
| volo_d4 | 193M | 224 | 85.7 | volo_d4_224.h5 |
| volo_d4 ↑448 | 193M | 448 | 86.8 | volo_d4_448.h5 |
| volo_d5 | 296M | 224 | 86.1 | volo_d5_224.h5 |
| volo_d5 ↑448 | 296M | 448 | 87.0 | volo_d5_448.h5 |
| volo_d5 ↑512 | 296M | 512 | 87.1 | volo_d5_512.h5 |
Other implemented tensorflow or keras models
- Github faustomorales/vit-keras
- Github rishigami/Swin-Transformer-TF
- Github tensorflow/resnet_rs
- Github google-research/big_transfer
- perceiver_image_classification
File details
Details for the file keras-cv-attention-models-1.1.1.tar.gz.
File metadata
- Download URL: keras-cv-attention-models-1.1.1.tar.gz
- Upload date:
- Size: 118.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a9870b382805b3a9162b001b3c4c4836751b496536d30fca02af2a2cfbba17c4 |
| MD5 | c697ccff94725c1bb34da486b04d4331 |
| BLAKE2b-256 | bbe53b9fb95715233110cc08ea68278353ee6cdcc4943de8a41606b09a5fda1c |
File details
Details for the file keras_cv_attention_models-1.1.1-py3-none-any.whl.
File metadata
- Download URL: keras_cv_attention_models-1.1.1-py3-none-any.whl
- Upload date:
- Size: 134.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bc1ed80bcacc884f63c94589942264aff6ae58d5721d2bfa0fab1ca633176ca1 |
| MD5 | b129c7258bef46c68d70a232d59c7b54 |
| BLAKE2b-256 | 00a00e4de495fdec03ac52e5a707b1ca9781b8f85539b52ef93ef47c2c134c7c |