llm_distil: Knowledge Distillation for Large Language Models

Requires Python 3.8+ and PyTorch.

A clean, production-ready library for distilling large language models using three knowledge distillation methods: KD, RevKD, and GKD.

Features

  • Three Distillation Methods:
    • KD (Knowledge Distillation): Standard forward KL divergence (mean-seeking)
    • RevKD (Reverse Knowledge Distillation): Reverse KL divergence (mode-seeking)
    • GKD (Generalized Knowledge Distillation): Generalized JSD with on-policy generation
  • Parameter-Efficient Fine-Tuning (PEFT):
    • LoRA: Low-Rank Adaptation (~0.1-1% trainable params)
    • QLoRA: Quantized LoRA with 4-bit/8-bit quantization
    • Prefix Tuning: Learn prefix vectors
    • Prompt Tuning: Learn soft prompts
    • IA3: Infused Adapter by Inhibiting and Amplifying Inner Activations
  • HuggingFace Integration: Built on top of transformers.Trainer for seamless workflow
  • Easy-to-Use API: Clean interfaces following best practices
  • Flexible Configuration: Dataclass-based configs with validation
  • Comprehensive Metrics: ROUGE, BLEU, perplexity tracking

Installation

pip install llm-distil

Or install from source:

git clone https://github.com/parmanu-lcs2/llm_distil.git
cd llm_distil
pip install -e .

For development with logging tools:

pip install -e ".[dev,logging]"

Optional: For PEFT support (LoRA, QLoRA, etc.):

pip install "peft>=0.7.0" "bitsandbytes>=0.41.0" "accelerate>=0.24.0"

Quick Start

Standard Knowledge Distillation (KD)

from llm_distil import KnowledgeDistillation, DistillationConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load models
teacher = AutoModelForCausalLM.from_pretrained("gpt2-medium")
student = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # Required for GPT-2

# Configure distillation
config = DistillationConfig(
    teacher_model_name="gpt2-medium",
    student_model_name="gpt2",
    temperature=2.0,
    kd_loss_weight=0.5,
    epochs=3,  # Note: 'epochs' not 'num_train_epochs'
    batch_size=8,  # Note: 'batch_size' not 'per_device_train_batch_size'
    learning_rate=5e-5,
)

# Initialize and train (teacher auto-moves to student's device)
# train_dataset and eval_dataset are assumed to be prepared beforehand (not shown here)
kd = KnowledgeDistillation(teacher, student, config)
kd.train(train_dataset, eval_dataset)

# Evaluate
metrics = kd.evaluate(test_dataset)
print(f"Perplexity: {metrics['perplexity']:.2f}")

# Save the distilled student
kd.save_student("./distilled_gpt2")
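
The quick start above passes train_dataset, eval_dataset, and test_dataset without showing how they are built. As a minimal sketch, assuming Dolly-style records with instruction, context, and response fields, the helper below joins one record into a single training string. The function name and the prompt template are hypothetical, and the dataset format that llm_distil expects (raw text or tokenized) is not stated here, so check docs/API_GUIDE.md before relying on it.

```python
def format_dolly_example(record: dict) -> str:
    """Join one Dolly-style record into a single training string.

    Hypothetical helper: the section markers below are an illustrative
    template, not something defined by llm_distil.
    """
    parts = [f"### Instruction:\n{record['instruction'].strip()}"]
    context = record.get("context", "").strip()
    if context:  # many records have an empty context field
        parts.append(f"### Context:\n{context}")
    parts.append(f"### Response:\n{record['response'].strip()}")
    return "\n\n".join(parts)


example = {
    "instruction": "Name a primary color.",
    "context": "",
    "response": "Red.",
}
text = format_dolly_example(example)
print(text)
```

The resulting strings would still need to be tokenized with the tokenizer loaded above before training.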

Reverse Knowledge Distillation (RevKD)

from llm_distil import ReverseKnowledgeDistillation

# Same setup as above, just use RevKD
revkd = ReverseKnowledgeDistillation(teacher, student, config)
revkd.train(train_dataset, eval_dataset)

Generalized Knowledge Distillation (GKD)

from llm_distil import GeneralizedKnowledgeDistillation, DistillationConfig

# GKD-specific config
config = DistillationConfig(
    teacher_model_name="gpt2-medium",
    student_model_name="gpt2",
    lambda_gkd=0.5,  # Mixture weight
    beta_gkd=0.5,    # On-policy weight
    epochs=3,  # Note: 'epochs' not 'num_train_epochs'
)

gkd = GeneralizedKnowledgeDistillation(teacher, student, config)
gkd.train(train_dataset, eval_dataset)

Parameter-Efficient Fine-Tuning with LoRA

from llm_distil import KnowledgeDistillation, DistillationConfig

# LoRA config - only train ~0.3M parameters instead of 124M!
config = DistillationConfig(
    teacher_model_name="gpt2-medium",
    student_model_name="gpt2",
    temperature=2.0,
    kd_loss_weight=0.5,
    epochs=3,
    use_peft=True,  # Enable PEFT
    peft_type="lora",  # Options: lora, qlora, prefix, prompt, ia3
    lora_r=8,  # LoRA rank
    lora_alpha=16,  # LoRA alpha
    lora_dropout=0.1
)

kd = KnowledgeDistillation(teacher, student, config)
kd.train(train_dataset, eval_dataset)

# Save only adapters (a few MB instead of ~500 MB)
kd.save_student("./lora_adapters")
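
The "~0.3M" figure can be checked with a little arithmetic. The sketch below assumes that the adapters are attached only to GPT-2's fused attention projection (c_attn, 768 -> 2304) in each of the 12 layers; which modules llm_distil actually targets is an assumption here, not something the README states. The helper name is hypothetical.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Parameters added by one LoRA adapter: A is (rank x d_in), B is (d_out x rank)."""
    return rank * d_in + d_out * rank


# GPT-2 small: hidden size 768, 12 transformer layers.
# Assumption: LoRA wraps only c_attn, which maps 768 -> 3 * 768 (fused Q, K, V).
hidden, layers, rank = 768, 12, 8

per_layer = lora_param_count(hidden, 3 * hidden, rank)
total = per_layer * layers

print(f"per layer: {per_layer:,}")
print(f"total:     {total:,}")
```

Under these assumptions the total comes to just under 0.3M trainable parameters; targeting more modules or raising lora_r increases it proportionally.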

API Reference

| Class | Description | Key Parameters |
|---|---|---|
| KnowledgeDistillation | Standard forward KL | temperature, kd_loss_weight |
| ReverseKnowledgeDistillation | Reverse KL (mode-seeking) | temperature, kd_loss_weight |
| GeneralizedKnowledgeDistillation | JSD with on-policy | lambda_gkd, beta_gkd |
| DistillationConfig | Configuration dataclass | All training hyperparameters |

Comparison of Methods

| Method | Loss Function | Behavior | Best For |
|---|---|---|---|
| KD | Forward KL: KL(Teacher ‖ Student) | Mean-seeking, covers all modes | General-purpose distillation |
| RevKD | Reverse KL: KL(Student ‖ Teacher) | Mode-seeking, focuses on peaks | High-confidence predictions |
| GKD | JSD: λ·JSD(T,S) + (1-λ)·JSD(T,S_gen) | Mixture of off/on-policy | Generative tasks |
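
The mean-seeking versus mode-seeking distinction can be illustrated on a toy distribution. This is a self-contained sketch in plain Python, not the library's loss code: a bimodal "teacher" is compared against a student that spreads its mass widely and a student that commits to one mode. All distributions are made up for illustration.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


teacher = [0.495, 0.01, 0.495]      # two modes, almost no mass in the middle
student_wide = [0.34, 0.32, 0.34]   # covers everything, including the gap
student_peak = [0.98, 0.01, 0.01]   # commits to a single teacher mode

forward_wide = kl(teacher, student_wide)   # KD objective
forward_peak = kl(teacher, student_peak)
reverse_wide = kl(student_wide, teacher)   # RevKD objective
reverse_peak = kl(student_peak, teacher)

print(f"forward KL  wide={forward_wide:.3f}  peak={forward_peak:.3f}")
print(f"reverse KL  wide={reverse_wide:.3f}  peak={reverse_peak:.3f}")
```

Forward KL scores the wide student better because it heavily penalizes missing a teacher mode, while reverse KL scores the peaked student better because it penalizes putting mass where the teacher has little.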

Temperature Scaling: KD and RevKD use temperature T to soften distributions. Loss is scaled by T² to preserve gradient magnitudes.
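
As a rough illustration of temperature scaling, the sketch below softens two sets of logits with a temperature and scales the forward KL by T². It is written in plain Python for a single token position; the function names are hypothetical and the library's own implementation (built on PyTorch) may differ in detail. The T² factor compensates for the soft-target gradients shrinking roughly as 1/T².

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax with the usual max-subtraction for stability."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(teacher_logits, student_logits, temperature=2.0):
    """Forward KL between softened distributions, scaled by T^2."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
    return temperature ** 2 * kl


teacher_logits = [4.0, 1.0, 0.5]
student_logits = [2.5, 1.5, 0.2]

sharp = softmax(teacher_logits, temperature=1.0)
soft = softmax(teacher_logits, temperature=2.0)
loss = kd_loss(teacher_logits, student_logits, temperature=2.0)

print("T=1:", [round(x, 3) for x in sharp])
print("T=2:", [round(x, 3) for x in soft])
print("loss:", round(loss, 4))
```

A higher temperature flattens the teacher distribution, so the student also learns from the relative probabilities of the non-top tokens.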

On-Policy Generation: GKD generates sequences from the student during training for more robust distillation.
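
The mixture in the comparison table, λ·JSD(T,S) + (1-λ)·JSD(T,S_gen), can be sketched for a single token position as follows. This is an illustration under assumptions: it uses the standard symmetric Jensen-Shannon divergence and toy distributions, and it does not model how llm_distil uses beta_gkd or how it generates student sequences. The function names are hypothetical.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jsd(p, q):
    """Symmetric Jensen-Shannon divergence with the midpoint mixture m = (p + q) / 2."""
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def mixed_jsd_loss(off_policy, on_policy, lam):
    """lam * JSD on dataset tokens + (1 - lam) * JSD on student-generated tokens."""
    t_off, s_off = off_policy
    t_on, s_on = on_policy
    return lam * jsd(t_off, s_off) + (1 - lam) * jsd(t_on, s_on)


# Toy (teacher, student) next-token distributions at one position.
off_policy = ([0.7, 0.2, 0.1], [0.6, 0.3, 0.1])   # context taken from the dataset
on_policy = ([0.5, 0.4, 0.1], [0.2, 0.3, 0.5])    # context generated by the student

loss = mixed_jsd_loss(off_policy, on_policy, lam=0.5)
print(round(loss, 4))
```

On student-generated contexts the two models tend to disagree more, which is why the on-policy term usually contributes the larger share of the loss in a setup like this.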

Expected Results

On Databricks Dolly-15k (instruction-following dataset):

Quick Demo (200 examples, 1 epoch)

| Model | Perplexity | Training Time | Trainable Params | Model Size |
|---|---|---|---|---|
| Teacher (GPT2-medium) | ~100-150 | - | 355M | 355M params |
| Student Baseline | ~80-120 | 2-3 min | 124M | 124M params |
| Student + KD | ~75-110 | 3-4 min | 124M | 124M params |
| Student + RevKD | ~75-110 | 3-4 min | 124M | 124M params |
| Student + GKD | ~75-110 | 4-5 min | 124M | 124M params |
| Student + LoRA | ~75-110 | 2-3 min | 0.3M (0.26%) | 124M + 2MB |

Quick demo results on a single GPU (T4/A100)

LoRA Benefits: 99.7% fewer trainable parameters, 75% less memory, 95% storage savings

Full Training (1000 examples, 3 epochs)

| Model | Perplexity | Training Time | Model Size |
|---|---|---|---|
| Teacher (GPT2-medium) | ~25.3 | - | 355M params |
| Student Baseline | ~32.1 | 45 min | 124M params |
| Student + KD | ~28.7 | 52 min | 124M params |
| Student + RevKD | ~29.2 | 51 min | 124M params |
| Student + GKD | ~28.4 | 65 min | 124M params |

Full training results on a single A100 GPU
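
For readers comparing the perplexity columns, the standard definition is the exponential of the mean negative log-likelihood per token. The sketch below shows that definition in plain Python; how kd.evaluate computes its 'perplexity' entry is not shown in this README, so treat this as the textbook formula rather than the library's code.

```python
import math


def perplexity(token_log_probs):
    """exp of the average negative log-likelihood over the given tokens."""
    nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(nll)


# A model that assigns probability 0.25 to every reference token is as
# uncertain as a uniform choice among 4 tokens.
uniform_log_probs = [math.log(0.25)] * 10
ppl = perplexity(uniform_log_probs)
print(round(ppl, 6))
```

Lower is better: a drop from ~32 to ~28 means the student is, on average, noticeably less surprised by the reference tokens.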

Documentation

See docs/API_GUIDE.md for detailed API documentation and advanced usage.

Citation

If you use this library, please cite:

@inproceedings{ramesh-etal-2025-generalization,
    title = "On the Generalization vs Fidelity Paradox in Knowledge Distillation",
    author = "Ramesh, Suhas Kamasetty  and
      Sengupta, Ayan  and
      Chakraborty, Tanmoy",
    editor = "Che, Wanxiang  and
      Nabende, Joyce  and
      Shutova, Ekaterina  and
      Pilehvar, Mohammad Taher",
    booktitle = "Findings of the Association for Computational Linguistics: ACL 2025",
    month = jul,
    year = "2025",
    address = "Vienna, Austria",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2025.findings-acl.923/",
    doi = "10.18653/v1/2025.findings-acl.923",
    pages = "17930--17951",
    ISBN = "979-8-89176-256-5",
}

License

Apache License 2.0 - see LICENSE file for details.

Contributing

Contributions welcome! Please open an issue or PR.

Acknowledgments

Built with HuggingFace Transformers and inspired by research in knowledge distillation for LLMs.
