BertDistiller: Knowledge Distillation for BERT Models

A flexible framework for distilling BERT models using various distillation techniques, built on the Hugging Face Transformers library.

Currently implements:

  • MiniLMv2: Multi-Head Self-Attention Relation Distillation for compressing pretrained Transformers.

Overview

BertDistiller distills BERT models with MiniLMv2, a task-agnostic technique that compresses large transformer models into smaller, faster models while maintaining comparable performance.

Key features:

  • Built on Hugging Face Transformers: Seamless integration with the transformers ecosystem
  • Task-agnostic distillation: Compress models without task-specific fine-tuning
  • Flexible architecture: Configure student models with different layer counts and dimensions
  • Teacher weight inheritance: Option to initialize student with teacher weights

Experimental Results

The following table compares our implementation's results with Microsoft's original MiniLM implementations on the GLUE benchmark:

Model                                     STSB   RTE    CoLA   QQP    SST-2  MNLI   QNLI   MRPC   Avg
MiniLM-L6-H768-distilled-from-BERT-Base   88.66  67.11  72.90  87.18  91.55  83.58  90.20  89.17  83.79
MiniLM-L6-H384-distilled-from-BERT-Base   87.33  64.74  66.63  85.72  90.58  81.85  89.55  88.00  81.80
Our Model (L6-H384)                       85.29  59.81  70.04  85.22  90.62  81.03  87.69  86.66  80.80

Our model's average score is about 1 point lower than Microsoft's L6-H384 model (80.80 vs 81.80). It was trained with a maximum sequence length of 128 tokens (vs 512 in the original paper) and distilled on a single RTX A6000 GPU, which demonstrates the efficiency and accessibility of our approach.
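As a quick sanity check, the averages and per-task gaps for the two L6-H384 rows can be recomputed from the table. The sketch below only restates the reported numbers; the variable names are illustrative and not part of BertDistiller.

```python
# GLUE scores copied from the table above, in column order.
tasks = ["STSB", "RTE", "CoLA", "QQP", "SST-2", "MNLI", "QNLI", "MRPC"]
scores = {
    "MiniLM-L6-H384-distilled-from-BERT-Base": [87.33, 64.74, 66.63, 85.72, 90.58, 81.85, 89.55, 88.00],
    "Our Model (L6-H384)": [85.29, 59.81, 70.04, 85.22, 90.62, 81.03, 87.69, 86.66],
}

# Unweighted mean over the eight tasks, as in the "Avg" column.
averages = {name: sum(vals) / len(vals) for name, vals in scores.items()}

# Per-task difference (ours minus Microsoft's); positive means ours is higher.
reference = scores["MiniLM-L6-H384-distilled-from-BERT-Base"]
ours = scores["Our Model (L6-H384)"]
per_task_gap = {task: o - r for task, o, r in zip(tasks, ours, reference)}

for name, avg in averages.items():
    print(f"{name}: {avg:.3f}")
for task, gap in per_task_gap.items():
    print(f"{task}: {gap:+.2f}")
```

The per-task view shows that the gap is not uniform: our model is ahead on CoLA and SST-2 and behind on the remaining tasks, with RTE contributing the largest deficit.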

Installation

pip install bertdistiller

Quick Start

See examples/minilm_distillation.py for a complete working example. Here's a simplified version:

from bertdistiller import MiniLMTrainer, MiniLMTrainingArguments, create_student
from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding

# 1. Create configuration
args = MiniLMTrainingArguments(
    teacher_layer=12,                # Which teacher layer to transfer from
    student_layer=6,                 # Number of layers in student model
    student_hidden_size=384,         # Hidden size of student model
    num_relation_heads=48,           # Number of relation heads for distillation
    relations={(1,1): 1.0, (2,2): 1.0, (3,3): 1.0},  # Q-Q, K-K, V-V relations and their weights

    # Training parameters
    output_dir="./output",
    per_device_train_batch_size=256,
    learning_rate=6e-4,
    max_steps=400_000,
)

# 2. Create models & trainer
teacher = AutoModel.from_pretrained("google-bert/bert-base-uncased")
student = create_student("google-bert/bert-base-uncased", args)

# Pad each batch dynamically using the teacher's tokenizer
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# train_dataset and eval_dataset are assumed to be tokenized datasets prepared
# beforehand (see examples/minilm_distillation.py)

trainer = MiniLMTrainer(
    args=args,
    teacher_model=teacher,
    model=student,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

# 3. Train and save
trainer.train()
student.save_pretrained("./distilled-model")

How MiniLMv2 Works

MiniLMv2 transfers knowledge through self-attention relations: the interactions between query, key, and value vectors within transformer layers. The implementation:

  1. Computes relation patterns using scaled dot-product between Q-Q, K-K, and V-V pairs
  2. Creates flexible "relation heads" that don't require matching teacher/student attention head counts
  3. Strategically selects which teacher layer to distill from (typically layer 12 for base models, an upper-middle layer for large models)

This approach provides more fine-grained knowledge transfer than traditional attention distillation methods.
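To make the relation computation concrete, here is a minimal pure-Python sketch of a single Q-Q relation loss. It is not BertDistiller's actual implementation: the function names (split_heads, relation, relation_kl) are hypothetical, the inputs are random stand-ins for concatenated query vectors, and it assumes the formulation from the MiniLMv2 paper, where vectors are re-split into relation heads, scored with a scaled dot-product, and compared between teacher and student with KL divergence.

```python
import math
import random


def split_heads(vectors, num_heads):
    """Split each token vector into num_heads equal chunks.

    Returns a list indexed as [head][token] -> chunk of size hidden // num_heads.
    """
    hidden = len(vectors[0])
    assert hidden % num_heads == 0, "hidden size must be divisible by num_heads"
    d = hidden // num_heads
    return [[v[h * d:(h + 1) * d] for v in vectors] for h in range(num_heads)]


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def relation(a, b, num_heads):
    """Per-head softmax(a_h @ b_h^T / sqrt(d_r)), indexed as [head][query][key]."""
    heads_a, heads_b = split_heads(a, num_heads), split_heads(b, num_heads)
    d_r = len(heads_a[0][0])  # relation head size
    out = []
    for head_a, head_b in zip(heads_a, heads_b):
        out.append([
            softmax([sum(x * y for x, y in zip(ta, tb)) / math.sqrt(d_r) for tb in head_b])
            for ta in head_a
        ])
    return out


def relation_kl(teacher_rel, student_rel):
    """Mean KL(teacher || student) over relation heads and query positions."""
    total, count = 0.0, 0
    for t_head, s_head in zip(teacher_rel, student_rel):
        for t_row, s_row in zip(t_head, s_head):
            total += sum(p * math.log(p / q) for p, q in zip(t_row, s_row) if p > 0)
            count += 1
    return total / count


# Random stand-ins for concatenated query vectors of a 4-token sequence.
random.seed(0)
seq_len, num_relation_heads = 4, 48
teacher_q = [[random.gauss(0, 1) for _ in range(768)] for _ in range(seq_len)]
student_q = [[random.gauss(0, 1) for _ in range(384)] for _ in range(seq_len)]

# The hidden sizes differ (768 vs 384), but both relations are 48 x 4 x 4,
# so they can be compared directly.
teacher_rel = relation(teacher_q, teacher_q, num_relation_heads)
student_rel = relation(student_q, student_q, num_relation_heads)
loss = relation_kl(teacher_rel, student_rel)
print(f"Q-Q relation loss: {loss:.4f}")
```

Because the relation matrices depend only on the number of relation heads and the sequence length, the student is free to use a different hidden size and attention head count than the teacher. Under the same assumptions, a full loss would repeat this for the K-K and V-V pairs and combine the three terms using the weights given in relations.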

Evaluation

BertDistiller includes utilities to evaluate distilled models on GLUE benchmark tasks:

from bertdistiller.evaluation import evaluate, create_summary_table

# Evaluate on GLUE tasks
evaluate(
    model_name_or_path="your-distilled-model",
    tasks=["mnli", "qnli", "qqp", "sst2"],
    learning_rate=[1e-5, 3e-5],
    epochs=[3, 5],
)

# Generate comparison table
summary = create_summary_table("./evaluation_results")
print(summary)

Recommendations

  • For base-size teachers (12 layers), use the last layer for distillation
  • For large-size teachers (24 layers), use an upper-middle layer (e.g., layer 21)
  • Using more relation heads (48+) generally improves performance
  • Initialize with teacher weights when possible
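When picking the number of relation heads, it can help to check that it divides both hidden sizes evenly. The helper below is a hypothetical sketch, not part of BertDistiller; the divisibility requirement is an assumption based on how the MiniLMv2 paper splits concatenated vectors into equal-sized relation heads, and the library may validate its arguments differently.

```python
def relation_head_sizes(teacher_hidden, student_hidden, num_relation_heads):
    """Return (teacher_head_size, student_head_size) for a relation head count.

    Raises ValueError if either hidden size cannot be split evenly.
    """
    for name, hidden in (("teacher", teacher_hidden), ("student", student_hidden)):
        if hidden % num_relation_heads != 0:
            raise ValueError(
                f"{name} hidden size {hidden} is not divisible by "
                f"{num_relation_heads} relation heads"
            )
    return teacher_hidden // num_relation_heads, student_hidden // num_relation_heads


# BERT-base teacher (hidden size 768), student with hidden size 384, 48 relation heads.
base_sizes = relation_head_sizes(768, 384, 48)    # (16, 8)

# BERT-large teacher (hidden size 1024), same student, 64 relation heads.
large_sizes = relation_head_sizes(1024, 384, 64)  # (16, 6)

print(base_sizes, large_sizes)
```

A head count such as 50 would raise ValueError for these hidden sizes, which is a cheap way to catch a misconfiguration before starting a long distillation run.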

Acknowledgements & Citation

Built using Hugging Face Transformers and inspired by the minilmv2.bb implementation and the original MiniLMv2 paper:

@article{wang2020minilmv2,
  title={MINILMv2: Multi-Head Self-Attention Relation Distillation for Compressing Pretrained Transformers},
  author={Wang, Wenhui and Bao, Hangbo and Huang, Shaohan and Dong, Li and Wei, Furu},
  journal={arXiv preprint arXiv:2012.15828},
  year={2020}
}

License

Apache License Version 2.0
