
Calibration-Free Model Compression with Reinforcement Learning-Based Policy Learning

Project description

PruneNet: Calibration-Free Model Compression with Policy Learning

Python 3.8+ PyTorch License: MIT

This repository contains PruneNet, a novel model compression framework that uses reinforcement learning to compress large language models without requiring calibration data.

Based on the paper: You Only Prune Once: Designing Calibration-Free Model Compression With Policy Learning

✨ Key Features

  • 🎯 No Calibration Data Required - Learns the compression policy directly from model weights
  • 🤖 Reinforcement Learning-Based - Learns an optimal neuron selection strategy
  • 📊 Preserves Spectral Properties - Maintains weight matrix characteristics
  • 🚀 Easy to Use - Simple fit() and compress() API following scikit-learn patterns
  • 🔧 Flexible Configuration - Extensive hyperparameter control
  • 📦 Multiple Architectures - Supports OPT, Llama, Phi, and Falcon

🚀 Quick Start

Installation

git clone https://github.com/parmanu-lcs2/efficient_pruners
cd efficient_pruners
pip install -e .

Basic Usage (New API)

from efficient_pruners import PruneNet, PruningConfig

# Configure hyperparameters
config = PruningConfig(
    num_episodes=20,
    learning_rate=0.001
)

# Initialize pruner
pruner = PruneNet(config)

# Train policy on specific model with target compression ratio
pruner.fit(model_name="facebook/opt-125m", compression_ratio=0.3)

# Compress with the same or different ratio
compressed_model = pruner.compress(compression_ratio=0.3)

# Save compressed model
compressed_model.save_pretrained("./compressed_model")

# Test text generation with compressed LLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
inputs = tokenizer("The future of AI is", return_tensors="pt")

# Generate text with compressed model
outputs = compressed_model.generate(**inputs, max_length=50)
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(text)
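
As a quick sanity check, you can compare parameter counts between the original checkpoint and the compressed model. This is not part of the PruneNet API, just plain PyTorch; it assumes the compressed_model produced by the snippet above:

from transformers import AutoModelForCausalLM

# Load the original checkpoint for comparison (the pruner loads its own copy during fit()).
original_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

original_params = sum(p.numel() for p in original_model.parameters())
compressed_params = sum(p.numel() for p in compressed_model.parameters())

print(f"Original parameters:   {original_params:,}")
print(f"Compressed parameters: {compressed_params:,}")
print(f"Size reduction:        {1 - compressed_params / original_params:.1%}")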

Legacy CLI Usage

The original command-line interface is still available in the prunenet/ directory:

python3 -m prunenet \
    --model_name facebook/opt-125m \
    --compression_ratio 0.3 \
    --save_dir ./models/ \
    --device cuda:0

📖 Documentation

📂 Project Structure

PruneNet/
├── src/efficient_pruners/     # Main package
│   ├── core.py                # PruneNet class (fit/compress API)
│   ├── config.py              # PruningConfig dataclass
│   ├── models/                # SparsityPredictor policy network
│   │   └── sparsity_predictor.py
│   └── utils/                 # Model and reward utilities
│       ├── model_utils.py
│       └── reward_utils.py
├── examples/                  # Test & usage examples
│   └── test_fit_compress.py   # Complete fit/compress test script
├── notebooks/                 # Interactive tutorials
│   └── test_fit_compress.ipynb  # Complete fit/compress test notebook
├── docs/                      # Documentation
│   └── API_GUIDE.md
├── prunenet/                  # Original CLI implementation
├── setup.py                   # Package setup
├── pyproject.toml             # Modern build system
└── requirements.txt           # Dependencies

🎯 Supported Models

  • OPT: facebook/opt-125m, facebook/opt-1.3b, etc.
  • Llama: meta-llama/Llama-2-7b-hf, etc.
  • Phi: microsoft/phi-1, microsoft/phi-2, etc.
  • Falcon: tiiuae/falcon-7b, etc.
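
The same fit()/compress() workflow shown in the Quick Start applies to any of the checkpoints above; only the model name changes, and larger models need correspondingly more memory and training time. A hypothetical run against microsoft/phi-2 (the episode count and ratio below are illustrative, not recommended settings):

from efficient_pruners import PruneNet, PruningConfig

pruner = PruneNet(PruningConfig(num_episodes=20, learning_rate=0.001))
pruner.fit(model_name="microsoft/phi-2", compression_ratio=0.3)
compressed_phi = pruner.compress(compression_ratio=0.3)
compressed_phi.save_pretrained("./compressed_phi2")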

🧪 Running Examples

Test Script

Run the comprehensive test to verify both fit() and compress() methods:

python examples/test_fit_compress.py

This script will:

  • ✅ Train an RL policy using fit()
  • ✅ Compress the model using compress()
  • ✅ Test .generate() on the compressed LLM
  • ✅ Compare outputs between original and compressed models (see the sketch after this list)
  • ✅ Display compression statistics
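
A minimal version of that output comparison, assuming the compressed_model from the Quick Start above (the test script's exact prompts and generation settings may differ):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
original_model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "The future of AI is"
inputs = tokenizer(prompt, return_tensors="pt")

# Greedy decoding keeps the comparison deterministic.
original_out = original_model.generate(**inputs, max_length=50, do_sample=False)
compressed_out = compressed_model.generate(**inputs, max_length=50, do_sample=False)

print("Original:  ", tokenizer.decode(original_out[0], skip_special_tokens=True))
print("Compressed:", tokenizer.decode(compressed_out[0], skip_special_tokens=True))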

Interactive Notebook

jupyter notebook notebooks/test_fit_compress.ipynb

The notebook includes:

  • Step-by-step walkthrough of fit() and compress()
  • Visualizations of training progress
  • Interactive text generation testing with compressed model
  • Side-by-side comparison of model outputs

⚙️ Advanced Configuration

config = PruningConfig(
    num_episodes=20,
    learning_rate=0.001,
    use_kld=True,          # Enable KL divergence regularization
    gamma=0.99,            # Reward discount factor
    device="auto",         # Auto-detect GPU/CPU
    save_dir="./outputs"   # Checkpoint directory
)

pruner = PruneNet(config)
pruner.fit(model_name="facebook/opt-125m")
compressed_model = pruner.compress(compression_ratio=0.3)
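
Because the policy is learned once, a fitted pruner can be reused to produce checkpoints at several ratios. A minimal sketch, assuming compress() can be called repeatedly on the same fitted pruner (check API_GUIDE.md for the exact contract); the output paths are illustrative:

# Hypothetical sweep over several compression ratios with one fitted policy.
for ratio in (0.2, 0.3, 0.4, 0.5):
    model = pruner.compress(compression_ratio=ratio)
    model.save_pretrained(f"./outputs/opt-125m-pruned-{int(ratio * 100)}")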

See API_GUIDE.md for all configuration options.

📊 Performance

Typical compression results on OPT-125M:

Compression   Size Reduction   Perplexity Impact
20%           ~15%             +2-3%
30%           ~22%             +3-5%
40%           ~30%             +5-8%
50%           ~37%             +8-12%
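
These numbers depend on the evaluation setup. If you want to measure the perplexity impact yourself, one common recipe (not the repository's evaluation script) is a fixed-stride sliding-window evaluation on WikiText-2. The sketch below assumes the compressed_model and tokenizer from the Quick Start, the datasets package, and illustrative context-length and stride values:

import torch
from datasets import load_dataset

def perplexity(model, tokenizer, max_length=1024, stride=512):
    """Fixed-stride sliding-window perplexity on the WikiText-2 test split."""
    test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")
    input_ids = encodings.input_ids
    device = next(model.parameters()).device

    model.eval()
    nlls, prev_end = [], 0
    for begin in range(0, input_ids.size(1), stride):
        end = min(begin + max_length, input_ids.size(1))
        target_len = end - prev_end          # tokens newly scored in this window
        ids = input_ids[:, begin:end].to(device)
        labels = ids.clone()
        labels[:, :-target_len] = -100       # mask tokens already scored by a previous window
        with torch.no_grad():
            loss = model(ids, labels=labels).loss
        nlls.append(loss * target_len)
        prev_end = end
        if end == input_ids.size(1):
            break
    return torch.exp(torch.stack(nlls).sum() / end)

print("Compressed model perplexity:", perplexity(compressed_model, tokenizer).item())

Running the same function on the original facebook/opt-125m checkpoint gives the baseline needed to compute the relative impact shown in the table.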

🔬 Research & Original Implementation

The original research scripts are preserved in prunenet/ and experiments/ directories. See the original README sections below for research-specific details.


Original Evaluation Scripts

Slicing the attention modules

Citation

If you find our work useful in your projects/research, kindly cite our paper:

@inproceedings{
    sengupta2025you,
    title={You Only Prune Once: Designing Calibration-Free Model Compression With Policy Learning},
    author={Ayan Sengupta and Siddhant Chaudhary and Tanmoy Chakraborty},
    booktitle={The Thirteenth International Conference on Learning Representations},
    year={2025},
    url={https://openreview.net/forum?id=5RZoYIT3u6}
}

Download files

Download the file for your platform.

Source Distribution

efficient_pruners-0.1.0.tar.gz (46.6 kB)

Uploaded Source

Built Distribution


efficient_pruners-0.1.0-py3-none-any.whl (45.5 kB)

Uploaded Python 3

File details

Details for the file efficient_pruners-0.1.0.tar.gz.

File metadata

  • Download URL: efficient_pruners-0.1.0.tar.gz
  • Upload date:
  • Size: 46.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.9

File hashes

Hashes for efficient_pruners-0.1.0.tar.gz
Algorithm Hash digest
SHA256 9ab5048e1edd659fa05140350634bdca10a1c732c2dc6dc345d0304cc3308efb
MD5 f65a958ed70c3dd3a42dd80436e2f537
BLAKE2b-256 29a321fc795e27fb3cca3f3f85241f6d5f91c9bfa5c35e29cd52db7a6b797c4e


File details

Details for the file efficient_pruners-0.1.0-py3-none-any.whl.

File metadata

File hashes

Hashes for efficient_pruners-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 23d2263d8a39c08db37b515ee495d15f34521e35beb470a0132feb6ea8468453
MD5 5640bfa20a2aea2a63a2e23b761a6cf4
BLAKE2b-256 83e66710b1dcace057b85a00bf18d7d045c410e3cddf7b1ac77eea961a735b53

