TensorFlow Keras computer vision attention models. Alias: kecam. https://github.com/leondgarse/keras_cv_attention_models

Keras_cv_attention_models

  • coco_train_script.py is under testing

Roadmap and todo list


General Usage

Basic

  • Currently the recommended TF version is tensorflow==2.8.0, especially for training or TFLite conversion.
  • Default imports:
    import os
    import tensorflow as tf
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from tensorflow import keras
    
  • Install as pip package:
    pip install -U keras-cv-attention-models
    # Or
    pip install -U git+https://github.com/leondgarse/keras_cv_attention_models
    
    Refer to each subdirectory for detailed usage.
  • Basic model prediction
    from keras_cv_attention_models import volo
    mm = volo.VOLO_d1(pretrained="imagenet")
    
    """ Run predict """
    import tensorflow as tf
    from tensorflow import keras
    from skimage.data import chelsea
    img = chelsea() # Chelsea the cat
    imm = keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
    pred = mm(tf.expand_dims(tf.image.resize(imm, mm.input_shape[1:3]), 0)).numpy()
    pred = tf.nn.softmax(pred).numpy()  # If classifier activation is not softmax
    print(keras.applications.imagenet_utils.decode_predictions(pred)[0])
    # [('n02124075', 'Egyptian_cat', 0.9692954),
    #  ('n02123045', 'tabby', 0.020203391),
    #  ('n02123159', 'tiger_cat', 0.006867502),
    #  ('n02127052', 'lynx', 0.00017674894),
    #  ('n02123597', 'Siamese_cat', 4.9493494e-05)]
    
    Or just use the model's preset preprocess_input and decode_predictions:
    from keras_cv_attention_models import coatnet
    from skimage.data import chelsea
    mm = coatnet.CoAtNet0()
    preds = mm(mm.preprocess_input(chelsea()))
    print(mm.decode_predictions(preds))
    # [[('n02124075', 'Egyptian_cat', 0.9653769), ('n02123159', 'tiger_cat', 0.018427467), ...]
    
  • Exclude the model top layers by setting num_classes=0:
    from keras_cv_attention_models import resnest
    mm = resnest.ResNest50(num_classes=0)
    print(mm.output_shape)
    # (None, 7, 7, 2048)
    
  • Reload your own model weights by setting pretrained="xxx.h5". This is the better option when reloading a model with a different input_shape, where some weight shapes do not match.
    import os
    from keras_cv_attention_models import coatnet
    pretrained = os.path.expanduser('~/.keras/models/coatnet0_224_imagenet.h5')
    mm = coatnet.CoAtNet1(input_shape=(384, 384, 3), pretrained=pretrained)  # CoAtNet0 weights into CoAtNet1: just showing usage
    
  • The alias name kecam can be used instead of keras_cv_attention_models. Its __init__.py contains only one line: from keras_cv_attention_models import *.
    import kecam
    mm = kecam.yolor.YOLOR_CSP()
    imm = kecam.test_images.dog_cat()
    preds = mm(mm.preprocess_input(imm))
    bboxs, labels, confidences = mm.decode_predictions(preds)[0]
    kecam.coco.show_image_with_bboxes(imm, bboxs, labels, confidences)
    
  • Calculate FLOPs, using the method from TF 2.0 Feature: Flops calculation #32809.
    from keras_cv_attention_models import coatnet, resnest, model_surgery
    
    model_surgery.get_flops(coatnet.CoAtNet0())
    # >>>> Flops: 4,221,908,559, GFlops: 4.2219G
    model_surgery.get_flops(resnest.ResNest50())
    # >>>> Flops: 5,378,399,992, GFlops: 5.3784G
    
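For reference, preprocess_input(img, mode='torch') scales pixels to [0, 1] and then normalizes each RGB channel with the ImageNet mean and std, and the tf.nn.softmax line turns raw logits into probabilities before decode_predictions picks the top 5. A plain-Python sketch of those three steps on one pixel and a toy logit vector; the helper names are hypothetical and not part of kecam.

```python
import math

# ImageNet per-channel statistics used by preprocess_input(..., mode="torch")
TORCH_MEAN = (0.485, 0.456, 0.406)
TORCH_STD = (0.229, 0.224, 0.225)


# Hypothetical helpers for illustration, not part of kecam.
def torch_preprocess_pixel(rgb):
    """Scale one uint8 RGB pixel to [0, 1], then normalize each channel."""
    return tuple((value / 255.0 - mean) / std for value, mean, std in zip(rgb, TORCH_MEAN, TORCH_STD))


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs, labels, k=5):
    """Return the k (label, probability) pairs with the highest probability."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [(labels[i], probs[i]) for i in order]


pixel = torch_preprocess_pixel((255, 128, 0))
probs = softmax([2.0, 1.0, 0.1])
print(top_k(probs, ["Egyptian_cat", "tabby", "tiger_cat"], k=2))
```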

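The FLOPs numbers above follow the method referenced in #32809, which is based on the TensorFlow profiler and counts a multiply and an add as separate operations. As a rough illustration of that convention, here is a sketch for a single Conv2D layer. It assumes FLOPs = 2 x output elements x kernel_h x kernel_w x input channels, and it ignores bias and activation. conv2d_flops and format_flops are hypothetical helpers that only mimic the printed format.

```python
# Hypothetical helpers; assumes the TF profiler convention of 2 FLOPs per multiply-accumulate.
def conv2d_flops(out_h, out_w, out_channels, kernel_h, kernel_w, in_channels, batch=1):
    """FLOPs of one Conv2D, counting each multiply-accumulate as 2 operations."""
    output_count = batch * out_h * out_w * out_channels
    return 2 * output_count * kernel_h * kernel_w * in_channels


def format_flops(flops):
    """Format like the '>>>> Flops: ..., GFlops: ...G' lines shown above."""
    return ">>>> Flops: {:,}, GFlops: {:.4f}G".format(flops, flops / 1e9)


stem = conv2d_flops(112, 112, 64, 7, 7, 3)  # 7x7 stride-2 stem conv on a 224x224 RGB input
print(format_flops(stem))  # >>>> Flops: 236,027,904, GFlops: 0.2360G
```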
Layers

  • attention_layers is an __init__.py-only module that imports the core layers defined in model architectures, like RelativePositionalEmbedding from botnet and outlook_attention from volo.
    import tensorflow as tf
    from keras_cv_attention_models import attention_layers
    aa = attention_layers.RelativePositionalEmbedding()
    print(f"{aa(tf.ones([1, 4, 14, 16, 256])).shape = }")
    # aa(tf.ones([1, 4, 14, 16, 256])).shape = TensorShape([1, 4, 14, 16, 14, 16])
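Assume the input is laid out as [batch, num_heads, height, width, key_dim]. Then the output shape [1, 4, 14, 16, 14, 16] holds one logit per (query position, key position) pair. BoTNet-style relative embeddings look those logits up from 2 * height - 1 (and 2 * width - 1) relative offsets per axis. Below is a generic sketch of the offset-to-index mapping along one axis. It illustrates the idea and is not RelativePositionalEmbedding's actual code.

```python
# Generic sketch of relative-position indexing, not the library's implementation.
def relative_index_table(length):
    """For each (query, key) pair on one axis, index into a table of 2 * length - 1 relative offsets."""
    # Offset key - query ranges from -(length - 1) to length - 1; shift it so it starts at 0.
    return [[key - query + length - 1 for key in range(length)] for query in range(length)]


def relative_logits_shape(batch, heads, height, width):
    """Shape of 2D relative-position logits: one logit per (query position, key position)."""
    return (batch, heads, height, width, height, width)


table = relative_index_table(14)
print(len(table), len(table[0]), max(max(row) for row in table) + 1)  # 14 14 27
print(relative_logits_shape(1, 4, 14, 16))  # (1, 4, 14, 16, 14, 16)
```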

Model surgery

  • model_surgery includes functions used to change model parameters after the model is built.
    from tensorflow import keras
    from keras_cv_attention_models import model_surgery
    mm = keras.applications.ResNet50()  # Trainable params: 25,583,592

    # Replace all ReLU with PReLU. Trainable params: 25,606,312
    mm = model_surgery.replace_ReLU(mm, target_activation='PReLU')

    # Fuse conv and batch_norm layers. Trainable params: 25,553,192
    mm = model_surgery.convert_to_fused_conv_bn_model(mm)
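convert_to_fused_conv_bn_model folds each BatchNormalization into the convolution before it. This matches the drop in trainable parameters above, since the BN gamma and beta disappear. The standard algebra is BN(w . x + b) = (w * s) . x + (b - mean) * s + beta, with s = gamma / sqrt(var + eps). Below is a plain-Python check of that identity for one output channel of a 1x1 convolution. It sketches the math, not the library's code, and eps=1e-3 assumes the Keras BatchNormalization default.

```python
import math


def fuse_conv_bn(weight, bias, gamma, beta, mean, var, eps=1e-3):
    """Fold BatchNorm (gamma, beta, moving mean/var) into one output channel's conv weight and bias.

    eps=1e-3 assumes the Keras BatchNormalization default.
    """
    scale = gamma / math.sqrt(var + eps)
    fused_weight = [w * scale for w in weight]
    fused_bias = (bias - mean) * scale + beta
    return fused_weight, fused_bias


def conv_1x1(inputs, weight, bias):
    """A 1x1 convolution for one output channel is a dot product plus bias."""
    return sum(x * w for x, w in zip(inputs, weight)) + bias


def batch_norm(x, gamma, beta, mean, var, eps=1e-3):
    """Inference-mode BatchNorm using moving statistics."""
    return gamma * (x - mean) / math.sqrt(var + eps) + beta


inputs = [0.5, -1.0, 2.0]
weight, bias = [0.2, 0.4, -0.1], 0.3
bn = dict(gamma=1.5, beta=-0.2, mean=0.1, var=0.8)

reference = batch_norm(conv_1x1(inputs, weight, bias), **bn)
fused_weight, fused_bias = fuse_conv_bn(weight, bias, **bn)
print(abs(conv_1x1(inputs, fused_weight, fused_bias) - reference) < 1e-12)  # True
```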

ImageNet training and evaluating

  • ImageNet contains more detailed usage and some comparison results.
  • Init Imagenet dataset using tensorflow_datasets #9.
  • For custom datasets, the recommended method is tfds.load; refer to Writing custom datasets and Creating private tensorflow_datasets from tfds #48 by @Medicmind.
  • custom_dataset_script.py can also be used to create a JSON file, which can then be passed as --data_name xxx.json for training. Detailed usage can be found in Custom recognition dataset.
  • aotnet.AotNet50 default parameters set is a typical ResNet50 architecture with Conv2D use_bias=False and padding like PyTorch.
  • Default parameters for train_script.py are like the A3 configuration from ResNet strikes back: An improved training procedure in timm, with batch_size=256, input_shape=(160, 160).
    # `antialias` is enabled by default for resize; it can be turned off by setting `--disable_antialias`.
    CUDA_VISIBLE_DEVICES='0' TF_XLA_FLAGS="--tf_xla_auto_jit=2" ./train_script.py --seed 0 -s aotnet50
    
    # Evaluation using input_shape (224, 224).
    # `antialias` usage should be the same as in training.
    CUDA_VISIBLE_DEVICES='1' ./eval_script.py -m aotnet50_epoch_103_val_acc_0.7674.h5 -i 224 --central_crop 0.95
    # >>>> Accuracy top1: 0.78466 top5: 0.94088
    
  • Progressive training follows PDF 2104.00298 EfficientNetV2: Smaller Models and Faster Training. AotNet50 A3 with progressive input shapes 96, 128, 160:
    CUDA_VISIBLE_DEVICES='1' TF_XLA_FLAGS="--tf_xla_auto_jit=2" ./progressive_train_script.py \
    --progressive_epochs 33 66 -1 \
    --progressive_input_shapes 96 128 160 \
    --progressive_magnitudes 2 4 6 \
    -s aotnet50_progressive_3_lr_steps_100 --seed 0
    
  • eval_script.py is used for evaluating model accuracy.
    # evaluating pretrained builtin model
    CUDA_VISIBLE_DEVICES='1' ./eval_script.py -m regnet.RegNetZD8
    # evaluating pretrained timm model
    CUDA_VISIBLE_DEVICES='1' ./eval_script.py -m timm.models.resmlp_12_224 --input_shape 224
    
    # evaluating specific h5 model
    CUDA_VISIBLE_DEVICES='1' ./eval_script.py -m checkpoints/xxx.h5
    # evaluating specific tflite model
    CUDA_VISIBLE_DEVICES='1' ./eval_script.py -m xxx.tflite
    
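The three --progressive_* lists pair up by position. Each stage trains with its own input shape and its own magnitude, presumably the augmentation strength. The sketch below maps an epoch to its stage. It assumes each --progressive_epochs value is the exclusive end epoch of its stage and that -1 means "until training ends"; progressive_train_script.py may handle the boundaries differently.

```python
def progressive_stage(epoch, stage_epochs, input_shapes, magnitudes):
    """Return (input_shape, magnitude) for a 0-based epoch.

    Assumption: stage_epochs[i] is the exclusive end epoch of stage i, and -1 marks the open-ended last stage.
    """
    for end, shape, magnitude in zip(stage_epochs, input_shapes, magnitudes):
        if end == -1 or epoch < end:
            return shape, magnitude
    raise ValueError("epoch is beyond the last stage; use -1 for an open-ended final stage")


# Mirrors --progressive_epochs 33 66 -1 --progressive_input_shapes 96 128 160 --progressive_magnitudes 2 4 6
schedule = dict(stage_epochs=[33, 66, -1], input_shapes=[96, 128, 160], magnitudes=[2, 4, 6])
for epoch in (0, 32, 33, 65, 66, 104):
    print(epoch, progressive_stage(epoch, **schedule))
```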

COCO training and evaluating

  • Currently still under testing.

  • COCO contains more detailed usage.

  • custom_dataset_script.py can be used to create a JSON file, which can then be passed as --data_name xxx.json for training. Detailed usage can be found in Custom detection dataset.

  • Default parameters for coco_train_script.py are EfficientDetD0 with input_shape=(256, 256, 3), batch_size=64, mosaic_mix_prob=0.5, freeze_backbone_epochs=32, total_epochs=105. Technically, any combination of a pyramid-structure backbone + EfficientDet / YOLOX header / YOLOR header + anchor_free / yolor_anchors / efficientdet_anchors is supported.

  • Currently 3 anchor types are supported:

    • use_anchor_free_mode controls whether to use the typical YOLOX anchor_free mode strategy.
    • use_yolor_anchors_mode controls whether to use YOLOR anchors.
    • Default is use_anchor_free_mode=False, use_yolor_anchors_mode=False, which means using EfficientDet preset anchors.
    | anchors_mode | use_object_scores | num_anchors | anchor_scale | aspect_ratios | num_scales | grid_zero_start |
    | --- | --- | --- | --- | --- | --- | --- |
    | efficientdet | False | 9 | 4 | [1, 2, 0.5] | 3 | False |
    | anchor_free | True | 1 | 1 | [1] | 1 | True |
    | yolor_anchors | True | 3 | None | presets | None | offset=0.5 |

    # Default EfficientDetD0
    CUDA_VISIBLE_DEVICES='0' ./coco_train_script.py
    # Default EfficientDetD0 using input_shape 512, optimizer adamw, freezing backbone 16 epochs, total 50 + 5 epochs
    CUDA_VISIBLE_DEVICES='0' ./coco_train_script.py -i 512 -p adamw --freeze_backbone_epochs 16 --lr_decay_steps 50
    
    # EfficientNetV2B0 backbone + EfficientDetD0 detection header
    CUDA_VISIBLE_DEVICES='0' ./coco_train_script.py --backbone efficientnet.EfficientNetV2B0 --det_header efficientdet.EfficientDetD0
    # ResNest50 backbone + EfficientDetD0 header using yolox like anchor_free_mode
    CUDA_VISIBLE_DEVICES='0' ./coco_train_script.py --backbone resnest.ResNest50 --use_anchor_free_mode
    # UniformerSmall32 backbone + EfficientDetD0 header using yolor anchors
    CUDA_VISIBLE_DEVICES='0' ./coco_train_script.py --backbone uniformer.UniformerSmall32 --use_yolor_anchors_mode
    
    # Typical YOLOXS with anchor_free_mode
    CUDA_VISIBLE_DEVICES='0' ./coco_train_script.py --det_header yolox.YOLOXS --use_anchor_free_mode
    # YOLOXS with efficientdet anchors
    CUDA_VISIBLE_DEVICES='0' ./coco_train_script.py --det_header yolox.YOLOXS
    # CoAtNet0 backbone + YOLOX header with yolor anchors
    CUDA_VISIBLE_DEVICES='0' ./coco_train_script.py --backbone coatnet.CoAtNet0 --det_header yolox.YOLOX --use_yolor_anchors_mode
    
    # Typical YOLOR_P6 with yolor anchors
    CUDA_VISIBLE_DEVICES='0' ./coco_train_script.py --det_header yolor.YOLOR_P6 --use_yolor_anchors_mode
    # YOLOR_P6 with anchor_free_mode
    CUDA_VISIBLE_DEVICES='0' ./coco_train_script.py --det_header yolor.YOLOR_P6 --use_anchor_free_mode
    # ConvNeXtTiny backbone + YOLOR header with efficientdet anchors
    CUDA_VISIBLE_DEVICES='0' ./coco_train_script.py --backbone convnext.ConvNeXtTiny --det_header yolor.YOLOR
    

    Note: COCO training is still under testing, so parameters and default behaviors may change. Take the risk if you would like to help with development.

  • coco_eval_script.py is used for evaluating model AP / AR on the COCO validation set. It depends on pycocotools (pip install pycocotools), which is not in the package requirements. More usage can be found in COCO Evaluation.

    # resize method for EfficientDetD0 is bilinear w/o antialias
    CUDA_VISIBLE_DEVICES='1' ./coco_eval_script.py -m efficientdet.EfficientDetD0 --resize_method bilinear --disable_antialias
    # Specify --use_anchor_free_mode for YOLOX, and BGR input format
    CUDA_VISIBLE_DEVICES='1' ./coco_eval_script.py -m yolox.YOLOXTiny --use_anchor_free_mode --use_bgr_input --nms_method hard --nms_iou_or_sigma 0.65
    # Specify --use_yolor_anchors_mode for YOLOR.
    CUDA_VISIBLE_DEVICES='1' ./coco_eval_script.py -m yolor.YOLOR_CSP --use_yolor_anchors_mode --nms_method hard --nms_iou_or_sigma 0.65 \
    --nms_max_output_size 300 --nms_topk -1 --letterbox_pad 64 --input_shape 704
    
    # Specific h5 model
    CUDA_VISIBLE_DEVICES='1' ./coco_eval_script.py -m checkpoints/yoloxtiny_yolor_anchor.h5 --use_yolor_anchors_mode
    
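The anchors table gives num_anchors = num_scales x len(aspect_ratios): 3 x 3 = 9 for the EfficientDet preset and 1 x 1 = 1 for anchor_free. Below is a generic sketch of EfficientDet-style anchors for one grid cell. It assumes base size = anchor_scale x stride, octave scales 2 ** (i / num_scales), and width and height scaled by sqrt(ratio) and 1 / sqrt(ratio). It also assumes grid_zero_start chooses between x * stride and (x + 0.5) * stride as the cell center. cell_anchors is a hypothetical helper, and kecam's own anchor code may differ in the details.

```python
import math


# Hypothetical helper; scale, ratio, and grid-offset conventions are assumptions.
def cell_anchors(x, y, stride, anchor_scale=4, aspect_ratios=(1, 2, 0.5), num_scales=3, grid_zero_start=False):
    """Return (center_x, center_y, width, height) anchors for grid cell (x, y) on a level with the given stride."""
    offset = 0.0 if grid_zero_start else 0.5
    center_x, center_y = (x + offset) * stride, (y + offset) * stride
    anchors = []
    for scale_id in range(num_scales):
        base = anchor_scale * stride * 2 ** (scale_id / num_scales)  # octave scale
        for ratio in aspect_ratios:
            width, height = base * math.sqrt(ratio), base / math.sqrt(ratio)  # same area, different shape
            anchors.append((center_x, center_y, width, height))
    return anchors


efficientdet = cell_anchors(0, 0, stride=8)  # 3 scales x 3 ratios
anchor_free = cell_anchors(0, 0, stride=8, anchor_scale=1, aspect_ratios=(1,), num_scales=1, grid_zero_start=True)
print(len(efficientdet), len(anchor_free))  # 9 1
```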

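With --nms_method hard, --nms_iou_or_sigma 0.65 presumably acts as an IoU threshold; the flag name suggests it serves as a sigma for soft NMS instead. Below is a minimal pure-Python sketch of greedy hard NMS, assuming [top, left, bottom, right] boxes. hard_nms and its max_output_size parameter are hypothetical and are not coco_eval_script.py internals.

```python
# Hypothetical sketch; boxes are assumed to be [top, left, bottom, right].
def iou(box_a, box_b):
    """Intersection over union of two [top, left, bottom, right] boxes."""
    top, left = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    bottom, right = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, bottom - top) * max(0.0, right - left)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def hard_nms(boxes, scores, iou_threshold=0.65, max_output_size=100):
    """Greedy NMS: keep the highest-scoring box, drop boxes overlapping it above the threshold, repeat."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for idx in order:
        if all(iou(boxes[idx], boxes[kept]) <= iou_threshold for kept in keep):
            keep.append(idx)
            if len(keep) == max_output_size:
                break
    return keep


boxes = [[0, 0, 10, 10], [1, 1, 10, 10], [20, 20, 30, 30]]
scores = [0.9, 0.8, 0.7]
print(hard_nms(boxes, scores))  # [0, 2]
```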
Visualizing

  • Visualizing is for visualizing convnet filters or attention map scores.
  • make_and_apply_gradcam_heatmap is for Grad-CAM class activation visualization.
    from keras_cv_attention_models import visualizing, test_images, resnest
    mm = resnest.ResNest50()
    img = test_images.dog()
    superimposed_img, heatmap, preds = visualizing.make_and_apply_gradcam_heatmap(mm, img, layer_name="auto")
    
  • plot_attention_score_maps is for visualizing model attention score maps.
    from keras_cv_attention_models import visualizing, test_images, botnet
    img = test_images.dog()
    _ = visualizing.plot_attention_score_maps(botnet.BotNetSE33T(), img)
    
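For reference, Grad-CAM weights each channel of the chosen layer's activations by the spatially averaged gradient of the class score. It then sums over channels and keeps the positive part. Computing the gradients needs a framework such as tf.GradientTape, but once you have activations and gradients, the combination step can be sketched in plain Python. This is the generic technique, not the code inside make_and_apply_gradcam_heatmap.

```python
# Generic Grad-CAM combination step, not the library's code.
def gradcam_heatmap(activations, gradients):
    """Combine one layer's activations and gradients into a Grad-CAM heatmap.

    Both inputs are nested lists shaped [height][width][channels].
    """
    height, width, channels = len(activations), len(activations[0]), len(activations[0][0])
    # Channel weight = global-average-pooled gradient of the class score.
    weights = [
        sum(gradients[h][w][c] for h in range(height) for w in range(width)) / (height * width)
        for c in range(channels)
    ]
    # Weighted sum over channels, then ReLU to keep only positive evidence for the class.
    heatmap = [
        [max(0.0, sum(weights[c] * activations[h][w][c] for c in range(channels))) for w in range(width)]
        for h in range(height)
    ]
    # Normalize to [0, 1] for display (skip if the map is all zeros).
    peak = max(max(row) for row in heatmap)
    return [[value / peak for value in row] for row in heatmap] if peak > 0 else heatmap


acts = [[[1.0, 0.0], [0.0, 2.0]], [[3.0, 1.0], [0.0, 0.0]]]
grads = [[[1.0, -1.0], [1.0, -1.0]], [[1.0, -1.0], [1.0, -1.0]]]
print(gradcam_heatmap(acts, grads))  # [[0.5, 0.0], [1.0, 0.0]]
```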

TFLite Conversion

  • Currently TFLite does not support Conv2D with groups>1 / gelu / tf.image.extract_patches / tf.transpose with len(perm) > 4. Some operations may be supported in the tf-nightly version, so try it if you encounter issues. More discussion can be found in Converting a trained keras CV attention model to TFLite #17. Some speed testing results can be found in How to speed up inference on a quantized model #44.
  • tf.nn.gelu(inputs, approximate=True) works for TFLite. Defining a model with activation="gelu/approximate" or activation="gelu/app" sets approximate=True for gelu. It is better to decide this before training; otherwise there may be some accuracy loss.
  • model_surgery.convert_groups_conv2d_2_split_conv2d converts Conv2D layers with groups>1 to SplitConv2D, using split -> conv -> concat:
    import numpy as np
    import tensorflow as tf
    from keras_cv_attention_models import regnet, model_surgery
    from keras_cv_attention_models.imagenet import eval_func

    bb = regnet.RegNetZD32()
    mm = model_surgery.convert_groups_conv2d_2_split_conv2d(bb)  # converts all `Conv2D` using `groups` to `SplitConv2D`
    test_inputs = np.random.uniform(size=[1, *mm.input_shape[1:]])
    print(np.allclose(mm(test_inputs), bb(test_inputs)))
    # True

    converter = tf.lite.TFLiteConverter.from_keras_model(mm)
    with open(mm.name + ".tflite", "wb") as ff:
        ff.write(converter.convert())
    print(np.allclose(mm(test_inputs), eval_func.TFLiteModelInterf(mm.name + '.tflite')(test_inputs), atol=1e-7))
    # True
    # True
    
  • model_surgery.convert_gelu_and_extract_patches_for_tflite converts gelu activations to gelu with approximate=True, and tf.image.extract_patches to a Conv2D version:
    import numpy as np
    import tensorflow as tf
    from keras_cv_attention_models import cotnet, model_surgery
    from keras_cv_attention_models.imagenet import eval_func

    mm = cotnet.CotNetSE50D()
    mm = model_surgery.convert_groups_conv2d_2_split_conv2d(mm)
    mm = model_surgery.convert_gelu_and_extract_patches_for_tflite(mm)
    converter = tf.lite.TFLiteConverter.from_keras_model(mm)
    with open(mm.name + ".tflite", "wb") as ff:
        ff.write(converter.convert())
    test_inputs = np.random.uniform(size=[1, *mm.input_shape[1:]])
    print(np.allclose(mm(test_inputs), eval_func.TFLiteModelInterf(mm.name + '.tflite')(test_inputs), atol=1e-7))
    # True
    
  • model_surgery.prepare_for_tflite is just a combination of the above 2 functions:
    import tensorflow as tf
    from keras_cv_attention_models import beit, model_surgery

    mm = beit.BeitBasePatch16()
    mm = model_surgery.prepare_for_tflite(mm)
    converter = tf.lite.TFLiteConverter.from_keras_model(mm)
    with open(mm.name + ".tflite", "wb") as ff:
        ff.write(converter.convert())
    
  • Converting VOLO / HaloNet models is not supported, because they need a longer tf.transpose perm.
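The exact (erf-based) GELU and the tanh-approximated GELU are numerically close but not identical. That is why switching to gelu/approximate after training can cost a little accuracy. The comparison below uses plain Python; the tanh form is the one tf.nn.gelu uses when approximate=True.

```python
import math


def gelu_exact(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_approximate(x):
    """Tanh approximation, the form tf.nn.gelu uses when approximate=True."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


grid = [i / 100.0 for i in range(-600, 601)]
max_gap = max(abs(gelu_exact(x) - gelu_approximate(x)) for x in grid)
print(f"max |exact - approximate| on [-6, 6]: {max_gap:.2e}")
```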

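Split -> conv -> concat gives exactly the same result, because a Conv2D with groups=g runs an independent convolution on each of g channel slices and stacks the results. The pure-Python check below uses a 1x1 kernel on one pixel. It assumes the Keras grouped-kernel layout [in_channels // groups, out_channels], with spatial dims dropped. It mirrors the idea behind SplitConv2D, not its implementation.

```python
# Assumed kernel layout: [in_channels // groups][out_channels], as in a Keras grouped Conv2D kernel.
def grouped_pointwise_conv(pixel, kernel, groups):
    """1x1 Conv2D with groups on one pixel: output channel o only sees its own group's input slice."""
    in_per_group = len(pixel) // groups
    out_channels = len(kernel[0])
    out_per_group = out_channels // groups
    outputs = []
    for out_c in range(out_channels):
        group = out_c // out_per_group
        group_inputs = pixel[group * in_per_group:(group + 1) * in_per_group]
        outputs.append(sum(x * kernel[i][out_c] for i, x in enumerate(group_inputs)))
    return outputs


def split_conv_concat(pixel, kernel, groups):
    """Same result via split -> ordinary conv per group -> concat."""
    in_per_group = len(pixel) // groups
    out_per_group = len(kernel[0]) // groups
    outputs = []
    for group in range(groups):
        inputs = pixel[group * in_per_group:(group + 1) * in_per_group]  # split
        sub_kernel = [row[group * out_per_group:(group + 1) * out_per_group] for row in kernel]
        outputs.extend(  # ordinary conv on this slice, appended in order (concat)
            sum(x * sub_kernel[i][o] for i, x in enumerate(inputs)) for o in range(out_per_group)
        )
    return outputs


pixel = [1.0, 2.0, 3.0, 4.0]  # 4 input channels
kernel = [[1.0, 0.5, -1.0, 2.0], [0.0, 1.0, 1.0, -0.5]]  # groups=2: 2 inputs per group, 4 outputs
print(grouped_pointwise_conv(pixel, kernel, groups=2) == split_conv_concat(pixel, kernel, groups=2))  # True
```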
Recognition Models

AotNet

  • Keras AotNet is just a ResNet / ResNetV2-like framework with parameters like attn_types and se_ratio, which are used to apply different types of attention layers. It works like byoanet / byobnet from timm.
  • Default parameters set is a typical ResNet architecture with Conv2D use_bias=False and padding like PyTorch.
    from keras_cv_attention_models import aotnet
    # Mixing se and outlook and halo and mhsa and cot_attention, 21M parameters.
    # 50 is just a picked number that is larger than the corresponding `num_block`.
    attn_types = [None, "outlook", ["bot", "halo"] * 50, "cot"]
    se_ratio = [0.25, 0, 0, 0]
    model = aotnet.AotNet50V2(attn_types=attn_types, se_ratio=se_ratio, stem_type="deep", strides=1)
    model.summary()
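In the AotNet example, each stage gets one entry. A list entry is presumably read block by block, which is why ["bot", "halo"] * 50 only needs to be at least as long as that stage's block count. The hypothetical sketch below shows that expansion. It assumes ResNet50V2-style num_blocks [3, 4, 6, 3] and that a non-list entry applies to every block of its stage. It also shows that a trailing comma after the list, as in `x = [...],`, silently builds a one-element tuple.

```python
# Hypothetical helper; the per-block indexing rule is an assumption.
def expand_per_block(stage_values, num_blocks):
    """Expand per-stage settings into one value per block.

    A list entry is indexed by block id; any other entry is shared by all blocks in that stage.
    """
    expanded = []
    for value, blocks in zip(stage_values, num_blocks):
        if isinstance(value, list):
            expanded.append(value[:blocks])
        else:
            expanded.append([value] * blocks)
    return expanded


num_blocks = [3, 4, 6, 3]  # ResNet50 / ResNet50V2 stage depths
attn_types = [None, "outlook", ["bot", "halo"] * 50, "cot"]
per_block = expand_per_block(attn_types, num_blocks)
print(per_block[2])  # ['bot', 'halo', 'bot', 'halo', 'bot', 'halo']

trailing_comma = [None, "outlook"],  # the trailing comma wraps the list in a tuple
print(type(trailing_comma).__name__)  # tuple
```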

BEIT

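| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| BeitBasePatch16, 21k | 86.53M | 224 | 85.240 | beit_base_patch16_224.h5 |
| BeitBasePatch16, 21k | 86.74M | 384 | 86.808 | beit_base_patch16_384.h5 |
| BeitLargePatch16, 21k | 304.43M | 224 | 87.476 | beit_large_patch16_224.h5 |
| BeitLargePatch16, 21k | 305.00M | 384 | 88.382 | beit_large_patch16_384.h5 |
| BeitLargePatch16, 21k | 305.67M | 512 | 88.584 | beit_large_patch16_512.h5 |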

BotNet

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| BotNet50 | 21M | 224 | | |
| BotNet101 | 41M | 224 | | |
| BotNet152 | 56M | 224 | | |
| BotNet26T | 12.5M | 256 | 79.246 | botnet26t_256_imagenet.h5 |
| BotNextECA26T | 10.59M | 256 | 79.270 | botnext_eca26t_256_imagenet.h5 |
| BotNetSE33T | 13.7M | 256 | 81.2 | botnet_se33t_256_imagenet.h5 |

CMT

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| CMTTiny (self trained 105 epochs) | 9.5M | 160 | 77.4 | |
| - 305 epochs | 9.5M | 160 | 78.8 | cmt_tiny_160_imagenet |
| - evaluate 224 (not fine-tuned) | 9.5M | 224 | 80.1 | |
| CMTTiny, 1000 epochs | 9.5M | 160 | 79.2 | |
| CMTXS | 15.2M | 192 | 81.8 | |
| CMTSmall | 25.1M | 224 | 83.5 | |
| CMTBig | 45.7M | 256 | 84.5 | |

CoaT

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| CoaTLiteTiny | 5.7M | 224 | 77.5 | coat_lite_tiny_imagenet.h5 |
| CoaTLiteMini | 11M | 224 | 79.1 | coat_lite_mini_imagenet.h5 |
| CoaTLiteSmall | 20M | 224 | 81.9 | coat_lite_small_imagenet.h5 |
| CoaTTiny | 5.5M | 224 | 78.3 | coat_tiny_imagenet.h5 |
| CoaTMini | 10M | 224 | 81.0 | coat_mini_imagenet.h5 |

CoAtNet

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| CoAtNet0 (self trained 105 epochs) | 23.8M | 160 | 80.50 | coatnet0_160_imagenet.h5 |
| - fine-tune 224, 37 epochs | 23.8M | 224 | 82.23 | coatnet0_224_imagenet.h5 |
| CoAtNet0 | 25M | 224 | 81.6 | |
| CoAtNet0, Strided DConv | 25M | 224 | 82.0 | |
| CoAtNet1 | 42M | 224 | 83.3 | |
| CoAtNet1, Strided DConv | 42M | 224 | 83.5 | |
| CoAtNet2 | 75M | 224 | 84.1 | |
| CoAtNet2, Strided DConv | 75M | 224 | 84.1 | |
| CoAtNet2, ImageNet-21k pretrain | 75M | 224 | 87.1 | |
| CoAtNet3 | 168M | 224 | 84.5 | |
| CoAtNet3, ImageNet-21k pretrain | 168M | 224 | 87.6 | |
| CoAtNet3, ImageNet-21k pretrain | 168M | 512 | 87.9 | |
| CoAtNet4, ImageNet-21k pretrain | 275M | 512 | 88.1 | |
| CoAtNet4, ImageNet-21K + PT-RA-E150 | 275M | 512 | 88.56 | |

JFT pre-trained models accuracy

| Model | Image resolution | Reported Params | Self-defined Params | Top1 Acc |
| --- | --- | --- | --- | --- |
| CoAtNet3 | 384 | 168M | 162.96M | 88.52 |
| CoAtNet3 | 512 | 168M | 163.57M | 88.81 |
| CoAtNet4 | 512 | 275M | 273.10M | 89.11 |
| CoAtNet5 | 512 | 688M | 680.47M | 89.77 |
| CoAtNet6 | 512 | 1.47B | 1.340B | 90.45 |
| CoAtNet7 | 512 | 2.44B | 2.422B | 90.88 |

ConvNeXt

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| ConvNeXtTiny | 28M | 224 | 82.1 | tiny_imagenet.h5 |
| ConvNeXtSmall | 50M | 224 | 83.1 | small_imagenet.h5 |
| ConvNeXtBase | 89M | 224 | 83.8 | base_224_imagenet.h5 |
| ConvNeXtBase | 89M | 384 | 85.1 | base_384_imagenet.h5 |
| - ImageNet21k-ft1k | 89M | 224 | 85.8 | base_224_21k.h5 |
| - ImageNet21k-ft1k | 89M | 384 | 86.8 | base_384_21k.h5 |
| ConvNeXtLarge | 198M | 224 | 84.3 | large_224_imagenet.h5 |
| ConvNeXtLarge | 198M | 384 | 85.5 | large_384_imagenet.h5 |
| - ImageNet21k-ft1k | 198M | 224 | 86.6 | large_224_21k.h5 |
| - ImageNet21k-ft1k | 198M | 384 | 87.5 | large_384_21k.h5 |
| ConvNeXtXLarge, 21k | 350M | 224 | 87.0 | xlarge_224_21k.h5 |
| ConvNeXtXLarge, 21k | 350M | 384 | 87.8 | xlarge_384_21k.h5 |

CoTNet

| Model | Params | Image resolution | FLOPs (G) | Top1 Acc | Download |
| --- | --- | --- | --- | --- | --- |
| CotNet50 | 22.2M | 224 | 3.3 | 81.3 | cotnet50_224_imagenet.h5 |
| CoTNeXt50 | 30.1M | 224 | 4.3 | 82.1 | |
| CotNetSE50D | 23.1M | 224 | 4.1 | 81.6 | cotnet_se50d_224_imagenet.h5 |
| CotNet101 | 38.3M | 224 | 6.1 | 82.8 | cotnet101_224_imagenet.h5 |
| CoTNeXt-101 | 53.4M | 224 | 8.2 | 83.2 | |
| CotNetSE101D | 40.9M | 224 | 8.5 | 83.2 | cotnet_se101d_224_imagenet.h5 |
| CotNetSE152D | 55.8M | 224 | 17.0 | 84.0 | cotnet_se152d_224_imagenet.h5 |
| CotNetSE152D | 55.8M | 320 | 26.5 | 84.6 | cotnet_se152d_320_imagenet.h5 |

EfficientNet

| V2 Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| EfficientNetV2B0 | 7.1M | 224 | 78.7 | effv2b0-imagenet.h5 |
| - ImageNet21k-ft1k | 7.1M | 224 | 77.55? | effv2b0-21k-ft1k.h5 |
| EfficientNetV2B1 | 8.1M | 240 | 79.8 | effv2b1-imagenet.h5 |
| - ImageNet21k-ft1k | 8.1M | 240 | 79.03? | effv2b1-21k-ft1k.h5 |
| EfficientNetV2B2 | 10.1M | 260 | 80.5 | effv2b2-imagenet.h5 |
| - ImageNet21k-ft1k | 10.1M | 260 | 79.48? | effv2b2-21k-ft1k.h5 |
| EfficientNetV2B3 | 14.4M | 300 | 82.1 | effv2b3-imagenet.h5 |
| - ImageNet21k-ft1k | 14.4M | 300 | 82.46? | effv2b3-21k-ft1k.h5 |
| EfficientNetV2T | 13.6M | 320 | 82.5 | effv2t-imagenet.h5 |
| EfficientNetV2S | 21.5M | 384 | 83.9 | effv2s-imagenet.h5 |
| - ImageNet21k-ft1k | 21.5M | 384 | 84.9 | effv2s-21k-ft1k.h5 |
| EfficientNetV2M | 54.1M | 480 | 85.2 | effv2m-imagenet.h5 |
| - ImageNet21k-ft1k | 54.1M | 480 | 86.2 | effv2m-21k-ft1k.h5 |
| EfficientNetV2L | 119.5M | 480 | 85.7 | effv2l-imagenet.h5 |
| - ImageNet21k-ft1k | 119.5M | 480 | 86.9 | effv2l-21k-ft1k.h5 |
| EfficientNetV2XL, 21k-ft1k | 206.8M | 512 | 87.2 | effv2xl-21k-ft1k.h5 |

| V1 Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| EfficientNetV1B0 | 5.3M | 224 | 77.6 | effv1-b0-imagenet.h5 |
| - NoisyStudent | 5.3M | 224 | 78.8 | effv1-b0-noisy_student.h5 |
| EfficientNetV1B1 | 7.8M | 240 | 79.6 | effv1-b1-imagenet.h5 |
| - NoisyStudent | 7.8M | 240 | 81.5 | effv1-b1-noisy_student.h5 |
| EfficientNetV1B2 | 9.1M | 260 | 80.5 | effv1-b2-imagenet.h5 |
| - NoisyStudent | 9.1M | 260 | 82.4 | effv1-b2-noisy_student.h5 |
| EfficientNetV1B3 | 12.2M | 300 | 81.9 | effv1-b3-imagenet.h5 |
| - NoisyStudent | 12.2M | 300 | 84.1 | effv1-b3-noisy_student.h5 |
| EfficientNetV1B4 | 19.3M | 380 | 83.3 | effv1-b4-imagenet.h5 |
| - NoisyStudent | 19.3M | 380 | 85.3 | effv1-b4-noisy_student.h5 |
| EfficientNetV1B5 | 30.4M | 456 | 84.3 | effv1-b5-imagenet.h5 |
| - NoisyStudent | 30.4M | 456 | 86.1 | effv1-b5-noisy_student.h5 |
| EfficientNetV1B6 | 43.0M | 528 | 84.8 | effv1-b6-imagenet.h5 |
| - NoisyStudent | 43.0M | 528 | 86.4 | effv1-b6-noisy_student.h5 |
| EfficientNetV1B7 | 66.3M | 600 | 85.2 | effv1-b7-imagenet.h5 |
| - NoisyStudent | 66.3M | 600 | 86.9 | effv1-b7-noisy_student.h5 |
| EfficientNetV1L2, NoisyStudent | 480.3M | 800 | 88.4 | effv1-l2-noisy_student.h5 |

FBNetV3

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| FBNetV3B | 5.57M | 256 | 79.15 | fbnetv3_b_imagenet.h5 |
| FBNetV3D | 10.31M | 256 | 79.68 | fbnetv3_d_imagenet.h5 |
| FBNetV3G | 16.62M | 256 | 82.05 | fbnetv3_g_imagenet.h5 |

GMLP

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| GMLPTiny16 | 6M | 224 | 72.3 | |
| GMLPS16 | 20M | 224 | 79.6 | gmlp_s16_imagenet.h5 |
| GMLPB16 | 73M | 224 | 81.6 | |

HaloNet

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| HaloNetH0 | 5.5M | 256 | 77.9 | |
| HaloNetH1 | 8.1M | 256 | 79.9 | |
| HaloNetH2 | 9.4M | 256 | 80.4 | |
| HaloNetH3 | 11.8M | 320 | 81.9 | |
| HaloNetH4 | 19.1M | 384 | 83.3 | |
| - 21k | 19.1M | 384 | 85.5 | |
| HaloNetH5 | 30.7M | 448 | 84.0 | |
| HaloNetH6 | 43.4M | 512 | 84.4 | |
| HaloNetH7 | 67.4M | 600 | 84.9 | |
| HaloNextECA26T | 10.7M | 256 | 79.50 | halonext_eca26t_256_imagenet.h5 |
| HaloNet26T | 12.5M | 256 | 79.13 | halonet26t_256_imagenet.h5 |
| HaloNetSE33T | 13.7M | 256 | 80.99 | halonet_se33t_256_imagenet.h5 |
| HaloRegNetZB | 11.68M | 224 | 81.042 | haloregnetz_b_224_imagenet.h5 |
| HaloNet50T | 22.7M | 256 | 81.70 | halonet50t_256_imagenet.h5 |
| HaloBotNet50T | 22.6M | 256 | 82.0 | halobotnet50t_256_imagenet.h5 |

LCNet

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| LCNet050 | 1.88M | 224 | 63.10 | lcnet_050_imagenet.h5 |
| LCNet075 | 2.36M | 224 | 68.82 | lcnet_075_imagenet.h5 |
| LCNet100 | 2.95M | 224 | 72.10 | lcnet_100_imagenet.h5 |
| LCNet150 | 4.52M | 224 | 73.71 | lcnet_150_imagenet.h5 |
| LCNet200 | 6.54M | 224 | 75.18 | lcnet_200_imagenet.h5 |
| LCNet250 | 9.04M | 224 | 76.60 | lcnet_250_imagenet.h5 |

LeViT

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| LeViT128S, distillation | 7.8M | 224 | 76.6 | levit128s_imagenet.h5 |
| LeViT128, distillation | 9.2M | 224 | 78.6 | levit128_imagenet.h5 |
| LeViT192, distillation | 11M | 224 | 80.0 | levit192_imagenet.h5 |
| LeViT256, distillation | 19M | 224 | 81.6 | levit256_imagenet.h5 |
| LeViT384, distillation | 39M | 224 | 82.6 | levit384_imagenet.h5 |

MLP mixer

| Model | Params | Top1 Acc | ImageNet | ImageNet21k | ImageNet SAM |
| --- | --- | --- | --- | --- | --- |
| MLPMixerS32 | 19.1M | 68.70 | | | |
| MLPMixerS16 | 18.5M | 73.83 | | | |
| MLPMixerB32 | 60.3M | 75.53 | | | b32_imagenet_sam.h5 |
| MLPMixerB16 | 59.9M | 80.00 | b16_imagenet.h5 | b16_imagenet21k.h5 | b16_imagenet_sam.h5 |
| MLPMixerL32 | 206.9M | 80.67 | | | |
| MLPMixerL16 | 208.2M | 84.82 | l16_imagenet.h5 | l16_imagenet21k.h5 | |
| - input 448 | 208.2M | 86.78 | | | |
| MLPMixerH14 | 432.3M | 86.32 | | | |
| - input 448 | 432.3M | 87.94 | | | |

MobileNetV3

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| MobileNetV3Small050 | 1.29M | 224 | 57.89 | small_050_imagenet.h5 |
| MobileNetV3Small075 | 2.04M | 224 | 65.24 | small_075_imagenet.h5 |
| MobileNetV3Small100 | 2.54M | 224 | 67.66 | small_100_imagenet.h5 |
| MobileNetV3Large075 | 3.99M | 224 | 73.44 | large_075_imagenet.h5 |
| MobileNetV3Large100 | 5.48M | 224 | 75.77 | large_100_imagenet.h5 |
| - miil | 5.48M | 224 | 77.92 | large_100_miil.h5 |

MobileViT

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| MobileViT_XXS | 1.3M | 256 | 69.0 | mobilevit_xxs_imagenet |
| MobileViT_XS | 2.3M | 256 | 74.7 | mobilevit_xs_imagenet |
| MobileViT_S | 5.6M | 256 | 78.3 | mobilevit_s_imagenet |

NFNets

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| NFNetL0 | 35.07M | 288 | 82.75 | nfnetl0_imagenet.h5 |
| NFNetF0 | 71.5M | 256 | 83.6 | nfnetf0_imagenet.h5 |
| NFNetF1 | 132.6M | 320 | 84.7 | nfnetf1_imagenet.h5 |
| NFNetF2 | 193.8M | 352 | 85.1 | nfnetf2_imagenet.h5 |
| NFNetF3 | 254.9M | 416 | 85.7 | nfnetf3_imagenet.h5 |
| NFNetF4 | 316.1M | 512 | 85.9 | nfnetf4_imagenet.h5 |
| NFNetF5 | 377.2M | 544 | 86.0 | nfnetf5_imagenet.h5 |
| NFNetF6 SAM | 438.4M | 576 | 86.5 | nfnetf6_imagenet.h5 |
| NFNetF7 | 499.5M | 608 | | |
| ECA_NFNetL0 | 24.14M | 288 | 82.58 | eca_nfnetl0_imagenet.h5 |
| ECA_NFNetL1 | 41.41M | 320 | 84.01 | eca_nfnetl1_imagenet.h5 |
| ECA_NFNetL2 | 56.72M | 384 | 84.70 | eca_nfnetl2_imagenet.h5 |
| ECA_NFNetL3 | 72.04M | 448 | | |

RegNetY

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| RegNetY040 | 20.65M | 224 | 82.3 | regnety_040_imagenet.h5 |
| RegNetY064 | 30.58M | 224 | 83.0 | regnety_064_imagenet.h5 |
| RegNetY080 | 39.18M | 224 | 83.17 | regnety_080_imagenet.h5 |
| RegNetY160 | 83.59M | 224 | 82.0 | regnety_160_imagenet.h5 |
| RegNetY320 | 145.05M | 224 | 82.5 | regnety_320_imagenet.h5 |

RegNetZ

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| RegNetZB16 | 9.72M | 224 | 79.868 | regnetz_b16_imagenet.h5 |
| RegNetZC16 | 13.46M | 256 | 82.164 | regnetz_c16_imagenet.h5 |
| RegNetZC16_EVO | 13.49M | 256 | 81.9 | regnetz_c16_evo_imagenet.h5 |
| RegNetZD32 | 27.58M | 256 | 83.422 | regnetz_d32_imagenet.h5 |
| RegNetZD8 | 23.37M | 256 | 83.5 | regnetz_d8_imagenet.h5 |
| RegNetZD8_EVO | 23.46M | 256 | 83.42 | regnetz_d8_evo_imagenet.h5 |
| RegNetZE8 | 57.70M | 256 | 84.5 | regnetz_e8_imagenet.h5 |

ResMLP

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| ResMLP12 | 15M | 224 | 77.8 | resmlp12_imagenet.h5 |
| ResMLP24 | 30M | 224 | 80.8 | resmlp24_imagenet.h5 |
| ResMLP36 | 116M | 224 | 81.1 | resmlp36_imagenet.h5 |
| ResMLP_B24 | 129M | 224 | 83.6 | resmlp_b24_imagenet.h5 |
| - imagenet22k | 129M | 224 | 84.4 | resmlp_b24_imagenet22k.h5 |

ResNeSt

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| resnest50 | 28M | 224 | 81.03 | resnest50.h5 |
| resnest101 | 49M | 256 | 82.83 | resnest101.h5 |
| resnest200 | 71M | 320 | 83.84 | resnest200.h5 |
| resnest269 | 111M | 416 | 84.54 | resnest269.h5 |

ResNetD

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| ResNet50D | 25.58M | 224 | 80.530 | resnet50d.h5 |
| ResNet101D | 44.57M | 224 | 83.022 | resnet101d.h5 |
| ResNet152D | 60.21M | 224 | 83.680 | resnet152d.h5 |
| ResNet200D | 64.69M | 224 | 83.962 | resnet200d.h5 |

ResNetQ

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| ResNet51Q | 35.7M | 224 | 82.36 | resnet51q.h5 |

ResNeXt

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| ResNeXt50 (32x4d) | 25M | 224 | 79.768 | resnext50_imagenet.h5 |
| - SWSL | 25M | 224 | 82.182 | resnext50_swsl.h5 |
| ResNeXt50D (32x4d + deep) | 25M | 224 | 79.676 | resnext50d_imagenet.h5 |
| ResNeXt101 (32x4d) | 42M | 224 | 80.334 | resnext101_imagenet.h5 |
| - SWSL | 42M | 224 | 83.230 | resnext101_swsl.h5 |
| ResNeXt101W (32x8d) | 89M | 224 | 79.308 | resnext101_imagenet.h5 |
| - SWSL | 89M | 224 | 84.284 | resnext101w_swsl.h5 |

SwinTransformerV2

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| SwinTransformerV2Tiny_ns | 28.3M | 224 | 81.8 | v2_tiny_ns_224_imagenet.h5 |
| SwinTransformerV2Small | 49.7M | 224 | 83.13 | v2_small_224_imagenet.h5 |
| SwinTransformerV2Base, 22k | 87.9M | 384 | 87.1 | |
| SwinTransformerV2Large, 22k | 196.7M | 384 | 87.7 | |
| SwinTransformerV2Giant, 22k+ext | 2.60B | 640 | 90.17 | |

TinyNet

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| TinyNetE | 2.04M | 106 | 59.86 | tinynet_e_imagenet.h5 |
| TinyNetD | 2.34M | 152 | 66.96 | tinynet_d_imagenet.h5 |
| TinyNetC | 2.46M | 184 | 71.23 | tinynet_c_imagenet.h5 |
| TinyNetB | 3.73M | 188 | 74.98 | tinynet_b_imagenet.h5 |
| TinyNetA | 6.19M | 192 | 77.65 | tinynet_a_imagenet.h5 |

UniFormer

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| UniformerSmall32 + TL | 22M | 224 | 83.4 | small_32_224_token_label |
| UniformerSmall64 | 22M | 224 | 82.9 | small_64_imagenet |
| - Token Labeling | 22M | 224 | 83.4 | small_64_token_label |
| UniformerSmallPlus32 | 24M | 224 | 83.4 | small_plus_32_imagenet |
| - Token Labeling | 24M | 224 | 83.9 | small_plus_32_token_label |
| UniformerSmallPlus64 | 24M | 224 | 83.4 | small_plus_64_imagenet |
| - Token Labeling | 24M | 224 | 83.6 | small_plus_64_token_label |
| UniformerBase32 + TL | 50M | 224 | 85.1 | base_32_224_token_label |
| UniformerBase64 | 50M | 224 | 83.8 | base_64_imagenet |
| - Token Labeling | 50M | 224 | 84.8 | base_64_224_token_label |
| UniformerLarge64 + TL | 100M | 224 | 85.6 | large_64_224_token_label |
| UniformerLarge64 + TL | 100M | 384 | 86.3 | large_64_384_token_label |

VOLO

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| VOLO_d1 | 27M | 224 | 84.2 | volo_d1_224_imagenet.h5 |
| VOLO_d1 ↑384 | 27M | 384 | 85.2 | volo_d1_384_imagenet.h5 |
| VOLO_d2 | 59M | 224 | 85.2 | volo_d2_224_imagenet.h5 |
| VOLO_d2 ↑384 | 59M | 384 | 86.0 | volo_d2_384_imagenet.h5 |
| VOLO_d3 | 86M | 224 | 85.4 | volo_d3_224_imagenet.h5 |
| VOLO_d3 ↑448 | 86M | 448 | 86.3 | volo_d3_448_imagenet.h5 |
| VOLO_d4 | 193M | 224 | 85.7 | volo_d4_224_imagenet.h5 |
| VOLO_d4 ↑448 | 193M | 448 | 86.8 | volo_d4_448_imagenet.h5 |
| VOLO_d5 | 296M | 224 | 86.1 | volo_d5_224_imagenet.h5 |
| VOLO_d5 ↑448 | 296M | 448 | 87.0 | volo_d5_448_imagenet.h5 |
| VOLO_d5 ↑512 | 296M | 512 | 87.1 | volo_d5_512_imagenet.h5 |

WaveMLP

| Model | Params | Image resolution | Top1 Acc | Download |
| --- | --- | --- | --- | --- |
| WaveMLP_T | 17M | 224 | 80.9 | wavemlp_t_imagenet.h5 |
| WaveMLP_S | 30M | 224 | 82.9 | wavemlp_s_imagenet.h5 |
| WaveMLP_M | 44M | 224 | 83.3 | wavemlp_m_imagenet.h5 |
| WaveMLP_B | 63M | 224 | 83.6 | |

Detection Models

EfficientDet

| Model | Params | Image resolution | COCO val AP | test AP | Download |
| --- | --- | --- | --- | --- | --- |
| EfficientDetD0 | 3.9M | 512 | 34.3 | 34.6 | efficientdet_d0.h5 |
| - Det-AdvProp | 3.9M | 512 | 35.1 | 35.3 | |
| EfficientDetD1 | 6.6M | 640 | 40.2 | 40.5 | efficientdet_d1.h5 |
| - Det-AdvProp | 6.6M | 640 | 40.8 | 40.9 | |
| EfficientDetD2 | 8.1M | 768 | 43.5 | 43.9 | efficientdet_d2.h5 |
| - Det-AdvProp | 8.1M | 768 | 44.3 | 44.3 | |
| EfficientDetD3 | 12.0M | 896 | 46.8 | 47.2 | efficientdet_d3.h5 |
| - Det-AdvProp | 12.0M | 896 | 47.7 | 48.0 | |
| EfficientDetD4 | 20.7M | 1024 | 49.3 | 49.7 | efficientdet_d4.h5 |
| - Det-AdvProp | 20.7M | 1024 | 50.4 | 50.4 | |
| EfficientDetD5 | 33.7M | 1280 | 51.2 | 51.5 | efficientdet_d5.h5 |
| - Det-AdvProp | 33.7M | 1280 | 52.2 | 52.5 | |
| EfficientDetD6 | 51.9M | 1280 | 52.1 | 52.6 | efficientdet_d6.h5 |
| EfficientDetD7 | 51.9M | 1536 | 53.4 | 53.7 | efficientdet_d7.h5 |
| EfficientDetD7X | 77.0M | 1536 | 54.4 | 55.1 | efficientdet_d7x.h5 |
| EfficientDetLite0 | 3.2M | 320 | 26.41 | | efficientdet_lite0.h5 |
| EfficientDetLite1 | 4.2M | 384 | 31.50 | | efficientdet_lite1.h5 |
| EfficientDetLite2 | 5.3M | 448 | 35.06 | | efficientdet_lite2.h5 |
| EfficientDetLite3 | 8.4M | 512 | 38.77 | | efficientdet_lite3.h5 |
| EfficientDetLite3X | 9.3M | 640 | 42.64 | | efficientdet_lite3x.h5 |
| EfficientDetLite4 | 15.1M | 640 | 43.18 | | efficientdet_lite4.h5 |

YOLOR

| Model | Params | Image resolution | COCO val AP | test AP | Download |
| --- | --- | --- | --- | --- | --- |
| YOLOR_CSP | 52.9M | 640 | 50.0 | 52.8 | yolor_csp_coco.h5 |
| YOLOR_CSPX | 99.8M | 640 | 51.5 | 54.8 | yolor_csp_x_coco.h5 |
| YOLOR_P6 | 37.3M | 1280 | 52.5 | 55.7 | yolor_p6_coco.h5 |
| YOLOR_W6 | 79.9M | 1280 | | 56.9 | yolor_w6_coco.h5 |
| YOLOR_E6 | 115.9M | 1280 | | 57.6 | yolor_e6_coco.h5 |
| YOLOR_D6 | 151.8M | 1280 | | 58.2 | yolor_d6_coco.h5 |

YOLOX

| Model | Params | Image resolution | COCO val AP | test AP | Download |
| --- | --- | --- | --- | --- | --- |
| YOLOXNano | 0.91M | 416 | 25.8 | | yolox_nano_coco.h5 |
| YOLOXTiny | 5.06M | 416 | 32.8 | | yolox_tiny_coco.h5 |
| YOLOXS | 9.0M | 640 | 40.5 | 40.5 | yolox_s_coco.h5 |
| YOLOXM | 25.3M | 640 | 46.9 | 47.2 | yolox_m_coco.h5 |
| YOLOXL | 54.2M | 640 | 49.7 | 50.1 | yolox_l_coco.h5 |
| YOLOXX | 99.1M | 640 | 51.5 | 51.5 | yolox_x_coco.h5 |

Other implemented tensorflow or keras models


Download files

Source Distribution

keras-cv-attention-models-1.2.18.tar.gz (418.1 kB)

Built Distribution

keras_cv_attention_models-1.2.18-py3-none-any.whl (449.8 kB)
