
A Modular, Configuration-Driven Framework for Knowledge Distillation. Trained models, training logs, and configurations are available to ensure reproducibility.


torchdistill: A Modular, Configuration-Driven Framework for Knowledge Distillation

DOI: 10.1007/978-3-030-76423-4_3

torchdistill (formerly kdkit) offers various state-of-the-art knowledge distillation methods and enables you to design (new) experiments simply by editing a declarative YAML config file instead of Python code. Even when you need to extract intermediate representations from teacher/student models, you will NOT need to reimplement the models, which often requires changing the interface of their forward functions; instead, you specify the module path(s) in the YAML file. Refer to this paper for more details.

In addition to knowledge distillation, this framework helps you design and perform general deep learning experiments (WITHOUT coding) for reproducible deep learning studies, i.e., it enables you to train models without teachers simply by excluding teacher entries from a declarative YAML config file. You can find such examples below and in configs/sample/.
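As a schematic sketch of this idea (the key names below are illustrative, modeled on the `models` entry style used in this README; see configs/sample/ for the actual schemas), a distillation config declares both models, and removing the teacher entry (plus any teacher-dependent losses) turns the same file into a plain, teacher-free training run:

```yaml
models:
  teacher_model:        # delete this entry (and teacher-specific losses)
    name: 'resnet34'    # to train the student without a teacher
    params:
      num_classes: 1000
      pretrained: True
  student_model:
    name: 'resnet18'
    params:
      num_classes: 1000
      pretrained: False
```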

Forward hook manager

Using ForwardHookManager, you can extract intermediate representations from a model without modifying the interface of its forward function.
This example notebook Open In Colab will give you a better idea of the usage, e.g., for knowledge distillation and analysis of intermediate representations.
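Under the hood, this kind of extraction builds on PyTorch's forward hooks. A minimal sketch of the mechanism (plain PyTorch here, not the torchdistill API) looks like:

```python
import torch
from torch import nn

# Toy model: we want a hidden activation without touching forward().
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

io_dict = {}

def save_io(name):
    # Forward-hook signature: (module, inputs_tuple, output)
    def hook(module, inputs, output):
        io_dict[name] = {'input': inputs[0].detach(), 'output': output.detach()}
    return hook

# Attach a hook to a submodule instead of editing the model code.
handle = model[0].register_forward_hook(save_io('0'))

y = model(torch.randn(3, 8))      # normal forward pass
hidden = io_dict['0']['output']   # captured intermediate representation
handle.remove()                   # detach the hook when done
print(hidden.shape)               # torch.Size([3, 4])
```

ForwardHookManager automates this bookkeeping: you specify module paths (e.g., in the YAML config), and it registers and manages the hooks for you.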

1 experiment → 1 declarative PyYAML config file

In torchdistill, many components and PyTorch modules are abstracted, e.g., models, datasets, optimizers, losses, and more! You can define them in a declarative PyYAML config file, which can then be seen as a summary of your experiment, and in many cases you will NOT need to write Python code at all. Take a look at some configurations available in configs/. You'll see which modules are abstracted and how they are defined in a declarative PyYAML config file to design an experiment.

Top-1 validation accuracy for ILSVRC 2012 (ImageNet)

| T: ResNet-34* | Pretrained | KD | AT | FT | CRD | Tf-KD | SSKD | L2 | PAD-L2 | KR |
|:---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| S: ResNet-18 | 69.76* | 71.37 | 70.90 | 71.56 | 70.93 | 70.52 | 70.09 | 71.08 | 71.71 | 71.64 |
| Original work | N/A | N/A | 70.70 | 71.43** | 71.17 | 70.42 | 71.62 | 70.90 | 71.71 | 71.61 |

* The pretrained ResNet-34 and ResNet-18 are provided by torchvision.
** FT is assessed with ILSVRC 2015 in the original work.
For the row S: ResNet-18, most of the results are reported in this paper, and their checkpoints (trained weights), configuration files, and log files are available. The configurations reuse the hyperparameters, such as the number of epochs, used in the original work, except for KD.

Examples

Executable code for tasks such as image classification and text classification can be found in examples/.

For CIFAR-10 and CIFAR-100, some models are reimplemented and available as pretrained models in torchdistill. More details can be found here.

Some Transformer models fine-tuned by torchdistill for GLUE tasks are available at Hugging Face Model Hub. Sample GLUE benchmark results and details can be found here.

Google Colab Examples

The following examples are available in demo/. Note that these examples are intended for Google Colab users; if you have your own GPU(s), examples/ is usually a better reference.

CIFAR-10 and CIFAR-100

  • Training without teacher models Open In Colab
  • Knowledge distillation Open In Colab

GLUE

  • Fine-tuning without teacher models Open In Colab
  • Knowledge distillation Open In Colab

These examples write out test prediction files that you can submit to the GLUE leaderboard system to see the test performance.

PyTorch Hub

If you find models on PyTorch Hub or in GitHub repositories supporting PyTorch Hub, you can import them as teacher/student models simply by editing a declarative YAML config file.

For example, to use a pretrained ResNeSt-50 available in rwightman/pytorch-image-models (aka timm) as a teacher model for the ImageNet dataset, you can import the model via PyTorch Hub with the following entry in your declarative YAML config file.

```yaml
models:
  teacher_model:
    name: 'resnest50d'
    repo_or_dir: 'rwightman/pytorch-image-models'
    params:
      num_classes: 1000
      pretrained: True
```

How to set up

  • Python >= 3.6
  • pipenv (optional)

Install by pip/pipenv

```shell
pip3 install torchdistill
# or use pipenv
pipenv install torchdistill
```

Install from this repository

```shell
git clone https://github.com/yoshitomo-matsubara/torchdistill.git
cd torchdistill/
pip3 install -e .
# or use pipenv
pipenv install "-e ."
```

Issues / Questions / Requests

The documentation is a work in progress. In the meantime, feel free to create an issue if you find a bug.
If you have a question or a feature request, start a new discussion here.

Citation

If you use torchdistill in your research, please cite the following paper.
[Paper] [Preprint]

@inproceedings{matsubara2021torchdistill,
  title={torchdistill: A Modular, Configuration-Driven Framework for Knowledge Distillation},
  author={Matsubara, Yoshitomo},
  booktitle={International Workshop on Reproducible Research in Pattern Recognition},
  pages={24--44},
  year={2021},
  organization={Springer}
}
