
Trained models, training logs, and configurations are available to ensure reproducibility.


torchdistill: A Modular, Configuration-Driven Framework for Knowledge Distillation


torchdistill (formerly kdkit) offers various knowledge distillation methods and enables you to design (new) experiments simply by editing a yaml file instead of Python code. Even when you need to extract intermediate representations from teacher/student models, you will NOT need to reimplement the models, which often changes the interface of the forward function; instead, you can specify the module path(s) in the yaml file. Refer to this paper for more details.

Forward hook manager

Using ForwardHookManager, you can extract intermediate representations from a model without modifying the interface of its forward function.
This example notebook will give you a better idea of the usage.
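ForwardHookManager builds on PyTorch's forward-hook mechanism. The following is a minimal sketch of that underlying mechanism in plain PyTorch (not torchdistill's actual implementation), showing how an intermediate representation can be captured by module path without touching the model's forward function:

```python
import torch
from torch import nn

# Toy model; torchdistill's ForwardHookManager automates attaching hooks
# like the one below to modules named by their paths in the yaml config.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

io_dict = {}

def save_io(module_path):
    def hook(module, inputs, output):
        # Store the first positional input and the output, detached.
        io_dict[module_path] = {'input': inputs[0].detach(),
                                'output': output.detach()}
    return hook

# '0' is the module path of the first Linear layer inside the Sequential.
handle = model.get_submodule('0').register_forward_hook(save_io('0'))

x = torch.randn(3, 4)
_ = model(x)     # hooks fire during the forward pass
handle.remove()  # detach the hook once the representations are collected

print(io_dict['0']['output'].shape)  # torch.Size([3, 8])
```

ForwardHookManager wraps this bookkeeping (registering, collecting, and removing hooks for many module paths at once) so that only the paths need to be listed in the config.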

1 experiment → 1 PyYAML config file

In torchdistill, many components and PyTorch modules are abstracted, e.g., models, datasets, optimizers, losses, and more! You can define them in a PyYAML config file, which can then be seen as a summary of your experiment, and in many cases you will NOT need to write Python code at all. Take a look at some of the configurations available in configs/. You'll see which modules are abstracted and how they are defined in a PyYAML config file to design an experiment.
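As a rough illustration, a config can declare the pieces of an experiment along these lines (a hypothetical fragment for illustration only; see configs/ for the exact schema and key names torchdistill expects):

```yaml
models:
  teacher_model:
    name: 'resnet34'
    params:
      num_classes: 1000
      pretrained: True
  student_model:
    name: 'resnet18'
    params:
      num_classes: 1000
      pretrained: False
train:
  num_epochs: 100
  optimizer:
    type: 'SGD'
    params:
      lr: 0.1
      momentum: 0.9
      weight_decay: 0.0001
```

Because every such component is declared in one place, the config file doubles as a reproducible record of the experiment.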

Top-1 validation accuracy for ILSVRC 2012 (ImageNet)

| T: ResNet-34*  | Pretrained | KD    | AT    | FT      | CRD   | Tf-KD | SSKD  | L2    | PAD-L2 |
|----------------|------------|-------|-------|---------|-------|-------|-------|-------|--------|
| S: ResNet-18   | 69.76*     | 71.37 | 70.90 | 71.56   | 70.93 | 70.52 | 70.09 | 71.08 | 71.71  |
| Original work  | N/A        | N/A   | 70.70 | 71.43** | 71.17 | 70.42 | 71.62 | 70.90 | 71.71  |

* The pretrained ResNet-34 and ResNet-18 are provided by torchvision.
** FT is assessed with ILSVRC 2015 in the original work.
For the 2nd row (S: ResNet-18), the checkpoints (trained weights), configuration, and log files are available, and the configurations reuse hyperparameters such as the number of epochs used in the original work, except for KD.

Examples

Executable code can be found in examples/.

For CIFAR-10 and CIFAR-100, some models are reimplemented and available as pretrained models in torchdistill. More details can be found here.

Google Colab Examples

CIFAR-10 and CIFAR-100

  • Training without teacher models
  • Knowledge distillation

These examples are available in demo/. Note that these examples are intended for Google Colab users; if you have your own GPU(s), examples/ is usually a better reference.

PyTorch Hub

If you find models on PyTorch Hub or GitHub repositories supporting PyTorch Hub, you can import them as teacher/student models simply by editing a yaml config file.

e.g., if you want to use a pretrained ResNeSt-50 available in rwightman/pytorch-image-models as a teacher model for the ImageNet dataset, you can import the model via PyTorch Hub with the following entry in your yaml config file.

models:
  teacher_model:
    name: 'resnest50d'
    repo_or_dir: 'rwightman/pytorch-image-models'
    params:
      num_classes: 1000
      pretrained: True
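A student model can be declared alongside the Hub-imported teacher in the same models section; for instance, a torchvision ResNet-18 student might look like this (an illustrative fragment; check configs/ for the exact keys used in the bundled experiments):

```yaml
models:
  student_model:
    name: 'resnet18'
    params:
      num_classes: 1000
      pretrained: False
```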

How to set up

  • Python >= 3.6
  • pipenv (optional)

Install by pip/pipenv

pip3 install torchdistill
# or use pipenv
pipenv install torchdistill

Install from this repository

git clone https://github.com/yoshitomo-matsubara/torchdistill.git
cd torchdistill/
pip3 install -e .
# or use pipenv
pipenv install "-e ."

Issues / Contact

The documentation is work-in-progress. In the meantime, feel free to create an issue if you have a feature request or email me ( yoshitom@uci.edu ) if you would like to ask me in private.

Citation

[Preprint]

@article{matsubara2020torchdistill,
  title={torchdistill: A Modular, Configuration-Driven Framework for Knowledge Distillation},
  author={Matsubara, Yoshitomo},
  year={2020},
  eprint={2011.12913},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}

