Visual Transformers (ViT) in PyTorch.
ViT PyTorch
Quickstart
Install with `pip install pytorch_pretrained_vit` and load a pretrained ViT with:

```python
from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)
```
Or find a Google Colab example here.
Overview
This repository contains an op-for-op PyTorch reimplementation of the Visual Transformer architecture from Google, along with pre-trained models and examples.
The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects.
At the moment, you can easily:
- Load pretrained ViT models
- Evaluate on ImageNet or your own data
- Finetune ViT on your own dataset (see the sketch after this list)
Coming soon:
- Train ViT from scratch on ImageNet (1K)
- Export to ONNX for efficient inference
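As a taste of the finetuning workflow mentioned above, below is a minimal sketch. Two assumptions are marked in the comments: the `num_classes` override to the `ViT` constructor (swap the final linear layer manually if your installed version does not accept it), and the dummy dataset, which stands in for your own data.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from pytorch_pretrained_vit import ViT

# ASSUMPTION: `num_classes` resizes the classification head for a new task;
# if your version's constructor does not accept it, replace the head manually.
model = ViT('B_16_imagenet1k', pretrained=True, num_classes=10)
model.train()

# Dummy stand-in for your own dataset: 8 random 384x384 images, 10 classes.
data = TensorDataset(torch.randn(8, 3, 384, 384), torch.randint(0, 10, (8,)))
train_loader = DataLoader(data, batch_size=4)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

for images, labels in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
```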
About ViT
Visual Transformers (ViT) are a straightforward application of the transformer architecture to image classification. Even in computer vision, it seems, attention is all you need.
The ViT architecture works as follows: (1) it splits an image into fixed-size patches and treats them as a 1-dimensional sequence of tokens, (2) it prepends a learnable classification token to the sequence, (3) it passes the sequence through a transformer encoder (like BERT), (4) it passes the first token of the encoder's output through a small MLP to obtain the classification logits. ViT is pretrained on a large-scale dataset (ImageNet-21k) with a huge amount of compute.
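To make these four steps concrete, here is a minimal, self-contained sketch of the same pipeline in plain PyTorch. This is an illustration only, not this library's implementation; the defaults below are the ViT-Base (B_16) hyperparameters, and PyTorch's built-in encoder layer differs in small details (e.g. normalization placement) from the original architecture.

```python
import torch
import torch.nn as nn

class MinimalViT(nn.Module):
    """Illustrative ViT forward pass: patchify -> class token -> encoder -> MLP head."""
    def __init__(self, image_size=384, patch_size=16, dim=768,
                 depth=12, heads=12, num_classes=1000):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # (1) A strided convolution cuts the image into patches and embeds each one.
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        # (2) Learnable classification token and positional embeddings.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        # (3) A standard transformer encoder, as in BERT.
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        # (4) A small head maps the class token to logits.
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):                     # x: (B, 3, H, W)
        x = self.patch_embed(x)               # (B, dim, H/16, W/16)
        x = x.flatten(2).transpose(1, 2)      # (B, num_patches, dim)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embed
        x = self.encoder(x)                   # (B, num_patches + 1, dim)
        return self.head(x[:, 0])             # logits from the class token

logits = MinimalViT()(torch.randn(1, 3, 384, 384))
print(logits.shape)  # torch.Size([1, 1000])
```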
About ViT-PyTorch
ViT-PyTorch is a PyTorch re-implementation of ViT. It is consistent with the original Jax implementation, so that it's easy to load Jax-pretrained weights.
At the same time, we aim to make our PyTorch implementation as simple, flexible, and extensible as possible.
Installation
Install with pip:

```bash
pip install pytorch_pretrained_vit
```

Or from source:

```bash
git clone https://github.com/lukemelas/ViT-PyTorch
cd ViT-PyTorch
pip install -e .
```
Usage
Loading pretrained models
Loading a pretrained model is easy:

```python
from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)
```
Details about the models are below:
Name | Pretrained on | Finetuned on | Available?
---|---|---|---
B_16 | ImageNet-21k | - | ✓
B_32 | ImageNet-21k | - | ✓
L_16 | ImageNet-21k | - | -
L_32 | ImageNet-21k | - | ✓
B_16_imagenet1k | ImageNet-21k | ImageNet-1k | ✓
B_32_imagenet1k | ImageNet-21k | ImageNet-1k | ✓
L_16_imagenet1k | ImageNet-21k | ImageNet-1k | ✓
L_32_imagenet1k | ImageNet-21k | ImageNet-1k | ✓
Custom ViT
Loading custom configurations is just as easy:

```python
from pytorch_pretrained_vit import ViT

# A small custom configuration. Note that this is *not* equivalent to
# ViT('B_16'), which uses hidden_size=768, num_heads=12, num_layers=12.
config = dict(hidden_size=512, num_heads=8, num_layers=6)
model = ViT.from_config(config)
```
Example: Classification
Below is a simple, complete example. It may also be found as a Jupyter notebook in examples/simple or as a Colab Notebook.
```python
from PIL import Image
import torch
from torchvision import transforms

# Load ViT
from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)
model.eval()

# Load image
# NOTE: Assumes an image `img.jpg` exists in the current directory
img = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])(Image.open('img.jpg')).unsqueeze(0)
print(img.shape)  # torch.Size([1, 3, 384, 384])

# Classify
with torch.no_grad():
    outputs = model(img)
print(outputs.shape)  # torch.Size([1, 1000])
```
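Continuing the example, the logits can be turned into predictions with a softmax and top-k. This prints class indices only; mapping indices to human-readable names requires an ImageNet label file, which is not assumed here:

```python
probs = torch.softmax(outputs, dim=-1)  # (1, 1000) class probabilities
top_probs, top_idxs = probs[0].topk(5)  # five most likely classes
for p, i in zip(top_probs, top_idxs):
    print(f'ImageNet class {i.item():4d}: probability {p.item():.4f}')
```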
ImageNet
See examples/imagenet for details about evaluating on ImageNet.
Credit
Other great repositories with this model include:
Contributing
If you find a bug, create a GitHub issue, or, even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.
I look forward to seeing what the community does with these models!
File details
Details for the file pytorch-pretrained-vit-0.0.7.tar.gz.
File metadata
- Download URL: pytorch-pretrained-vit-0.0.7.tar.gz
- Size: 13.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.24.0 setuptools/50.3.0.post20201006 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | 2c74057c5898c63c0076ac518645bcea0f380b8b1aa73dbb8fedbedec757e6ec
MD5 | f868742c7a9a78993493e1fdfaf128bf
BLAKE2b-256 | 028db404fe410a984ce2bc95a8ce02d397e3b8b12d6dd3118db6ac9b8edaa370
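If you download the source distribution manually, you can verify it against the published digests above. A small sketch, assuming the tarball sits in the current directory:

```python
import hashlib

# Expected SHA256 digest, copied from the table above.
EXPECTED = '2c74057c5898c63c0076ac518645bcea0f380b8b1aa73dbb8fedbedec757e6ec'

with open('pytorch-pretrained-vit-0.0.7.tar.gz', 'rb') as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print('OK' if digest == EXPECTED else 'MISMATCH')
```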