
BitNet v3: Ultra-Low Quality Loss 1-bit LLMs Through Multi-Stage Progressive Quantization and Adaptive Hadamard Transform


BitNet v3: Ultra-Low Quality Loss 1-bit LLMs

Python 3.8+ | License: MIT

A comprehensive PyTorch implementation of BitNet v3, a novel framework for training 1-bit large language models (LLMs) that aims to substantially reduce quality loss while keeping the computational efficiency benefits of extreme quantization.

🚀 Key Features

BitNet v3 introduces five key innovations for ultra-low quality loss 1-bit LLMs:

  1. 🔄 Multi-stage Progressive Quantization (MPQ) - Gradually reduces bit-width during training
  2. 🧮 Adaptive Hadamard Transform with Learnable Parameters (AHT-LP) - Dynamically adjusts to activation distributions
  3. 🎓 Gradient-Aware Knowledge Distillation (GAKD) - Preserves critical gradient information during quantization
  4. ⚖️ Dynamic Regularization with Quantization-Aware Penalties (DR-QAP) - Stabilizes training with adaptive penalties
  5. 💫 Enhanced Straight-Through Estimator with Momentum (ESTE-M) - Improves gradient approximation

🔬 Research Status

This implementation provides a framework for training 1-bit LLMs that has the potential for significant quality improvements over existing methods, but these gains are not yet validated. Performance evaluation is ongoing, and we're actively seeking contributors to help with testing, benchmarking, and validation across different model sizes and datasets.


🚨 CONTRIBUTORS WANTED! 🚨
Help us validate BitNet v3! We need researchers and engineers to test performance, optimize code, and validate results. All skill levels are welcome, from bug reports to research contributions. See the Contributing section below or start a discussion!


🛠️ Installation

From PyPI (Recommended)

pip install bitnet-v3

From Source

git clone https://github.com/ProCreations-Official/bitnet-v3.git
cd bitnet-v3
pip install -e .

Development Installation

git clone https://github.com/ProCreations-Official/bitnet-v3.git
cd bitnet-v3
pip install -e ".[dev]"

🎯 Quick Start

Simple Usage

import bitnet_v3

# Create a BitNet v3 model
model = bitnet_v3.create_model(
    vocab_size=32000,
    hidden_size=2048,
    num_layers=24,
    num_heads=32,
)

# Create trainer with MPQ schedule
trainer = bitnet_v3.create_trainer(
    model,
    learning_rate=3e-4,
    batch_size=256,
    enable_mpq=True,
    enable_gakd=True,
)

# Train the model (train_dataloader is a torch DataLoader you supply)
trainer.train(train_dataloader)

Advanced Usage with All Features

import torch
import bitnet_v3

# Configure model with all innovations
config = bitnet_v3.BitNetV3Config(
    vocab_size=32000,
    hidden_size=4096,
    num_layers=32,
    num_heads=32,
    # MPQ configuration
    mpq_stages=[
        {"epochs": 20, "bits": 8},
        {"epochs": 20, "bits": 4}, 
        {"epochs": 15, "bits": 2},
        {"epochs": 15, "bits": 1.58},
    ],
    # AHT-LP configuration
    adaptive_hadamard=True,
    hadamard_learnable_scale=True,
    # GAKD configuration
    knowledge_distillation=True,
    gakd_alpha=0.7,
    gakd_beta=0.2,
    gakd_gamma=0.1,
    # DR-QAP configuration
    dynamic_regularization=True,
    qap_initial_lambda=0.1,
    # ESTE-M configuration
    enhanced_ste=True,
    ste_momentum=0.9,
)

# Create model and trainer
model = bitnet_v3.BitNetV3Model(config)
trainer = bitnet_v3.BitNetV3Trainer(model, config)

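# Load a full-precision teacher model for knowledge distillation.
# Unpickling a whole nn.Module requires weights_only=False on PyTorch >= 2.6;
# only load checkpoints you trust.
teacher_model = torch.load("teacher_model.pth", weights_only=False)
teacher_model.eval()
trainer.set_teacher_model(teacher_model)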

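# Train with all features (70 epochs = the sum of the mpq_stages epochs;
# train_dataloader and val_dataloader are DataLoaders you supply)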
trainer.train(
    train_dataloader,
    val_dataloader,
    num_epochs=70,
    save_every=5,
    eval_every=1,
)

🏗️ Architecture Overview

Core Components

  • bitnet_v3.core - Core quantization functions and utilities
  • bitnet_v3.modules - Individual innovation modules (MPQ, AHT-LP, GAKD, etc.)
  • bitnet_v3.models - Complete BitNet v3 model implementations
  • bitnet_v3.training - Training pipeline and utilities
  • bitnet_v3.utils - Configuration, logging, and metrics

Key Modules

# Enhanced H-BitLinear with all innovations
linear_layer = bitnet_v3.EnhancedHBitLinear(
    in_features=2048,
    out_features=2048,
    bias=False,
    adaptive_hadamard=True,
    progressive_quantization=True,
)

# Multi-stage Progressive Quantizer
mpq = bitnet_v3.MultiStageProgressiveQuantizer(
    stages=[8, 4, 2, 1.58],
    stage_epochs=[20, 20, 15, 15],
)

# Adaptive Hadamard Transform
aht = bitnet_v3.AdaptiveHadamardTransform(
    size=2048,
    learnable_params=True,
)

# Gradient-Aware Knowledge Distillation
gakd = bitnet_v3.GradientAwareKnowledgeDistillation(
    alpha=0.7,  # KL divergence weight
    beta=0.2,   # Gradient alignment weight  
    gamma=0.1,  # Feature alignment weight
)

📚 Detailed Documentation

Multi-Stage Progressive Quantization (MPQ)

MPQ lowers the bit-width in stages during training (8, 4, 2, then 1.58 bits), so the model can adapt to each precision before the next reduction:

# Configure MPQ stages
mpq_config = {
    "stages": [
        {"start_epoch": 1, "end_epoch": 20, "bits": 8},
        {"start_epoch": 21, "end_epoch": 40, "bits": 4},
        {"start_epoch": 41, "end_epoch": 55, "bits": 2},
        {"start_epoch": 56, "end_epoch": 70, "bits": 1.58},
    ],
    "temperature_schedule": "linear",  # or "cosine"
}

scheduler = bitnet_v3.MPQScheduler(**mpq_config)
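To make the stage table and the temperature-based transition (Equation 1 in the Mathematical Formulations section) concrete, here is a minimal plain-Python sketch. It is not the `MPQScheduler` API: the helpers `stage_index`, `blend_weight`, `quantize`, and `mpq_quantize` are hypothetical, and the sigmoid ramp for β_t (from -6 to +6 across each stage) is an assumption chosen for illustration.

```python
import math

# Stage table mirroring mpq_config above (1-indexed, inclusive epoch ranges).
STAGES = [
    {"start_epoch": 1, "end_epoch": 20, "bits": 8},
    {"start_epoch": 21, "end_epoch": 40, "bits": 4},
    {"start_epoch": 41, "end_epoch": 55, "bits": 2},
    {"start_epoch": 56, "end_epoch": 70, "bits": 1.58},
]


def stage_index(epoch):
    """Index of the stage whose epoch range contains `epoch`."""
    for i, stage in enumerate(STAGES):
        if stage["start_epoch"] <= epoch <= stage["end_epoch"]:
            return i
    raise ValueError(f"epoch {epoch} is outside the schedule")


def blend_weight(epoch, schedule="linear", beta_max=6.0):
    """sigma(beta_t): weight on the current stage's quantizer.

    Hypothetical ramp: beta_t goes from -beta_max to +beta_max across the
    stage, so the weight moves from ~0 (previous bit-width) to ~1 (current).
    """
    stage = STAGES[stage_index(epoch)]
    span = stage["end_epoch"] - stage["start_epoch"]
    progress = (epoch - stage["start_epoch"]) / span if span else 1.0
    if schedule == "cosine":
        progress = 0.5 * (1.0 - math.cos(math.pi * progress))
    beta_t = beta_max * (2.0 * progress - 1.0)
    return 1.0 / (1.0 + math.exp(-beta_t))


def quantize(x, bits):
    """Toy symmetric absmax fake-quantizer; 1.58 bits means ternary {-1, 0, 1}."""
    qmax = 1 if bits < 2 else 2 ** (int(bits) - 1) - 1  # 127, 7, 1, 1
    scale = max(abs(v) for v in x) or 1.0
    return [round(v / scale * qmax) / qmax * scale for v in x]


def mpq_quantize(x, epoch, schedule="linear"):
    """Equation 1: s * Q_{b_t}(x) + (1 - s) * Q_{b_{t-1}}(x), with s = sigma(beta_t)."""
    i = stage_index(epoch)
    current = quantize(x, STAGES[i]["bits"])
    # Assumption: the first stage blends against full precision (no earlier stage).
    previous = quantize(x, STAGES[i - 1]["bits"]) if i > 0 else list(x)
    s = blend_weight(epoch, schedule)
    return [s * c + (1.0 - s) * p for c, p in zip(current, previous)]
```

At the first epoch of a stage the output is almost entirely the previous bit-width, and by the last epoch it is almost entirely the new one, which is the smooth hand-off the schedule is meant to provide.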

Adaptive Hadamard Transform (AHT-LP)

AHT-LP extends the standard Hadamard transform with learnable scale and shift parameters:

# Standard Hadamard transform (x: activation tensor; Hadamard sizes are typically powers of two)
x_transformed = bitnet_v3.hadamard_transform(x)

# Adaptive Hadamard with learnable parameters
aht = bitnet_v3.AdaptiveHadamardTransform(
    size=x.size(-1),
    learnable_scale=True,
    learnable_shift=True,
)
x_adaptive = aht(x)
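For intuition, the sketch below implements an orthonormal fast Walsh-Hadamard transform in plain Python and then applies Equation 2's learnable scale γ and shift β element-wise. The function names are hypothetical, and the 1/√n normalization is an assumption; `bitnet_v3.hadamard_transform` may normalize or pad differently.

```python
import math


def fwht(x):
    """Orthonormal fast Walsh-Hadamard transform of a list whose length is a power of two."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    y = list(x)
    h = 1
    while h < n:
        # Butterfly step: combine pairs (i, i + h) inside each block of size 2h.
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = y[i], y[i + h]
                y[i], y[i + h] = a + b, a - b
        h *= 2
    # Scaling by 1/sqrt(n) makes the transform orthonormal (and its own inverse).
    norm = 1.0 / math.sqrt(n)
    return [v * norm for v in y]


def adaptive_hadamard(x, gamma, beta):
    """Equation 2: gamma * (H_m x) + beta, applied element-wise per channel."""
    return [g * v + b for g, v, b in zip(gamma, fwht(x), beta)]
```

Applying `fwht` to a vector with a single outlier, such as `[8.0] + [0.0] * 7`, spreads that value evenly (every entry becomes 8/√8 ≈ 2.83). Flattening outliers like this is the usual motivation for rotating activations before low-bit quantization; the learnable γ and β then let each channel re-adjust after the fixed rotation.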

Gradient-Aware Knowledge Distillation (GAKD)

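GAKD adds gradient and feature alignment terms to standard output-distribution distillation, so the student also learns to match the teacher's gradients: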

# Set up GAKD
gakd_loss = bitnet_v3.GradientAwareKnowledgeDistillation(
    alpha=0.7,  # Output distribution weight
    beta=0.2,   # Gradient alignment weight
    gamma=0.1,  # Feature alignment weight
)

# Compute distillation loss
loss = gakd_loss(
    student_outputs,
    teacher_outputs,
    student_features,
    teacher_features,
    student_gradients,
    teacher_gradients,
)
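Equation 3 combines three terms as a weighted sum. The plain-Python sketch below shows that combination on single vectors. The individual term definitions are assumptions for illustration (KL divergence between temperature-softened teacher and student distributions, cosine distance between gradients, mean squared error between features), as is the distillation temperature `tau`; `GradientAwareKnowledgeDistillation` may define them differently.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax with a distillation temperature."""
    m = max(logits)
    exps = [math.exp((z - m) / temperature) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two probability vectors."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def cosine_distance(u, v, eps=1e-12):
    """1 - cos(u, v); 0 when the vectors point the same way."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (nu * nv + eps)


def mse(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v)) / len(u)


def gakd_loss(student_logits, teacher_logits, student_feats, teacher_feats,
              student_grads, teacher_grads, alpha=0.7, beta=0.2, gamma=0.1, tau=2.0):
    """Equation 3: alpha * L_KL + beta * L_grad + gamma * L_feature (illustrative terms)."""
    # Distillation direction KL(teacher || student) on softened distributions.
    l_kl = kl_divergence(softmax(teacher_logits, tau), softmax(student_logits, tau))
    l_grad = cosine_distance(student_grads, teacher_grads)
    l_feat = mse(student_feats, teacher_feats)
    return alpha * l_kl + beta * l_grad + gamma * l_feat
```

With the default weights (0.7, 0.2, 0.1), output matching dominates, and the gradient and feature terms act as smaller corrections that only vanish when the student agrees with the teacher on all three signals.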

🧪 Examples

Training from Scratch

import bitnet_v3
from torch.utils.data import DataLoader

# Load your dataset (YourDataset and tokenizer are placeholders for your own code)
train_dataset = YourDataset("train")
train_loader = DataLoader(train_dataset, batch_size=256)

# Create model with default config
model = bitnet_v3.create_model(
    vocab_size=len(tokenizer),
    hidden_size=2048,
    num_layers=24,
)

# Train with MPQ
trainer = bitnet_v3.create_trainer(model)
trainer.train(train_loader, num_epochs=70)

Fine-tuning Pre-trained Model

import copy

# Load pre-trained model
model = bitnet_v3.BitNetV3Model.from_pretrained("path/to/model")

# Keep an untouched copy as the teacher in case conversion modifies `model` in place
teacher_model = copy.deepcopy(model)

# Convert to BitNet v3 with progressive quantization
bitnet_model = bitnet_v3.convert_to_bitnet_v3(
    model,
    enable_all_features=True,
)

# Fine-tune with knowledge distillation (fine_tune_loader is a DataLoader you supply)
trainer = bitnet_v3.create_trainer(bitnet_model)
trainer.set_teacher_model(teacher_model)  # Use the original as teacher
trainer.train(fine_tune_loader, num_epochs=20)

Inference

# Load trained BitNet v3 model
model = bitnet_v3.BitNetV3Model.from_pretrained("path/to/bitnet_v3_model")

# Generate text (input_ids: a tensor of token IDs from your tokenizer)
output = model.generate(
    input_ids,
    max_length=100,
    temperature=0.7,
    do_sample=True,
)

🔬 Research Paper Implementation

This implementation includes all techniques from the original BitNet v3 research paper:

Quantization Functions

  • Ternary weight quantization: {-1, 0, 1}
  • 4-bit activation quantization with Hadamard transform
  • AbsMean and AbsMax quantization schemes
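As a reference point, the sketch below shows the two schemes in plain Python: AbsMean ternary quantization for weights (scale by the mean absolute value, round, clip to {-1, 0, 1}, as in BitNet b1.58) and per-token AbsMax quantization for activations to a signed 4-bit range. The function names, the epsilon, and the choice of absmax for the 4-bit activations are assumptions, not necessarily what `bitnet_v3.core` does.

```python
import math

# A ternary value carries log2(3) bits, hence the "1.58-bit" stage name.
TERNARY_BITS = math.log2(3)  # ≈ 1.585 bits per weight


def absmean_ternary(weights, eps=1e-8):
    """Ternary weight quantization: scale by mean |w|, round, clip to {-1, 0, 1}."""
    scale = sum(abs(w) for w in weights) / len(weights)
    q = [max(-1, min(1, round(w / (scale + eps)))) for w in weights]
    return q, scale  # dequantize as q_i * scale


def absmax_activation(x, bits=4, eps=1e-8):
    """Per-token absmax quantization to signed `bits`-bit integers."""
    qmax = 2 ** (bits - 1) - 1  # 7 for 4-bit; the clip below allows [-8, 7]
    scale = max(abs(v) for v in x) / qmax
    q = [max(-qmax - 1, min(qmax, round(v / (scale + eps)))) for v in x]
    return q, scale  # dequantize as q_i * scale
```

For example, the weights `[0.9, -0.05, -1.2, 0.3]` have a mean absolute value of 0.6125, so they map to the ternary codes `[1, 0, -1, 0]`. Small weights collapse to 0, which is what lets ternary matrix multiplication skip them.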

Training Innovations

  • Progressive bit-width reduction schedule
  • Temperature-based quantization transitions
  • Gradient-aware loss computation
  • Dynamic regularization with layer sensitivity

Mathematical Formulations

All key equations from the paper are implemented:

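Temperature-based transition (Equation 1):

$$Q_t(x) = \sigma(\beta_t)\, Q_{b_t}(x) + \bigl(1 - \sigma(\beta_t)\bigr)\, Q_{b_{t-1}}(x)$$

Adaptive Hadamard transform (Equation 2):

$$H_{\text{adaptive}}(x) = \gamma \odot (H_m x) + \beta$$

GAKD loss (Equation 3):

$$\mathcal{L}_{\text{GAKD}} = \alpha\, \mathcal{L}_{\text{KL}} + \beta\, \mathcal{L}_{\text{grad}} + \gamma\, \mathcal{L}_{\text{feature}}$$

Dynamic regularization (Equation 4):

$$R_{\text{QAP}} = \lambda(t) \sum_i \omega_i\, \lVert W_i - Q(W_i) \rVert^2$$

Here σ is the sigmoid function, b_t is the bit-width of the current stage, H_m is an m × m Hadamard matrix, ⊙ is element-wise multiplication, and ω_i is a per-layer sensitivity weight. The symbols β and γ are reused: in Equation 2 they are the learnable shift and scale, while in Equation 3 they weight the gradient and feature terms.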
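Equation 4 can be read directly as code: quantize each layer, measure its squared distance from the full-precision weights, weight it by the layer's sensitivity ω_i, and scale the sum by λ(t). The plain-Python sketch below assumes an AbsMean ternary quantizer for Q and leaves the λ(t) schedule to the caller; the names `ternary_quantize` and `qap_penalty` are hypothetical, and how the package chooses ω_i or evolves λ(t) from `qap_initial_lambda` is not shown here.

```python
def ternary_quantize(weights, eps=1e-8):
    """AbsMean ternary fake-quantization, returned in the original scale."""
    scale = sum(abs(w) for w in weights) / len(weights)
    return [max(-1, min(1, round(w / (scale + eps)))) * scale for w in weights]


def qap_penalty(layers, sensitivities, lam, quantize=ternary_quantize):
    """Equation 4: lam * sum_i omega_i * ||W_i - Q(W_i)||^2.

    layers: list of flattened weight lists W_i.
    sensitivities: per-layer weights omega_i.
    lam: the value of lambda(t) at the current training step.
    """
    total = 0.0
    for w, omega in zip(layers, sensitivities):
        q = quantize(w)
        # Squared L2 (Frobenius) distance between W_i and its quantized version.
        total += omega * sum((a - b) ** 2 for a, b in zip(w, q))
    return lam * total
```

A layer that is already exactly representable in ternary form, such as `[0.5, -0.5, 0.5, -0.5]`, contributes zero penalty, so the regularizer only pushes on weights that would lose information when quantized.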

📈 Evaluation and Metrics

Built-in evaluation tools for comprehensive analysis:

# Compute perplexity
ppl = bitnet_v3.compute_perplexity(model, test_loader)

# Efficiency metrics (theoretical analysis)
metrics = bitnet_v3.compute_efficiency_metrics(
    bitnet_model, 
    baseline_model,
    test_input,
)
print(f"Expected speedup: {metrics['speedup']:.1f}x")
print(f"Expected memory reduction: {metrics['memory_reduction']:.1f}%")

# Downstream task evaluation framework
results = bitnet_v3.evaluate_downstream_tasks(
    model,
    tasks=["hellaswag", "mmlu", "truthfulqa"],
)
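Two of these numbers reduce to simple formulas: perplexity is the exponential of the mean per-token negative log-likelihood, and the theoretical weight-memory reduction follows from the ratio of bits per weight. The back-of-envelope sketch below is not how `compute_perplexity` or `compute_efficiency_metrics` are implemented. The 16-bit baseline and the helper names are assumptions, and real savings also depend on packing overhead, scale factors, and any layers kept in higher precision (for example, embeddings).

```python
import math


def perplexity(token_nlls):
    """exp(mean negative log-likelihood), with NLLs in nats."""
    return math.exp(sum(token_nlls) / len(token_nlls))


def weight_memory_reduction(bits_quantized=1.58, bits_baseline=16):
    """Percent reduction in weight storage, ignoring scales and packing overhead."""
    return 100.0 * (1.0 - bits_quantized / bits_baseline)


def weight_memory_gb(num_params, bits):
    """Raw weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits / 8 / 1e9
```

For a 7B-parameter model, this gives about 14 GB of 16-bit weights versus roughly 1.4 GB at 1.58 bits, a theoretical reduction of about 90%.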

Note: Performance validation is ongoing. We encourage the community to help benchmark BitNet v3 across different tasks and model sizes.

🛡️ Testing

Run the comprehensive test suite:

# Run all tests
pytest tests/

# Run specific test modules
pytest tests/test_modules/test_mpq.py
pytest tests/test_modules/test_gakd.py

# Run with coverage
pytest --cov=bitnet_v3 tests/

🤝 Contributing - We Need Your Help!

BitNet v3 is an active research project, and we're looking for contributors! Whether you're a researcher, engineer, or enthusiast, there are many ways to help advance 1-bit LLM technology.

🚨 High Priority Contributions Needed

  • 🧪 Performance Benchmarking: Help us validate BitNet v3 across different model sizes (1B, 3B, 7B+)
  • 📊 Dataset Testing: Test on various datasets (language modeling, downstream tasks, multilingual)
  • ⚡ Optimization: CUDA kernels, memory optimizations, training speed improvements
  • 🔧 Integration: HuggingFace Transformers integration, ONNX export, deployment tools
  • 📝 Documentation: Tutorials, guides, and improved examples
  • 🐛 Bug Reports: Help us identify and fix issues in the codebase

🎯 Research Opportunities

  • Compare against other quantization methods (GPTQ, AWQ, etc.)
  • Explore different MPQ schedules and temperature functions
  • Investigate GAKD effectiveness across model architectures
  • Test on specialized domains (code, math, science)
  • Efficiency analysis on different hardware (GPUs, edge devices)

💡 Easy Ways to Get Started

  1. Run the examples and report any issues
  2. Test installation on different systems (Windows, Mac, Linux)
  3. Improve documentation - add docstrings, fix typos, clarify explanations
  4. Add unit tests for untested modules
  5. Create tutorials for specific use cases

🛠 Development Setup

# Fork the repository on GitHub first, then clone your fork (replace <your-username>)
git clone https://github.com/<your-username>/bitnet-v3.git
cd bitnet-v3

# Install in development mode
pip install -e ".[dev]"

# Set up pre-commit hooks
pre-commit install

# Run tests to ensure everything works
pytest tests/

📋 Contribution Guidelines

  • All skill levels welcome - from typo fixes to major algorithmic improvements
  • Research-first approach - we prioritize correctness and reproducibility
  • Open communication - discuss ideas in GitHub Issues before major changes
  • Documentation required - all new features need documentation and examples
  • Testing encouraged - add tests for new functionality when possible

🌟 Recognition

Contributors will be:

  • Added to the contributors list in the README
  • Acknowledged in any resulting research papers
  • Invited to collaborate on follow-up research

📞 Get in Touch

Every contribution matters - from fixing a typo to implementing a new feature. Join us in making 1-bit LLMs a reality! 🚀

📄 Citation

If you use BitNet v3 in your research, please cite this repository:

@software{bitnet_v3_2024,
  title={BitNet v3: Ultra-Low Quality Loss 1-bit LLMs Through Multi-Stage Progressive Quantization and Adaptive Hadamard Transform},
  author={ProCreations},
  url={https://github.com/ProCreations-Official/bitnet-v3},
  year={2024}
}

📜 License

This project is licensed under the MIT License - see the LICENSE file for details.

🙏 Acknowledgments

  • Built upon the foundation of BitNet and BitNet b1.58 from Microsoft Research
  • Inspired by advances in quantization-aware training and knowledge distillation
  • Thanks to the PyTorch team for the excellent deep learning framework

📞 Support


BitNet v3 - Bringing 1-bit LLMs closer to practical deployment! 🚀



Download files


Source Distribution

bitnet_v3-1.0.1.tar.gz (60.3 kB)

Uploaded Source

Built Distribution


bitnet_v3-1.0.1-py3-none-any.whl (57.6 kB)

Uploaded Python 3

File details

Details for the file bitnet_v3-1.0.1.tar.gz.

File metadata

  • Download URL: bitnet_v3-1.0.1.tar.gz
  • Upload date:
  • Size: 60.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.3

File hashes

Hashes for bitnet_v3-1.0.1.tar.gz

  • SHA256: a44378fbdd0a4098120fe98e8f80d87d45e1371096cd08ee06e07fa843bcaf1a
  • MD5: 9b94f74ee3876d5a148bc9e5d06b74a6
  • BLAKE2b-256: b6abbefead7ea73a138905d88863b8ea79199cb6d68404bb7003161f6b430a0d


File details

Details for the file bitnet_v3-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: bitnet_v3-1.0.1-py3-none-any.whl
  • Upload date:
  • Size: 57.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.3

File hashes

Hashes for bitnet_v3-1.0.1-py3-none-any.whl

  • SHA256: 98f97acf2b2fa3cc776c5416294217c6813078e14d2986558758d7a48733b654
  • MD5: fe87bd4870c4729034f94c7c46c3dac0
  • BLAKE2b-256: efa608cf118d40084b19df0ad81728527e2b5e4b378b81340671f7cf706811fd

