A PyTorch vision transformer for distillation.
PyTorch Vision Transformers with Distillation
Based on the paper "Training data-efficient image transformers & distillation through attention".
This repository allows you to use distillation techniques with vision transformers in PyTorch. Most importantly, you can use pretrained models for the teacher, the student, or even both! My motivation was to use transfer learning to reduce the resources required to train a vision transformer.
Quickstart
Install with pip install distillable_vision_transformer
and load a pretrained transformer with:
from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_pretrained('ViT-B_16')
Installation
Install via pip:
pip install distillable_vision_transformer
Or install from source:
git clone https://github.com/Graeme22/DistillableVisionTransformer.git
cd DistillableVisionTransformer
pip install -e .
Usage
Load a model architecture:
from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_name('ViT-B_16')
Load a pretrained model:
from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_pretrained('ViT-B_16')
Default hyperparameters:
Param\Model | ViT-B_16 | ViT-B_32 | ViT-L_16 | ViT-L_32 |
---|---|---|---|---|
image_size | 384, 384 | 384, 384 | 384, 384 | 384, 384 |
patch_size | 16 | 32 | 16 | 32 |
emb_dim | 768 | 768 | 1024 | 1024 |
mlp_dim | 3072 | 3072 | 4096 | 4096 |
num_heads | 12 | 12 | 16 | 16 |
num_layers | 12 | 12 | 24 | 24 |
num_classes | 1000 | 1000 | 1000 | 1000 |
attn_dropout_rate | 0.0 | 0.0 | 0.0 | 0.0 |
dropout_rate | 0.1 | 0.1 | 0.1 | 0.1 |
If you need to modify these hyperparameters, just override them when constructing the model:
model = DistillableVisionTransformer.from_name('ViT-B_16', patch_size=64, emb_dim=2048, ...)
Training
Wrap the student (an instance of DistillableVisionTransformer) and the teacher (any network that you want to use to train the student) with a DistillationTrainer:
from distillable_vision_transformer import DistillableVisionTransformer, DistillationTrainer
student = DistillableVisionTransformer.from_pretrained('ViT-B_16')
trainer = DistillationTrainer(teacher=teacher, student=student) # where teacher is some pretrained network, e.g. an EfficientNet
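The teacher can be any pretrained classifier whose logits have the same number of classes as the student (1000 by default). As a minimal sketch, using an EfficientNet from torchvision as the teacher (the specific model and weights below are illustrative, not part of this package):
import torchvision
teacher = torchvision.models.efficientnet_b0(weights="IMAGENET1K_V1")  # any pretrained ImageNet classifier works here
teacher.eval()  # the teacher is typically kept frozen during distillation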
For the loss function, it is recommended that you use the DistilledLoss class, which is a kind of hybrid between cross-entropy and KL-divergence loss. It takes as arguments teacher_logits, student_logits, and distill_logits, which are obtained from the forward pass on DistillationTrainer, as well as the true labels, labels.
from distillable_vision_transformer import DistilledLoss
loss_fn = DistilledLoss(alpha=0.5, temperature=3.0)
loss = loss_fn(teacher_logits, student_logits, distill_logits, labels)
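Putting the pieces together, a single training step might look like the sketch below. The exact return signature of the DistillationTrainer forward pass is not documented here, so treating it as returning (teacher_logits, student_logits, distill_logits) is an assumption, as are the optimizer choice and the dataloader variable.
import torch
from distillable_vision_transformer import DistillableVisionTransformer, DistillationTrainer, DistilledLoss

student = DistillableVisionTransformer.from_pretrained('ViT-B_16')
trainer = DistillationTrainer(teacher=teacher, student=student)  # teacher as defined earlier
loss_fn = DistilledLoss(alpha=0.5, temperature=3.0)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)  # optimizer choice is illustrative

for images, labels in dataloader:  # dataloader is assumed to yield (images, labels) batches
    optimizer.zero_grad()
    # assumption: the trainer's forward pass returns the three sets of logits in this order
    teacher_logits, student_logits, distill_logits = trainer(images)
    loss = loss_fn(teacher_logits, student_logits, distill_logits, labels)
    loss.backward()
    optimizer.step()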
Inference
For inference, we want to use the DistillableVisionTransformer instance, not its DistillationTrainer wrapper.
import torch
from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_pretrained('ViT-B_16')
model.eval()
inputs = torch.rand(1, 3, *model.image_size)
# we can discard the distillation tokens, as they are only needed to calculate loss
outputs, _ = model(inputs)
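To turn the outputs into a class prediction (assuming, as the loss section suggests, that they are per-class logits):
probs = torch.softmax(outputs, dim=-1)  # convert logits to class probabilities
prediction = probs.argmax(dim=-1)  # index of the most likely class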
Contributing
If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.