Calibration-Free Model Compression with Reinforcement Learning-Based Policy Learning
Project description
PruneNet: Calibration-Free Model Compression with Policy Learning
This repository contains PruneNet, a novel model compression framework that uses reinforcement learning to compress large language models without requiring calibration data.
Based on the paper: You Only Prune Once: Designing Calibration-Free Model Compression With Policy Learning
Key Features
- No Calibration Data Required - learns the compression policy directly from model weights
- Reinforcement Learning-Based - learns an optimal neuron selection strategy
- Preserves Spectral Properties - maintains weight matrix characteristics
- Easy to Use - simple `fit()` and `compress()` API following scikit-learn patterns
- Flexible Configuration - extensive hyperparameter control
- Multiple Architectures - supports OPT, Llama, Phi, and Falcon
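The first feature is the core idea: neuron importance is estimated from the model weights themselves, with no calibration dataset. As a simplified, hypothetical illustration (not PruneNet's actual learned policy), the sketch below scores the output neurons of a toy weight matrix by row norm and slices the matrix to a target compression ratio; PruneNet replaces this fixed heuristic with an RL-learned scoring policy, but likewise needs only the weights as input.

```python
import numpy as np

# Toy FFN weight matrix: 8 output neurons, 4 inputs.
# In a real model this would be one of the sliced linear layers.
rng = np.random.default_rng(0)
W = rng.normal(size=(8, 4))

compression_ratio = 0.5                     # fraction of neurons to remove
keep = int(W.shape[0] * (1 - compression_ratio))

# Calibration-free importance: computed from the weights alone
# (here: L2 norm of each neuron's weight row). PruneNet instead
# learns this scoring function with reinforcement learning.
importance = np.linalg.norm(W, axis=1)
kept_rows = np.sort(np.argsort(importance)[::-1][:keep])

W_pruned = W[kept_rows]                     # sliced weight matrix
print(W.shape, "->", W_pruned.shape)        # (8, 4) -> (4, 4)
```

Because the scores depend only on `W`, no forward passes over calibration text are required at any point.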
Quick Start
Installation
```shell
git clone https://github.com/parmanu-lcs2/efficient_pruners
cd efficient_pruners
pip install -e .
```
Basic Usage (New API)
```python
from efficient_pruners import PruneNet, PruningConfig

# Configure hyperparameters
config = PruningConfig(
    num_episodes=20,
    learning_rate=0.001,
)

# Initialize pruner
pruner = PruneNet(config)

# Train the policy on a specific model with a target compression ratio
pruner.fit(model_name="facebook/opt-125m", compression_ratio=0.3)

# Compress with the same or a different ratio
compressed_model = pruner.compress(compression_ratio=0.3)

# Save the compressed model
compressed_model.save_pretrained("./compressed_model")

# Test text generation with the compressed LLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
inputs = tokenizer("The future of AI is", return_tensors="pt")

outputs = compressed_model.generate(**inputs, max_length=50)
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(text)
```
Legacy CLI Usage
The original command-line interface is still available in the `prunenet/` directory:
```shell
python3 -m prunenet \
    --model_name facebook/opt-125m \
    --compression_ratio 0.3 \
    --save_dir ./models/ \
    --device cuda:0
```
Documentation
- API Guide - Complete API reference
- Test Notebook - Interactive fit/compress test with visualizations
- Test Script - Automated fit/compress test
Project Structure
```
PruneNet/
├── src/efficient_pruners/       # Main package
│   ├── core.py                  # PruneNet class (fit/compress API)
│   ├── config.py                # PruningConfig dataclass
│   ├── models/                  # SparsityPredictor policy network
│   │   └── sparsity_predictor.py
│   └── utils/                   # Model and reward utilities
│       ├── model_utils.py
│       └── reward_utils.py
├── examples/                    # Test & usage examples
│   └── test_fit_compress.py     # Complete fit/compress test script
├── notebooks/                   # Interactive tutorials
│   └── test_fit_compress.ipynb  # Complete fit/compress test notebook
├── docs/                        # Documentation
│   └── API_GUIDE.md
├── prunenet/                    # Original CLI implementation
├── setup.py                     # Package setup
├── pyproject.toml               # Modern build system
└── requirements.txt             # Dependencies
```
Supported Models
- OPT: facebook/opt-125m, facebook/opt-1.3b, etc.
- Llama: meta-llama/Llama-2-7b-hf, etc.
- Phi: microsoft/phi-1, microsoft/phi-2, etc.
- Falcon: tiiuae/falcon-7b, etc.
Running Examples
Test Script
Run the comprehensive test to verify both the `fit()` and `compress()` methods:
```shell
python examples/test_fit_compress.py
```
This script will:
- Train an RL policy using `fit()`
- Compress the model using `compress()`
- Test `.generate()` on the compressed LLM
- Compare outputs between the original and compressed models
- Display compression statistics
Interactive Notebook
```shell
jupyter notebook notebooks/test_fit_compress.ipynb
```
The notebook includes:
- Step-by-step walkthrough of `fit()` and `compress()`
- Visualizations of training progress
- Interactive text generation testing with compressed model
- Side-by-side comparison of model outputs
Advanced Configuration
```python
from efficient_pruners import PruneNet, PruningConfig

config = PruningConfig(
    num_episodes=20,
    learning_rate=0.001,
    use_kld=True,         # Enable KL divergence regularization
    gamma=0.99,           # Reward discount factor
    device="auto",        # Auto-detect GPU/CPU
    save_dir="./outputs"  # Checkpoint directory
)

pruner = PruneNet(config)
pruner.fit(model_name="facebook/opt-125m")
compressed_model = pruner.compress(compression_ratio=0.3)
```
See API_GUIDE.md for all configuration options.
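Two of these knobs are standard reinforcement-learning quantities: `gamma` discounts per-step rewards into a return, and `use_kld` adds a KL-divergence penalty between distributions (the exact reward definition lives in `reward_utils.py`; the code below is a generic, self-contained sketch of both computations, not the package's implementation).

```python
import numpy as np

def discounted_returns(rewards, gamma=0.99):
    """G_t = r_t + gamma * G_{t+1}, computed backwards over an episode."""
    returns = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for discrete distributions, with smoothing for zeros."""
    p = np.asarray(p, dtype=float) + eps
    q = np.asarray(q, dtype=float) + eps
    p, q = p / p.sum(), q / q.sum()
    return float(np.sum(p * np.log(p / q)))

# With gamma=0.5 and three unit rewards: G = [1.75, 1.5, 1.0]
print(discounted_returns([1.0, 1.0, 1.0], gamma=0.5))
# Identical distributions have (near-)zero divergence
print(kl_divergence([0.5, 0.5], [0.5, 0.5]))
```

A larger `gamma` makes early pruning decisions in an episode more sensitive to rewards earned later; the KL term discourages the compressed weights from drifting far from the original distribution.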
Performance
Typical compression results on OPT-125M:
| Compression | Size Reduction | Perplexity Impact |
|---|---|---|
| 20% | ~15% | +2-3% |
| 30% | ~22% | +3-5% |
| 40% | ~30% | +5-8% |
| 50% | ~37% | +8-12% |
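The size reduction is smaller than the compression ratio because not every parameter is prunable: embeddings stay full size. A back-of-envelope estimate for OPT-125M, using its public architecture numbers and the illustrative assumption that both attention and FFN projection matrices are sliced, lands in the same range as the table:

```python
# Back-of-envelope: why 30% neuron compression != 30% size reduction.
# Architecture numbers are OPT-125M's public config; treating exactly
# the attention + FFN projections as prunable is an assumption here.
layers, d, ffn, vocab, max_pos = 12, 768, 3072, 50272, 2050

attn_params = layers * 4 * d * d        # Q, K, V, output projections
ffn_params = layers * 2 * d * ffn       # up and down projections
embed_params = (vocab + max_pos) * d    # token + position embeddings

total = attn_params + ffn_params + embed_params
prunable = attn_params + ffn_params

ratio = 0.3
reduction = ratio * prunable / total
print(f"prunable fraction: {prunable / total:.0%}")                 # ~68%
print(f"size reduction at {ratio:.0%} compression: {reduction:.0%}")  # ~20%
```

The estimate (~20% at a 0.3 ratio) roughly matches the ~22% measured above; the gap depends on exactly which modules the learned policy slices.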
Research & Original Implementation
The original research scripts, including the evaluation scripts and the code for slicing the attention modules, are preserved in the `prunenet/` and `experiments/` directories.
Citation
If you find our work useful in your projects/research, kindly cite our paper:
```
@inproceedings{sengupta2025you,
  title={You Only Prune Once: Designing Calibration-Free Model Compression With Policy Learning},
  author={Ayan Sengupta and Siddhant Chaudhary and Tanmoy Chakraborty},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=5RZoYIT3u6}
}
```