A PyTorch vision transformer for distillation.
PyTorch Vision Transformers with Distillation
Based on the paper "Training data-efficient image transformers & distillation through attention".
This repository lets you apply distillation techniques to vision transformers in PyTorch. Most importantly, you can use pretrained models for the teacher, the student, or both. My motivation was to use transfer learning to reduce the resources needed to train a vision transformer.
Quickstart
Install with pip install distillable_vision_transformer and load a pretrained transformer with:
from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_pretrained('ViT-B_16')
Installation
Install via pip:
pip install distillable_vision_transformer
Or install from source:
git clone https://github.com/Graeme22/DistillableVisionTransformer.git
cd DistillableVisionTransformer
pip install -e .
Usage
Load a model architecture:
from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_name('ViT-B_16')
Load a pretrained model:
from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_pretrained('ViT-B_16')
Default hyperparameters:
| Param\Model | ViT-B_16 | ViT-B_32 | ViT-L_16 | ViT-L_32 |
|---|---|---|---|---|
| image_size | 384, 384 | 384, 384 | 384, 384 | 384, 384 |
| patch_size | 16 | 32 | 16 | 32 |
| emb_dim | 768 | 768 | 1024 | 1024 |
| mlp_dim | 3072 | 3072 | 4096 | 4096 |
| num_heads | 12 | 12 | 16 | 16 |
| num_layers | 12 | 12 | 24 | 24 |
| num_classes | 1000 | 1000 | 1000 | 1000 |
| attn_dropout_rate | 0.0 | 0.0 | 0.0 | 0.0 |
| dropout_rate | 0.1 | 0.1 | 0.1 | 0.1 |
If you need to modify these hyperparameters, just overwrite them:
model = DistillableVisionTransformer.from_name('ViT-B_16', patch_size=64, emb_dim=2048, ...)
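Changing patch_size or image_size changes how many patches each image is split into, and therefore the sequence length the transformer processes. The following is a minimal sketch of that arithmetic, assuming the usual ViT scheme of non-overlapping square patches. The helper num_patches is hypothetical and is not part of this package.

```python
def num_patches(image_size, patch_size):
    """Number of non-overlapping square patches in an (H, W) image.

    Hypothetical helper for illustration only; not part of the package.
    """
    height, width = image_size
    if height % patch_size or width % patch_size:
        raise ValueError("image dimensions must be divisible by patch_size")
    return (height // patch_size) * (width // patch_size)


# Default 384x384 images from the table above.
print(num_patches((384, 384), 16))  # ViT-B_16 / ViT-L_16 -> 576
print(num_patches((384, 384), 32))  # ViT-B_32 / ViT-L_32 -> 144
print(num_patches((384, 384), 64))  # patch_size overridden to 64 -> 36
```

Larger patches give shorter sequences, which is cheaper to compute but coarser in spatial detail.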
Training
Wrap the student (an instance of DistillableVisionTransformer) and the teacher (any network you want to use to train the student) in a DistillationTrainer:
from distillable_vision_transformer import DistillableVisionTransformer, DistillationTrainer
student = DistillableVisionTransformer.from_pretrained('ViT-B_16')
# teacher is any pretrained network you supply, e.g. an EfficientNet
trainer = DistillationTrainer(teacher=teacher, student=student)
For the loss function, we recommend the DistilledLoss class, a hybrid of cross-entropy and KL-divergence loss.
It takes teacher_logits, student_logits, and distill_logits, which are obtained from the forward pass on DistillationTrainer, along with the ground-truth labels (labels).
from distillable_vision_transformer import DistilledLoss
loss_fn = DistilledLoss(alpha=0.5, temperature=3.0)
loss = loss_fn(teacher_logits, student_logits, distill_logits, labels)
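To make the "hybrid" concrete, here is a rough sketch of one common way to combine the two terms, following the soft-distillation form in the paper cited above. It is written in plain Python for a single example so each term is visible. The exact formula, the use of the distillation head in the KL term, and the name distilled_loss_sketch are assumptions for illustration; this is not the package's DistilledLoss implementation.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of a list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in scaled))
    return [x - log_norm for x in scaled]


def distilled_loss_sketch(teacher_logits, student_logits, distill_logits,
                          label, alpha=0.5, temperature=3.0):
    """Assumed form: (1 - alpha) * CE + alpha * T^2 * KL(teacher || distill)."""
    # Hard-label term: cross-entropy of the student's class head.
    cross_entropy = -log_softmax(student_logits)[label]

    # Soft-label term: KL divergence between the temperature-softened
    # teacher distribution and the distillation head's distribution.
    teacher_log_p = log_softmax(teacher_logits, temperature)
    distill_log_p = log_softmax(distill_logits, temperature)
    kl = sum(math.exp(t) * (t - d)
             for t, d in zip(teacher_log_p, distill_log_p))

    # T^2 keeps the soft term's gradient scale comparable across temperatures.
    return (1 - alpha) * cross_entropy + alpha * temperature ** 2 * kl


# Single three-class example with made-up logits.
print(distilled_loss_sketch([2.0, 0.5, -1.0], [1.5, 0.2, -0.3],
                            [1.8, 0.4, -0.9], label=0))
```

Under this assumed form, alpha=0 reduces to ordinary cross-entropy, and the KL term vanishes when the distillation head matches the teacher exactly.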
Inference
For inference, use the DistillableVisionTransformer instance directly, not its DistillationTrainer wrapper.
import torch
from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_pretrained('ViT-B_16')
model.eval()
inputs = torch.rand(1, 3, *model.image_size)
# we can discard the distillation tokens, as they are only needed to calculate loss
with torch.no_grad():  # gradients are not needed for inference
    outputs, _ = model(inputs)
Contributing
If you find a bug, create a GitHub issue or, even better, submit a pull request. If you have questions, post them as GitHub issues as well.