A PyTorch vision transformer for distillation.
PyTorch Vision Transformers with Distillation
Based on the paper "Training data-efficient image transformers & distillation through attention".
This repository lets you apply distillation techniques to vision transformers in PyTorch. Most importantly, you can use pretrained models for the teacher, the student, or even both! My motivation was to use transfer learning to reduce the compute and data needed to train a vision transformer.
Quickstart
Install with pip install distillable_vision_transformer
and load a pretrained transformer with:
from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_pretrained('ViT-B_16')
Installation
Install via pip:
pip install distillable_vision_transformer
Or install from source:
git clone https://github.com/Graeme22/DistillableVisionTransformer.git
cd DistillableVisionTransformer
pip install -e .
Usage
Load a model architecture (no pretrained weights):
from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_name('ViT-B_16')
Load a pretrained model:
from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_pretrained('ViT-B_16')
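To run the loaded model, feed it batches of images at the architecture's default input resolution (384x384 for ViT-B_16, see the table below). The following is a minimal forward-pass sketch; following the DeiT formulation, it assumes forward returns a pair of outputs, one from the class token and one from the distillation token. That return shape is an assumption here, so check the forward signature in your installed version.

import torch
from distillable_vision_transformer import DistillableVisionTransformer

model = DistillableVisionTransformer.from_pretrained('ViT-B_16')
model.eval()

images = torch.randn(2, 3, 384, 384)  # dummy batch at the default 384x384 input size

with torch.no_grad():
    # Assumption: forward returns (class-token output, distillation-token output),
    # as in DeiT. Adjust the unpacking if your version returns something else.
    cls_out, dist_out = model(images)

# A common DeiT-style inference choice is to average the two heads.
preds = ((cls_out + dist_out) / 2).argmax(dim=-1)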
Default hyperparameters:
Param\Model | ViT-B_16 | ViT-B_32 | ViT-L_16 | ViT-L_32 |
---|---|---|---|---|
image_size | 384, 384 | 384, 384 | 384, 384 | 384, 384 |
patch_size | 16 | 32 | 16 | 32 |
emb_dim | 768 | 768 | 1024 | 1024 |
mlp_dim | 3072 | 3072 | 4096 | 4096 |
num_heads | 12 | 12 | 16 | 16 |
num_layers | 12 | 12 | 24 | 24 |
num_classes | 1000 | 1000 | 1000 | 1000 |
attn_dropout_rate | 0.0 | 0.0 | 0.0 | 0.0 |
dropout_rate | 0.1 | 0.1 | 0.1 | 0.1 |
If you need to modify these hyperparameters, just override them as keyword arguments:
model = DistillableVisionTransformer.from_name('ViT-B_16', patch_size=64, emb_dim=2048, ...)
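Putting it together, distillation pairs a frozen teacher with a trainable student and adds a distillation term to the loss. The sketch below is a minimal soft-distillation training step, not a built-in API of this package: the temperature T, weight alpha, and the KL-divergence term are standard DeiT-style choices, and it again assumes the model's forward returns (class output, distillation output) as in the example above.

import torch
import torch.nn.functional as F
from distillable_vision_transformer import DistillableVisionTransformer

# As described above, the teacher and the student can both start from pretrained weights.
teacher = DistillableVisionTransformer.from_pretrained('ViT-B_16').eval()
student = DistillableVisionTransformer.from_pretrained('ViT-B_16')
for p in teacher.parameters():
    p.requires_grad_(False)  # freeze the teacher

optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4)
T, alpha = 3.0, 0.5  # temperature and distillation weight (illustrative values)

def train_step(images, labels):
    with torch.no_grad():
        teacher_out, _ = teacher(images)  # assumption: distill from the teacher's class-token output

    cls_out, dist_out = student(images)  # assumption: (class, distillation) outputs

    # Supervised loss on the class token, soft distillation loss on the distillation token.
    ce = F.cross_entropy(cls_out, labels)
    kd = F.kl_div(
        F.log_softmax(dist_out / T, dim=-1),
        F.softmax(teacher_out / T, dim=-1),
        reduction='batchmean',
    ) * (T * T)

    loss = (1 - alpha) * ce + alpha * kd
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

In practice the teacher is often a larger pretrained model and the student a smaller one created with from_name, but as noted above, either end (or both) can start from pretrained weights.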
Contributing
If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.