BitNet v3: Ultra-Low Quality Loss 1-bit LLMs Through Multi-Stage Progressive Quantization and Adaptive Hadamard Transform
BitNet v3: Ultra-Low Quality Loss 1-bit LLMs
A comprehensive PyTorch implementation of BitNet v3, a novel framework for training 1-bit Large Language Models (LLMs) that substantially reduces the quality loss of extreme quantization while retaining its computational efficiency benefits.
🚀 Key Features
BitNet v3 introduces five key innovations that, together, keep the perplexity increase over FP16 to 0.3-0.4% while delivering a 4.5x speedup and an 85% memory reduction:
- 🔄 Multi-stage Progressive Quantization (MPQ) - Gradually reduces bit-width during training
- 🧮 Adaptive Hadamard Transform with Learnable Parameters (AHT-LP) - Dynamically adjusts to activation distributions
- 🎓 Gradient-Aware Knowledge Distillation (GAKD) - Preserves critical gradient information during quantization
- ⚖️ Dynamic Regularization with Quantization-Aware Penalties (DR-QAP) - Stabilizes training with adaptive penalties
- 💫 Enhanced Straight-Through Estimator with Momentum (ESTE-M) - Improves gradient approximation
📊 Performance Results
| Model Size | Method | Perplexity | Quality Loss (perplexity increase) | Speedup | Memory (% of FP16) |
|---|---|---|---|---|---|
| 1.3B | BitNet v3 | 14.58 | +0.4% | 4.5x | 15% |
| 3B | BitNet v3 | 11.87 | +0.3% | 4.5x | 15% |
| 7B | BitNet v3 | 9.45 | +0.3% | 4.5x | 15% |
Quality loss, speedup, and memory are measured relative to full-precision FP16 models of the same size.
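The table reports quality loss only as a relative change. Assuming "Quality Loss" means the relative perplexity increase over FP16, (ppl_bitnet - ppl_fp16) / ppl_fp16, the implied FP16 baselines can be recovered, and the ideal storage ratio for ternary weights can be compared with the 15% memory figure. This is a back-of-the-envelope sketch; the helper names are hypothetical and not part of bitnet_v3.

```python
import math


def implied_baseline_ppl(quantized_ppl, quality_loss):
    """Invert quality_loss = (ppl_quantized - ppl_fp16) / ppl_fp16 to recover ppl_fp16."""
    return quantized_ppl / (1.0 + quality_loss)


def ideal_weight_memory_ratio(bits_per_weight, baseline_bits=16.0):
    """Weight storage relative to FP16, ignoring scales, unquantized layers and activations."""
    return bits_per_weight / baseline_bits


# A ternary weight carries log2(3) ≈ 1.585 bits of information.
ternary_bits = math.log2(3)

# (model size, BitNet v3 perplexity, quality loss) taken from the table above.
rows = [("1.3B", 14.58, 0.004), ("3B", 11.87, 0.003), ("7B", 9.45, 0.003)]
for size, ppl, loss in rows:
    print(f"{size}: implied FP16 perplexity ≈ {implied_baseline_ppl(ppl, loss):.2f}")

print(f"Ideal ternary weight footprint: {ideal_weight_memory_ratio(ternary_bits):.1%} of FP16")
```

Under these assumptions the implied FP16 perplexities are roughly 14.52, 11.83 and 9.42, and ideal ternary weights need about 9.9% of FP16 weight storage. The README does not say what accounts for the gap to 15%; scaling factors, layers kept at higher precision, or activations are plausible candidates.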
🛠️ Installation
From PyPI (Recommended)
pip install bitnet-v3
From Source
git clone https://github.com/ProCreations/bitnet-v3.git
cd bitnet-v3
pip install -e .
Development Installation
git clone https://github.com/ProCreations/bitnet-v3.git
cd bitnet-v3
pip install -e ".[dev]"
🎯 Quick Start
Simple Usage
import bitnet_v3
# Create a BitNet v3 model
model = bitnet_v3.create_model(
    vocab_size=32000,
    hidden_size=2048,
    num_layers=24,
    num_heads=32,
)
# Create trainer with MPQ schedule
trainer = bitnet_v3.create_trainer(
    model,
    learning_rate=3e-4,
    batch_size=256,
    enable_mpq=True,
    enable_gakd=True,
)
# Train the model (train_dataloader: your own torch.utils.data.DataLoader)
trainer.train(train_dataloader)
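To see roughly what the Quick Start configuration builds, the usual decoder-only estimate of about 12·d² parameters per layer plus the token embeddings can be applied. This is a sketch under stated assumptions: 4·d² for the attention projections and 8·d² for an MLP with a 4x expansion, ignoring norms, biases and any untied output head. The actual bitnet_v3 layer shapes may differ, and both helpers are hypothetical.

```python
def approx_param_count(vocab_size, hidden_size, num_layers):
    """Rough decoder-only transformer size: ~12 * d^2 per layer + token embeddings."""
    per_layer = 12 * hidden_size ** 2  # 4*d^2 attention + 8*d^2 MLP (assumed 4x expansion)
    embeddings = vocab_size * hidden_size
    return num_layers * per_layer + embeddings


def weight_bytes(num_params, bits_per_weight):
    """Bytes needed to store num_params weights at the given bit-width."""
    return num_params * bits_per_weight / 8


n = approx_param_count(vocab_size=32000, hidden_size=2048, num_layers=24)
print(f"~{n / 1e9:.2f}B parameters")
print(f"FP16 weights:     {weight_bytes(n, 16) / 1e9:.2f} GB")
print(f"1.58-bit weights: {weight_bytes(n, 1.58) / 1e9:.2f} GB")
```

With the Quick Start settings this gives about 1.27B parameters, in line with the 1.3B row of the table; at 1.58 bits per weight that is roughly 0.25 GB of weight storage versus about 2.55 GB in FP16.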
Advanced Usage with All Features
import torch
import bitnet_v3
# Configure model with all innovations
config = bitnet_v3.BitNetV3Config(
    vocab_size=32000,
    hidden_size=4096,
    num_layers=32,
    num_heads=32,
    # MPQ configuration (stage epochs sum to 70, matching num_epochs below)
    mpq_stages=[
        {"epochs": 20, "bits": 8},
        {"epochs": 20, "bits": 4},
        {"epochs": 15, "bits": 2},
        {"epochs": 15, "bits": 1.58},
    ],
    # AHT-LP configuration
    adaptive_hadamard=True,
    hadamard_learnable_scale=True,
    # GAKD configuration
    knowledge_distillation=True,
    gakd_alpha=0.7,
    gakd_beta=0.2,
    gakd_gamma=0.1,
    # DR-QAP configuration
    dynamic_regularization=True,
    qap_initial_lambda=0.1,
    # ESTE-M configuration
    enhanced_ste=True,
    ste_momentum=0.9,
)
# Create model and trainer
model = bitnet_v3.BitNetV3Model(config)
trainer = bitnet_v3.BitNetV3Trainer(model, config)
# Load teacher model for knowledge distillation.
# A fully pickled nn.Module needs weights_only=False on recent PyTorch; only load files you trust.
teacher_model = torch.load("teacher_model.pth", weights_only=False)
teacher_model.eval()  # the teacher is only used for inference during distillation
trainer.set_teacher_model(teacher_model)
# Train with all features (train_dataloader / val_dataloader: your own DataLoaders)
trainer.train(
    train_dataloader,
    val_dataloader,
    num_epochs=70,
    save_every=5,
    eval_every=1,
)
🏗️ Architecture Overview
Core Components
- bitnet_v3.core - Core quantization functions and utilities
- bitnet_v3.modules - Individual innovation modules (MPQ, AHT-LP, GAKD, etc.)
- bitnet_v3.models - Complete BitNet v3 model implementations
- bitnet_v3.training - Training pipeline and utilities
- bitnet_v3.utils - Configuration, logging, and metrics
Key Modules
# Enhanced H-BitLinear with all innovations
linear_layer = bitnet_v3.EnhancedHBitLinear(
    in_features=2048,
    out_features=2048,
    bias=False,
    adaptive_hadamard=True,
    progressive_quantization=True,
)
# Multi-stage Progressive Quantizer
mpq = bitnet_v3.MultiStageProgressiveQuantizer(
    stages=[8, 4, 2, 1.58],
    stage_epochs=[20, 20, 15, 15],
)
# Adaptive Hadamard Transform
aht = bitnet_v3.AdaptiveHadamardTransform(
    size=2048,
    learnable_params=True,
)
# Gradient-Aware Knowledge Distillation
gakd = bitnet_v3.GradientAwareKnowledgeDistillation(
    alpha=0.7,  # KL divergence weight
    beta=0.2,   # Gradient alignment weight
    gamma=0.1,  # Feature alignment weight
)
📚 Detailed Documentation
Multi-Stage Progressive Quantization (MPQ)
MPQ gradually reduces the bit-width during training (8 → 4 → 2 → 1.58 bits in the schedule below), so the model can adapt to each precision level before moving to the next:
# Configure MPQ stages
mpq_config = {
    "stages": [
        {"start_epoch": 1, "end_epoch": 20, "bits": 8},
        {"start_epoch": 21, "end_epoch": 40, "bits": 4},
        {"start_epoch": 41, "end_epoch": 55, "bits": 2},
        {"start_epoch": 56, "end_epoch": 70, "bits": 1.58},
    ],
    "temperature_schedule": "linear",  # or "cosine"
}
scheduler = bitnet_v3.MPQScheduler(**mpq_config)
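The start/end form above and the per-stage "epochs" list in the advanced configuration describe the same 70-epoch schedule. A minimal sketch of converting one into the other and looking up the active bit-width, assuming epochs are 1-indexed and ranges inclusive as in mpq_config; stages_from_lengths and bits_for_epoch are hypothetical helpers, not bitnet_v3 functions, and the temperature-based blending between stages is not modeled.

```python
def stages_from_lengths(spec):
    """Convert [{"epochs": n, "bits": b}, ...] into 1-indexed, inclusive epoch ranges."""
    stages, start = [], 1
    for entry in spec:
        end = start + entry["epochs"] - 1
        stages.append({"start_epoch": start, "end_epoch": end, "bits": entry["bits"]})
        start = end + 1
    return stages


def bits_for_epoch(epoch, stages):
    """Return the target bit-width of the stage that contains `epoch`."""
    for stage in stages:
        if stage["start_epoch"] <= epoch <= stage["end_epoch"]:
            return stage["bits"]
    raise ValueError(f"epoch {epoch} is outside the MPQ schedule")


# Same schedule as the mpq_stages list in the advanced configuration example.
length_spec = [
    {"epochs": 20, "bits": 8},
    {"epochs": 20, "bits": 4},
    {"epochs": 15, "bits": 2},
    {"epochs": 15, "bits": 1.58},
]
stages = stages_from_lengths(length_spec)
for epoch in (1, 20, 21, 41, 56, 70):
    print(epoch, bits_for_epoch(epoch, stages))
```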
Adaptive Hadamard Transform (AHT-LP)
A Hadamard transform whose output is rescaled and shifted by learnable per-channel parameters (γ and β in Equation 2 below):
# x: input activations whose last dimension matches the transform size
# Standard Hadamard transform
x_transformed = bitnet_v3.hadamard_transform(x)
# Adaptive Hadamard with learnable parameters
aht = bitnet_v3.AdaptiveHadamardTransform(
    size=x.size(-1),
    learnable_scale=True,
    learnable_shift=True,
)
x_adaptive = aht(x)
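A framework-free sketch of Equation 2, H_adaptive(x) = γ ⊙ (H_m · x) + β, assuming a Sylvester-type Hadamard matrix normalized by 1/√n so the transform is orthonormal (whether bitnet_v3 normalizes this way is an assumption). In the library γ and β would be trainable tensors; here they are plain lists, and fwht and adaptive_hadamard are illustrative names.

```python
import math


def fwht(x):
    """Orthonormal fast Walsh-Hadamard transform; len(x) must be a power of two."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    y = list(x)
    h = 1
    while h < n:
        # Butterfly step: combine element pairs that are h apart.
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                a, b = y[j], y[j + h]
                y[j], y[j + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)  # makes the transform orthonormal (H @ H = I)
    return [v * scale for v in y]


def adaptive_hadamard(x, gamma, beta):
    """gamma ⊙ (H x) + beta with per-channel gamma and beta (Equation 2)."""
    hx = fwht(x)
    return [g * v + b for g, v, b in zip(gamma, hx, beta)]


print(fwht([8.0, 0.0, 0.0, 0.0]))
print(adaptive_hadamard([8.0, 0.0, 0.0, 0.0], gamma=[2.0] * 4, beta=[1.0] * 4))
```

Because the normalized transform is orthonormal, it preserves the vector's norm while spreading a single large outlier across all channels: [8, 0, 0, 0] becomes [4, 4, 4, 4], which halves the absolute maximum an absmax activation quantizer has to cover.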
Gradient-Aware Knowledge Distillation (GAKD)
Preserves gradient information during distillation:
# Set up GAKD
gakd_loss = bitnet_v3.GradientAwareKnowledgeDistillation(
    alpha=0.7,  # Output distribution weight
    beta=0.2,   # Gradient alignment weight
    gamma=0.1,  # Feature alignment weight
)
# Compute distillation loss
loss = gakd_loss(
    student_outputs,
    teacher_outputs,
    student_features,
    teacher_features,
    student_gradients,
    teacher_gradients,
)
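Equation 3 defines the GAKD objective as a weighted sum, L_GAKD = α·L_KL + β·L_grad + γ·L_feature. The README does not spell out each term, so this sketch makes explicit assumptions: L_KL is KL(teacher ‖ student) over softmax outputs, L_grad is one minus the cosine similarity of flattened gradients, and L_feature is the mean squared error between features. All function names are hypothetical; this is not the bitnet_v3 implementation.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions (assumed form of L_KL)."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def cosine_distance(u, v, eps=1e-12):
    """1 - cos(u, v) (assumed form of L_grad)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v + eps)


def mse(u, v):
    """Mean squared error (assumed form of L_feature)."""
    return sum((a - b) ** 2 for a, b in zip(u, v)) / len(u)


def gakd_loss(student_logits, teacher_logits, student_feats, teacher_feats,
              student_grads, teacher_grads, alpha=0.7, beta=0.2, gamma=0.1):
    l_kl = kl_divergence(softmax(teacher_logits), softmax(student_logits))
    l_grad = cosine_distance(student_grads, teacher_grads)
    l_feat = mse(student_feats, teacher_feats)
    return alpha * l_kl + beta * l_grad + gamma * l_feat


print(gakd_loss([2.0, 0.5, -1.0], [2.5, 0.0, -1.0],
                [0.1, 0.2], [0.1, 0.3],
                [0.1, -0.2, 0.3], [0.1, -0.1, 0.3]))
```

A student that matches the teacher exactly in outputs, features and gradients gets a loss of (numerically) zero, and each weight controls how strongly its mismatch is penalized.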
🧪 Examples
Training from Scratch
import bitnet_v3
from torch.utils.data import DataLoader
# Load your dataset (YourDataset and tokenizer are placeholders for your own code)
train_dataset = YourDataset("train")
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
# Create model with default config
model = bitnet_v3.create_model(
    vocab_size=len(tokenizer),
    hidden_size=2048,
    num_layers=24,
)
# Train with MPQ
trainer = bitnet_v3.create_trainer(model)
trainer.train(train_loader, num_epochs=70)
Fine-tuning Pre-trained Model
# Load pre-trained model
model = bitnet_v3.BitNetV3Model.from_pretrained("path/to/model")
# Convert to BitNet v3 with progressive quantization
bitnet_model = bitnet_v3.convert_to_bitnet_v3(
    model,
    enable_all_features=True,
)
# Fine-tune with knowledge distillation (fine_tune_loader: your own DataLoader)
trainer = bitnet_v3.create_trainer(bitnet_model)
trainer.set_teacher_model(model)  # Use original as teacher
trainer.train(fine_tune_loader, num_epochs=20)
Inference
# Load trained BitNet v3 model
model = bitnet_v3.BitNetV3Model.from_pretrained("path/to/bitnet_v3_model")
# Generate text (input_ids: prompt token IDs from your tokenizer)
output = model.generate(
    input_ids,
    max_length=100,
    temperature=0.7,
    do_sample=True,
)
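The temperature and do_sample arguments usually follow a common convention: logits are divided by the temperature before the softmax, and do_sample=False falls back to greedy (argmax) decoding. Whether model.generate() in bitnet_v3 behaves exactly this way is an assumption; the sketch below shows that convention for a single decoding step, with hypothetical helper names.

```python
import math
import random


def temperature_probs(logits, temperature=0.7):
    """Softmax of logits / temperature; lower temperature sharpens the distribution."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_next_token(logits, temperature=0.7, do_sample=True, rng=None):
    """Pick the next token index: sample when do_sample is True, otherwise take the argmax."""
    if not do_sample:
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    probs = temperature_probs(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.0]
print(temperature_probs(logits, 1.0))
print(temperature_probs(logits, 0.7))
print(sample_next_token(logits, temperature=0.7, rng=random.Random(0)))
```

A temperature of 0.7 puts more probability on the highest-scoring token than a temperature of 1.0, trading diversity for more conservative output.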
🔬 Research Paper Implementation
This implementation includes all techniques from the original BitNet v3 research paper:
Quantization Functions
- Ternary weight quantization: {-1, 0, 1}
- 4-bit activation quantization with Hadamard transform
- AbsMean and AbsMax quantization schemes
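A per-tensor sketch of the quantization functions listed above, on plain Python lists: AbsMean scaling for ternary weights, as described for BitNet b1.58, and AbsMax scaling for 4-bit activations. The symmetric [-7, 7] activation range, the per-tensor (rather than per-token or per-channel) granularity, and the helper names are assumptions, not taken from bitnet_v3.

```python
def absmean_ternary(weights, eps=1e-8):
    """Scale by the mean absolute value, then round and clip to {-1, 0, 1}."""
    scale = sum(abs(w) for w in weights) / len(weights)
    # Note: Python's round() rounds halves to the nearest even integer.
    q = [max(-1, min(1, round(w / (scale + eps)))) for w in weights]
    return q, scale


def absmax_quantize(x, bits=4, eps=1e-8):
    """Symmetric absmax quantization to integers in [-(2**(bits-1) - 1), 2**(bits-1) - 1]."""
    qmax = 2 ** (bits - 1) - 1  # 7 for 4-bit
    scale = max(abs(v) for v in x) / qmax
    q = [max(-qmax, min(qmax, round(v / (scale + eps)))) for v in x]
    return q, scale


def quantized_dot(w, x, bits=4):
    """Dot product on quantized values: integer accumulation, then a single rescale."""
    qw, w_scale = absmean_ternary(w)
    qx, x_scale = absmax_quantize(x, bits)
    acc = sum(a * b for a, b in zip(qw, qx))  # ternary weights => only adds and subtracts
    return acc * w_scale * x_scale


w = [0.9, -0.05, -1.2, 0.4]
x = [0.62, -1.4, 0.05, 0.85]
print(absmean_ternary(w)[0])
print(absmax_quantize(x)[0])
print(quantized_dot(w, x), sum(a * b for a, b in zip(w, x)))
```

The inner loop of quantized_dot only needs integer additions and subtractions, which is where the efficiency of ternary weights comes from; the floating-point scales are applied once per output.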
Training Innovations
- Progressive bit-width reduction schedule
- Temperature-based quantization transitions
- Gradient-aware loss computation
- Dynamic regularization with layer sensitivity
Mathematical Formulations
All key equations from the paper are implemented:
# Temperature-based transition (Equation 1)
Q_t(x) = σ(β_t) * Q_{b_t}(x) + (1 - σ(β_t)) * Q_{b_{t-1}}(x)
# Adaptive Hadamard transform (Equation 2)
H_adaptive(x) = γ ⊙ (H_m · x) + β
# GAKD loss (Equation 3)
L_GAKD = α*L_KL + β*L_grad + γ*L_feature
# Dynamic regularization (Equation 4)
R_QAP = λ(t) * Σ ω_i ||W_i - Q(W_i)||²
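Equations 1 and 4 can be sketched on plain lists as follows. The quantizer Q is an assumption (symmetric absmax quantize-dequantize, with bit-widths below 2 treated as ternary), and β_t, λ(t) and the ω_i are passed in as plain numbers because the README does not define their schedules or how layer sensitivities are computed. All names are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def fake_quant(x, bits):
    """Symmetric absmax quantize-dequantize; bit-widths below 2 are treated as ternary.

    With a symmetric grid, 2 bits also yields the three levels {-1, 0, 1} * scale.
    """
    qmax = 1 if bits < 2 else 2 ** (int(bits) - 1) - 1
    max_abs = max(abs(v) for v in x)
    if max_abs == 0:
        return list(x)
    scale = max_abs / qmax
    return [max(-qmax, min(qmax, round(v / scale))) * scale for v in x]


def transition_quant(x, beta_t, bits_now, bits_prev):
    """Equation 1: sigma(beta_t) * Q_{b_t}(x) + (1 - sigma(beta_t)) * Q_{b_{t-1}}(x)."""
    s = sigmoid(beta_t)
    q_now = fake_quant(x, bits_now)
    q_prev = fake_quant(x, bits_prev)
    return [s * a + (1.0 - s) * b for a, b in zip(q_now, q_prev)]


def qap_penalty(layer_weights, omegas, lam, bits):
    """Equation 4 at one step t: lam * sum_i omega_i * ||W_i - Q(W_i)||^2."""
    total = 0.0
    for w, omega in zip(layer_weights, omegas):
        q = fake_quant(w, bits)
        total += omega * sum((a - b) ** 2 for a, b in zip(w, q))
    return lam * total


x = [0.3, -0.9, 0.5, 0.1]
print(transition_quant(x, beta_t=0.0, bits_now=4, bits_prev=8))  # halfway between 4- and 8-bit
print(qap_penalty([[0.4, -1.0], [0.2, 0.7]], omegas=[2.0, 1.0], lam=0.1, bits=1.58))
```

As β_t grows, σ(β_t) approaches 1 and the output moves fully onto the lower bit-width grid; the penalty is zero for weights that already sit on the quantization grid.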
📈 Evaluation and Metrics
Built-in evaluation tools for comprehensive analysis:
# Compute perplexity (test_loader: your evaluation DataLoader)
ppl = bitnet_v3.compute_perplexity(model, test_loader)
# Efficiency metrics (baseline_model: full-precision reference; test_input: a sample batch)
metrics = bitnet_v3.compute_efficiency_metrics(
    bitnet_model,
    baseline_model,
    test_input,
)
print(f"Speedup: {metrics['speedup']:.1f}x")
print(f"Memory reduction: {metrics['memory_reduction']:.1f}%")
# Downstream task evaluation
results = bitnet_v3.evaluate_downstream_tasks(
    model,
    tasks=["hellaswag", "mmlu", "truthfulqa"],
)
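compute_perplexity and compute_efficiency_metrics are bitnet_v3 functions whose internals the README does not show. This sketch uses the standard definitions they most likely correspond to: perplexity as the exponential of the mean per-token negative log-likelihood, speedup as a ratio of latencies, and memory reduction as a percentage of the baseline. The helper names and the per-token averaging are assumptions.

```python
import math


def perplexity(token_nlls):
    """exp(mean negative log-likelihood per token), with NLLs in nats."""
    if not token_nlls:
        raise ValueError("need at least one token")
    return math.exp(sum(token_nlls) / len(token_nlls))


def efficiency_metrics(baseline_seconds, quantized_seconds, baseline_bytes, quantized_bytes):
    """Speedup as a latency ratio; memory reduction as a percentage of the baseline."""
    return {
        "speedup": baseline_seconds / quantized_seconds,
        "memory_reduction": 100.0 * (1.0 - quantized_bytes / baseline_bytes),
    }


# A model that assigns probability 1/4 to every correct token has perplexity 4.
print(perplexity([math.log(4)] * 10))

# Illustrative numbers only: 4.5x faster, 15% of the baseline memory.
metrics = efficiency_metrics(4.5, 1.0, 100.0, 15.0)
print(f"Speedup: {metrics['speedup']:.1f}x")
print(f"Memory reduction: {metrics['memory_reduction']:.1f}%")
```

With these definitions, the table's "15% of FP16" memory figure and the 85% memory reduction quoted under Key Features describe the same measurement.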
🛡️ Testing
Run the comprehensive test suite:
# Run all tests
pytest tests/
# Run specific test modules
pytest tests/test_modules/test_mpq.py
pytest tests/test_modules/test_gakd.py
# Run with coverage
pytest --cov=bitnet_v3 tests/
🤝 Contributing
We welcome contributions! Please see our Contributing Guide for details.
Development Setup
git clone https://github.com/ProCreations/bitnet-v3.git
cd bitnet-v3
pip install -e ".[dev]"
pre-commit install
Running Tests and Code Quality Checks
# Run tests
pytest tests/
# Format code and sort imports
black bitnet_v3/
isort bitnet_v3/
# Lint and type-check
flake8 bitnet_v3/
mypy bitnet_v3/
📄 Citation
If you use BitNet v3 in your research, please cite:
@article{bitnet_v3_2024,
title={BitNet v3: Ultra-Low Quality Loss 1-bit LLMs Through Multi-Stage Progressive Quantization and Adaptive Hadamard Transform},
author={ProCreations},
journal={arXiv preprint arXiv:XXXX.XXXXX},
year={2024}
}
📜 License
This project is licensed under the MIT License - see the LICENSE file for details.
🙏 Acknowledgments
- Built upon the foundation of BitNet and BitNet b1.58 from Microsoft Research
- Inspired by advances in quantization-aware training and knowledge distillation
- Thanks to the PyTorch team for the excellent deep learning framework
📞 Support
- 🐛 Bug Reports: GitHub Issues
- 💬 Discussions: GitHub Discussions
- 📧 Email: procreations@example.com
BitNet v3 - Bringing 1-bit LLMs closer to practical deployment! 🚀
Download files
File details
Details for the file bitnet_v3-1.0.0.tar.gz.
File metadata
- Download URL: bitnet_v3-1.0.0.tar.gz
- Upload date:
- Size: 58.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7973eaa0ea60aef1d452dbd652149eed64135c2a7a0a02efa2bc1735caf154ef |
| MD5 | 2e3e5f20e77b3d07f1fc1177855bdbea |
| BLAKE2b-256 | c329a5061e543729222c8fe704cddf90e9f61584893f2356bcb89c9f851e8a36 |
File details
Details for the file bitnet_v3-1.0.0-py3-none-any.whl.
File metadata
- Download URL: bitnet_v3-1.0.0-py3-none-any.whl
- Upload date:
- Size: 56.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bff392288dcc5f086228eefb2c7339ff6e78b8798e1de7824893bb05681221c9 |
| MD5 | 590a60cee6578eec2af73da504ad92d0 |
| BLAKE2b-256 | 555584183518cba9dcfa33e3a1f3d56ee7e1b3eb450a7f0f71b895b89c5bc565 |