
MPDistil 🎓

Meta-Policy Knowledge Distillation for Training Compact Student Models

Python 3.8+ | PyTorch | License: MIT

MPDistil is a teacher-student collaborative knowledge distillation framework that enables compact student models to outperform larger teacher models through meta-learning and curriculum learning.

Based on the ICLR 2024 paper: A Good Learner can Teach Better: Teacher-Student Collaborative Knowledge Distillation

🌟 Key Features

  • 📊 Superior Performance: a 6-layer BERT student matches or outperforms a 12-layer BERT teacher on 5 of 6 SuperGLUE tasks
  • 🎯 4-Phase Training: Teacher fine-tuning → PKD → Meta-teacher → Curriculum learning
  • 🚀 Simple API: Easy-to-use .train() method with full control over all phases
  • 📏 Flexible Metrics: Built-in support for accuracy, F1, MCC, correlation via HuggingFace evaluate
  • 🔧 Customizable: Works with any HuggingFace model and custom datasets
  • 💻 Colab-Ready: GPU-optimized for Google Colab environments
  • 📦 Easy Installation: Single pip command to get started

📈 Methodology


MPDistil consists of 4 training phases:

  1. Teacher Fine-tuning: Fine-tune teacher model on the target task
  2. Student PKD: distill the teacher into the student with Patient Knowledge Distillation (soft targets plus intermediate-layer matching)
  3. Meta-Teacher Learning: Collaborative or competitive loss for meta-learning
  4. Curriculum Learning: Reinforcement learning-based task selection

🚀 Installation

From GitHub (Recommended)

pip install git+https://github.com/parmanu-lcs2/mpdistil.git

From Source

git clone https://github.com/parmanu-lcs2/mpdistil.git
cd mpdistil
pip install -e .

💡 Quick Start

Basic Usage

from mpdistil import MPDistil, load_superglue_dataset

# Load data
loaders, num_labels = load_superglue_dataset('CB')

# Initialize MPDistil
model = MPDistil(
    task_name='CB',
    num_labels=num_labels,
    metric='f1',  # Options: 'accuracy', 'f1', 'mcc', 'correlation', 'auto'
    teacher_model='bert-base-uncased',
    student_model='bert-base-uncased',
    student_layers=6
)

# Train with all 4 phases
history = model.train(
    train_loader=loaders['train'],
    val_loader=loaders['val'],
    teacher_epochs=10,   # Phase 1
    student_epochs=10,   # Phase 2
    meta_epochs=1        # Phase 3 (NEW!)
)

# Save trained student
model.save_student('./my_student_model')

# Make predictions
predictions = model.predict(loaders['test'])

With Custom Data

from mpdistil import MPDistil, create_simple_dataloader

# Prepare your data
texts = [("This is text A", "This is text B"), ...]
labels = [0, 1, 0, ...]

# Create DataLoader
train_loader = create_simple_dataloader(
    texts=texts,
    labels=labels,
    tokenizer_name='bert-base-uncased',
    max_length=128,
    batch_size=8
)

# Build a validation loader the same way, then train
model = MPDistil(task_name='MyTask', num_labels=2, metric='accuracy')
history = model.train(train_loader, val_loader)

With Meta-Learning (Curriculum)

# Load multiple tasks for curriculum learning
cb_loaders, _ = load_superglue_dataset('CB')
rte_loaders, _ = load_superglue_dataset('RTE')
boolq_loaders, _ = load_superglue_dataset('BoolQ')

# Train with curriculum learning
history = model.train(
    train_loader=cb_loaders['train'],
    val_loader=cb_loaders['val'],
    meta_loaders={
        'RTE': rte_loaders['val'],
        'BoolQ': boolq_loaders['val']
    },
    teacher_epochs=10,   # Phase 1
    student_epochs=10,   # Phase 2  
    meta_epochs=3,       # Phase 3 - can train for multiple epochs!
    num_episodes=200     # Phase 4 - curriculum learning episodes
)

📖 API Reference

MPDistil Class

Constructor

MPDistil(
    task_name: str,              # Name of the main task
    num_labels: int,             # Number of output classes
    metric: str = 'accuracy',    # Metric: 'accuracy', 'f1', 'mcc', 'correlation', 'auto'
    teacher_model: str = 'bert-base-uncased',  # HuggingFace model name
    student_model: str = 'bert-base-uncased',  # HuggingFace model name
    student_layers: int = 6,     # Number of layers for student
    device: str = 'auto',        # 'auto', 'cuda', or 'cpu'
    output_dir: str = './mpdistil_outputs'
)

Methods

| Method | Description |
| --- | --- |
| train(train_loader, val_loader, **kwargs) | Train the model (all 4 phases) |
| predict(test_loader) | Generate predictions |
| save_student(path) | Save the student model in HuggingFace format |
| load_student(path) | Load a saved student model |
| save_predictions(predictions, path, label_mapping) | Save predictions to TSV |
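For example, reloading a saved student and exporting test predictions might look like this; the label names in label_mapping are illustrative assumptions for CB, not values the package prescribes:

# Hypothetical round-trip: reload the saved student, predict, export a TSV
model.load_student('./my_student_model')
predictions = model.predict(loaders['test'])
model.save_predictions(
    predictions,
    './CB_predictions.tsv',
    label_mapping={0: 'entailment', 1: 'contradiction', 2: 'neutral'},  # assumed CB labels
)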

TrainingConfig

Configure training hyperparameters:

from mpdistil import TrainingConfig

config = TrainingConfig(
    # Phase 1: Teacher
    teacher_epochs=10,
    teacher_lr=2e-5,
    
    # Phase 2: Student PKD
    student_epochs=10,
    student_lr=3e-5,
    alpha=0.5,          # Soft loss weight
    beta=100.0,         # PKD loss weight
    temperature=5.0,    # Distillation temperature
    
    # Phase 3: Meta-Teacher
    meta_epochs=3,      # NEW! Meta-teacher can train for multiple epochs
    meta_lr=1e-3,
    use_competitive_loss=False,  # Use collaborative loss
    
    # Phase 4: Curriculum
    num_episodes=200,
    reward_type='binary',  # or 'real'
    
    # General
    batch_size=8,
    seed=42,
    report_to=None      # Options: 'wandb', 'tensorboard', None
)

history = model.train(train_loader, val_loader, config=config)

Training Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| teacher_epochs | int | 10 | Phase 1: teacher training epochs |
| student_epochs | int | 10 | Phase 2: student training epochs |
| meta_epochs | int | 1 | Phase 3: meta-teacher training epochs (NEW!) |
| num_episodes | int | 200 | Phase 4: curriculum learning episodes |
| teacher_lr | float | 2e-5 | Teacher learning rate |
| student_lr | float | 3e-5 | Student learning rate |
| meta_lr | float | 1e-3 | Meta-learning rate |
| alpha | float | 0.5 | Soft vs. hard loss weight |
| beta | float | 100.0 | PKD loss weight |
| temperature | float | 5.0 | Distillation temperature |
| use_competitive_loss | bool | False | Competitive vs. collaborative meta-loss |
| reward_type | str | 'binary' | 'binary' or 'real' |
| batch_size | int | 8 | Batch size |
| seed | int | 42 | Random seed |

📊 Results

Performance on SuperGLUE tasks (BERT-base teacher → BERT-6L student):

| Model | BoolQ | CB | COPA | RTE | WiC | WSC |
| --- | --- | --- | --- | --- | --- | --- |
| Teacher | 75.3 | 83.9 | 63.0 | 67.1 | 57.1 | 64.4 |
| Student (undistilled) | 71.6 | 75.0 | 53.0 | 64.6 | 56.0 | 63.5 |
| MPDistil (ours) | 73.4 | 83.9 | 70.0 | 67.5 | 59.6 | 65.4 |

The student matches or outperforms the teacher on 5 of 6 tasks!

📝 Examples

See the examples/ directory in the repository for Jupyter notebooks.

📊 Evaluation Metrics

MPDistil supports multiple evaluation metrics via the HuggingFace evaluate library:

Available Metrics

| Metric | Use Case | Returns |
| --- | --- | --- |
| 'accuracy' | Standard classification (default) | {'acc': float} |
| 'f1' | Imbalanced datasets, multi-class | {'acc': float, 'f1': float, 'acc_and_f1': float} |
| 'mcc' | Binary classification, imbalanced | {'mcc': float} |
| 'correlation' | Regression tasks (STS-B) | {'pearson': float, 'spearmanr': float} |
| 'auto' | Auto-detect based on task | Task-specific metric |

Example: Using Different Metrics

# Accuracy (default)
model = MPDistil(task_name='BoolQ', num_labels=2, metric='accuracy')

# F1 score (recommended for CB, MultiRC)
model = MPDistil(task_name='CB', num_labels=3, metric='f1')

# Matthews Correlation (for binary with imbalance)
model = MPDistil(task_name='CoLA', num_labels=2, metric='mcc')

# Correlation (for regression)
model = MPDistil(task_name='STS-B', num_labels=1, metric='correlation', is_regression=True)

# Auto-detect
model = MPDistil(task_name='CB', num_labels=3, metric='auto')  # Uses F1 for CB

🔬 How It Works

Phase 1: Teacher Fine-tuning

Fine-tune a large teacher model (e.g., BERT-base) on your target task.

Phase 2: Student PKD

Train a smaller student model using:

  • Soft targets from teacher (KL divergence)
  • Hard labels (cross-entropy)
  • Patient KD (intermediate feature matching)

Loss: α * soft_loss + (1-α) * hard_loss + β * pkd_loss
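As a minimal sketch, the combined objective could be computed as below; the tensor names and the layer-matching step are illustrative assumptions, not MPDistil's exact internals:

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels,
                      student_hidden, teacher_hidden,
                      alpha=0.5, beta=100.0, temperature=5.0):
    # Soft targets: KL divergence between temperature-scaled distributions
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction='batchmean',
    ) * temperature ** 2
    # Hard labels: standard cross-entropy
    hard = F.cross_entropy(student_logits, labels)
    # Patient KD: match normalized intermediate representations
    # (assumed: lists of [batch, hidden] tensors from selected layers)
    patient = sum(
        F.mse_loss(F.normalize(s, dim=-1), F.normalize(t, dim=-1))
        for s, t in zip(student_hidden, teacher_hidden)
    )
    return alpha * soft + (1 - alpha) * hard + beta * patient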

Phase 3: Meta-Teacher

Create a meta-teacher that learns from both teacher and student representations:

Collaborative loss (default):

L = 0.5 * CE(T'(h_teacher), y) + 0.5 * CE(T'(h_student), y)

Competitive loss:

L = -mean(P_teacher) + mean(P_student) + CE_loss
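A sketch of both variants, assuming a meta-teacher head meta_head that maps pooled representations to logits; reading P_teacher and P_student as gold-label probabilities is our assumption:

import torch.nn.functional as F

def meta_teacher_loss(meta_head, h_teacher, h_student, labels,
                      use_competitive_loss=False):
    logits_t = meta_head(h_teacher)  # T'(h_teacher)
    logits_s = meta_head(h_student)  # T'(h_student)
    if not use_competitive_loss:
        # Collaborative: average cross-entropy over both representations
        return 0.5 * F.cross_entropy(logits_t, labels) + \
               0.5 * F.cross_entropy(logits_s, labels)
    # Competitive: favor the teacher's gold-label probability over the student's
    p_teacher = F.softmax(logits_t, dim=-1).gather(1, labels.unsqueeze(1))
    p_student = F.softmax(logits_s, dim=-1).gather(1, labels.unsqueeze(1))
    return -p_teacher.mean() + p_student.mean() + F.cross_entropy(logits_t, labels)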

Phase 4: Curriculum Learning

Use reinforcement learning to select which auxiliary tasks help the student learn:

  • Action model selects next task
  • Reward based on student improvement over teacher
  • REINFORCE algorithm updates policy
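A minimal sketch of one such episode with REINFORCE; the policy network, the placeholder state vector, and the helper callables are assumptions for illustration, not the package's actual components:

import torch

def curriculum_episode(policy, optimizer, task_names, train_on_task, reward_fn):
    # Placeholder state; MPDistil derives this from the student (assumed here)
    state = torch.ones(1, policy.in_features)
    dist = torch.distributions.Categorical(logits=policy(state))
    action = dist.sample()                    # action model selects the next task
    train_on_task(task_names[action.item()])  # student step on the chosen task
    reward = reward_fn()                      # e.g. 1.0 if student beats teacher, else 0.0
    # REINFORCE: increase log-probability of actions that earned reward
    loss = -(dist.log_prob(action) * reward).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

With reward_type='binary' the reward is 0/1 as above; presumably reward_type='real' substitutes a continuous improvement score.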

🛠️ Advanced Usage

Custom Model Architectures

model = MPDistil(
    task_name='MyTask',
    num_labels=3,
    teacher_model='roberta-large',
    student_model='distilbert-base-uncased',
    student_layers=6
)

Weights & Biases Logging

history = model.train(
    train_loader=train_loader,
    val_loader=val_loader,
    report_to='wandb',      # Options: 'wandb', 'tensorboard', None
    wandb_project='my-project',
    wandb_run_name='experiment-1'
)

Access Trained Models

# Access student model
student = model.student

# Access teacher model
teacher = model.teacher

# Use with HuggingFace
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('./my_student_model')
student_hf = AutoModel.from_pretrained('./my_student_model')

📚 Citation

If you use MPDistil in your research, please cite:

@inproceedings{sengupta2024mpdistil,
  title={A Good Learner can Teach Better: Teacher-Student Collaborative Knowledge Distillation},
  author={Sengupta, Ayan and Dixit, Shantanu and Akhtar, Md Shad and Chakraborty, Tanmoy},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024},
  url={https://openreview.net/forum?id=Ixi4j6LtdX}
}

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

🙏 Acknowledgments

📧 Contact

For questions or issues, please open an issue on GitHub or contact the authors.


Made with ❤️ for the research community
