BitNet v3: Ultra-Low Quality Loss 1-bit LLMs Through Multi-Stage Progressive Quantization and Adaptive Hadamard Transform
BitNet v3: Ultra-Low Quality Loss 1-bit LLMs
A comprehensive PyTorch implementation of BitNet v3, a novel framework for training 1-bit Large Language Models (LLMs). It is designed to reduce quality loss while retaining the computational efficiency benefits of extreme quantization.
🚀 Key Features
BitNet v3 introduces five key innovations for ultra-low quality loss 1-bit LLMs:
- 🔄 Multi-stage Progressive Quantization (MPQ) - Gradually reduces bit-width during training
- 🧮 Adaptive Hadamard Transform with Learnable Parameters (AHT-LP) - Dynamically adjusts to activation distributions
- 🎓 Gradient-Aware Knowledge Distillation (GAKD) - Preserves critical gradient information during quantization
- ⚖️ Dynamic Regularization with Quantization-Aware Penalties (DR-QAP) - Stabilizes training with adaptive penalties
- 💫 Enhanced Straight-Through Estimator with Momentum (ESTE-M) - Improves gradient approximation
🔬 Research Status
This implementation provides a framework for training 1-bit LLMs, with the potential for significant quality improvements over existing methods. Performance evaluation is ongoing, and we're actively seeking contributors to help with testing, benchmarking, and validation across different model sizes and datasets.
🚨 CONTRIBUTORS WANTED! 🚨
Help us validate BitNet v3! We need researchers and engineers to test performance, optimize code, and validate results. All skill levels are welcome, from bug reports to research contributions. Jump to the Contributing section or start a discussion!
🛠️ Installation
From PyPI (Recommended)
```bash
pip install bitnet-v3
```
From Source
```bash
git clone https://github.com/ProCreations-Official/bitnet-v3.git
cd bitnet-v3
pip install -e .
```
Development Installation
```bash
git clone https://github.com/ProCreations-Official/bitnet-v3.git
cd bitnet-v3
pip install -e ".[dev]"
```
🎯 Quick Start
Simple Usage
```python
import bitnet_v3

# Create a BitNet v3 model
model = bitnet_v3.create_model(
    vocab_size=32000,
    hidden_size=2048,
    num_layers=24,
    num_heads=32,
)

# Create trainer with MPQ schedule
trainer = bitnet_v3.create_trainer(
    model,
    learning_rate=3e-4,
    batch_size=256,
    enable_mpq=True,
    enable_gakd=True,
)

# Train the model
trainer.train(train_dataloader)
```
Advanced Usage with All Features
```python
import torch
import bitnet_v3

# Configure model with all innovations
config = bitnet_v3.BitNetV3Config(
    vocab_size=32000,
    hidden_size=4096,
    num_layers=32,
    num_heads=32,
    # MPQ configuration
    mpq_stages=[
        {"epochs": 20, "bits": 8},
        {"epochs": 20, "bits": 4},
        {"epochs": 15, "bits": 2},
        {"epochs": 15, "bits": 1.58},
    ],
    # AHT-LP configuration
    adaptive_hadamard=True,
    hadamard_learnable_scale=True,
    # GAKD configuration
    knowledge_distillation=True,
    gakd_alpha=0.7,
    gakd_beta=0.2,
    gakd_gamma=0.1,
    # DR-QAP configuration
    dynamic_regularization=True,
    qap_initial_lambda=0.1,
    # ESTE-M configuration
    enhanced_ste=True,
    ste_momentum=0.9,
)

# Create model and trainer
model = bitnet_v3.BitNetV3Model(config)
trainer = bitnet_v3.BitNetV3Trainer(model, config)

# Load teacher model for knowledge distillation.
# Assuming the file holds a full pickled model (not just a state_dict),
# weights_only=False is required on PyTorch >= 2.6; only load files you trust.
teacher_model = torch.load("teacher_model.pth", weights_only=False)
trainer.set_teacher_model(teacher_model)

# Train with all features (70 epochs = the 20 + 20 + 15 + 15 MPQ stages)
trainer.train(
    train_dataloader,
    val_dataloader,
    num_epochs=70,
    save_every=5,
    eval_every=1,
)
```
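The `enhanced_ste` and `ste_momentum` options above are not described further in this README. The sketch below shows one plausible reading of a "straight-through estimator with momentum", written in plain Python so it runs without PyTorch. It is for illustration only.

- The backward pass uses the common clipped STE: the gradient passes through where |w| ≤ 1 and is zero elsewhere.
- That gradient is then smoothed with an exponential moving average controlled by `momentum`.

The `MomentumSTE` class and the EMA formulation are hypothetical and may not match how bitnet_v3 implements ESTE-M.

```python
class MomentumSTE:
    """Hypothetical straight-through estimator with momentum (EMA-smoothed gradients).

    Assumption: 'momentum' is an exponential moving average over the
    straight-through gradients; bitnet_v3's ESTE-M may work differently.
    """

    def __init__(self, momentum=0.9, clip=1.0):
        self.momentum = momentum
        self.clip = clip
        self.buffer = None  # EMA of past straight-through gradients

    @staticmethod
    def forward(weights):
        # Forward: round and clip to {-1, 0, 1} (per-tensor scaling omitted for brevity)
        return [float(round(max(-1.0, min(1.0, w)))) for w in weights]

    def backward(self, weights, upstream_grad):
        # Clipped STE: treat quantization as identity where |w| <= clip, zero elsewhere
        ste = [g if abs(w) <= self.clip else 0.0 for w, g in zip(weights, upstream_grad)]
        if self.buffer is None:
            self.buffer = ste  # initialize the EMA with the first gradient
        else:
            m = self.momentum
            self.buffer = [m * b + (1 - m) * g for b, g in zip(self.buffer, ste)]
        return self.buffer


ste = MomentumSTE(momentum=0.9)
w = [0.3, -1.7, 0.8]
g1 = ste.backward(w, [1.0, 1.0, 1.0])
g2 = ste.backward(w, [0.0, 0.0, 0.0])  # gradient vanishes, EMA decays gradually
```

With `momentum=0.9`, a gradient that suddenly drops to zero decays over several steps instead of vanishing at once. This is one way a momentum term could smooth the noisy STE approximation.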
🏗️ Architecture Overview
Core Components
- bitnet_v3.core - Core quantization functions and utilities
- bitnet_v3.modules - Individual innovation modules (MPQ, AHT-LP, GAKD, etc.)
- bitnet_v3.models - Complete BitNet v3 model implementations
- bitnet_v3.training - Training pipeline and utilities
- bitnet_v3.utils - Configuration, logging, and metrics
Key Modules
```python
import bitnet_v3

# Enhanced H-BitLinear with all innovations
linear_layer = bitnet_v3.EnhancedHBitLinear(
    in_features=2048,
    out_features=2048,
    bias=False,
    adaptive_hadamard=True,
    progressive_quantization=True,
)

# Multi-stage Progressive Quantizer
mpq = bitnet_v3.MultiStageProgressiveQuantizer(
    stages=[8, 4, 2, 1.58],
    stage_epochs=[20, 20, 15, 15],
)

# Adaptive Hadamard Transform
aht = bitnet_v3.AdaptiveHadamardTransform(
    size=2048,
    learnable_params=True,
)

# Gradient-Aware Knowledge Distillation
gakd = bitnet_v3.GradientAwareKnowledgeDistillation(
    alpha=0.7,  # KL divergence weight
    beta=0.2,   # Gradient alignment weight
    gamma=0.1,  # Feature alignment weight
)
```
📚 Detailed Documentation
Multi-Stage Progressive Quantization (MPQ)
MPQ gradually reduces bit-width during training, allowing models to adapt smoothly:
```python
import bitnet_v3

# Configure MPQ stages
mpq_config = {
    "stages": [
        {"start_epoch": 1, "end_epoch": 20, "bits": 8},
        {"start_epoch": 21, "end_epoch": 40, "bits": 4},
        {"start_epoch": 41, "end_epoch": 55, "bits": 2},
        {"start_epoch": 56, "end_epoch": 70, "bits": 1.58},
    ],
    "temperature_schedule": "linear",  # or "cosine"
}
scheduler = bitnet_v3.MPQScheduler(**mpq_config)
```
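The stage table above maps each epoch to a target bit-width. Equation 1 (under Mathematical Formulations) blends the current and previous quantizers with a sigmoid gate σ(β_t). Below is a minimal pure-Python sketch of both pieces, built on these assumptions:

- a symmetric AbsMax quantizer for integer bit-widths;
- ternary AbsMean quantization for 1.58 bits;
- a linear ramp of β_t from -6 to 6 within each stage.

The helper names (`bits_for_epoch`, `quantize`, `blended_quantize`) and the β range are hypothetical, not `bitnet_v3.MPQScheduler` internals.

```python
import math

STAGES = [
    {"start_epoch": 1, "end_epoch": 20, "bits": 8},
    {"start_epoch": 21, "end_epoch": 40, "bits": 4},
    {"start_epoch": 41, "end_epoch": 55, "bits": 2},
    {"start_epoch": 56, "end_epoch": 70, "bits": 1.58},
]


def bits_for_epoch(epoch, stages=STAGES):
    """Return (current_bits, previous_bits, progress in [0, 1]) for a 1-based epoch."""
    for i, stage in enumerate(stages):
        if stage["start_epoch"] <= epoch <= stage["end_epoch"]:
            prev_bits = stages[i - 1]["bits"] if i > 0 else stage["bits"]
            span = max(stage["end_epoch"] - stage["start_epoch"], 1)
            progress = (epoch - stage["start_epoch"]) / span
            return stage["bits"], prev_bits, progress
    raise ValueError(f"epoch {epoch} is outside the schedule")


def quantize(x, bits):
    """Fake-quantize a list of floats: quantize to integer codes, then rescale."""
    if bits == 1.58:
        # Ternary AbsMean: codes in {-1, 0, 1}, scale = mean |x|
        scale = sum(abs(v) for v in x) / len(x) + 1e-8
        return [max(-1, min(1, round(v / scale))) * scale for v in x]
    # Symmetric AbsMax: codes in [-(2^(b-1) - 1), 2^(b-1) - 1]
    q_max = 2 ** (int(bits) - 1) - 1
    scale = max(abs(v) for v in x) / q_max + 1e-8
    return [max(-q_max, min(q_max, round(v / scale))) * scale for v in x]


def blended_quantize(x, epoch, beta_min=-6.0, beta_max=6.0):
    """Q_t(x) = sigma(beta_t) * Q_{b_t}(x) + (1 - sigma(beta_t)) * Q_{b_{t-1}}(x)."""
    bits, prev_bits, progress = bits_for_epoch(epoch)
    beta_t = beta_min + (beta_max - beta_min) * progress  # assumed linear ramp
    gate = 1.0 / (1.0 + math.exp(-beta_t))
    q_now, q_prev = quantize(x, bits), quantize(x, prev_bits)
    return [gate * a + (1.0 - gate) * b for a, b in zip(q_now, q_prev)]
```

Early in a stage the gate is close to 0, so the output stays near the previous bit-width's quantizer. By the end of the stage the gate is close to 1 and the new bit-width dominates, which is the smooth hand-off MPQ describes. A "cosine" schedule would only change how `beta_t` is computed from `progress`.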
Adaptive Hadamard Transform (AHT-LP)
Enhanced Hadamard transformation with learnable parameters:
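```python
import torch
import bitnet_v3

# Example activations of shape (batch, hidden_size); 2048 is a power of two
x = torch.randn(8, 2048)

# Standard Hadamard transform
x_transformed = bitnet_v3.hadamard_transform(x)

# Adaptive Hadamard with learnable parameters
aht = bitnet_v3.AdaptiveHadamardTransform(
    size=x.size(-1),
    learnable_scale=True,
    learnable_shift=True,
)
x_adaptive = aht(x)
```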
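Equation 2, H_adaptive(x) = γ ⊙ (H_m · x) + β, can be read as a normalized Hadamard rotation followed by a learnable per-channel scale γ and shift β. The sketch below implements it in pure Python with an in-place fast Walsh-Hadamard transform (O(m log m)) rather than an explicit m×m matrix. The 1/√m normalization, the power-of-two length requirement, and the function names are assumptions, not confirmed details of `bitnet_v3.AdaptiveHadamardTransform`.

```python
import math


def fast_hadamard(x):
    """Orthonormal fast Walsh-Hadamard transform; len(x) must be a power of two."""
    m = len(x)
    if m == 0 or m & (m - 1):
        raise ValueError("length must be a power of two")
    y = list(x)
    h = 1
    while h < m:
        # Butterfly step: combine pairs (i, i + h) inside blocks of size 2h
        for start in range(0, m, 2 * h):
            for i in range(start, start + h):
                a, b = y[i], y[i + h]
                y[i], y[i + h] = a + b, a - b
        h *= 2
    norm = 1.0 / math.sqrt(m)  # scaling so that applying the transform twice gives x back
    return [v * norm for v in y]


def adaptive_hadamard(x, gamma, beta):
    """gamma * (H_m x) + beta with per-channel (learnable) gamma and beta."""
    return [g * v + b for g, v, b in zip(gamma, fast_hadamard(x), beta)]


spike = [1.0, 0.0, 0.0, 0.0]
spread = fast_hadamard(spike)  # a single outlier is spread evenly across channels
```

The usage line shows why a Hadamard rotation helps before low-bit activation quantization. A lone outlier becomes a flat vector with a much smaller maximum, so an AbsMax scale wastes fewer quantization levels. With γ = 1 and β = 0, the adaptive version reduces to the standard transform.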
Gradient-Aware Knowledge Distillation (GAKD)
Preserves gradient information during distillation:
```python
import bitnet_v3

# Set up GAKD
gakd_loss = bitnet_v3.GradientAwareKnowledgeDistillation(
    alpha=0.7,  # Output distribution weight
    beta=0.2,   # Gradient alignment weight
    gamma=0.1,  # Feature alignment weight
)

# Compute distillation loss
loss = gakd_loss(
    student_outputs,
    teacher_outputs,
    student_features,
    teacher_features,
    student_gradients,
    teacher_gradients,
)
```
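Equation 3 weights three terms: L_GAKD = α·L_KL + β·L_grad + γ·L_feature. This README does not define L_grad or L_feature, so the sketch below makes these assumptions:

- L_KL is a temperature-scaled KL(teacher ‖ student), multiplied by T² as in standard Hinton-style distillation.
- L_grad is 1 - cosine similarity between flattened gradient vectors.
- L_feature is the mean squared error between feature vectors.

These choices and the helper names are illustrative, not the confirmed definition used by `bitnet_v3.GradientAwareKnowledgeDistillation`.

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    top = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - top) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two probability vectors."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))


def cosine_similarity(a, b, eps=1e-12):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b + eps)


def gakd_loss(student_logits, teacher_logits, student_feats, teacher_feats,
              student_grads, teacher_grads,
              alpha=0.7, beta=0.2, gamma=0.1, temperature=2.0):
    # L_KL: the teacher distribution is the target, so compute KL(teacher || student)
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    l_kl = kl_divergence(p_teacher, p_student) * temperature ** 2
    # L_grad (assumed form): penalize misaligned gradient directions
    l_grad = 1.0 - cosine_similarity(student_grads, teacher_grads)
    # L_feature (assumed form): mean squared error between hidden features
    l_feat = sum((s - t) ** 2 for s, t in zip(student_feats, teacher_feats)) / len(student_feats)
    return alpha * l_kl + beta * l_grad + gamma * l_feat
```

When student and teacher agree on outputs, features, and gradient direction, every term is zero. Gradients pointing in opposite directions contribute the maximum alignment penalty of β · 2.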
🧪 Examples
Training from Scratch
```python
import bitnet_v3
from torch.utils.data import DataLoader

# Load your dataset (YourDataset and tokenizer are placeholders for your own data pipeline)
train_dataset = YourDataset("train")
train_loader = DataLoader(train_dataset, batch_size=256)

# Create model with default config
model = bitnet_v3.create_model(
    vocab_size=len(tokenizer),
    hidden_size=2048,
    num_layers=24,
)

# Train with MPQ
trainer = bitnet_v3.create_trainer(model)
trainer.train(train_loader, num_epochs=70)
```
Fine-tuning Pre-trained Model
```python
import copy
import bitnet_v3

# Load pre-trained model
model = bitnet_v3.BitNetV3Model.from_pretrained("path/to/model")

# Keep an untouched copy as the teacher, in case conversion modifies the model in place
teacher_model = copy.deepcopy(model)

# Convert to BitNet v3 with progressive quantization
bitnet_model = bitnet_v3.convert_to_bitnet_v3(
    model,
    enable_all_features=True,
)

# Fine-tune with knowledge distillation
trainer = bitnet_v3.create_trainer(bitnet_model)
trainer.set_teacher_model(teacher_model)  # Use the original model as teacher
trainer.train(fine_tune_loader, num_epochs=20)
```
Inference
```python
import bitnet_v3

# Load trained BitNet v3 model
model = bitnet_v3.BitNetV3Model.from_pretrained("path/to/bitnet_v3_model")

# Generate text
output = model.generate(
    input_ids,
    max_length=100,
    temperature=0.7,
    do_sample=True,
)
```
🔬 Research Paper Implementation
This implementation includes all techniques from the original BitNet v3 research paper:
Quantization Functions
- Ternary weight quantization: {-1, 0, 1}
- 4-bit activation quantization with Hadamard transform
- AbsMean and AbsMax quantization schemes
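The schemes listed above follow the BitNet b1.58 recipe:

- AbsMean: weights are divided by their mean absolute value, then rounded and clipped to {-1, 0, 1}.
- AbsMax: activations are scaled by their maximum absolute value into a signed integer range.

Below is a minimal pure-Python sketch. It assumes per-tensor scales and a symmetric 4-bit range of [-7, 7]; the repository may instead use [-8, 7] or per-token scales. It shows why ternary weights are cheap: the integer dot product needs only additions and subtractions, followed by a single floating-point rescale. The function names are hypothetical.

```python
def absmean_ternary(weights, eps=1e-8):
    """AbsMean: scale by mean |w|, then round and clip to {-1, 0, 1}."""
    scale = sum(abs(w) for w in weights) / len(weights) + eps
    codes = [max(-1, min(1, round(w / scale))) for w in weights]
    return codes, scale


def absmax_int(activations, bits=4, eps=1e-8):
    """AbsMax: scale by max |x| into [-(2^(b-1) - 1), 2^(b-1) - 1]."""
    q_max = 2 ** (bits - 1) - 1
    scale = max(abs(x) for x in activations) / q_max + eps
    codes = [max(-q_max, min(q_max, round(x / scale))) for x in activations]
    return codes, scale


def ternary_dot(w_codes, x_codes):
    """Integer dot product with ternary weights: only adds and subtracts."""
    acc = 0
    for w, x in zip(w_codes, x_codes):
        if w == 1:
            acc += x
        elif w == -1:
            acc -= x
    return acc


def quantized_dot(weights, activations, bits=4):
    w_codes, w_scale = absmean_ternary(weights)
    x_codes, x_scale = absmax_int(activations, bits)
    # Accumulate in integers, then rescale once back to floating point
    return ternary_dot(w_codes, x_codes) * w_scale * x_scale


weights = [0.4, -0.2, 0.0, 0.8]
activations = [0.7, -0.3, 0.1, 0.0]
approx = quantized_dot(weights, activations)
exact = sum(w * x for w, x in zip(weights, activations))
```

In this toy case the quantized result (about 0.35) lands close to the full-precision dot product (0.34). In a real H-BitLinear layer, the activations would pass through the Hadamard transform before `absmax_int`.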
Training Innovations
- Progressive bit-width reduction schedule
- Temperature-based quantization transitions
- Gradient-aware loss computation
- Dynamic regularization with layer sensitivity
Mathematical Formulations
All key equations from the paper are implemented:
```text
# Temperature-based transition (Equation 1)
Q_t(x) = σ(β_t) * Q_{b_t}(x) + (1 - σ(β_t)) * Q_{b_{t-1}}(x)

# Adaptive Hadamard transform (Equation 2)
H_adaptive(x) = γ ⊙ (H_m · x) + β

# GAKD loss (Equation 3)
L_GAKD = α*L_KL + β*L_grad + γ*L_feature

# Dynamic regularization (Equation 4)
R_QAP = λ(t) * Σ_i ω_i ||W_i - Q(W_i)||²
```
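Equation 4 penalizes the squared distance between each weight tensor W_i and its quantized version Q(W_i). Each layer's penalty is weighted by a sensitivity ω_i, and the sum is scaled by a time-varying strength λ(t). The README gives only `qap_initial_lambda=0.1` and does not say how λ(t) evolves or how ω_i is measured. The sketch below therefore takes both as inputs and uses ternary AbsMean fake-quantization for Q. The function names are hypothetical.

```python
def ternary_fake_quant(w, eps=1e-8):
    """Q(W): AbsMean ternary quantization, rescaled back to the weight's range."""
    scale = sum(abs(v) for v in w) / len(w) + eps
    return [max(-1, min(1, round(v / scale))) * scale for v in w]


def qap_penalty(layers, sensitivities, lam):
    """R_QAP = lam * sum_i omega_i * ||W_i - Q(W_i)||^2.

    layers:        flattened weight tensors W_i (lists of floats)
    sensitivities: per-layer weights omega_i (treated as given here)
    lam:           lambda(t) for the current training step
    """
    total = 0.0
    for w, omega in zip(layers, sensitivities):
        q = ternary_fake_quant(w)
        total += omega * sum((a - b) ** 2 for a, b in zip(w, q))
    return lam * total


layers = [
    [0.5, -0.5, 0.5, -0.5],  # already sits on the ternary grid: near-zero penalty
    [1.0, 0.0, 0.0, 0.0],    # far from the grid: large quantization error
]
penalty = qap_penalty(layers, sensitivities=[1.0, 2.0], lam=0.1)
```

The penalty pulls weights toward values that survive quantization, and layers marked as more sensitive (larger ω_i) are pulled harder. In a PyTorch training loop this value would simply be added to the task loss.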
📈 Evaluation and Metrics
Built-in evaluation tools for comprehensive analysis:
```python
import bitnet_v3

# Compute perplexity
ppl = bitnet_v3.compute_perplexity(model, test_loader)

# Efficiency metrics (theoretical analysis)
metrics = bitnet_v3.compute_efficiency_metrics(
    bitnet_model,
    baseline_model,
    test_input,
)
print(f"Expected speedup: {metrics['speedup']:.1f}x")
print(f"Expected memory reduction: {metrics['memory_reduction']:.1f}%")

# Downstream task evaluation framework
results = bitnet_v3.evaluate_downstream_tasks(
    model,
    tasks=["hellaswag", "mmlu", "truthfulqa"],
)
```
Note: Performance validation is ongoing. We encourage the community to help benchmark BitNet v3 across different tasks and model sizes.
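The memory figure from `compute_efficiency_metrics` is described as theoretical. As a rough back-of-the-envelope check (not that function's actual formula), weight storage scales with bits per parameter. Going from 16-bit to ternary weights, which carry log2(3) ≈ 1.58 bits each, therefore cuts weight memory by about 90%. The sketch ignores activations, KV cache, any layers kept at higher precision, and packing overhead. The helper names and the 7B parameter count are illustrative.

```python
import math


def weight_memory_gb(num_params, bits_per_param):
    """Weight storage in gigabytes (1 GB = 1e9 bytes) at a given bit-width."""
    return num_params * bits_per_param / 8 / 1e9


def memory_reduction_pct(baseline_bits, quant_bits):
    return 100.0 * (1.0 - quant_bits / baseline_bits)


ternary_bits = math.log2(3)  # information content of one weight in {-1, 0, 1}
for label, bits in [("FP16", 16), ("ternary, ideal packing", ternary_bits), ("ternary, 2-bit packing", 2)]:
    print(f"{label:>24}: {weight_memory_gb(7e9, bits):5.2f} GB "
          f"({memory_reduction_pct(16, bits):4.1f}% smaller than FP16)")
```

For a 7B-parameter model this gives roughly 14 GB, 1.39 GB, and 1.75 GB respectively. Storing ternary weights in 2-bit slots, a common practical layout, yields an 87.5% reduction rather than the ideal ~90.1%.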
🛡️ Testing
Run the comprehensive test suite:
```bash
# Run all tests
pytest tests/

# Run specific test modules
pytest tests/test_modules/test_mpq.py
pytest tests/test_modules/test_gakd.py

# Run with coverage
pytest --cov=bitnet_v3 tests/
```
🤝 Contributing - We Need Your Help!
BitNet v3 is an active research project, and we're looking for contributors! Whether you're a researcher, engineer, or enthusiast, there are many ways to help advance 1-bit LLM technology.
🚨 High Priority Contributions Needed
- 🧪 Performance Benchmarking: Help us validate BitNet v3 across different model sizes (1B, 3B, 7B+)
- 📊 Dataset Testing: Test on various datasets (language modeling, downstream tasks, multilingual)
- ⚡ Optimization: CUDA kernels, memory optimizations, training speed improvements
- 🔧 Integration: HuggingFace Transformers integration, ONNX export, deployment tools
- 📝 Documentation: Tutorials, guides, and improved examples
- 🐛 Bug Reports: Help us identify and fix issues in the codebase
🎯 Research Opportunities
- Compare against other quantization methods (GPTQ, AWQ, etc.)
- Explore different MPQ schedules and temperature functions
- Investigate GAKD effectiveness across model architectures
- Test on specialized domains (code, math, science)
- Efficiency analysis on different hardware (GPUs, edge devices)
💡 Easy Ways to Get Started
- Run the examples and report any issues
- Test installation on different systems (Windows, Mac, Linux)
- Improve documentation - add docstrings, fix typos, clarify explanations
- Add unit tests for untested modules
- Create tutorials for specific use cases
🛠 Development Setup
```bash
# Fork the repository on GitHub first, then clone your fork
git clone https://github.com/<your-username>/bitnet-v3.git
cd bitnet-v3

# Install in development mode
pip install -e ".[dev]"

# Set up pre-commit hooks
pre-commit install

# Run tests to ensure everything works
pytest tests/
```
📋 Contribution Guidelines
- All skill levels welcome - from typo fixes to major algorithmic improvements
- Research-first approach - we prioritize correctness and reproducibility
- Open communication - discuss ideas in GitHub Issues before major changes
- Documentation required - all new features need documentation and examples
- Testing encouraged - add tests for new functionality when possible
🌟 Recognition
Contributors will be:
- Added to the contributors list in the README
- Acknowledged in any resulting research papers
- Invited to collaborate on follow-up research
📞 Get in Touch
- 💬 Start a Discussion: GitHub Discussions for questions and ideas
- 🐛 Report Issues: GitHub Issues for bugs and feature requests
Every contribution matters - from fixing a typo to implementing a new feature. Join us in making 1-bit LLMs a reality! 🚀
📄 Citation
If you use BitNet v3 in your research, please cite this repository:
```bibtex
@software{bitnet_v3_2024,
  title={BitNet v3: Ultra-Low Quality Loss 1-bit LLMs Through Multi-Stage Progressive Quantization and Adaptive Hadamard Transform},
  author={ProCreations},
  url={https://github.com/ProCreations-Official/bitnet-v3},
  year={2024}
}
```
📜 License
This project is licensed under the MIT License - see the LICENSE file for details.
🙏 Acknowledgments
- Built upon the foundation of BitNet and BitNet b1.58 from Microsoft Research
- Inspired by advances in quantization-aware training and knowledge distillation
- Thanks to the PyTorch team for the excellent deep learning framework
📞 Support
- 🐛 Bug Reports: GitHub Issues
- 💬 Discussions: GitHub Discussions
BitNet v3 - Bringing 1-bit LLMs closer to practical deployment! 🚀