Distil Trainer
A comprehensive knowledge distillation training framework for transformer models.
Features
- 7 Distillation Strategies:
  - Classical Embedding Distillation (MSE/Cosine loss)
  - Layer Reduction (Depth Pruning)
  - Width Pruning (Hidden size, Attention heads, MLP)
  - Combined Depth-Width Pruning
  - Multilingual Model Extension
  - LLM to Embedding Model Conversion
  - Reasoning/Chain-of-Thought Distillation
- Flexible Architecture:
  - Support for SentenceTransformers and HuggingFace models
  - Multiple loss functions (MSE, KL Divergence, Cosine, Ranking)
  - Configurable importance estimation for pruning
  - PCA projection for dimension reduction
- Production Ready:
  - Export to HuggingFace Hub
  - ONNX export support
  - Distributed training with Accelerate
  - Comprehensive evaluation framework
Installation
pip install distil-trainer
With optional dependencies:
# For experiment tracking (quotes keep shells such as zsh from expanding the brackets)
pip install "distil-trainer[tracking]"
# For model export
pip install "distil-trainer[export]"
# For MTEB evaluation
pip install "distil-trainer[evaluation]"
# For all features
pip install "distil-trainer[all]"
Quick Start
Basic Embedding Distillation
Distill knowledge from a large teacher model to a smaller student model:
from distil_trainer import DistilTrainer, DistilTrainerConfig, DistillationConfig

config = DistilTrainerConfig(
    teacher_model="sentence-transformers/all-mpnet-base-v2",
    student_model="sentence-transformers/paraphrase-TinyBERT-L6-v2",
    distillation_config=DistillationConfig(
        loss_type="mse",  # Options: mse, kl_divergence, cosine, ranking, combined
        use_pca_projection=True,  # When student dim < teacher dim
    ),
    output_dir="./distilled_model",
)

trainer = DistilTrainer(config)
trainer.load_data(train_data="sentence-transformers/all-nli")
trainer.train()
trainer.save_model("./final_model")
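To make the `mse` and `cosine` loss options more concrete, here is a minimal, library-independent sketch of what such losses typically compute on a single pair of embeddings. The function names and the toy vectors are illustrative assumptions, not part of the distil-trainer API, and the library's own implementation may differ (for example in batching or normalization).

```python
import math


def mse_loss(student, teacher):
    """Mean squared error between two equal-length embedding vectors."""
    if len(student) != len(teacher):
        raise ValueError("embeddings must have the same dimension")
    return sum((s - t) ** 2 for s, t in zip(student, teacher)) / len(student)


def cosine_loss(student, teacher):
    """1 - cosine similarity; close to 0 when the vectors point the same way."""
    dot = sum(s * t for s, t in zip(student, teacher))
    norm = math.sqrt(sum(s * s for s in student)) * math.sqrt(sum(t * t for t in teacher))
    return 1.0 - dot / norm


# Made-up embeddings: the student points in the teacher's direction
# but has twice the length.
teacher_emb = [1.0, 0.0, 2.0]
student_emb = [2.0, 0.0, 4.0]

print("mse:", mse_loss(student_emb, teacher_emb))
print("cosine:", cosine_loss(student_emb, teacher_emb))
```

In this toy case the MSE is clearly non-zero while the cosine loss is essentially zero, which illustrates the practical difference: MSE also penalizes a mismatch in vector length, whereas cosine only looks at direction.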
Custom Dataset Columns
If your dataset uses different column names (e.g., "text" instead of "sentence"), you can specify this when loading data:
# Load dataset with custom text column
trainer.load_data(
    train_data="alibayram/cosmos-corpus-00-5",
    text_column="text",
)
Layer Reduction (Depth Pruning)
Reduce model depth by keeping only selected layers:
from distil_trainer import DistilTrainer, DistilTrainerConfig
from distil_trainer.core.config import LayerReductionConfig

config = DistilTrainerConfig(
    teacher_model="mixedbread-ai/mxbai-embed-large-v1",
    student_init_strategy="layer_reduction",
    pruning_config=LayerReductionConfig(
        # Explicitly specify layers to keep (0-indexed)
        layers_to_keep=[0, 3, 6, 9, 12, 15, 18, 21],
        # Or use automatic selection
        # num_layers_to_keep=8,
        # layer_selection="importance",  # Options: first, last, even, importance, custom
    ),
    output_dir="./layer_reduced_model",
)

trainer = DistilTrainer(config)
trainer.train()
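The position-based selection options (`first`, `last`, `even`) can be pictured with a small stand-alone sketch. The helper `select_layers` below is hypothetical and only illustrates one plausible reading of those option names; how distil-trainer actually maps them to layer indices is an assumption here. With 24 layers and 8 kept, the `even` variant happens to reproduce the `layers_to_keep` list used above.

```python
def select_layers(num_layers, num_to_keep, strategy="even"):
    """Return the 0-indexed layers to keep under a simple positional strategy."""
    if not 0 < num_to_keep <= num_layers:
        raise ValueError("num_to_keep must be between 1 and num_layers")
    if strategy == "first":
        # Keep the bottom of the stack.
        return list(range(num_to_keep))
    if strategy == "last":
        # Keep the top of the stack.
        return list(range(num_layers - num_to_keep, num_layers))
    if strategy == "even":
        # Spread the kept layers evenly, starting at layer 0.
        return [i * num_layers // num_to_keep for i in range(num_to_keep)]
    raise ValueError(f"unknown strategy: {strategy}")


print(select_layers(24, 8, "even"))   # [0, 3, 6, 9, 12, 15, 18, 21]
print(select_layers(24, 8, "first"))  # [0, 1, 2, 3, 4, 5, 6, 7]
print(select_layers(24, 8, "last"))   # [16, 17, 18, 19, 20, 21, 22, 23]
```

Importance-based selection is not covered by this sketch, since it depends on scores measured on calibration data rather than on layer position alone.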
Width Pruning
Reduce hidden dimensions, attention heads, and intermediate sizes:
from distil_trainer import DistilTrainer, DistilTrainerConfig
from distil_trainer.core.config import WidthPruningConfig

config = DistilTrainerConfig(
    teacher_model="Qwen/Qwen3-8B",
    student_init_strategy="width_pruning",
    pruning_config=WidthPruningConfig(
        # Target absolute dimensions
        target_hidden_size=3072,
        target_intermediate_size=9216,
        target_num_attention_heads=24,
        target_num_key_value_heads=4,
        # Or use ratios (alternative to absolute values)
        # hidden_size_ratio=0.75,
        # intermediate_size_ratio=0.75,
        # Importance estimation method
        importance_method="activation",  # Options: activation, gradient, taylor, wanda, cosine_similarity
        calibration_samples=1024,
    ),
    output_dir="./width_pruned_model",
)

trainer = DistilTrainer(config)
trainer.train()
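As a rough illustration of activation-based importance, the sketch below scores each hidden channel by its mean absolute activation over a few calibration samples and keeps the highest-scoring channels. This is a simplified stand-in, not the library's implementation: the function names and the tiny calibration matrix are invented for the example, and real estimators usually aggregate over tokens, layers, and many more samples.

```python
def activation_importance(activations):
    """Mean absolute activation per channel.

    `activations` is a list of samples, each a list of per-channel values.
    """
    num_channels = len(activations[0])
    totals = [0.0] * num_channels
    for sample in activations:
        for channel, value in enumerate(sample):
            totals[channel] += abs(value)
    return [total / len(activations) for total in totals]


def channels_to_keep(importance, target_size):
    """Indices of the `target_size` most important channels, in original order."""
    ranked = sorted(range(len(importance)), key=lambda c: importance[c], reverse=True)
    return sorted(ranked[:target_size])


# Made-up calibration activations: 3 samples, 4 channels.
calibration = [
    [0.1, -2.0, 0.0, 1.5],
    [-0.2, 1.8, 0.1, -1.2],
    [0.0, -2.2, 0.0, 1.0],
]

scores = activation_importance(calibration)
print(channels_to_keep(scores, target_size=2))  # [1, 3]
```

Channels 1 and 3 carry the largest activations in the toy data, so they survive pruning to a width of 2.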
Combined Depth-Width Pruning
Apply both depth and width pruning for maximum compression:
from distil_trainer import DistilTrainer, DistilTrainerConfig
from distil_trainer.core.config import (
    CombinedPruningConfig,
    LayerReductionConfig,
    WidthPruningConfig,
)

config = DistilTrainerConfig(
    teacher_model="meta-llama/Llama-3.2-3B",
    student_init_strategy="combined_pruning",
    pruning_config=CombinedPruningConfig(
        depth_config=LayerReductionConfig(
            num_layers_to_keep=16,
            layer_selection="importance",
        ),
        width_config=WidthPruningConfig(
            hidden_size_ratio=0.75,
            intermediate_size_ratio=0.75,
        ),
        pruning_order="depth_first",  # Options: depth_first, width_first, interleaved
        num_iterations=1,
    ),
    output_dir="./compressed_model",
)

trainer = DistilTrainer(config)
trainer.train()
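To get a feel for how much a combined depth and width configuration might shrink a model, the following back-of-the-envelope sketch estimates the fraction of transformer-block parameters that remain. It rests on loud simplifications: each layer is assumed to have four square attention projections and a three-matrix gated MLP, while embeddings, norms, biases, and grouped-query attention are ignored. The dimensions (28 layers, hidden size 3072, intermediate size 8192) are illustrative values, not read from any model config.

```python
def layer_params(hidden_size, intermediate_size):
    """Approximate weight count of one transformer block (simplified)."""
    attention = 4 * hidden_size * hidden_size   # Q, K, V, and output projections
    mlp = 3 * hidden_size * intermediate_size   # gate, up, and down projections
    return attention + mlp


def retained_fraction(num_layers, hidden_size, intermediate_size,
                      layers_kept, width_ratio):
    """Fraction of block parameters left after depth and width pruning."""
    original = num_layers * layer_params(hidden_size, intermediate_size)
    pruned = layers_kept * layer_params(
        int(hidden_size * width_ratio),
        int(intermediate_size * width_ratio),
    )
    return pruned / original


fraction = retained_fraction(
    num_layers=28, hidden_size=3072, intermediate_size=8192,
    layers_kept=16, width_ratio=0.75,
)
print(f"{fraction:.3f}")  # 0.321
```

Under these assumptions, scaling both width dimensions by 0.75 leaves 0.75² of each layer, and keeping 16 of 28 layers multiplies that by 16/28, so roughly a third of the block parameters remain.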
Multilingual Model Extension
Extend a monolingual model to support multiple languages:
from distil_trainer.core.config import MultilingualConfig
from distil_trainer.distillation import MultilingualDistillationStrategy

config = MultilingualConfig(
    # Teacher understands these languages
    source_languages=["en"],
    # Student should learn these languages
    target_languages=["de", "es", "fr", "it", "pt", "zh", "ja", "ko"],
    # Student model (multilingual encoder)
    student_model="xlm-roberta-base",
    student_max_seq_length=128,
    # Parallel sentence datasets for training
    parallel_datasets=[
        "sentence-transformers/parallel-sentences-talks",
        "sentence-transformers/parallel-sentences-tatoeba",
    ],
    max_sentences_per_language=500000,
    # Training settings
    num_train_epochs=5,
    evaluation_steps=5000,
)

strategy = MultilingualDistillationStrategy(
    teacher_model="paraphrase-distilroberta-base-v2",
    config=config,
)
strategy.train()
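The usual idea behind this kind of multilingual extension is that the student should map both a source sentence and its translation to the embedding the teacher produces for the source sentence. The sketch below shows that objective on a single parallel pair with made-up two-dimensional embeddings. Whether `MultilingualDistillationStrategy` uses exactly this formulation is an assumption, and `multilingual_loss` is a hypothetical helper.

```python
def mse(a, b):
    """Mean squared error between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def multilingual_loss(teacher_src, student_src, student_tgt):
    """Pull both student embeddings toward the teacher's source embedding."""
    return mse(student_src, teacher_src) + mse(student_tgt, teacher_src)


# Made-up embeddings for one parallel sentence pair.
teacher_en = [0.5, -1.0]  # teacher("Hello world")
student_en = [0.5, -1.0]  # student("Hello world")
student_de = [0.0, -1.0]  # student("Hallo Welt")

print(multilingual_loss(teacher_en, student_en, student_de))  # 0.125
```

Here the student already matches the teacher on the English sentence, so the entire loss comes from the German translation, which is what drives the student to align the new language with the teacher's embedding space.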
Configuration Reference
DistilTrainerConfig
Main configuration class for distillation training:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `teacher_model` | str/Model | Required | Teacher model name or instance |
| `student_model` | str/Model | None | Student model (None creates from teacher) |
| `student_init_strategy` | str | "from_pretrained" | How to initialize student |
| `pruning_config` | PruningConfig | None | Pruning configuration |
| `distillation_config` | DistillationConfig | Default | Distillation loss settings |
| `training_config` | TrainingConfig | Default | Training hyperparameters |
| `output_dir` | str | "./distilled_model" | Output directory |
| `device` | str | "auto" | Device to use |
| `precision` | str | "bf16" | Training precision |
DistillationConfig
Configuration for distillation losses:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `loss_type` | str | "mse" | Loss function: mse, kl_divergence, cosine, ranking, combined |
| `logit_loss_weight` | float | 1.0 | Weight for logit distillation loss |
| `embedding_loss_weight` | float | 1.0 | Weight for embedding distillation loss |
| `intermediate_loss_weight` | float | 0.0 | Weight for intermediate layer loss |
| `temperature` | float | 1.0 | Temperature for KL divergence |
| `use_pca_projection` | bool | True | Use PCA when student dim < teacher dim |
| `precompute_teacher_embeddings` | bool | True | Cache teacher embeddings |
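The `temperature` and `*_loss_weight` parameters can be illustrated with a small stand-alone sketch of temperature-scaled KL divergence and a weighted sum of loss terms. This follows the common textbook formulation (softened distributions, KL from teacher to student, scaled by the squared temperature). The function names are hypothetical, and the exact formula distil-trainer uses for `kl_divergence` and `combined` is an assumption.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over temperature-scaled logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_distillation_loss(student_logits, teacher_logits, temperature=1.0):
    """KL(teacher || student) on softened distributions, scaled by T^2."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return kl * temperature ** 2


def combined_loss(logit_loss, embedding_loss, intermediate_loss,
                  logit_loss_weight=1.0, embedding_loss_weight=1.0,
                  intermediate_loss_weight=0.0):
    """Weighted sum of the three loss terms (defaults mirror the table above)."""
    return (logit_loss_weight * logit_loss
            + embedding_loss_weight * embedding_loss
            + intermediate_loss_weight * intermediate_loss)


teacher_logits = [2.0, 1.0, 0.1]
student_logits = [1.5, 1.2, 0.3]

print(kl_distillation_loss(student_logits, teacher_logits, temperature=1.0))
print(kl_distillation_loss(student_logits, teacher_logits, temperature=2.0))
print(combined_loss(2.0, 4.0, 8.0))  # 6.0 with the default weights
```

A higher temperature flattens both distributions so that the student also learns from the teacher's lower-probability classes. With the default weights, the intermediate-layer term contributes nothing until `intermediate_loss_weight` is raised above 0.0.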
TrainingConfig
Training hyperparameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `num_train_epochs` | int | 1 | Number of training epochs |
| `per_device_train_batch_size` | int | 64 | Batch size per device |
| `learning_rate` | float | 1e-4 | Initial learning rate |
| `lr_scheduler_type` | str | "cosine" | LR scheduler type |
| `warmup_ratio` | float | 0.1 | Warmup ratio |
| `weight_decay` | float | 0.01 | Weight decay |
| `max_grad_norm` | float | 1.0 | Gradient clipping |
| `eval_strategy` | str | "steps" | Evaluation strategy |
| `eval_steps` | int | 500 | Evaluation frequency |
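To show how `learning_rate`, `lr_scheduler_type="cosine"`, and `warmup_ratio` typically interact, here is a minimal sketch of a linear-warmup, cosine-decay schedule. It mirrors the widely used formulation (linear ramp from 0 to the peak, then half a cosine down to 0). The helper `lr_at_step` is hypothetical, and the exact schedule applied by the trainer is an assumption.

```python
import math


def lr_at_step(step, total_steps, peak_lr=1e-4, warmup_ratio=0.1):
    """Learning rate at `step` for linear warmup followed by cosine decay."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear ramp from 0 up to the peak learning rate.
        return peak_lr * step / warmup_steps
    # Cosine decay from the peak down to 0 over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


total = 1000
for step in (0, 50, 100, 550, 1000):
    print(step, lr_at_step(step, total))
```

With the defaults from the table, the first 10% of steps ramp the learning rate up to 1e-4, the midpoint of the remaining steps sits at about half the peak, and the final step reaches zero.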
Evaluation
The framework includes comprehensive evaluation tools:
from distil_trainer.evaluation import (
    EmbeddingSimilarityEvaluator,
    MSEEvaluator,
    BenchmarkRunner,
)

# `teacher`, `student`, and `test_sentences` must be defined by the caller
# (two loaded models and a list of strings).

# Evaluate embedding similarity between teacher and student
evaluator = EmbeddingSimilarityEvaluator(
    teacher_model=teacher,
    student_model=student,
)
results = evaluator.evaluate(test_sentences)

# Run MTEB benchmarks (requires distil-trainer[evaluation])
runner = BenchmarkRunner(
    model="./distilled_model",
    tasks=["STS12", "STS13", "STS14", "STS15", "STS16"],
)
benchmark_results = runner.run()
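STS-style benchmarks are commonly scored with the Spearman rank correlation between the model's similarity scores and human-annotated gold scores. The sketch below computes that metric in plain Python on invented numbers. It is meant only to clarify what such a score measures, not to reproduce how `BenchmarkRunner` or MTEB compute it internally.

```python
import math


def average_ranks(values):
    """1-based ranks, with tied values sharing their average rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j to cover the run of values tied with position i.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        shared_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = shared_rank
        i = j + 1
    return ranks


def pearson(x, y):
    """Pearson correlation coefficient of two equal-length sequences."""
    mean_x = sum(x) / len(x)
    mean_y = sum(y) / len(y)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)


def spearman(x, y):
    """Spearman correlation: Pearson correlation of the rank vectors."""
    return pearson(average_ranks(x), average_ranks(y))


# Invented gold similarity scores and model cosine similarities.
gold_scores = [4.8, 1.0, 3.2, 0.5]
model_scores = [0.91, 0.35, 0.40, 0.10]

print(spearman(gold_scores, model_scores))
```

Because the model ranks the four pairs in the same order as the gold scores, the correlation is at its maximum here even though the raw values live on different scales.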
Data Loading
Support for multiple data formats:
from distil_trainer.data import (
    DistillationDataModule,
    SentenceDistillationDataset,
    TripletDataset,
    ParallelSentencesDataset,
)

# Load from HuggingFace datasets
data_module = DistillationDataModule(
    train_data="sentence-transformers/all-nli",
    eval_data="sentence-transformers/stsb",
    text_column="sentence",
    max_seq_length=512,
    num_workers=4,
)

# Or use triplet format for ranking loss
triplet_dataset = TripletDataset(
    data_path="path/to/triplets.jsonl",
    query_column="query",
    positive_column="positive",
    negative_column="negative",
)
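For reference, a triplet JSONL file with the column names used above would hold one JSON object per line. The sketch below parses such data with the standard library only. The `read_triplets` helper and the sample records are illustrative assumptions and say nothing about how `TripletDataset` reads files internally.

```python
import io
import json


def read_triplets(lines, query_column="query",
                  positive_column="positive", negative_column="negative"):
    """Parse JSONL lines into (query, positive, negative) tuples."""
    triplets = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # Skip blank lines.
        record = json.loads(line)
        triplets.append(
            (record[query_column], record[positive_column], record[negative_column])
        )
    return triplets


# In-memory stand-in for a triplets.jsonl file.
sample_file = io.StringIO(
    '{"query": "capital of France", "positive": "Paris is the capital of France.", '
    '"negative": "Berlin is the capital of Germany."}\n'
    "\n"
    '{"query": "largest planet", "positive": "Jupiter is the largest planet.", '
    '"negative": "Mercury is the smallest planet."}\n'
)

triplets = read_triplets(sample_file)
print(len(triplets))    # 2
print(triplets[0][0])   # capital of France
```

Passing different column names to the helper mirrors what the `query_column`, `positive_column`, and `negative_column` arguments are for: mapping a dataset's own field names onto the three roles of a triplet.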
Best Practices
Based on NVIDIA Minitron research:
- Sizing: Train the largest model first, then prune and distill iteratively.
- Pruning: Prefer width pruning over depth pruning for better accuracy.
- Retraining: Retrain with distillation loss exclusively, rather than conventional training.
- Loss Selection:
  - Use logit + intermediate + embedding losses when depth is reduced significantly.
  - Use logit-only loss when depth is not reduced significantly.
Recommended Workflow
from distil_trainer import DistilTrainer, DistilTrainerConfig, DistillationConfig
from distil_trainer.pruning import ImportanceEstimator, WidthPruner

# `model`, `calibration_data`, and `original_model` (the unpruned teacher)
# must be provided by the caller.

# Step 1: Start with importance estimation
estimator = ImportanceEstimator(model, method="activation")
importance_scores = estimator.estimate(calibration_data)

# Step 2: Apply pruning based on importance
pruner = WidthPruner(model, importance_scores)
pruned_model = pruner.prune(target_hidden_size=2048)

# Step 3: Distill knowledge from teacher
config = DistilTrainerConfig(
    teacher_model=original_model,
    student_model=pruned_model,
    distillation_config=DistillationConfig(
        loss_type="combined",
        logit_loss_weight=1.0,
        embedding_loss_weight=0.5,
    ),
)
trainer = DistilTrainer(config)
trainer.train()
License
MIT License - see LICENSE for details.
Citation
If you use this library in your research, please cite:
@software{distil_trainer,
title = {Distil Trainer: A Comprehensive Knowledge Distillation Framework},
author = {Ali Bayram},
year = {2025},
url = {https://github.com/malibayram/distil-trainer}
}
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.