KD-Lib
A PyTorch model compression library containing easy-to-use methods for knowledge distillation, pruning, and quantization
Installation
From source (recommended)
git clone https://github.com/SforAiDl/KD_Lib.git
cd KD_Lib
python setup.py install
From PyPI
pip install KD-Lib
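To verify that the package is importable after installation, a minimal sanity check (not an official test) is:

python -c "from KD_Lib.KD import VanillaKD, DML"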
Example usage
To implement the most basic version of knowledge distillation from Distilling the Knowledge in a Neural Network and plot loss curves:
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import VanillaKD
# This part is where you define your datasets, dataloaders, models and optimizers
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

teacher_model = <your model>
student_model = <your model>

teacher_optimizer = optim.SGD(teacher_model.parameters(), 0.01)
student_optimizer = optim.SGD(student_model.parameters(), 0.01)

# Now, this is where KD_Lib comes into the picture
distiller = VanillaKD(teacher_model, student_model, train_loader, test_loader,
                      teacher_optimizer, student_optimizer)

distiller.train_teacher(epochs=5, plot_losses=True, save_model=True)  # Train the teacher network
distiller.train_student(epochs=5, plot_losses=True, save_model=True)  # Train the student network
distiller.evaluate(teacher=False)                                     # Evaluate the student network
distiller.get_parameters()  # A utility function to get the number of
                            # parameters in the teacher and the student network
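The <your model> placeholders above stand for ordinary torch.nn.Module instances that output class logits. As a hedged illustration (hypothetical architectures, not part of KD-Lib; any teacher/student pair works), a small MNIST teacher and an even smaller student could be defined as:

import torch.nn as nn

# Larger "teacher" network: flattens 28x28 MNIST images and produces 10 logits
teacher_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 512),
    nn.ReLU(),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

# Smaller "student" network with far fewer parameters
student_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)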
To train a cohort of student models in an online fashion using the framework in Deep Mutual Learning and log training details to TensorBoard:
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import DML
from KD_Lib.models import ResNet18, ResNet50 # To use models packaged in KD_Lib
# Define your datasets, dataloaders, models and optimizers
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)
student_params = [4, 4, 4, 4, 4]
student_model_1 = ResNet50(student_params, 1, 10)
student_model_2 = ResNet18(student_params, 1, 10)
student_cohort = [student_model_1, student_model_2]
student_optimizer_1 = optim.SGD(student_model_1.parameters(), 0.01)
student_optimizer_2 = optim.SGD(student_model_2.parameters(), 0.01)
student_optimizers = [student_optimizer_1, student_optimizer_2]
# Now, this is where KD_Lib comes into the picture
distiller = DML(student_cohort, train_loader, test_loader, student_optimizers, log=True, logdir="./logs")
distiller.train_students(epochs=5)
distiller.evaluate()
distiller.get_parameters()
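Since DML takes a list of student models and a matching list of optimizers, the cohort is not limited to two networks. A sketch that grows it to three students (assuming, as the list-based interface suggests, that any cohort size is accepted):

student_model_3 = ResNet18(student_params, 1, 10)
student_optimizer_3 = optim.SGD(student_model_3.parameters(), 0.01)

# Same DML call as above, just with longer lists of students and optimizers
distiller = DML(
    [student_model_1, student_model_2, student_model_3],
    train_loader,
    test_loader,
    [student_optimizer_1, student_optimizer_2, student_optimizer_3],
    log=True,
    logdir="./logs",
)
distiller.train_students(epochs=5)

The logged training details can then be viewed by pointing TensorBoard at the log directory (tensorboard --logdir ./logs).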
Methods Implemented
Some benchmark results can be found in the logs file.
Paper / Method | Link | Repository (KD_Lib/) |
---|---|---|
Distilling the Knowledge in a Neural Network | https://arxiv.org/abs/1503.02531 | KD/vision/vanilla |
Improved Knowledge Distillation via Teacher Assistant | https://arxiv.org/abs/1902.03393 | KD/vision/TAKD |
Relational Knowledge Distillation | https://arxiv.org/abs/1904.05068 | KD/vision/RKD |
Distilling Knowledge from Noisy Teachers | https://arxiv.org/abs/1610.09650 | KD/vision/noisy |
Paying More Attention To The Attention | https://arxiv.org/abs/1612.03928 | KD/vision/attention |
Revisit Knowledge Distillation: a Teacher-free Framework | https://arxiv.org/abs/1909.11723 | KD/vision/teacher_free |
Mean Teachers are Better Role Models | https://arxiv.org/abs/1703.01780 | KD/vision/mean_teacher |
Knowledge Distillation via Route Constrained Optimization | https://arxiv.org/abs/1904.09149 | KD/vision/RCO |
Born Again Neural Networks | https://arxiv.org/abs/1805.04770 | KD/vision/BANN |
Preparing Lessons: Improve Knowledge Distillation with Better Supervision | https://arxiv.org/abs/1911.07471 | KD/vision/KA |
Improving Generalization Robustness with Noisy Collaboration in Knowledge Distillation | https://arxiv.org/abs/1910.05057 | KD/vision/noisy |
Distilling Task-Specific Knowledge from BERT into Simple Neural Networks | https://arxiv.org/abs/1903.12136 | KD/text/BERT2LSTM |
Deep Mutual Learning | https://arxiv.org/abs/1706.00384 | KD/vision/DML |
The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks | https://arxiv.org/abs/1803.03635 | Pruning/lottery_tickets |
Regularizing Class-wise Predictions via Self-knowledge Distillation | https://arxiv.org/abs/2003.13964 | KD/vision/CSDK |
Please cite our pre-print if you find KD-Lib useful in any way :)
@misc{shah2020kdlib,
title={KD-Lib: A PyTorch library for Knowledge Distillation, Pruning and Quantization},
author={Het Shah and Avishree Khare and Neelay Shah and Khizir Siddiqui},
year={2020},
eprint={2011.14691},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
KD_Lib-0.0.32.tar.gz (317.7 kB)
Built Distribution
KD_Lib-0.0.32-py2.py3-none-any.whl (68.1 kB)
File details
Details for the file KD_Lib-0.0.32.tar.gz.
File metadata
- Download URL: KD_Lib-0.0.32.tar.gz
- Upload date:
- Size: 317.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.9.13
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | a3f864fdf5e748efe1386054812e51fe341544882fad1b699785f67f51e6eefa |
MD5 | 3cb85cdae02a2afcc593710512dcc608 |
BLAKE2b-256 | 6658300c3245390ef8a14b87b906485e8caaeaaa1664307bca56b8110a22eb3c |
File details
Details for the file KD_Lib-0.0.32-py2.py3-none-any.whl.
File metadata
- Download URL: KD_Lib-0.0.32-py2.py3-none-any.whl
- Upload date:
- Size: 68.1 kB
- Tags: Python 2, Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.9.13
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 40f288b3dbbb6506df8159874787bd302cab0089ef8280e0522c1d1e53408df4 |
MD5 | b04ba6a5abe13851d6cb9011b47d84b5 |
BLAKE2b-256 | 24fecd566e70615002e80d4d268df72457ca3d74eccc441c099cf074f09eb83b |