Vision Transformer (ViT) in PyTorch.

Project description

TODO

Urgent

  • Rewrite README
  • Implement model
  • Convert pretrained weights
  • Colab example
  • pip

ViT PyTorch

Quickstart

Install with pip install vit_pytorch and load a pretrained ViT with:

from vit_pytorch import ViT
model = ViT('B_16_imagenet1k', pretrained=True)

Or find a Google Colab example here.

Overview

This repository contains an op-for-op PyTorch reimplementation of Google's Vision Transformer (ViT) architecture, along with pretrained models and examples.

The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects.

At the moment, you can easily:

  • Load pretrained ViT models
  • Evaluate on ImageNet or your own data

Coming soon:

  • Finetune ViT on your own dataset
  • Train ViT from scratch on ImageNet-1k
  • Export to ONNX for efficient inference

Table of contents

  1. About ViT
  2. About ViT-PyTorch
  3. Installation
  4. Usage
  5. Contributing

About ViT

Vision Transformers (ViT) are a straightforward application of the transformer architecture to image classification. Even in computer vision, it seems, attention is all you need.

The ViT architecture works as follows: (1) it splits an image into fixed-size patches and linearly embeds each one, treating the image as a 1-dimensional sequence of patch embeddings; (2) it prepends a learnable classification token to the sequence and adds position embeddings; (3) it passes this sequence through a transformer encoder (like BERT); (4) it passes the first output token (the one corresponding to the classification token) through a small MLP head to obtain the classification logits. ViT is pretrained on a large-scale dataset (ImageNet-21k) with a large amount of compute.
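The four steps above can be sketched in a few dozen lines of PyTorch. This is a hedged, illustrative sketch rather than this package's implementation: `MiniViT` and its small default sizes are hypothetical, the patch embedding is written as a strided convolution (equivalent to flattening each patch and applying one shared linear layer), `torch.nn.TransformerEncoderLayer` with `norm_first=True` stands in for the encoder blocks, and a single linear layer replaces the small MLP head.

```python
import torch
from torch import nn


class MiniViT(nn.Module):
    """A minimal ViT-style classifier (hypothetical sketch, not this package's model)."""

    def __init__(self, image_size=32, patch_size=8, dim=64, depth=2,
                 heads=4, mlp_dim=128, num_classes=10):
        super().__init__()
        assert image_size % patch_size == 0, "image size must be divisible by patch size"
        num_patches = (image_size // patch_size) ** 2
        # (1) Patch embedding: a conv with kernel = stride = patch_size applies the
        #     same linear projection to every non-overlapping patch.
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        # (2) Learnable classification token and position embeddings
        #     (one per patch, plus one for the classification token).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, dim) * 0.02)
        # (3) Pre-norm transformer encoder.
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                           dim_feedforward=mlp_dim,
                                           activation="gelu", batch_first=True,
                                           norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(dim)
        # (4) Classification head applied to the classification-token output.
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)                        # (B, dim, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)               # (B, num_patches, dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embed
        x = self.encoder(x)                            # (B, num_patches + 1, dim)
        return self.head(self.norm(x[:, 0]))           # logits from the first token
```

For example, `MiniViT()(torch.randn(2, 3, 32, 32))` returns logits of shape `(2, 10)`: a 32x32 image with 8x8 patches becomes a sequence of 16 patch tokens plus one classification token.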

About ViT-PyTorch

ViT-PyTorch is a PyTorch reimplementation of ViT. It is consistent with the original JAX implementation, so it is easy to load JAX-pretrained weights.
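Keeping the module structure consistent with the JAX implementation means each JAX parameter has a PyTorch counterpart with the same role, so loading weights is mostly a matter of renaming and re-laying-out arrays. As a hedged sketch of the layout part only: Flax stores dense kernels as `(in, out)` while `torch.nn.Linear.weight` is `(out, in)`, and Flax convolution kernels are `(H, W, in, out)` while `torch.nn.Conv2d.weight` is `(out, in, H, W)`. The helper names below are hypothetical and are not this package's conversion code.

```python
import numpy as np


def convert_dense_kernel(kernel: np.ndarray) -> np.ndarray:
    """Flax Dense kernel (in, out) -> torch.nn.Linear weight (out, in). Hypothetical helper."""
    return np.ascontiguousarray(kernel.T)


def convert_conv_kernel(kernel: np.ndarray) -> np.ndarray:
    """Flax Conv kernel (H, W, in, out) -> torch.nn.Conv2d weight (out, in, H, W). Hypothetical helper."""
    return np.ascontiguousarray(kernel.transpose(3, 2, 0, 1))


# Example: a ViT-B/16-sized patch-embedding kernel (16x16 patches, RGB input, 768 outputs).
patch_kernel = np.zeros((16, 16, 3, 768), dtype=np.float32)
print(convert_conv_kernel(patch_kernel).shape)  # (768, 3, 16, 16)
```

The transposes preserve the computation: a JAX dense layer computes `x @ kernel`, and the converted PyTorch weight computes the same result as `weight @ x`.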

At the same time, we aim to make our PyTorch implementation as simple, flexible, and extensible as possible.

Installation

Install with pip:

pip install vit_pytorch

Or from source:

git clone https://github.com/lukemelas/ViT-PyTorch
cd ViT-PyTorch
pip install -e .

Usage

Loading pretrained models

Loading a pretrained model is easy:

from vit_pytorch import ViT
model = ViT('B_16_imagenet1k', pretrained=True)

Details about the models are below:

Name              Pretrained on   Finetuned on
B_16              ImageNet-21k    -
B_16_imagenet1k   ImageNet-21k    ImageNet-1k
B_32_imagenet1k   ImageNet-21k    ImageNet-1k
L_16_imagenet1k   ImageNet-21k    ImageNet-1k
L_32_imagenet1k   ImageNet-21k    ImageNet-1k
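The model names follow the ViT paper's convention: `B`/`L` is the Base/Large variant and the number is the patch size in pixels. As a hedged sketch (`describe` and `VARIANTS` are hypothetical names, not this package's API), the helper below decodes a name using the published ViT-Base and ViT-Large hyperparameters and computes the encoder's input sequence length; the 384x384 default matches the resolution used in this README's classification example.

```python
# Published ViT-Base / ViT-Large hyperparameters (hypothetical lookup table).
VARIANTS = {
    "B": dict(hidden_size=768, mlp_dim=3072, num_heads=12, num_layers=12),
    "L": dict(hidden_size=1024, mlp_dim=4096, num_heads=16, num_layers=24),
}


def describe(name: str, image_size: int = 384) -> dict:
    """Decode a model name such as 'L_32_imagenet1k' (hypothetical helper)."""
    variant, patch, *rest = name.split("_")
    patch_size = int(patch)
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by the patch size")
    num_patches = (image_size // patch_size) ** 2
    return dict(
        VARIANTS[variant],
        patch_size=patch_size,
        finetuned_on=rest[0] if rest else None,
        seq_len=num_patches + 1,  # patch tokens plus the classification token
    )


print(describe("B_16_imagenet1k")["seq_len"])  # 577 = (384 / 16) ** 2 + 1
print(describe("L_32_imagenet1k")["seq_len"])  # 145 = (384 / 32) ** 2 + 1
```

The sequence length matters because self-attention cost grows quadratically with it: at the same resolution, a /16 model processes four times as many patch tokens as its /32 counterpart.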

Custom ViT

Loading custom configurations is just as easy:

from vit_pytorch import ViT
# A custom, smaller configuration (for reference, B_16 uses hidden_size=768, num_heads=12, num_layers=12)
config = dict(hidden_size=512, num_heads=8, num_layers=6)
model = ViT.from_config(config)
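When choosing a custom configuration, two quick sanity checks help: `hidden_size` must be divisible by `num_heads` (each head gets `hidden_size // num_heads` dimensions), and the encoder's size grows roughly with `num_layers * hidden_size ** 2`. The hypothetical helper below assumes a standard pre-norm encoder layer (attention projections with biases, a two-layer MLP, two LayerNorms) and, since the config above does not specify it, an MLP width of `4 * hidden_size`; it ignores the patch embedding, position embeddings, and classification head.

```python
def encoder_param_count(hidden_size, num_heads, num_layers, mlp_dim=None):
    """Parameter count of a ViT-style encoder stack (hypothetical helper)."""
    if hidden_size % num_heads:
        raise ValueError("hidden_size must be divisible by num_heads")
    d = hidden_size
    m = mlp_dim if mlp_dim is not None else 4 * d  # assumption: 4x MLP expansion
    attention = 4 * d * d + 4 * d   # Q, K, V and output projections, with biases
    mlp = 2 * d * m + m + d         # two linear layers, with biases
    norms = 2 * 2 * d               # two LayerNorms, each with scale and bias
    return num_layers * (attention + mlp + norms)


config = dict(hidden_size=512, num_heads=8, num_layers=6)
print(config["hidden_size"] // config["num_heads"])  # per-head dimension: 64
print(encoder_param_count(**config))                 # 18914304
print(encoder_param_count(768, 12, 12))              # B_16-sized encoder: 85054464
```

The second count (about 85M) is consistent with the roughly 86M parameters reported for ViT-Base once the embeddings and head are added back.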

Example: Classification

Below is a simple, complete example. It may also be found as a Jupyter notebook in examples/simple or as a Colab Notebook.

from PIL import Image
import torch
from torchvision import transforms

# Load ViT
from vit_pytorch import ViT
model = ViT('B_16_imagenet1k', pretrained=True)
model.eval()

# Load image
# NOTE: Assumes an image `img.jpg` exists in the current directory
tfms = transforms.Compose([
    transforms.Resize((384, 384)),  # resize to the 384x384 input size used here
    transforms.ToTensor(),          # PIL image -> float tensor in [0, 1]
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # [0, 1] -> [-1, 1]
])
img = tfms(Image.open('img.jpg').convert('RGB')).unsqueeze(0)  # add a batch dimension
print(img.shape)  # torch.Size([1, 3, 384, 384])

# Classify
with torch.no_grad():
    outputs = model(img)
print(outputs.shape)  # torch.Size([1, 1000])
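The model returns raw logits, one per ImageNet-1k class. As a hedged sketch, the helper below turns them into the top-k class probabilities with `torch.softmax` and `torch.topk`. `top_k_predictions` is a hypothetical name, and the optional `labels` list (one class name per index, e.g. loaded from an ImageNet label file) is an assumption, since this README does not ship one.

```python
import torch


def top_k_predictions(logits, k=5, labels=None):
    """Return [(label_or_index, probability), ...] for the first image in a batch.

    Hypothetical helper: `labels` is an optional list mapping class index -> name.
    """
    probs = torch.softmax(logits, dim=-1)[0]   # probabilities for image 0
    values, indices = torch.topk(probs, k)     # highest-probability classes, best first
    return [
        (labels[i] if labels is not None else i, p)
        for i, p in zip(indices.tolist(), values.tolist())
    ]


# Usage with the `outputs` tensor from the classification example:
# for label, prob in top_k_predictions(outputs, k=5):
#     print(f"{label:<20} {prob:.2%}")
```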

ImageNet

See examples/imagenet for details about evaluating on ImageNet.
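Evaluating on ImageNet (or your own data) boils down to computing top-1 and top-5 accuracy over a data loader. The function below is a hedged sketch of such a loop, not the script in examples/imagenet: `evaluate` is a hypothetical name, and it assumes a `DataLoader` yielding `(images, integer_labels)` batches that are already preprocessed as in the classification example (for instance via `torchvision.datasets.ImageFolder`).

```python
import torch


@torch.no_grad()
def evaluate(model, loader, device="cpu", topk=(1, 5)):
    """Return top-k accuracies (as fractions) of `model` over `loader` (hypothetical helper)."""
    model.eval().to(device)
    correct = {k: 0 for k in topk}
    total = 0
    for images, targets in loader:
        images, targets = images.to(device), targets.to(device)
        logits = model(images)
        # Indices of the max(topk) highest-scoring classes, best first: (B, max_k)
        pred = logits.topk(max(topk), dim=-1).indices
        hits = pred.eq(targets.unsqueeze(1))  # (B, max_k) boolean matrix
        for k in topk:
            # A sample counts as correct at k if its label is among the first k predictions.
            correct[k] += hits[:, :k].any(dim=1).sum().item()
        total += targets.size(0)
    return {f"top{k}": correct[k] / total for k in topk}
```

Using `topk` once with the largest k and slicing the result avoids sorting the logits separately for each accuracy level.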

Credit

Other great repositories with this model include:

Contributing

If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.

I look forward to seeing what the community does with these models!


Download files

Source Distribution

pretrained-vit-pytorch-0.0.1.tar.gz (10.0 kB)
