Project description

KD-Lib

A PyTorch model compression library containing easy-to-use methods for knowledge distillation, pruning, and quantization

Installation

From source (recommended)

git clone https://github.com/SforAiDl/KD_Lib.git
cd KD_Lib
python setup.py install

From PyPI

pip install KD-Lib
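
A quick sanity check that either install succeeded is to import the package (this assumes only that the installation completed; no GPU or data is needed):

python -c "import KD_Lib"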

Example usage

To implement the most basic version of knowledge distillation from Distilling the Knowledge in a Neural Network and plot loss curves:

import torch
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import VanillaKD

# This part is where you define your datasets, dataloaders, models and optimizers

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

teacher_model = <your model>
student_model = <your model>

teacher_optimizer = optim.SGD(teacher_model.parameters(), 0.01)
student_optimizer = optim.SGD(student_model.parameters(), 0.01)

# Now, this is where KD_Lib comes into the picture

distiller = VanillaKD(teacher_model, student_model, train_loader, test_loader, 
                      teacher_optimizer, student_optimizer)  
distiller.train_teacher(epochs=5, plot_losses=True, save_model=True)    # Train the teacher network
distiller.train_student(epochs=5, plot_losses=True, save_model=True)    # Train the student network
distiller.evaluate(teacher=False)                                       # Evaluate the student network
distiller.get_parameters()                                              # A utility function to get the number of
                                                                        # parameters in the teacher and the student networks
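
The <your model> placeholders above can be any pair of torch.nn.Module instances that map MNIST batches of shape (N, 1, 28, 28) to 10 class logits. A minimal sketch, assuming simple fully-connected networks (hypothetical architectures, not part of KD-Lib):

import torch.nn as nn

# Hypothetical stand-ins for the <your model> placeholders above:
# a larger teacher and a smaller student, both producing 10 MNIST logits.
teacher_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 1200),
    nn.ReLU(),
    nn.Linear(1200, 1200),
    nn.ReLU(),
    nn.Linear(1200, 10),
)

student_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 100),
    nn.ReLU(),
    nn.Linear(100, 10),
)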

To train a cohort of two models in an online fashion using the framework in Deep Mutual Learning and log training details to TensorBoard:

import torch
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import DML
from KD_Lib.models import ResNet18, ResNet50          # To use models packaged in KD_Lib

# Define your datasets, dataloaders, models and optimizers

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

student_params = [4, 4, 4, 4, 4]
student_model_1 = ResNet50(student_params, 1, 10)
student_model_2 = ResNet18(student_params, 1, 10)

student_cohort = [student_model_1, student_model_2]

student_optimizer_1 = optim.SGD(student_model_1.parameters(), 0.01)
student_optimizer_2 = optim.SGD(student_model_2.parameters(), 0.01)

student_optimizers = [student_optimizer_1, student_optimizer_2]

# Now, this is where KD_Lib comes into the picture 

distiller = DML(student_cohort, train_loader, test_loader, student_optimizers, log=True, logdir="./logs")

distiller.train_students(epochs=5)
distiller.evaluate()
distiller.get_parameters()
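
With log=True, the distiller writes TensorBoard event files under the directory passed as logdir. Assuming TensorBoard is installed, the logged training curves can then be inspected with:

tensorboard --logdir ./logs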

Methods Implemented

Some benchmark results can be found in the logs file.

Paper / Method Link Repository (KD_Lib/)
Distilling the Knowledge in a Neural Network https://arxiv.org/abs/1503.02531 KD/vision/vanilla
Improved Knowledge Distillation via Teacher Assistant https://arxiv.org/abs/1902.03393 KD/vision/TAKD
Relational Knowledge Distillation https://arxiv.org/abs/1904.05068 KD/vision/RKD
Distilling Knowledge from Noisy Teachers https://arxiv.org/abs/1610.09650 KD/vision/noisy
Paying More Attention To The Attention https://arxiv.org/abs/1612.03928 KD/vision/attention
Revisit Knowledge Distillation: a Teacher-free Framework https://arxiv.org/abs/1909.11723 KD/vision/teacher_free
Mean Teachers are Better Role Models https://arxiv.org/abs/1703.01780 KD/vision/mean_teacher
Knowledge Distillation via Route Constrained Optimization https://arxiv.org/abs/1904.09149 KD/vision/RCO
Born Again Neural Networks https://arxiv.org/abs/1805.04770 KD/vision/BANN
Preparing Lessons: Improve Knowledge Distillation with Better Supervision https://arxiv.org/abs/1911.07471 KD/vision/KA
Improving Generalization Robustness with Noisy Collaboration in Knowledge Distillation https://arxiv.org/abs/1910.05057 KD/vision/noisy
Distilling Task-Specific Knowledge from BERT into Simple Neural Networks https://arxiv.org/abs/1903.12136 KD/text/BERT2LSTM
Deep Mutual Learning https://arxiv.org/abs/1706.00384 KD/vision/DML
The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks https://arxiv.org/abs/1803.03635 Pruning/lottery_tickets
Regularizing Class-wise Predictions via Self-knowledge Distillation https://arxiv.org/abs/2003.13964 KD/vision/CSDK

Please cite our pre-print if you find KD-Lib useful in any way :)

@misc{shah2020kdlib,
  title={KD-Lib: A PyTorch library for Knowledge Distillation, Pruning and Quantization}, 
  author={Het Shah and Avishree Khare and Neelay Shah and Khizir Siddiqui},
  year={2020},
  eprint={2011.14691},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

KD_Lib-0.0.32.tar.gz (317.7 kB)

Uploaded: Source

Built Distribution

KD_Lib-0.0.32-py2.py3-none-any.whl (68.1 kB)

Uploaded: Python 2, Python 3

File details

Details for the file KD_Lib-0.0.32.tar.gz.

File metadata

  • Download URL: KD_Lib-0.0.32.tar.gz
  • Upload date:
  • Size: 317.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.9.13

File hashes

Hashes for KD_Lib-0.0.32.tar.gz
Algorithm Hash digest
SHA256 a3f864fdf5e748efe1386054812e51fe341544882fad1b699785f67f51e6eefa
MD5 3cb85cdae02a2afcc593710512dcc608
BLAKE2b-256 6658300c3245390ef8a14b87b906485e8caaeaaa1664307bca56b8110a22eb3c


File details

Details for the file KD_Lib-0.0.32-py2.py3-none-any.whl.

File metadata

  • Download URL: KD_Lib-0.0.32-py2.py3-none-any.whl
  • Upload date:
  • Size: 68.1 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.9.13

File hashes

Hashes for KD_Lib-0.0.32-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 40f288b3dbbb6506df8159874787bd302cab0089ef8280e0522c1d1e53408df4
MD5 b04ba6a5abe13851d6cb9011b47d84b5
BLAKE2b-256 24fecd566e70615002e80d4d268df72457ca3d74eccc441c099cf074f09eb83b

