VisionTransformer implemented in PyTorch.

Project description

Vision Transformer Pytorch

This project is modified from lukemelas/EfficientNet-PyTorch and asyml/vision-transformer-pytorch to provide an out-of-the-box API that lets you use VisionTransformer as easily as EfficientNet.

Quickstart

Install with pip install vision_transformer_pytorch and load a pretrained VisionTransformer with:

from vision_transformer_pytorch import VisionTransformer
model = VisionTransformer.from_pretrained('ViT-B_16')
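
For an end-to-end smoke test, the pretrained model can classify an image directly. A minimal sketch; the preprocessing (384x384 resize, normalization to [-1, 1]) and the input file cat.jpg are illustrative assumptions, not settings confirmed by this package:

import torch
from PIL import Image
from torchvision import transforms
from vision_transformer_pytorch import VisionTransformer

model = VisionTransformer.from_pretrained('ViT-B_16')
model.eval()

# Assumed preprocessing: resize to the 384x384 input resolution and
# normalize to [-1, 1]; check the repository for the exact pipeline.
preprocess = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# 'cat.jpg' is a hypothetical input file.
image = preprocess(Image.open('cat.jpg').convert('RGB')).unsqueeze(0)

with torch.no_grad():
    logits = model(image)  # (1, 1000) ImageNet class logits
print(logits.argmax(dim=-1).item())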

About Vision Transformer PyTorch

Vision Transformer Pytorch is a PyTorch re-implementation of Vision Transformer. It is based on one of the best-practice examples among commonly used deep learning libraries, EfficientNet-PyTorch, and on an elegant implementation of VisionTransformer, vision-transformer-pytorch. In this project, we aim to make our PyTorch implementation as simple, flexible, and extensible as possible.

If you have any feature requests or questions, feel free to leave them as GitHub issues!

Installation

Install via pip:

pip install vision_transformer_pytorch

Or install from source:

git clone https://github.com/tczhangzhi/VisionTransformer-Pytorch
cd VisionTransformer-Pytorch
pip install -e .

Usage

Loading pretrained models

Load a VisionTransformer:

from vision_transformer_pytorch import VisionTransformer
model = VisionTransformer.from_name('ViT-B_16')

Load a pretrained VisionTransformer:

from vision_transformer_pytorch import VisionTransformer
model = VisionTransformer.from_pretrained('ViT-B_16')
# inputs = torch.randn(1, 3, *model.image_size)
# model(inputs)
# model.extract_features(inputs)
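
The commented lines above show the two entry points. Expanded into a runnable sketch (the exact shape returned by extract_features is not documented here, so it is only printed):

import torch
from vision_transformer_pytorch import VisionTransformer

model = VisionTransformer.from_pretrained('ViT-B_16')
model.eval()

# Dummy batch at the model's native resolution (384x384 for ViT-B_16).
inputs = torch.randn(1, 3, *model.image_size)

with torch.no_grad():
    logits = model(inputs)                     # classification logits
    features = model.extract_features(inputs)  # embeddings before the classifier head

print(logits.shape, features.shape)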

Default hyperparameters:

Param\Model        ViT-B_16  ViT-B_32  ViT-L_16  ViT-L_32  R50+ViT-B_16
image_size         384       384       384       384       384
patch_size         16        32        16        32        1
emb_dim            768       768       1024      1024      768
mlp_dim            3072      3072      4096      4096      3072
num_heads          12        12        16        16        12
num_layers         12        12        24        24        12
num_classes        1000      1000      1000      1000      1000
attn_dropout_rate  0.0       0.0       0.0       0.0       0.0
dropout_rate       0.1       0.1       0.1       0.1       0.1

If you need to modify these hyperparameters, use:

from vision_transformer_pytorch import VisionTransformer
model = VisionTransformer.from_name('ViT-B_16', image_size=256, patch_size=64, ...)
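
When overriding these values, keep image_size divisible by patch_size: the image is cut into non-overlapping patches, and the token count grows quadratically with their ratio. A small sanity-check sketch (the helper num_patches is not part of the package):

def num_patches(image_size: int, patch_size: int) -> int:
    """Number of patch tokens a ViT produces for a square input."""
    assert image_size % patch_size == 0, 'image_size must be divisible by patch_size'
    return (image_size // patch_size) ** 2

print(num_patches(384, 16))  # 576 tokens for the default ViT-B_16
print(num_patches(256, 64))  # 16 tokens for the override shown above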

ImageNet

See examples/imagenet for details about evaluating on ImageNet.
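
The evaluation script itself lives in the repository; as a rough outline of what such a run looks like (the dataset path, batch size, and preprocessing below are illustrative assumptions, not the script's actual settings):

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from vision_transformer_pytorch import VisionTransformer

model = VisionTransformer.from_pretrained('ViT-B_16')
model.eval()

# Hypothetical validation-set path and preprocessing; defer to
# examples/imagenet for the settings the project actually uses.
preprocess = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
val_set = datasets.ImageFolder('/path/to/imagenet/val', transform=preprocess)
loader = DataLoader(val_set, batch_size=32, num_workers=4)

correct = total = 0
with torch.no_grad():
    for images, labels in loader:
        preds = model(images).argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.numel()

print(f'top-1 accuracy: {correct / total:.4f}')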

Contributing

If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.

I look forward to seeing what the community does with these models!

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

vision_transformer_pytorch-1.0.3.tar.gz (10.3 kB)

Built Distribution

vision_transformer_pytorch-1.0.3-py2.py3-none-any.whl (13.6 kB)

File details

Details for the file vision_transformer_pytorch-1.0.3.tar.gz.

File metadata

  • Download URL: vision_transformer_pytorch-1.0.3.tar.gz
  • Upload date:
  • Size: 10.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.25.0 setuptools/50.3.1.post20201107 requests-toolbelt/0.9.1 tqdm/4.54.0 CPython/3.8.5

File hashes

Hashes for vision_transformer_pytorch-1.0.3.tar.gz

Algorithm    Hash digest
SHA256       820597ecc56f02362c82e74f0a9e4434052c2166383bac8535b78162c485bf55
MD5          501596a15f3afb262dce21178748b5d6
BLAKE2b-256  f77a6a66b49cfca9b767cbef24c13dcd1bd2236b181e040b5e45ceb7afb1f17e

See the PyPI documentation for more details on using hashes.
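
To check a downloaded archive against the digests above, compute the hash locally; a minimal sketch using only the standard library:

import hashlib

def sha256_of(path: str) -> str:
    """Stream a file in chunks and return its SHA256 hex digest."""
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(8192), b''):
            digest.update(chunk)
    return digest.hexdigest()

expected = '820597ecc56f02362c82e74f0a9e4434052c2166383bac8535b78162c485bf55'
assert sha256_of('vision_transformer_pytorch-1.0.3.tar.gz') == expected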

File details

Details for the file vision_transformer_pytorch-1.0.3-py2.py3-none-any.whl.

File metadata

  • Download URL: vision_transformer_pytorch-1.0.3-py2.py3-none-any.whl
  • Upload date:
  • Size: 13.6 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.25.0 setuptools/50.3.1.post20201107 requests-toolbelt/0.9.1 tqdm/4.54.0 CPython/3.8.5

File hashes

Hashes for vision_transformer_pytorch-1.0.3-py2.py3-none-any.whl

Algorithm    Hash digest
SHA256       0652e15527e87d4ce34d5e07153adf596f7570ccf05b45ec274c4968f3311fee
MD5          a9984e7893639a1bcc74ad3fa0be65b7
BLAKE2b-256  af28b781e01992e2ad2271f0a89c5a813a7e5bd758af0a45105a06c5ecf0bdb0

See the PyPI documentation for more details on using hashes.
