Project description
KD-Lib
A PyTorch model compression library containing easy-to-use methods for knowledge distillation, pruning, and quantization
Installation
From source (recommended)
git clone https://github.com/SforAiDl/KD_Lib.git
cd KD_Lib
python setup.py install
From PyPI
pip install KD-Lib
Example usage
To implement the most basic version of knowledge distillation from Distilling the Knowledge in a Neural Network and plot loss curves:
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import VanillaKD
# This part is where you define your datasets, dataloaders, models and optimizers
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)
teacher_model = <your model>
student_model = <your model>
teacher_optimizer = optim.SGD(teacher_model.parameters(), 0.01)
student_optimizer = optim.SGD(student_model.parameters(), 0.01)
# Now, this is where KD_Lib comes into the picture
distiller = VanillaKD(teacher_model, student_model, train_loader, test_loader,
                      teacher_optimizer, student_optimizer)
distiller.train_teacher(epochs=5, plot_losses=True, save_model=True)  # Train the teacher network
distiller.train_student(epochs=5, plot_losses=True, save_model=True)  # Train the student network
distiller.evaluate(teacher=False)                                     # Evaluate the student network
distiller.get_parameters()                                            # A utility function to get the number of parameters in the teacher and the student network
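The `<your model>` placeholders above stand for any teacher/student `nn.Module` pair. One possible choice (purely illustrative, not required by the library) is to reuse the ResNet implementations packaged in `KD_Lib.models`, which also appear in the Deep Mutual Learning example below:

```python
from KD_Lib.models import ResNet18, ResNet50  # ResNet variants packaged with KD_Lib

# Illustrative choice for the placeholders above: a larger ResNet50 teacher and a
# smaller ResNet18 student, built for 1-channel MNIST images with 10 classes
# (same constructor arguments as in the Deep Mutual Learning example below).
params = [4, 4, 4, 4, 4]
teacher_model = ResNet50(params, 1, 10)
student_model = ResNet18(params, 1, 10)
```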
To train a cohort of two models in an online fashion using the framework in Deep Mutual Learning and log training details to TensorBoard:
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import DML
from KD_Lib.models import ResNet18, ResNet50 # To use models packaged in KD_Lib
# Define your datasets, dataloaders, models and optimizers
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)
student_params = [4, 4, 4, 4, 4]
student_model_1 = ResNet50(student_params, 1, 10)
student_model_2 = ResNet18(student_params, 1, 10)
student_cohort = [student_model_1, student_model_2]
student_optimizer_1 = optim.SGD(student_model_1.parameters(), 0.01)
student_optimizer_2 = optim.SGD(student_model_2.parameters(), 0.01)
student_optimizers = [student_optimizer_1, student_optimizer_2]
# Now, this is where KD_Lib comes into the picture
distiller = DML(student_cohort, train_loader, test_loader, student_optimizers, log=True, logdir="./logs")
distiller.train_students(epochs=5)
distiller.evaluate()
distiller.get_parameters()
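Since `log=True` and `logdir="./logs"` are passed to `DML`, the training details written during `train_students` can be inspected afterwards with TensorBoard (assuming it is installed), e.g. by running `tensorboard --logdir ./logs` and opening the printed URL in a browser.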
Methods Implemented
Some benchmark results can be found in the logs file.
| Paper / Method | Link | Repository (KD_Lib/) |
|---|---|---|
| Distilling the Knowledge in a Neural Network | https://arxiv.org/abs/1503.02531 | KD/vision/vanilla |
| Improved Knowledge Distillation via Teacher Assistant | https://arxiv.org/abs/1902.03393 | KD/vision/TAKD |
| Relational Knowledge Distillation | https://arxiv.org/abs/1904.05068 | KD/vision/RKD |
| Distilling Knowledge from Noisy Teachers | https://arxiv.org/abs/1610.09650 | KD/vision/noisy |
| Paying More Attention To The Attention | https://arxiv.org/abs/1612.03928 | KD/vision/attention |
| Revisit Knowledge Distillation: a Teacher-free Framework | https://arxiv.org/abs/1909.11723 | KD/vision/teacher_free |
| Mean Teachers are Better Role Models | https://arxiv.org/abs/1703.01780 | KD/vision/mean_teacher |
| Knowledge Distillation via Route Constrained Optimization | https://arxiv.org/abs/1904.09149 | KD/vision/RCO |
| Born Again Neural Networks | https://arxiv.org/abs/1805.04770 | KD/vision/BANN |
| Preparing Lessons: Improve Knowledge Distillation with Better Supervision | https://arxiv.org/abs/1911.07471 | KD/vision/KA |
| Improving Generalization Robustness with Noisy Collaboration in Knowledge Distillation | https://arxiv.org/abs/1910.05057 | KD/vision/noisy |
| Distilling Task-Specific Knowledge from BERT into Simple Neural Networks | https://arxiv.org/abs/1903.12136 | KD/text/BERT2LSTM |
| Deep Mutual Learning | https://arxiv.org/abs/1706.00384 | KD/vision/DML |
| The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks | https://arxiv.org/abs/1803.03635 | Pruning/lottery_tickets |
| Regularizing Class-wise Predictions via Self-knowledge Distillation | https://arxiv.org/abs/2003.13964 | KD/vision/CSDK |
Please cite our pre-print if you find KD-Lib useful in any way :)
@misc{shah2020kdlib,
    title={KD-Lib: A PyTorch library for Knowledge Distillation, Pruning and Quantization},
    author={Het Shah and Avishree Khare and Neelay Shah and Khizir Siddiqui},
    year={2020},
    eprint={2011.14691},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
Download files
File details
Details for the file KD_Lib-0.0.32.tar.gz.
File metadata
- Download URL: KD_Lib-0.0.32.tar.gz
- Upload date:
- Size: 317.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.9.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a3f864fdf5e748efe1386054812e51fe341544882fad1b699785f67f51e6eefa |
| MD5 | 3cb85cdae02a2afcc593710512dcc608 |
| BLAKE2b-256 | 6658300c3245390ef8a14b87b906485e8caaeaaa1664307bca56b8110a22eb3c |
File details
Details for the file KD_Lib-0.0.32-py2.py3-none-any.whl.
File metadata
- Download URL: KD_Lib-0.0.32-py2.py3-none-any.whl
- Upload date:
- Size: 68.1 kB
- Tags: Python 2, Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.9.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 40f288b3dbbb6506df8159874787bd302cab0089ef8280e0522c1d1e53408df4 |
| MD5 | b04ba6a5abe13851d6cb9011b47d84b5 |
| BLAKE2b-256 | 24fecd566e70615002e80d4d268df72457ca3d74eccc441c099cf074f09eb83b |
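To confirm that a downloaded file matches the digests listed above, the hash can be recomputed locally. A minimal sketch in Python, assuming the source distribution has already been downloaded to the working directory:

```python
import hashlib

# Published SHA256 digest for KD_Lib-0.0.32.tar.gz (from the table above)
EXPECTED_SHA256 = "a3f864fdf5e748efe1386054812e51fe341544882fad1b699785f67f51e6eefa"

# Recompute the digest of the local file and compare it with the published value
with open("KD_Lib-0.0.32.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print("OK" if digest == EXPECTED_SHA256 else "hash mismatch")
```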