BERT Distillation
Project description
BertDistiller: Knowledge Distillation for BERT Models
A flexible framework for distilling BERT models using various distillation techniques, built on the Hugging Face Transformers library.
Currently implements:
- MiniLMv2: Multi-Head Self-Attention Relation Distillation for compressing pretrained Transformers.
Overview
BertDistiller enables knowledge distillation of BERT models using the MiniLMv2 technique - a task-agnostic approach that compresses large transformer models into smaller, faster models while maintaining comparable performance.
Key features:
- Built on Hugging Face Transformers: Seamless integration with the transformers ecosystem
- Task-agnostic distillation: Compress a pretrained model without task-specific fine-tuning; the distilled student can then be fine-tuned on any downstream task
- Flexible architecture: Configure student models with different layer counts and dimensions
- Teacher weight inheritance: Option to initialize student with teacher weights
Experimental Results
The following table compares our implementation's results with Microsoft's original MiniLM implementations on the GLUE benchmark:
| Model | STSB | RTE | CoLA | QQP | SST-2 | MNLI | QNLI | MRPC | Avg |
|---|---|---|---|---|---|---|---|---|---|
| MiniLM-L6-H768-distilled-from-BERT-Base | 88.66 | 67.11 | 72.90 | 87.18 | 91.55 | 83.58 | 90.20 | 89.17 | 83.79 |
| MiniLM-L6-H384-distilled-from-BERT-Base | 87.33 | 64.74 | 66.63 | 85.72 | 90.58 | 81.85 | 89.55 | 88.00 | 81.80 |
| Our Model (L6-H384) | 85.29 | 59.81 | 70.04 | 85.22 | 90.62 | 81.03 | 87.69 | 86.66 | 80.80 |
Our model's average score is only 1.0 point below Microsoft's MiniLM-L6-H384 (80.80 vs. 81.80). It reaches this even though it was trained with a maximum sequence length of 128 tokens (vs. 512 in the original paper) and distilled on a single RTX A6000 GPU, which demonstrates the efficiency and accessibility of our approach.
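The averages and the gap can be recomputed from the table. The Avg column is the unweighted mean of the eight task scores. The gap to MiniLM-L6-H384 is therefore 1.0 point in absolute terms, or about 1.2% relative. Here is a short check in plain Python; the row keys are shorthand labels for the table rows:

```python
# GLUE scores from the table above, in column order
TASKS = ["STSB", "RTE", "CoLA", "QQP", "SST-2", "MNLI", "QNLI", "MRPC"]
SCORES = {
    "MiniLM-L6-H768": [88.66, 67.11, 72.90, 87.18, 91.55, 83.58, 90.20, 89.17],
    "MiniLM-L6-H384": [87.33, 64.74, 66.63, 85.72, 90.58, 81.85, 89.55, 88.00],
    "Ours-L6-H384": [85.29, 59.81, 70.04, 85.22, 90.62, 81.03, 87.69, 86.66],
}


def average(model):
    """Unweighted mean over the eight GLUE tasks, as in the Avg column."""
    return sum(SCORES[model]) / len(SCORES[model])


baseline, ours = "MiniLM-L6-H384", "Ours-L6-H384"
gap = average(baseline) - average(ours)  # absolute gap in GLUE points
relative_gap = 100 * gap / average(baseline)  # gap as a percentage of the baseline average
# Per-task difference (ours minus baseline), rounded to the table's precision
deltas = {task: round(o - b, 2) for task, b, o in zip(TASKS, SCORES[baseline], SCORES[ours])}

print(f"absolute gap: {gap:.1f} points")  # absolute gap: 1.0 points
print(f"relative gap: {relative_gap:.1f}%")  # relative gap: 1.2%
print({t: d for t, d in deltas.items() if d > 0})  # {'CoLA': 3.41, 'SST-2': 0.04}
```

Per task, our model is ahead on CoLA and SST-2 and behind on the other six. The largest drop is on RTE.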
Installation
```bash
pip install bertdistiller
```
Quick Start
See examples/minilm_distillation.py for a complete working example. Here's a simplified version:
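```python
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding

from bertdistiller import MiniLMTrainer, MiniLMTrainingArguments, create_student

teacher_name = "google-bert/bert-base-uncased"

# 1. Create configuration
args = MiniLMTrainingArguments(
    teacher_layer=12,  # Which teacher layer to transfer from
    student_layer=6,  # Number of layers in student model
    student_hidden_size=384,  # Hidden size of student model
    num_relation_heads=48,  # Number of relation heads for distillation
    relations={(1, 1): 1.0, (2, 2): 1.0, (3, 3): 1.0},  # Q-Q, K-K, V-V relations
    # Training parameters
    output_dir="./output",
    per_device_train_batch_size=256,
    learning_rate=6e-4,
    max_steps=400_000,
)

# 2. Prepare unlabeled text (assumption: WikiText-103 as a stand-in corpus;
#    the full example script may use different data and preprocessing)
tokenizer = AutoTokenizer.from_pretrained(teacher_name)
raw = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1")
raw = raw.filter(lambda ex: ex["text"].strip() != "")  # drop blank lines


def tokenize(batch):
    # Truncate to 128 tokens, the sequence length used for the results above
    return tokenizer(batch["text"], truncation=True, max_length=128)


tokenized = raw.map(tokenize, batched=True, remove_columns=["text"])
train_dataset = tokenized["train"]
eval_dataset = tokenized["validation"]
data_collator = DataCollatorWithPadding(tokenizer)  # pads each batch dynamically

# 3. Create models & trainer
teacher = AutoModel.from_pretrained(teacher_name)
student = create_student(teacher_name, args)
trainer = MiniLMTrainer(
    args=args,
    teacher_model=teacher,
    model=student,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

# 4. Train and save (the tokenizer is saved too, so the directory is self-contained)
trainer.train()
student.save_pretrained("./distilled-model")
tokenizer.save_pretrained("./distilled-model")
```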
How MiniLMv2 Works
MiniLMv2 transfers knowledge using self-attention relations - interactions between query, key, and value vectors within transformer layers. The implementation:
- Computes relation distributions by applying a softmax to scaled dot-products between Q-Q, K-K, and V-V pairs
- Splits the concatenated query, key, and value vectors into a configurable number of "relation heads", so the teacher and student do not need the same number of attention heads
- Lets you choose which teacher layer to distill from (typically the last layer, layer 12, for base-size models and an upper-middle layer for large models)
This provides finer-grained knowledge transfer than traditional methods that distill only the teacher's attention distributions.
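The relation computation can be made concrete with a minimal NumPy sketch of the loss described in the MiniLMv2 paper. It works in three steps. First, the concatenated vectors of one layer are split into relation heads. Second, each head becomes a softmax-normalized scaled dot-product relation. Third, the student is penalized by the KL divergence from the teacher's relations, averaged over relation heads and token positions.

This is not bertdistiller's code. The function names, the array layout, and the mapping 1 → Q, 2 → K, 3 → V are assumptions; the mapping is taken from the relations comment in Quick Start.

```python
import numpy as np


def log_relation(a, b, num_relation_heads):
    """Log self-attention relation between vectors a and b of one layer.

    a, b: (batch, seq_len, hidden) arrays with all attention heads concatenated
    along the hidden axis. Returns log-softmax relations of shape
    (batch, num_relation_heads, seq_len, seq_len).
    """
    batch, seq_len, hidden = a.shape
    if hidden % num_relation_heads != 0:
        raise ValueError("hidden size must be divisible by num_relation_heads")
    d_r = hidden // num_relation_heads
    # Split the hidden axis into relation heads: (batch, heads, seq_len, d_r)
    a = a.reshape(batch, seq_len, num_relation_heads, d_r).transpose(0, 2, 1, 3)
    b = b.reshape(batch, seq_len, num_relation_heads, d_r).transpose(0, 2, 1, 3)
    scores = a @ b.transpose(0, 1, 3, 2) / np.sqrt(d_r)  # scaled dot-product
    # Numerically stable log-softmax over the last axis
    m = scores.max(axis=-1, keepdims=True)
    return scores - (m + np.log(np.exp(scores - m).sum(axis=-1, keepdims=True)))


def relation_loss(teacher, student, relations, num_relation_heads):
    """Weighted KL(teacher relation || student relation), averaged over
    batch, relation heads, and token positions.

    teacher, student: dicts mapping 1 -> Q, 2 -> K, 3 -> V for the distilled layer.
    relations: e.g. {(1, 1): 1.0, (2, 2): 1.0, (3, 3): 1.0} for Q-Q, K-K, V-V.
    """
    total = 0.0
    for (i, j), weight in relations.items():
        log_t = log_relation(teacher[i], teacher[j], num_relation_heads)
        log_s = log_relation(student[i], student[j], num_relation_heads)
        kl = (np.exp(log_t) * (log_t - log_s)).sum(axis=-1)  # KL per query position
        total += weight * kl.mean()
    return total


# Usage with random stand-ins for one layer's Q, K, V
rng = np.random.default_rng(0)
teacher = {k: rng.standard_normal((2, 16, 768)) for k in (1, 2, 3)}  # BERT-base-sized
student = {k: rng.standard_normal((2, 16, 384)) for k in (1, 2, 3)}  # H384 student
loss = relation_loss(teacher, student, {(1, 1): 1.0, (2, 2): 1.0, (3, 3): 1.0}, num_relation_heads=48)
print(f"relation loss: {loss:.4f}")
```

The teacher (hidden size 768) and the student (hidden size 384) both yield relations of shape (batch, 48, seq_len, seq_len). This is why their attention head counts do not have to match.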
Evaluation
BertDistiller includes utilities to evaluate distilled models on GLUE benchmark tasks:
```python
from bertdistiller.evaluation import evaluate, create_summary_table

# Evaluate on GLUE tasks
evaluate(
    model_name_or_path="your-distilled-model",
    tasks=["mnli", "qnli", "qqp", "sst2"],
    learning_rate=[1e-5, 3e-5],
    epochs=[3, 5],
)

# Generate comparison table
summary = create_summary_table("./evaluation_results")
print(summary)
```
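The learning_rate and epochs arguments take lists. Assume, for illustration, that evaluate fine-tunes once for every combination of learning rate and epoch count on each task; that is an assumption about its behavior, not documented here. Under that assumption, the call above amounts to the following runs, which is useful to know when budgeting GPU time:

```python
from itertools import product

tasks = ["mnli", "qnli", "qqp", "sst2"]
learning_rates = [1e-5, 3e-5]
epochs = [3, 5]

# Hypothetical: one fine-tuning run per (task, learning rate, epochs) combination
runs = list(product(tasks, learning_rates, epochs))
print(len(runs))  # 16
for task, lr, n_epochs in runs[:2]:
    print(task, lr, n_epochs)
```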
Recommendations
- For base-size teachers (12 layers), use the last layer for distillation
- For large-size teachers (24 layers), use an upper-middle layer (e.g., layer 21)
- Using more relation heads (48+) generally improves performance
- Initialize with teacher weights when possible
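The recommendations above can be encoded as a small helper. The helper recommended_settings is hypothetical and not part of bertdistiller. It also checks one assumption: relation heads split the concatenated hidden vectors evenly, so their number must divide both the teacher and the student hidden sizes. Under that assumption, 48 relation heads fit a 768-dimensional base teacher but not a 1024-dimensional large one, since 1024 / 48 is not an integer. By contrast, 64 divides both 1024 and 384.

```python
def recommended_settings(num_teacher_layers, teacher_hidden, student_hidden, num_relation_heads=48):
    """Hypothetical helper encoding the recommendations above (not part of bertdistiller)."""
    # Last layer for base-size teachers, an upper-middle layer for large ones
    layer_by_depth = {12: 12, 24: 21}
    if num_teacher_layers not in layer_by_depth:
        raise ValueError(f"no recommendation for a {num_teacher_layers}-layer teacher")
    # Assumption: relation heads split each hidden size into equal chunks
    for hidden in (teacher_hidden, student_hidden):
        if hidden % num_relation_heads != 0:
            raise ValueError(f"{num_relation_heads} relation heads do not divide hidden size {hidden}")
    return {"teacher_layer": layer_by_depth[num_teacher_layers], "num_relation_heads": num_relation_heads}


print(recommended_settings(12, 768, 384))  # {'teacher_layer': 12, 'num_relation_heads': 48}
print(recommended_settings(24, 1024, 384, num_relation_heads=64))  # {'teacher_layer': 21, 'num_relation_heads': 64}
```

The returned keys match the teacher_layer and num_relation_heads arguments shown in Quick Start, so the dict could be unpacked into MiniLMTrainingArguments alongside the other settings.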
Acknowledgements & Citation
Built using Hugging Face Transformers and inspired by the minilmv2.bb implementation and the original MiniLMv2 paper:
```bibtex
@article{wang2020minilmv2,
  title={MINILMv2: Multi-Head Self-Attention Relation Distillation for Compressing Pretrained Transformers},
  author={Wang, Wenhui and Bao, Hangbo and Huang, Shaohan and Dong, Li and Wei, Furu},
  journal={arXiv preprint arXiv:2012.15828},
  year={2020}
}
```
License
Apache License Version 2.0
File details
Details for the file bertdistiller-0.1.1.tar.gz.
File metadata
- Download URL: bertdistiller-0.1.1.tar.gz
- Upload date:
- Size: 25.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.8.1 CPython/3.10.13 Darwin/24.3.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 06ca0d74eb1a9a0b99153548975e7cdab08fecf0a411de92f5ab9d5e23e58ea0 |
| MD5 | f62f587d14ce64cf0d73aa002498eec9 |
| BLAKE2b-256 | 5cd8bcc7cc57dcc4a6cc6c6cedce17743a1784c20d47f598ceb83823f3df6ff2 |
File details
Details for the file bertdistiller-0.1.1-py3-none-any.whl.
File metadata
- Download URL: bertdistiller-0.1.1-py3-none-any.whl
- Upload date:
- Size: 26.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.8.1 CPython/3.10.13 Darwin/24.3.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5a8e381e75b143628890b09eba300ce1fa70dd549865470f3531307bde3bcf37 |
| MD5 | d3f59543689728ddb917567348af6ebc |
| BLAKE2b-256 | 24c83636d4edf8246e1d74f63b9df853a06a3b5ca2a84024da616cd380c9abf8 |