VisionTransformer implemented in PyTorch.

Project description

Vision Transformer Pytorch

This project is adapted from lukemelas/EfficientNet-PyTorch and asyml/vision-transformer-pytorch to provide an out-of-the-box API, so you can use VisionTransformer as easily as EfficientNet.

Quickstart

Install with pip install vision_transformer_pytorch and load a pretrained VisionTransformer with:

from vision_transformer_pytorch import VisionTransformer
model = VisionTransformer.from_pretrained('ViT-B_16')

About Vision Transformer PyTorch

Vision Transformer Pytorch is a PyTorch re-implementation of Vision Transformer, based on the best practices of a widely used deep learning library, EfficientNet-PyTorch, and on an elegant implementation of VisionTransformer, vision-transformer-pytorch. In this project, we aim to make our PyTorch implementation as simple, flexible, and extensible as possible.

If you have any feature requests or questions, feel free to leave them as GitHub issues!

Installation

Install via pip:

pip install vision_transformer_pytorch

Or install from source:

git clone https://github.com/tczhangzhi/VisionTransformer-Pytorch
cd VisionTransformer-Pytorch
pip install -e .

Usage

Loading pretrained models

Load a VisionTransformer architecture (randomly initialized):

from vision_transformer_pytorch import VisionTransformer
model = VisionTransformer.from_name('ViT-B_16')

Load a pretrained VisionTransformer and run inference:

import torch
from vision_transformer_pytorch import VisionTransformer

model = VisionTransformer.from_pretrained('ViT-B_16')
# Dummy batch matching the model's expected input resolution.
inputs = torch.randn(1, 3, *model.image_size)
with torch.no_grad():
    outputs = model(inputs)                    # class logits
    features = model.extract_features(inputs)  # features before the classifier head

Default hyperparameters:

Param              ViT-B_16  ViT-B_32  ViT-L_16  ViT-L_32
image_size         384       384       384       384
patch_size         16        32        16        32
emb_dim            768       768       1024      1024
mlp_dim            3072      3072      4096      4096
num_heads          12        12        16        16
num_layers         12        12        24        24
num_classes        1000      1000      1000      1000
attn_dropout_rate  0.0       0.0       0.0       0.0
dropout_rate       0.1       0.1       0.1       0.1
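As a quick sanity check on the table above (a stdlib-only sketch, not part of the package's API), a ViT splits each image into (image_size / patch_size)² patches and prepends one class token, which determines the transformer's sequence length:

```python
# Token-count check for the configurations in the table above (stdlib only).
CONFIGS = {
    "ViT-B_16": {"image_size": 384, "patch_size": 16},
    "ViT-B_32": {"image_size": 384, "patch_size": 32},
    "ViT-L_16": {"image_size": 384, "patch_size": 16},
    "ViT-L_32": {"image_size": 384, "patch_size": 32},
}

def seq_length(image_size: int, patch_size: int) -> int:
    """Number of input tokens: one per patch, plus the class token."""
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1

for name, cfg in CONFIGS.items():
    print(name, seq_length(cfg["image_size"], cfg["patch_size"]))
# ViT-B_16 and ViT-L_16 -> 577 tokens; ViT-B_32 and ViT-L_32 -> 145 tokens
```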

If you need to modify these hyperparameters, use:

from vision_transformer_pytorch import VisionTransformer
model = VisionTransformer.from_name('ViT-B_16', image_size=256, patch_size=64, ...)
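When overriding these values, the usual ViT patch embedding assumes image_size is evenly divisible by patch_size. A small stdlib sketch of that constraint, using the hypothetical custom values from the example above:

```python
def patch_grid(image_size: int, patch_size: int) -> int:
    """Patches per side; raises if the image cannot be tiled evenly."""
    if image_size % patch_size != 0:
        raise ValueError(
            f"image_size {image_size} is not divisible by patch_size {patch_size}"
        )
    return image_size // patch_size

print(patch_grid(256, 64))  # 4x4 patch grid -> prints 4
```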

ImageNet

See examples/imagenet for details about evaluating on ImageNet.

Contributing

If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.

I look forward to seeing what the community does with these models!

Download files

Download the file for your platform.

Source Distribution

vision_transformer_pytorch-1.0.2.tar.gz (8.6 kB)

Uploaded: Source

Built Distribution

vision_transformer_pytorch-1.0.2-py2.py3-none-any.whl (11.6 kB)

Uploaded: Python 2, Python 3

File details

Details for the file vision_transformer_pytorch-1.0.2.tar.gz.

File metadata

  • Download URL: vision_transformer_pytorch-1.0.2.tar.gz
  • Upload date:
  • Size: 8.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.25.0 setuptools/50.3.1.post20201107 requests-toolbelt/0.9.1 tqdm/4.54.0 CPython/3.8.5

File hashes

Hashes for vision_transformer_pytorch-1.0.2.tar.gz
Algorithm Hash digest
SHA256 7969abceb97ac82262338f2f013578f771773afdc8cecb22ba5f3291b424f907
MD5 2132e81b04d88f82a34698bbd4fd7b84
BLAKE2b-256 333533c70adf1b1d4df3060c9bc8a3ae7ec44c7327e5a8fdf81f7d804add31d0


File details

Details for the file vision_transformer_pytorch-1.0.2-py2.py3-none-any.whl.

File metadata

  • Download URL: vision_transformer_pytorch-1.0.2-py2.py3-none-any.whl
  • Upload date:
  • Size: 11.6 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.25.0 setuptools/50.3.1.post20201107 requests-toolbelt/0.9.1 tqdm/4.54.0 CPython/3.8.5

File hashes

Hashes for vision_transformer_pytorch-1.0.2-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 4832525acbdd380b3b3a1747244563474c251c6c4ce9ea41f08b7712860fd779
MD5 eedf31021518e3d5082e7d90cf1c7a58
BLAKE2b-256 a4de482bc185ae2f8ae88e40466e158b4a8e26657af0f090be6f4a4ee8bbcd04

