
Implementation of the ViT model based on TensorFlow

Project description

vit_tf2

This package implements the Vision Transformer (ViT) model in TensorFlow. ViT was proposed in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale". The package uses weights pre-trained on the ImageNet-21k and ImageNet-2012 datasets, stored in .npz format.

◈ Requirements

  • Python >= 3.7

  • TensorFlow >= 2.9

Q1: What can you do with this package?

  • Build a pre-trained ViT model in one of the standard configurations.

  • Build a ViT with custom hyperparameters to suit your task.

Q2: How do I build a pre-trained ViT?

  1. Quickly build a pre-trained ViT_B16

    from vit_tf2.vit import ViT_B16
    vit = ViT_B16()
    

    Four pre-trained configurations are available: ViT_B16, ViT_B32, ViT_L16 and ViT_L32.

    config     patch size   hidden dim   MLP dim   attention heads   encoder depth
    ViT_B16    16×16        768          3072      12                12
    ViT_B32    32×32        768          3072      12                12
    ViT_L16    16×16        1024         4096      16                24
    ViT_L32    32×32        1024         4096      16                24

    The "imagenet21k" and "imagenet21k+imagenet2012" weights differ slightly, as shown in the table below.

    dataset                    image size   classes   pre logits   known labels
    imagenet21k                224          21843     True         False
    imagenet21k+imagenet2012   384          1000      False        True

  2. Build ViT_B16 with different pre-trained weights

    from vit_tf2.vit import ViT_B16
    vit_1 = ViT_B16(weights="imagenet21k")               # 224×224 input, 21843 classes
    vit_2 = ViT_B16(weights="imagenet21k+imagenet2012")  # 384×384 input, 1000 classes
    
  3. Build ViT_B16 without pre-trained weights

    from vit_tf2.vit import ViT_B16
    vit = ViT_B16(pre_trained=False)
    

    When pre_trained=True, the pre-trained weights file is downloaded to C:\Users\user_name\.keras\weights.

  4. Build a pre-trained ViT_B32 with custom parameters

    from vit_tf2.vit import ViT_B32
    vit = ViT_B32(
        image_size = 128,
        num_classes = 12, 
        pre_logits = False,
        weights = "imagenet21k",
        )
    

    When a changed parameter alters the shape of a layer, that layer does not load pre-trained weights; layers whose shapes are unchanged still load them. Call loading_summary() to see which layers were loaded.

    vit.loading_summary()
    >>
    Model: "ViT-B-32-128"
    -----------------------------------------------------------------
    layers                             load weights inf
    =================================================================
    patch_embedding                    loaded
    
    add_cls_token                      loaded - imagenet
    
    position_embedding                 not loaded - mismatch
    
    transformer_block_0                loaded - imagenet
    
    transformer_block_1                loaded - imagenet
    
    transformer_block_2                loaded - imagenet
    
    transformer_block_3                loaded - imagenet
    
    transformer_block_4                loaded - imagenet
    
    transformer_block_5                loaded - imagenet
    
    transformer_block_6                loaded - imagenet
    
    transformer_block_7                loaded - imagenet
    
    transformer_block_8                loaded - imagenet
    
    transformer_block_9                loaded - imagenet
    
    transformer_block_10               loaded - imagenet
    
    transformer_block_11               loaded - imagenet
    
    layer_norm                         loaded - imagenet
    
    mlp_head                           not loaded - mismatch
    =================================================================
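The "not loaded - mismatch" entries follow from tensor shapes. The sketch below assumes the usual ViT layout (one token per non-overlapping patch plus one class token, and a learned position embedding of shape (tokens, hidden_dim), ignoring any leading batch axis) and computes that shape for the configurations in the table above. The helper names are hypothetical and not part of vit_tf2. For ViT_B32, the "imagenet21k" weights were trained at 224×224 (50 tokens), whereas image_size=128 yields only 17 tokens, so the stored position embedding cannot be reused as-is; likewise a 12-class mlp_head cannot take the weights of a 21843-class head.

```python
# Hypothetical helpers (not part of vit_tf2): shape arithmetic for a standard ViT.
# Assumes non-overlapping square patches and one prepended class token.

CONFIGS = {
    # name: (patch_size, hidden_dim, mlp_dim, attention_heads, encoder_depth)
    "ViT_B16": (16, 768, 3072, 12, 12),
    "ViT_B32": (32, 768, 3072, 12, 12),
    "ViT_L16": (16, 1024, 4096, 16, 24),
    "ViT_L32": (32, 1024, 4096, 16, 24),
}


def num_tokens(image_size: int, patch_size: int) -> int:
    """Patches per image (grid rounded down) plus one class token."""
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1


def position_embedding_shape(name: str, image_size: int) -> tuple:
    """(tokens, hidden_dim) for a named configuration at a given input size."""
    patch_size, hidden_dim, *_ = CONFIGS[name]
    return (num_tokens(image_size, patch_size), hidden_dim)


def can_reuse_position_embedding(name: str, pretrained_size: int, new_size: int) -> bool:
    """Pre-trained position embeddings only fit if the token count is unchanged."""
    return position_embedding_shape(name, pretrained_size) == position_embedding_shape(name, new_size)


if __name__ == "__main__":
    print(position_embedding_shape("ViT_B32", 224))  # (50, 768)
    print(position_embedding_shape("ViT_B32", 128))  # (17, 768)
    print(can_reuse_position_embedding("ViT_B32", 224, 128))  # False
```

The same arithmetic gives 577 tokens for a 16×16-patch model at 384×384, the input size of the "imagenet21k+imagenet2012" weights.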
    

Q3: How do I build a custom ViT?

  1. Instantiate the ViT class to build a custom model

    from vit_tf2.vit import ViT
    vit = ViT(
        image_size = 128,
        patch_size = 36,
        num_classes = 1,
        hidden_dim = 128,
        mlp_dim = 512,
        atten_heads = 32,
        encoder_depth = 4,
        dropout_rate = 0.1,
        activation = "sigmoid",
        pre_logits = True,
        include_mlp_head = True,
        )
    vit.summary()
    
    >>
    Model: "ViT-CUSTOM_SIZE-36-128"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #
    =================================================================
     patch_embedding (PatchEmbed  (None, 9, 128)           497792
     ding)
    
     add_cls_token (AddCLSToken)  (None, 10, 128)          128
    
     position_embedding (AddPosi  (None, 10, 128)          1280
     tionEmbedding)
    
     transformer_block_0 (Transf  (None, 10, 128)          198272
     ormerEncoder)
    
     transformer_block_1 (Transf  (None, 10, 128)          198272
     ormerEncoder)
    
     transformer_block_2 (Transf  (None, 10, 128)          198272
     ormerEncoder)
    
     transformer_block_3 (Transf  (None, 10, 128)          198272
     ormerEncoder)
    
     layer_norm (LayerNormalizat  (None, 10, 128)          256
     ion)
    
     extract_token (Lambda)      (None, 128)               0
    
     pre_logits (Dense)          (None, 128)               16512
    
     mlp_head (Dense)            (None, 1)                 129
    
    =================================================================
    Total params: 1,309,185
    Trainable params: 1,309,185
    Non-trainable params: 0
    _________________________________________________________________
    

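    Note that "hidden_dim" must be divisible by "atten_heads". It is also best to choose an "image_size" that is evenly divisible by "patch_size"; otherwise the patch grid is rounded down (in the example above, 128 // 36 = 3 patches per side, giving the 9 patches shown in the summary).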

  2. Load pre-trained weights into a custom model

    from vit_tf2 import utils, vit
    vit_custom = vit.ViT(
        image_size=128,
        patch_size=8,
        encoder_depth=4
        )
    utils.load_imgnet_weights(vit_custom, "ViT-B_16_imagenet21k.npz")
    vit_custom.loading_summary()
    
    >>
    Model: "ViT-CUSTOM_SIZE-8-128"
    -----------------------------------------------------------------
    layers                             load weights inf
    =================================================================
    patch_embedding                    mismatch
    
    add_cls_token                      loaded - imagenet
    
    position_embedding                 not loaded - mismatch
    
    transformer_block_0                loaded - imagenet
    
    transformer_block_1                loaded - imagenet
    
    transformer_block_2                loaded - imagenet
    
    transformer_block_3                loaded - imagenet
    
    layer_norm                         loaded - imagenet
    
    pre_logits                         loaded - imagenet
    
    mlp_head                           not loaded - mismatch
    =================================================================
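The parameter counts in the vit.summary() output for the custom model can be reproduced by hand, which is a quick way to sanity-check a configuration before training. The sketch below is a hypothetical helper, not part of vit_tf2. It assumes each transformer_block holds two layer norms, a multi-head attention layer with query, key, value and output projections of size hidden_dim × hidden_dim (with biases), and a two-layer MLP; these assumptions are inferred from the printed counts rather than from the package source. It also enforces the "hidden_dim" / "atten_heads" divisibility rule noted above.

```python
# Hypothetical helper (not part of vit_tf2): reproduces the per-layer parameter
# counts printed by vit.summary() for a custom ViT, under the assumptions above.

def dense_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def vit_param_counts(image_size, patch_size, hidden_dim, mlp_dim, atten_heads,
                     encoder_depth, num_classes, channels=3, pre_logits=True):
    if hidden_dim % atten_heads != 0:
        raise ValueError("hidden_dim must be divisible by atten_heads")
    num_patches = (image_size // patch_size) ** 2
    tokens = num_patches + 1  # plus the class token

    counts = {
        # Same count as a Conv2D with patch_size x patch_size kernels and hidden_dim filters
        "patch_embedding": dense_params(patch_size * patch_size * channels, hidden_dim),
        "add_cls_token": hidden_dim,
        "position_embedding": tokens * hidden_dim,
    }
    block = (
        2 * (2 * hidden_dim)                        # two layer norms (scale + offset)
        + 4 * dense_params(hidden_dim, hidden_dim)  # Q, K, V and output projections
        + dense_params(hidden_dim, mlp_dim)         # MLP expansion
        + dense_params(mlp_dim, hidden_dim)         # MLP projection back
    )
    for i in range(encoder_depth):
        counts[f"transformer_block_{i}"] = block
    counts["layer_norm"] = 2 * hidden_dim
    if pre_logits:
        counts["pre_logits"] = dense_params(hidden_dim, hidden_dim)
    counts["mlp_head"] = dense_params(hidden_dim, num_classes)
    return counts


if __name__ == "__main__":
    counts = vit_param_counts(image_size=128, patch_size=36, hidden_dim=128,
                              mlp_dim=512, atten_heads=32, encoder_depth=4,
                              num_classes=1)
    print(counts["transformer_block_0"])  # 198272
    print(sum(counts.values()))           # 1309185
```

With the custom settings from item 1, the per-layer counts and the total of 1,309,185 match the summary printed above.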
    

Q4: How do I fine-tune a pre-trained ViT or use it for image classification?

  1. Fine-tune a pre-trained ViT

    from vit_tf2.vit import ViT_B32
    
    # Set parameters
    IMAGE_SIZE = ...
    NUM_CLASSES = ...
    ACTIVATION = ...
    ...
    
    # Build the ViT (pre-trained weights are loaded wherever layer shapes match)
    vit = ViT_B32(
        image_size = IMAGE_SIZE,
        num_classes = NUM_CLASSES,
        activation = ACTIVATION,
        )
    
    # Compile the ViT
    vit.compile(
        optimizer = ...,
        loss = ...,
        metrics = ...
        )
    
    # Define training, validation and test data
    train_generator = ...
    valid_generator = ...
    test_generator  = ...
    
    # Fine-tune the ViT
    vit.fit(
        x = train_generator,
        validation_data = valid_generator,
        steps_per_epoch = ...,
        validation_steps = ...,
        )
    
    # Evaluate on the test data
    vit.evaluate(x = test_generator, steps = ...)
    
  2. Apply a pre-trained ViT to image classification

    from vit_tf2 import vit
    from vit_tf2 import utils
    
    # Build a pre-trained ViT_B16 with the "imagenet21k+imagenet2012" weights
    vit_model = vit.ViT_B16(weights="imagenet21k+imagenet2012")
    
    # Load an image, resize it to the model's input size and add a batch axis
    img = utils.read_img("test.jpg", resize=vit_model.image_size)
    img = img.reshape((1, *vit_model.image_size, 3))
    
    # Classify the image and print the most likely ImageNet-2012 label
    y = vit_model.predict(img)
    classes = utils.get_imagenet2012_classes()
    print(classes[y[0].argmax()])
    

    There are currently no class labels for "imagenet21k", so use "imagenet21k+imagenet2012" when applying a pre-trained ViT directly to image classification. Both "imagenet21k" and "imagenet21k+imagenet2012" can be used for fine-tuning.
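The classification snippet above prints only the single most likely class. A minimal sketch of reporting the top k predictions instead, using only NumPy, is shown below. It assumes y is the (batch, num_classes) array returned by vit_model.predict and that classes can be indexed by class index, as in the snippet above; the top_k helper is hypothetical and not part of vit_tf2. Ranking by argsort gives the same order whether the model outputs raw logits or softmax probabilities.

```python
import numpy as np


def top_k(y, classes, k=5):
    """Return the k highest-scoring (label, score) pairs for the first image in a batch.

    y is assumed to have shape (batch, num_classes); classes maps index -> label.
    """
    scores = np.asarray(y)[0]
    # argsort is ascending, so take the last k indices and reverse them
    best = np.argsort(scores)[-k:][::-1]
    return [(classes[i], float(scores[i])) for i in best]


if __name__ == "__main__":
    # Stand-in output for a 4-class model; real ImageNet-2012 output has 1000 columns.
    fake_y = np.array([[0.1, 2.5, -1.0, 0.7]])
    fake_classes = ["cat", "dog", "fish", "bird"]
    print(top_k(fake_y, fake_classes, k=2))  # [('dog', 2.5), ('bird', 0.7)]
```

With the real model, the call would be top_k(y, classes) after the predict step above.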



Download files


Source Distribution

vit-tf2-1.0.2.tar.gz (4.2 kB)

Uploaded Source

Built Distribution

vit_tf2-1.0.2-py3-none-any.whl (4.3 kB)

Uploaded Python 3

File details

Details for the file vit-tf2-1.0.2.tar.gz.

File metadata

  • Download URL: vit-tf2-1.0.2.tar.gz
  • Upload date:
  • Size: 4.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.7.8

File hashes

Hashes for vit-tf2-1.0.2.tar.gz
Algorithm Hash digest
SHA256 8bbd8b87ae07ba2178e0578526b085baabeca1e7b7d87e9122d348de2cc52a34
MD5 a06712c1d609cb74adf2e4ff426ece3d
BLAKE2b-256 9361d9f9fd50b09c6a12f5e1c403aada48a961f129e9864a9a1e7b4788495f9d


File details

Details for the file vit_tf2-1.0.2-py3-none-any.whl.

File metadata

  • Download URL: vit_tf2-1.0.2-py3-none-any.whl
  • Upload date:
  • Size: 4.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.7.8

File hashes

Hashes for vit_tf2-1.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 9f328094f20fedf1027d4f9c29021073f969b20d21533c139428c161a0892d86
MD5 192707bf35cff79b44602359bd1cebe3
BLAKE2b-256 f290e4b07b43b286a109c81943b70077b86a09a5cd89997a363558b0eeb91ebd

