
A PyTorch vision transformer for distillation.


PyTorch Vision Transformers with Distillation

Based on the paper "Training data-efficient image transformers & distillation through attention".

This repository lets you apply distillation techniques to vision transformers in PyTorch. Most importantly, you can use pretrained models for the teacher, the student, or both. My motivation was to use transfer learning to reduce the resources needed to train a vision transformer.

Quickstart

Install with pip install distillable_vision_transformer and load a pretrained transformer with:

from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_pretrained('ViT-B_16')

Installation

Install via pip:

pip install distillable_vision_transformer

Or install from source:

git clone https://github.com/Graeme22/DistillableVisionTransformer.git
cd DistillableVisionTransformer
pip install -e .

Usage

Load a model architecture:

from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_name('ViT-B_16')

Load a pretrained model:

from distillable_vision_transformer import DistillableVisionTransformer
model = DistillableVisionTransformer.from_pretrained('ViT-B_16')

Default hyperparameters:

Param\Model          ViT-B_16    ViT-B_32    ViT-L_16    ViT-L_32
image_size           384, 384    384, 384    384, 384    384, 384
patch_size           16          32          16          32
emb_dim              768         768         1024        1024
mlp_dim              3072        3072        4096        4096
num_heads            12          12          16          16
num_layers           12          12          24          24
num_classes          1000        1000        1000        1000
attn_dropout_rate    0.0         0.0         0.0         0.0
dropout_rate         0.1         0.1         0.1         0.1

If you need to modify these hyperparameters, just overwrite them:

model = DistillableVisionTransformer.from_name('ViT-B_16', patch_size=64, emb_dim=2048, ...)
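
When overriding hyperparameters, it can help to check how the values interact before building a model. The sketch below is not part of this package: vit_shape_summary is a hypothetical helper that uses only the standard library. It assumes the usual ViT layout (non-overlapping square patches) plus two extra tokens, one class token and one distillation token, as in the paper. It also assumes that attention splits emb_dim evenly across heads, which may not match this package's implementation.

```python
def vit_shape_summary(image_size, patch_size, emb_dim, num_heads, extra_tokens=2):
    """Summarize token counts for a ViT configuration (hypothetical helper)."""
    height, width = image_size
    # Non-overlapping patches: one patch per full patch_size step in each dimension.
    num_patches = (height // patch_size) * (width // patch_size)
    return {
        "num_patches": num_patches,
        # Assumption: one class token and one distillation token are prepended.
        "sequence_length": num_patches + extra_tokens,
        # Assumption: standard multi-head attention needs emb_dim % num_heads == 0.
        "heads_divide_evenly": emb_dim % num_heads == 0,
    }


# Defaults for ViT-B_16 from the table above.
print(vit_shape_summary((384, 384), patch_size=16, emb_dim=768, num_heads=12))
# → {'num_patches': 576, 'sequence_length': 578, 'heads_divide_evenly': True}
```

Under these assumptions, doubling patch_size from 16 to 32 cuts the number of patches from 576 to 144, which is why the _32 variants are cheaper to run at the same image size.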

Training

Wrap the student (an instance of DistillableVisionTransformer) and the teacher (any network that you want to use to train the student) in a DistillationTrainer:

from distillable_vision_transformer import DistillableVisionTransformer, DistillationTrainer

student = DistillableVisionTransformer.from_pretrained('ViT-B_16')
# teacher is any pretrained network you have already loaded, e.g. an EfficientNet
trainer = DistillationTrainer(teacher=teacher, student=student)

For the loss function, the DistilledLoss class is recommended. It combines a cross-entropy loss with a KL-divergence loss. When called, it takes teacher_logits, student_logits, and distill_logits, which are obtained from the forward pass on DistillationTrainer, along with the true labels.

from distillable_vision_transformer import DistilledLoss

loss_fn = DistilledLoss(alpha=0.5, temperature=3.0)
loss = loss_fn(teacher_logits, student_logits, distill_logits, labels)
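
To make the roles of alpha and temperature concrete, here is a dependency-free sketch of one plausible form of such a loss for a single sample, following the soft distillation objective described in the paper. It is an assumption, not the package's actual DistilledLoss: the function names are hypothetical, and the choice to apply cross-entropy to student_logits and the KL term to distill_logits is a guess based on the argument names.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of floats."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    """Negative log-probability of the true class."""
    return -math.log(softmax(logits)[label])


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def soft_distillation_loss(teacher_logits, student_logits, distill_logits,
                           label, alpha=0.5, temperature=3.0):
    # Hard term: the class-token output is compared with the true label.
    hard = cross_entropy(student_logits, label)
    # Soft term: the distillation-token output is pulled toward the teacher,
    # with both distributions softened by the temperature.
    soft = kl_divergence(softmax(teacher_logits, temperature),
                         softmax(distill_logits, temperature))
    # temperature**2 keeps the soft term's gradient scale comparable
    # to the hard term when the temperature changes.
    return (1 - alpha) * hard + alpha * temperature ** 2 * soft


loss = soft_distillation_loss([2.0, 0.5, -1.0], [1.0, 0.2, 0.1],
                              [1.5, 0.3, -0.5], label=0)
print(loss)
```

In this form, alpha=0 reduces to plain cross-entropy on the labels, while alpha=1 trains only against the teacher. In PyTorch the same terms would typically be computed with torch.nn.functional.cross_entropy and torch.nn.functional.kl_div.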

Inference

For inference, use the DistillableVisionTransformer instance itself, not its DistillationTrainer wrapper.

import torch
from distillable_vision_transformer import DistillableVisionTransformer

model = DistillableVisionTransformer.from_pretrained('ViT-B_16')
model.eval()

inputs = torch.rand(1, 3, *model.image_size)
# gradients are not needed at inference time
with torch.no_grad():
    # we can discard the distillation tokens, as they are only needed to calculate loss
    outputs, _ = model(inputs)
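
The first output still has to be turned into predictions. The sketch below assumes that outputs holds one row of raw class logits per image, which this README does not state explicitly, and shows the softmax and top-k step with the standard library only. top_k_predictions is a hypothetical helper, not part of this package.

```python
import math


def top_k_predictions(logits, k=5):
    """Return the k most likely (class_index, probability) pairs."""
    peak = max(logits)
    # Subtracting the maximum keeps exp() from overflowing.
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Sort class indices by probability, highest first.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]


# Made-up logits for a 4-class problem, for illustration only.
example_logits = [0.1, 2.3, -1.0, 0.7]
for class_index, prob in top_k_predictions(example_logits, k=2):
    print(class_index, round(prob, 3))
```

With a real model, a single row can be converted to a Python list with outputs[0].tolist() before calling the helper, again assuming outputs has shape (1, num_classes).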

Contributing

If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.
