
BERT Distillation

Project description

BertDistiller: Knowledge Distillation for BERT Models


A flexible framework for distilling BERT models using various distillation techniques, built on the Hugging Face Transformers library.

Currently implements:

  • MiniLMv2: Multi-Head Self-Attention Relation Distillation for compressing pretrained Transformers.

Overview

BertDistiller enables knowledge distillation of BERT models with MiniLMv2, a task-agnostic technique that compresses large Transformer models into smaller, faster models while maintaining comparable performance.

Key features:

  • Built on Hugging Face Transformers: Seamless integration with the transformers ecosystem
  • Task-agnostic distillation: Distill on general text, independent of any downstream task; the student can then be fine-tuned on any task
  • Flexible architecture: Configure student models with different layer counts and dimensions
  • Teacher weight inheritance: Option to initialize the student with teacher weights

Experimental Results

The following table compares our implementation's results with Microsoft's original MiniLM implementations on the GLUE benchmark:

Model                                     STSB   RTE    CoLA   QQP    SST-2  MNLI   QNLI   MRPC   Avg
MiniLM-L6-H768-distilled-from-BERT-Base   88.66  67.11  72.90  87.18  91.55  83.58  90.20  89.17  83.79
MiniLM-L6-H384-distilled-from-BERT-Base   87.33  64.74  66.63  85.72  90.58  81.85  89.55  88.00  81.80
Our Model (L6-H384)                       85.29  59.81  70.04  85.22  90.62  81.03  87.69  86.66  80.80

Our L6-H384 model's average GLUE score is 1.0 point below Microsoft's MiniLM-L6-H384 (80.80 vs. 81.80), even though it was distilled with a maximum sequence length of 128 tokens (vs. 512 in the original paper) on a single RTX A6000 GPU. It also scores higher on CoLA and SST-2. This shows that task-agnostic distillation is practical on modest hardware.
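The Avg column is the unweighted mean of the eight task scores. As a quick sanity check, the sketch below recomputes it from the table, along with the gap between the two L6-H384 models. It uses `decimal.Decimal` so that rounding to two places is exact rather than subject to binary floating-point error.

```python
from decimal import Decimal, ROUND_HALF_UP

# Per-task scores copied from the table (STSB, RTE, CoLA, QQP, SST-2, MNLI, QNLI, MRPC).
rows = {
    "MiniLM-L6-H768-distilled-from-BERT-Base": "88.66 67.11 72.90 87.18 91.55 83.58 90.20 89.17",
    "MiniLM-L6-H384-distilled-from-BERT-Base": "87.33 64.74 66.63 85.72 90.58 81.85 89.55 88.00",
    "Our Model (L6-H384)": "85.29 59.81 70.04 85.22 90.62 81.03 87.69 86.66",
}


def average(scores: str) -> Decimal:
    """Unweighted mean of the space-separated task scores, computed exactly."""
    values = [Decimal(s) for s in scores.split()]
    return sum(values) / len(values)


def two_places(value: Decimal) -> Decimal:
    return value.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)


averages = {name: average(scores) for name, scores in rows.items()}
for name, avg in averages.items():
    print(f"{name}: {two_places(avg)}")

baseline = averages["MiniLM-L6-H384-distilled-from-BERT-Base"]
gap = baseline - averages["Our Model (L6-H384)"]
relative = gap / baseline * 100
print(f"gap: {gap} points ({relative:.2f}% relative)")
```

On the unrounded means the gap is 1.005 points, or about 1.2% of the baseline average.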

Installation

pip install bertdistiller

Quick Start

See examples/minilm_distillation.py for a complete working example. Here is a simplified version:

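from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding

from bertdistiller import MiniLMTrainer, MiniLMTrainingArguments, create_student

# 1. Create configuration
args = MiniLMTrainingArguments(
    teacher_layer=12,                # Teacher layer to distill from
    student_layer=6,                 # Number of layers in the student model
    student_hidden_size=384,         # Hidden size of the student model
    num_relation_heads=48,           # Number of relation heads for distillation
    relations={(1, 1): 1.0, (2, 2): 1.0, (3, 3): 1.0},  # Q-Q, K-K, V-V relations

    # Training parameters
    output_dir="./output",
    per_device_train_batch_size=256,
    learning_rate=6e-4,
    max_steps=400_000,
)

# 2. Prepare data. Assumption for illustration only: plain text from WikiText-2
#    (a small corpus, suitable for a smoke test), tokenized to
#    input_ids/attention_mask with a 128-token limit. See
#    examples/minilm_distillation.py for the data pipeline the project uses.
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
raw = load_dataset("wikitext", "wikitext-2-raw-v1")
raw = raw.filter(lambda example: example["text"].strip() != "")  # drop blank lines

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, max_length=128)

tokenized = raw.map(tokenize, batched=True, remove_columns=["text"])
train_dataset = tokenized["train"]
eval_dataset = tokenized["validation"]
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)  # pads each batch dynamically

# 3. Create models & trainer
teacher = AutoModel.from_pretrained("google-bert/bert-base-uncased")
student = create_student("google-bert/bert-base-uncased", args)

trainer = MiniLMTrainer(
    args=args,
    teacher_model=teacher,
    model=student,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

# 4. Train and save
trainer.train()
student.save_pretrained("./distilled-model")
tokenizer.save_pretrained("./distilled-model")  # keep the tokenizer next to the model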

How MiniLMv2 Works

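MiniLMv2 transfers knowledge through self-attention relations: the scaled dot-product interactions between query, key, and value vectors within a Transformer layer. Following the MiniLMv2 paper, the implementation:

  1. Computes relation matrices as the softmax of scaled dot products for Q-Q, K-K, and V-V pairs, and trains the student to minimize the KL divergence between the teacher's and the student's relation matrices
  2. Concatenates the query (key, value) vectors of all attention heads and re-splits them into a configurable number of "relation heads", so the teacher and student do not need matching attention head counts
  3. Distills from a single, strategically chosen teacher layer (typically the last layer, layer 12, for base-size models and an upper-middle layer for large models) into the student's last layer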

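Because it combines several relation types and allows more relation heads than attention heads, this approach transfers finer-grained self-attention knowledge than methods that distill only attention distributions.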
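The relation computation in steps 1 and 2 can be sketched in NumPy. This is a minimal illustration of the loss as described in the MiniLMv2 paper, not BertDistiller's internal code. Assumptions: a single unbatched, unmasked sequence; random arrays stand in for the projected queries, keys, and values of the chosen teacher layer and the student's last layer; and the keys 1, 2, 3 map to Q, K, V, following the comment on the Quick Start `relations` argument. The helper names are hypothetical.

```python
import numpy as np


def split_heads(x: np.ndarray, num_heads: int) -> np.ndarray:
    """Reshape (seq_len, hidden) into (num_heads, seq_len, head_dim)."""
    seq_len, hidden = x.shape
    if hidden % num_heads != 0:
        raise ValueError(f"hidden size {hidden} is not divisible by {num_heads} relation heads")
    return x.reshape(seq_len, num_heads, hidden // num_heads).transpose(1, 0, 2)


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def relation(a: np.ndarray, b: np.ndarray, num_heads: int) -> np.ndarray:
    """Per relation head: softmax(A_h B_h^T / sqrt(d_r)), shape (num_heads, seq_len, seq_len)."""
    a_h, b_h = split_heads(a, num_heads), split_heads(b, num_heads)
    d_r = a_h.shape[-1]
    return softmax(a_h @ b_h.transpose(0, 2, 1) / np.sqrt(d_r), axis=-1)


def kl_div(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> float:
    """KL(p || q) over the last axis, averaged over relation heads and positions."""
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=-1).mean())


def minilmv2_loss(teacher_qkv, student_qkv, relations, num_heads):
    """Weighted sum of teacher-vs-student relation KL terms, one per (i, j) pair."""
    loss = 0.0
    for (i, j), weight in relations.items():
        r_teacher = relation(teacher_qkv[i], teacher_qkv[j], num_heads)
        r_student = relation(student_qkv[i], student_qkv[j], num_heads)
        loss += weight * kl_div(r_teacher, r_student)
    return loss


# Random stand-ins for one 8-token sequence; keys 1, 2, 3 = Q, K, V (assumed mapping).
rng = np.random.default_rng(0)
seq_len = 8
teacher_qkv = {k: rng.standard_normal((seq_len, 768)) for k in (1, 2, 3)}  # BERT-Base width
student_qkv = {k: rng.standard_normal((seq_len, 384)) for k in (1, 2, 3)}  # H384 student
relations = {(1, 1): 1.0, (2, 2): 1.0, (3, 3): 1.0}  # Q-Q, K-K, V-V

loss = minilmv2_loss(teacher_qkv, student_qkv, relations, num_heads=48)
print(f"relation loss: {loss:.4f}")
```

Each relation matrix is sequence length by sequence length, so the 768-wide teacher and the 384-wide student produce relations of the same shape, as long as both hidden sizes are divisible by the number of relation heads (768 / 48 = 16 and 384 / 48 = 8 here).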

Evaluation

BertDistiller includes utilities to evaluate distilled models on GLUE benchmark tasks:

from bertdistiller.evaluation import evaluate, create_summary_table

# Evaluate on GLUE tasks
evaluate(
    model_name_or_path="your-distilled-model",
    tasks=["mnli", "qnli", "qqp", "sst2"],
    learning_rate=[1e-5, 3e-5],
    epochs=[3, 5],
)

# Generate comparison table
summary = create_summary_table("./evaluation_results")
print(summary)
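The list-valued `learning_rate` and `epochs` arguments suggest a hyperparameter sweep. Assuming `evaluate` fine-tunes one run per (task, learning rate, epoch count) combination, which is an assumption rather than documented behavior, the call above would launch 4 × 2 × 2 = 16 fine-tuning runs. This stdlib sketch enumerates those combinations, which can help estimate compute before launching:

```python
from itertools import product

# Same values as the evaluate() call; assumes one fine-tuning run per combination.
tasks = ["mnli", "qnli", "qqp", "sst2"]
learning_rates = [1e-5, 3e-5]
epoch_counts = [3, 5]

runs = [
    {"task": task, "learning_rate": lr, "epochs": epochs}
    for task, lr, epochs in product(tasks, learning_rates, epoch_counts)
]
print(len(runs))  # 16
for run in runs[:3]:
    print(run)
```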

Recommendations

  • For base-size teachers (12 layers), use the last layer for distillation
  • For large-size teachers (24 layers), use an upper-middle layer (e.g., layer 21)
  • Using more relation heads (48+) generally improves performance
  • Initialize with teacher weights when possible
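For a 24-layer teacher, the Quick Start configuration changes mainly in `teacher_layer`. The sketch below is untuned and uses only the arguments shown in the Quick Start. One detail worth checking: if relation heads are formed by splitting hidden vectors evenly, as in the MiniLMv2 paper, `num_relation_heads` must divide both hidden sizes. 48 does not divide BERT-Large's 1024, so this sketch assumes 64 (1024 / 64 = 16 and 384 / 64 = 6).

```python
from transformers import AutoModel

from bertdistiller import MiniLMTrainingArguments, create_student

# Untuned sketch for a 24-layer teacher; only the values marked below differ
# from the Quick Start.
args = MiniLMTrainingArguments(
    teacher_layer=21,                # upper-middle layer of a 24-layer teacher
    student_layer=6,
    student_hidden_size=384,
    num_relation_heads=64,           # assumed: must divide both 1024 and 384
    relations={(1, 1): 1.0, (2, 2): 1.0, (3, 3): 1.0},  # Q-Q, K-K, V-V relations
    output_dir="./output-large",
    per_device_train_batch_size=256,
    learning_rate=6e-4,
    max_steps=400_000,
)

teacher = AutoModel.from_pretrained("google-bert/bert-large-uncased")
student = create_student("google-bert/bert-large-uncased", args)
```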

Acknowledgements & Citation

Built with Hugging Face Transformers and inspired by the minilmv2.bb implementation and the original MiniLMv2 paper:

@article{wang2020minilmv2,
  title={MINILMv2: Multi-Head Self-Attention Relation Distillation for Compressing Pretrained Transformers},
  author={Wang, Wenhui and Bao, Hangbo and Huang, Shaohan and Dong, Li and Wei, Furu},
  journal={arXiv preprint arXiv:2012.15828},
  year={2020}
}

License

Apache License Version 2.0


Download files


Source Distribution

bertdistiller-0.1.2.tar.gz (25.0 kB)

Uploaded Source

Built Distribution


bertdistiller-0.1.2-py3-none-any.whl (26.2 kB)

Uploaded Python 3

File details

Details for the file bertdistiller-0.1.2.tar.gz.

File metadata

  • Download URL: bertdistiller-0.1.2.tar.gz
  • Upload date:
  • Size: 25.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.1 CPython/3.10.13 Darwin/24.3.0

File hashes

Hashes for bertdistiller-0.1.2.tar.gz
Algorithm Hash digest
SHA256 ad24e19c0b0a827eca06fb5f94071520b3ab68103c98ddd1a0413063de880405
MD5 428b7462abf0ffbbfea7685b468d6e96
BLAKE2b-256 a957c04a1ec951ac27d025962804516738f83db67b175d1cba83ad6b7075bc40


File details

Details for the file bertdistiller-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: bertdistiller-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 26.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.1 CPython/3.10.13 Darwin/24.3.0

File hashes

Hashes for bertdistiller-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 22a4323267f59c93bdea5ea5ee0fcc0f251f722185f50869e697c30a55d6a401
MD5 4acd6211a072114b5f6ff064408a854b
BLAKE2b-256 b6dc050333b093e2347d1a3c72faf9f1ceb73bb70cfef962355057a576522e2a

