Vision Transformer Pytorch

VisionTransformer implemented in PyTorch.
This project is modified from lukemelas/EfficientNet-PyTorch and asyml/vision-transformer-pytorch to provide an out-of-the-box API, so you can use VisionTransformer as easily as EfficientNet.
Quickstart
Install with `pip install vision_transformer_pytorch` and load a pretrained VisionTransformer with:

```python
from vision_transformer_pytorch import VisionTransformer

model = VisionTransformer.from_pretrained('ViT-B_16')
```
About Vision Transformer PyTorch
Vision Transformer Pytorch is a PyTorch re-implementation of Vision Transformer. It is based on EfficientNet-PyTorch, one of the best-practice examples among commonly used deep learning libraries, and on vision-transformer-pytorch, an elegant implementation of Vision Transformer. In this project, we aim to make our PyTorch implementation as simple, flexible, and extensible as possible.
If you have any feature requests or questions, feel free to leave them as GitHub issues!
Installation
Install via pip:

```bash
pip install vision_transformer_pytorch
```

Or install from source:

```bash
git clone https://github.com/tczhangzhi/VisionTransformer-Pytorch
cd VisionTransformer-Pytorch
pip install -e .
```
Usage
Loading pretrained models
Load a VisionTransformer:

```python
from vision_transformer_pytorch import VisionTransformer

model = VisionTransformer.from_name('ViT-B_16')
```
Load a pretrained VisionTransformer:

```python
import torch
from vision_transformer_pytorch import VisionTransformer

model = VisionTransformer.from_pretrained('ViT-B_16')
model.eval()

inputs = torch.randn(1, 3, *model.image_size)  # model.image_size is (height, width)
with torch.no_grad():
    logits = model(inputs)                     # (1, num_classes) classification logits
    features = model.extract_features(inputs)  # features before the classifier head
```
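For real images, inputs should be resized to the model's resolution and normalized before the forward pass. A minimal sketch using torchvision follows; the normalization statistics and the file `cat.jpg` are assumptions for illustration, so check examples/imagenet for the exact preprocessing:

```python
import torch
from PIL import Image
from torchvision import transforms
from vision_transformer_pytorch import VisionTransformer

model = VisionTransformer.from_pretrained('ViT-B_16')
model.eval()

# Assumed preprocessing: resize to the model's input resolution and
# normalize to roughly [-1, 1]; verify against examples/imagenet.
preprocess = transforms.Compose([
    transforms.Resize(model.image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

img = preprocess(Image.open('cat.jpg').convert('RGB')).unsqueeze(0)  # hypothetical image
with torch.no_grad():
    probs = torch.softmax(model(img), dim=-1)
print(probs.topk(5))
```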
Default hyperparameters:
Param\Model | ViT-B_16 | ViT-B_32 | ViT-L_16 | ViT-L_32 | R50+ViT-B_16 |
---|---|---|---|---|---|
image_size | 384 | 384 | 384 | 384 | 384 |
patch_size | 16 | 32 | 16 | 32 | 1 |
emb_dim | 768 | 768 | 1024 | 1024 | 768 |
mlp_dim | 3072 | 3072 | 4096 | 4096 | 3072 |
num_heads | 12 | 12 | 16 | 16 | 12 |
num_layers | 12 | 12 | 24 | 24 | 12 |
num_classes | 1000 | 1000 | 1000 | 1000 | 1000 |
attn_dropout_rate | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
dropout_rate | 0.1 | 0.1 | 0.1 | 0.1 | 0.1 |
If you need to modify these hyperparameters, please use:

```python
from vision_transformer_pytorch import VisionTransformer

model = VisionTransformer.from_name('ViT-B_16', image_size=256, patch_size=64, ...)
```
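For example, to repurpose the backbone for a smaller dataset, you could override num_classes from the table above. A minimal sketch, assuming every parameter in the table is accepted as a keyword override (from_name builds an untrained model, so you would fine-tune the weights yourself):

```python
from vision_transformer_pytorch import VisionTransformer

# Hypothetical 10-way classifier; from_name initializes weights randomly.
model = VisionTransformer.from_name('ViT-B_16', num_classes=10)
```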
ImageNet
See examples/imagenet for details about evaluating on ImageNet.
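For orientation, a top-1 evaluation loop could look like the sketch below; the dataset path, batch size, and preprocessing are assumptions, and examples/imagenet remains the authoritative reference:

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from vision_transformer_pytorch import VisionTransformer

model = VisionTransformer.from_pretrained('ViT-B_16').eval().cuda()

# Assumed preprocessing; check examples/imagenet for the exact pipeline.
preprocess = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
val_set = datasets.ImageFolder('/path/to/imagenet/val', preprocess)  # hypothetical path
loader = DataLoader(val_set, batch_size=32, num_workers=4)

correct = total = 0
with torch.no_grad():
    for images, labels in loader:
        preds = model(images.cuda()).argmax(dim=-1).cpu()
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f'top-1 accuracy: {correct / total:.4f}')
```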
Contributing
If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.
I look forward to seeing what the community does with these models!