Keras implementation of ViT (Vision Transformer)

Project description

vit-keras

This is a Keras implementation of the models described in An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. It is based on an earlier implementation by tuvovan, modified to match the Flax implementation in the official repository.

The weights here are ported from those provided in the official repository. See utils.load_weights_numpy for how this is done (it's not pretty, but it does the job).

Usage

Install this package using pip install vit-keras

You can use the model out-of-the-box with ImageNet 2012 classes using something like the following. The weights will be downloaded automatically.

from vit_keras import vit, utils

image_size = 384
classes = utils.get_imagenet_classes()
model = vit.vit_b16(
    image_size=image_size,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
    pretrained_top=True
)
url = 'https://upload.wikimedia.org/wikipedia/commons/d/d7/Granny_smith_and_cross_section.jpg'
image = utils.read(url, image_size)
X = vit.preprocess_inputs(image).reshape(1, image_size, image_size, 3)
y = model.predict(X)
print(classes[y[0].argmax()]) # Granny smith
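
The same constructor pattern covers the other model sizes. For example, assuming the package also exposes vit_b32 and vit_l16 with the same signature as the vit_b16 and vit_l32 used in this README, loading a different variant only changes the constructor name:

model_b32 = vit.vit_b32(
    image_size=image_size,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
    pretrained_top=True
)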

You can fine-tune a model loaded as follows.

image_size = 224
model = vit.vit_l32(
    image_size=image_size,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
    pretrained_top=False,
    classes=200
)
# Train this model on your data as desired.
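
As a minimal sketch of that training step using the standard Keras API (the random x_train/y_train, optimizer, loss, and hyperparameters below are illustrative placeholders, not recommendations from this package):

import numpy as np
import tensorflow as tf

# Placeholder data: replace with your own images (preprocessed with
# vit.preprocess_inputs) and one-hot labels for your 200 classes.
x_train = np.random.rand(8, image_size, image_size, 3).astype('float32')
y_train = tf.keras.utils.to_categorical(
    np.random.randint(0, 200, size=8), num_classes=200
)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
model.fit(x_train, y_train, epochs=1, batch_size=4)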

Visualizing Attention Maps

There's some functionality for plotting attention maps for a given image and model; see the example below. I'm not sure I'm doing this correctly (the official repository didn't have example code). Feedback/corrections welcome!

import numpy as np
import matplotlib.pyplot as plt
from vit_keras import vit, utils, visualize

# Load a model
image_size = 384
classes = utils.get_imagenet_classes()
model = vit.vit_b16(
    image_size=image_size,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
    pretrained_top=True
)

# Get an image and compute the attention map
url = 'https://upload.wikimedia.org/wikipedia/commons/b/bc/Free%21_%283987584939%29.jpg'
image = utils.read(url, image_size)
attention_map = visualize.attention_map(model=model, image=image)
print('Prediction:', classes[
    model.predict(vit.preprocess_inputs(image)[np.newaxis])[0].argmax()]
)  # Prediction: Eskimo dog, husky

# Plot results
fig, (ax1, ax2) = plt.subplots(ncols=2)
ax1.axis('off')
ax2.axis('off')
ax1.set_title('Original')
ax2.set_title('Attention Map')
_ = ax1.imshow(image)
_ = ax2.imshow(attention_map)
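
For reference, attention maps like this are often computed with attention rollout (Abnar & Zuidema, 2020): average each layer's attention over heads, add the identity matrix to account for residual connections, renormalize, and multiply the per-layer matrices together. Whether visualize.attention_map implements exactly this is an assumption on my part; a minimal NumPy sketch of the idea:

import numpy as np

def attention_rollout(attentions):
    # attentions: array of shape (num_layers, num_heads, num_tokens, num_tokens)
    num_tokens = attentions.shape[-1]
    rollout = np.eye(num_tokens)
    for layer_attention in attentions.mean(axis=1):  # average over heads
        layer_attention = layer_attention + np.eye(num_tokens)  # residual connections
        layer_attention = layer_attention / layer_attention.sum(axis=-1, keepdims=True)
        rollout = layer_attention @ rollout
    # Attention flowing from the class token (index 0) to each image patch
    return rollout[0, 1:]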

(Example output: the original image on the left and its attention map on the right.)

Download files

Download the file for your platform.

Source Distribution

vit_keras-0.2.0.tar.gz (140.9 kB)

Built Distribution

vit_keras-0.2.0-py3-none-any.whl (24.6 kB)

File details

Details for the file vit_keras-0.2.0.tar.gz.

File metadata

  • Download URL: vit_keras-0.2.0.tar.gz
  • Size: 140.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.6

File hashes

Hashes for vit_keras-0.2.0.tar.gz

  • SHA256: fcff0397f94187823cbf8f5a453b7836b1cd365a7c9e6c422c2e959f8babb1dc
  • MD5: c886149f10ce57d9759ba9e926f5e254
  • BLAKE2b-256: d29c2cc182b43a0924aa23f03a5b4d22052d5353dfb06ee2bbc9ba6ee4dc2026

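To verify a download against the digests above, you can use Python's standard-library hashlib; a minimal sketch (the expected value is the SHA256 digest listed above):

import hashlib

expected = 'fcff0397f94187823cbf8f5a453b7836b1cd365a7c9e6c422c2e959f8babb1dc'
with open('vit_keras-0.2.0.tar.gz', 'rb') as f:
    digest = hashlib.sha256(f.read()).hexdigest()
assert digest == expected, 'SHA256 mismatch: the download may be corrupted'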

File details

Details for the file vit_keras-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: vit_keras-0.2.0-py3-none-any.whl
  • Size: 24.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.6

File hashes

Hashes for vit_keras-0.2.0-py3-none-any.whl

  • SHA256: a1a48b8cbe01895d420cb2b4e96512e7839022347c8733a8fb223fb8fec7510d
  • MD5: 004985716d0cccea2c10eeacb2482809
  • BLAKE2b-256: 660db2ea088fd10306b9914229e6eb7e82cb02ac0daaa3e81aff7412b75c278b

