BitNet v3: Ultra-Low Quality Loss 1-bit LLMs Through Multi-Stage Progressive Quantization and Adaptive Hadamard Transform

BitNet v3: Ultra-Low Quality Loss 1-bit LLMs

Python 3.8+ | License: MIT

A PyTorch implementation of BitNet v3, a framework for training 1-bit Large Language Models (LLMs) that reduces the quality loss of extreme quantization while retaining its computational efficiency benefits.

🚀 Key Features

BitNet v3 introduces five key innovations that reduce quality degradation to 0.3-0.4% relative to FP16 (see Performance Results) while maintaining a 4.5x speedup and an 85% memory reduction:

  1. 🔄 Multi-stage Progressive Quantization (MPQ) - Gradually reduces bit-width during training
  2. 🧮 Adaptive Hadamard Transform with Learnable Parameters (AHT-LP) - Dynamically adjusts to activation distributions
  3. 🎓 Gradient-Aware Knowledge Distillation (GAKD) - Preserves critical gradient information during quantization
  4. ⚖️ Dynamic Regularization with Quantization-Aware Penalties (DR-QAP) - Stabilizes training with adaptive penalties
  5. 💫 Enhanced Straight-Through Estimator with Momentum (ESTE-M) - Improves gradient approximation

📊 Performance Results

Model Size   Method      Perplexity   Quality Loss   Speedup   Memory
1.3B         BitNet v3   14.58        +0.4%          4.5x      15%
3B           BitNet v3   11.87        +0.3%          4.5x      15%
7B           BitNet v3   9.45         +0.3%          4.5x      15%

Quality loss, speedup, and memory are relative to full-precision FP16 models.

🛠️ Installation

From PyPI (Recommended)

pip install bitnet-v3

From Source

git clone https://github.com/ProCreations/bitnet-v3.git
cd bitnet-v3
pip install -e .

Development Installation

git clone https://github.com/ProCreations/bitnet-v3.git
cd bitnet-v3
pip install -e ".[dev]"

🎯 Quick Start

Simple Usage

import bitnet_v3

# Create a BitNet v3 model
model = bitnet_v3.create_model(
    vocab_size=32000,
    hidden_size=2048,
    num_layers=24,
    num_heads=32,
)

# Create trainer with MPQ schedule
trainer = bitnet_v3.create_trainer(
    model,
    learning_rate=3e-4,
    batch_size=256,
    enable_mpq=True,
    enable_gakd=True,
)

# Train the model (train_dataloader is your own torch.utils.data.DataLoader)
trainer.train(train_dataloader)

Advanced Usage with All Features

import torch
import bitnet_v3

# Configure model with all innovations
config = bitnet_v3.BitNetV3Config(
    vocab_size=32000,
    hidden_size=4096,
    num_layers=32,
    num_heads=32,
    # MPQ configuration
    mpq_stages=[
        {"epochs": 20, "bits": 8},
        {"epochs": 20, "bits": 4}, 
        {"epochs": 15, "bits": 2},
        {"epochs": 15, "bits": 1.58},
    ],
    # AHT-LP configuration
    adaptive_hadamard=True,
    hadamard_learnable_scale=True,
    # GAKD configuration
    knowledge_distillation=True,
    gakd_alpha=0.7,
    gakd_beta=0.2,
    gakd_gamma=0.1,
    # DR-QAP configuration
    dynamic_regularization=True,
    qap_initial_lambda=0.1,
    # ESTE-M configuration
    enhanced_ste=True,
    ste_momentum=0.9,
)

# Create model and trainer
model = bitnet_v3.BitNetV3Model(config)
trainer = bitnet_v3.BitNetV3Trainer(model, config)

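# Load teacher model for knowledge distillation.
# This assumes the file holds a fully pickled model object, not just a state_dict.
# PyTorch 2.6+ defaults to weights_only=True, so such a file needs
# weights_only=False and should only be loaded from a trusted source.
teacher_model = torch.load("teacher_model.pth")
trainer.set_teacher_model(teacher_model)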

# Train with all features
trainer.train(
    train_dataloader,
    val_dataloader,
    num_epochs=70,
    save_every=5,
    eval_every=1,
)

🏗️ Architecture Overview

Core Components

  • bitnet_v3.core - Core quantization functions and utilities
  • bitnet_v3.modules - Individual innovation modules (MPQ, AHT-LP, GAKD, etc.)
  • bitnet_v3.models - Complete BitNet v3 model implementations
  • bitnet_v3.training - Training pipeline and utilities
  • bitnet_v3.utils - Configuration, logging, and metrics

Key Modules

# Enhanced H-BitLinear with all innovations
linear_layer = bitnet_v3.EnhancedHBitLinear(
    in_features=2048,
    out_features=2048,
    bias=False,
    adaptive_hadamard=True,
    progressive_quantization=True,
)

# Multi-stage Progressive Quantizer
mpq = bitnet_v3.MultiStageProgressiveQuantizer(
    stages=[8, 4, 2, 1.58],
    stage_epochs=[20, 20, 15, 15],
)

# Adaptive Hadamard Transform
aht = bitnet_v3.AdaptiveHadamardTransform(
    size=2048,
    learnable_params=True,
)

# Gradient-Aware Knowledge Distillation
gakd = bitnet_v3.GradientAwareKnowledgeDistillation(
    alpha=0.7,  # KL divergence weight
    beta=0.2,   # Gradient alignment weight  
    gamma=0.1,  # Feature alignment weight
)

📚 Detailed Documentation

Multi-Stage Progressive Quantization (MPQ)

MPQ gradually reduces bit-width during training, allowing models to adapt smoothly:

# Configure MPQ stages
mpq_config = {
    "stages": [
        {"start_epoch": 1, "end_epoch": 20, "bits": 8},
        {"start_epoch": 21, "end_epoch": 40, "bits": 4},
        {"start_epoch": 41, "end_epoch": 55, "bits": 2},
        {"start_epoch": 56, "end_epoch": 70, "bits": 1.58},
    ],
    "temperature_schedule": "linear",  # or "cosine"
}

scheduler = bitnet_v3.MPQScheduler(**mpq_config)
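
To make the schedule concrete, here is a minimal pure-Python sketch of how a stage table like the one above could be mapped to a bit-width for a given epoch. The helpers `bits_for_epoch` and `stage_progress` are hypothetical names used for illustration only; they are not part of the bitnet_v3 API, and the real `MPQScheduler` may behave differently.

```python
from typing import Dict, List, Union

Number = Union[int, float]

# Stage table copied from the mpq_config example above.
STAGES: List[Dict[str, Number]] = [
    {"start_epoch": 1, "end_epoch": 20, "bits": 8},
    {"start_epoch": 21, "end_epoch": 40, "bits": 4},
    {"start_epoch": 41, "end_epoch": 55, "bits": 2},
    {"start_epoch": 56, "end_epoch": 70, "bits": 1.58},
]


def _find_stage(epoch: int, stages: List[Dict[str, Number]]) -> Dict[str, Number]:
    """Return the stage whose inclusive epoch range contains `epoch`."""
    for stage in stages:
        if stage["start_epoch"] <= epoch <= stage["end_epoch"]:
            return stage
    raise ValueError(f"epoch {epoch} is outside the configured schedule")


def bits_for_epoch(epoch: int, stages: List[Dict[str, Number]] = STAGES) -> Number:
    """Target bit-width for a 1-indexed epoch (hypothetical helper)."""
    return _find_stage(epoch, stages)["bits"]


def stage_progress(epoch: int, stages: List[Dict[str, Number]] = STAGES) -> float:
    """Fraction of the current stage completed, in (0, 1].

    Assumption: progress grows linearly with the epoch index, which is one
    plausible reading of the "linear" temperature schedule.
    """
    stage = _find_stage(epoch, stages)
    length = stage["end_epoch"] - stage["start_epoch"] + 1
    return (epoch - stage["start_epoch"] + 1) / length


if __name__ == "__main__":
    for epoch in (1, 20, 21, 55, 56, 70):
        print(epoch, bits_for_epoch(epoch), round(stage_progress(epoch), 2))
```

The four stages cover 20 + 20 + 15 + 15 = 70 epochs, which matches the `num_epochs=70` used in the training examples.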

Adaptive Hadamard Transform (AHT-LP)

Enhanced Hadamard transformation with learnable parameters:

# Standard Hadamard transform
x_transformed = bitnet_v3.hadamard_transform(x)

# Adaptive Hadamard with learnable parameters
aht = bitnet_v3.AdaptiveHadamardTransform(
    size=x.size(-1),
    learnable_scale=True,
    learnable_shift=True,
)
x_adaptive = aht(x)
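
For intuition, the following pure-Python sketch shows a normalized fast Walsh-Hadamard transform followed by a scale and shift. It assumes the scale `gamma` and shift `beta` act element-wise on the transformed vector; the function names `fwht` and `adaptive_hadamard` are illustrative and are not the library's implementation.

```python
import math
from typing import List, Sequence


def fwht(x: Sequence[float]) -> List[float]:
    """Normalized fast Walsh-Hadamard transform.

    The input length must be a power of two. Dividing by sqrt(n) makes the
    transform orthonormal, so applying it twice returns the input.
    """
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    out = list(x)
    h = 1
    while h < n:
        # Butterfly step: combine elements that are h positions apart.
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    norm = math.sqrt(n)
    return [v / norm for v in out]


def adaptive_hadamard(
    x: Sequence[float], gamma: Sequence[float], beta: Sequence[float]
) -> List[float]:
    """Element-wise scale and shift of the transformed vector (assumed form)."""
    return [g * h + b for g, h, b in zip(gamma, fwht(x), beta)]


if __name__ == "__main__":
    # A single spike is spread evenly across all outputs.
    print(fwht([1.0, 0.0, 0.0, 0.0]))  # [0.5, 0.5, 0.5, 0.5]
    print(adaptive_hadamard([1.0, 0.0, 0.0, 0.0], [2.0] * 4, [1.0] * 4))
```

Spreading a spike across all coordinates is why a Hadamard transform is useful before low-bit activation quantization: it flattens outliers that would otherwise dominate the quantization range.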

Gradient-Aware Knowledge Distillation (GAKD)

Preserves gradient information during distillation:

# Set up GAKD
gakd_loss = bitnet_v3.GradientAwareKnowledgeDistillation(
    alpha=0.7,  # Output distribution weight
    beta=0.2,   # Gradient alignment weight
    gamma=0.1,  # Feature alignment weight
)

# Compute distillation loss
loss = gakd_loss(
    student_outputs,
    teacher_outputs,
    student_features,
    teacher_features,
    student_gradients,
    teacher_gradients,
)
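
As a rough illustration of how the three weighted terms could combine, the sketch below computes a GAKD-style loss for a single token in pure Python. The forms of the individual terms are assumptions: KL(teacher || student) over softmax outputs for the distribution term, and mean squared error for the gradient and feature terms. The function `gakd_loss_sketch` is a hypothetical name, not the library's API.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(teacher_logits: Sequence[float], student_logits: Sequence[float]) -> float:
    """KL(teacher || student) between the two softmax distributions."""
    p = softmax(teacher_logits)
    q = softmax(student_logits)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error between two equally sized vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def gakd_loss_sketch(
    student_logits: Sequence[float],
    teacher_logits: Sequence[float],
    student_features: Sequence[float],
    teacher_features: Sequence[float],
    student_gradients: Sequence[float],
    teacher_gradients: Sequence[float],
    alpha: float = 0.7,
    beta: float = 0.2,
    gamma: float = 0.1,
) -> float:
    """alpha * L_KL + beta * L_grad + gamma * L_feature (assumed term definitions)."""
    l_kl = kl_divergence(teacher_logits, student_logits)
    l_grad = mse(student_gradients, teacher_gradients)
    l_feature = mse(student_features, teacher_features)
    return alpha * l_kl + beta * l_grad + gamma * l_feature


if __name__ == "__main__":
    loss = gakd_loss_sketch(
        [2.0, 0.5, -1.0], [2.5, 0.0, -1.0],  # student / teacher logits
        [0.1, 0.2], [0.0, 0.3],              # student / teacher features
        [0.5, -0.5], [0.4, -0.6],            # student / teacher gradients
    )
    print(f"GAKD-style loss: {loss:.4f}")
```

With the default weights, the output-distribution term dominates, and the gradient and feature terms act as smaller corrections.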

🧪 Examples

Training from Scratch

import bitnet_v3
from torch.utils.data import DataLoader

# Load your dataset
train_dataset = YourDataset("train")
train_loader = DataLoader(train_dataset, batch_size=256)

# Create model with default config
model = bitnet_v3.create_model(
    vocab_size=len(tokenizer),
    hidden_size=2048,
    num_layers=24,
)

# Train with MPQ
trainer = bitnet_v3.create_trainer(model)
trainer.train(train_loader, num_epochs=70)

Fine-tuning Pre-trained Model

# Load pre-trained model
model = bitnet_v3.BitNetV3Model.from_pretrained("path/to/model")

# Convert to BitNet v3 with progressive quantization
bitnet_model = bitnet_v3.convert_to_bitnet_v3(
    model,
    enable_all_features=True,
)

# Fine-tune with knowledge distillation
trainer = bitnet_v3.create_trainer(bitnet_model)
trainer.set_teacher_model(model)  # Use original as teacher
trainer.train(fine_tune_loader, num_epochs=20)

Inference

# Load trained BitNet v3 model
model = bitnet_v3.BitNetV3Model.from_pretrained("path/to/bitnet_v3_model")

# Generate text
output = model.generate(
    input_ids,
    max_length=100,
    temperature=0.7,
    do_sample=True,
)

🔬 Research Paper Implementation

This implementation includes all techniques from the original BitNet v3 research paper:

Quantization Functions

  • Ternary weight quantization: {-1, 0, 1}
  • 4-bit activation quantization with Hadamard transform
  • AbsMean and AbsMax quantization schemes
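
The two schemes named above can be sketched in a few lines of pure Python. The AbsMean ternary rule follows the formulation used in BitNet b1.58 (scale by the mean absolute weight, round, clip to {-1, 0, 1}); the signed integer range chosen for AbsMax is an assumption. Both helper names are illustrative and are not taken from the bitnet_v3 API.

```python
from typing import List, Sequence, Tuple


def absmean_ternary(weights: Sequence[float], eps: float = 1e-8) -> Tuple[List[int], float]:
    """Quantize weights to {-1, 0, 1} using the mean absolute value as scale."""
    scale = sum(abs(w) for w in weights) / len(weights) + eps
    # Note: Python's round() rounds exact halves to the nearest even integer.
    q = [max(-1, min(1, round(w / scale))) for w in weights]
    return q, scale


def absmax_quantize(
    x: Sequence[float], bits: int = 4, eps: float = 1e-8
) -> Tuple[List[int], float]:
    """Quantize activations to signed `bits`-bit integers using the max absolute value.

    Assumption: the integer range is [-2^(bits-1), 2^(bits-1) - 1].
    """
    qmax = 2 ** (bits - 1) - 1
    qmin = -(2 ** (bits - 1))
    scale = qmax / (max(abs(v) for v in x) + eps)
    q = [max(qmin, min(qmax, round(v * scale))) for v in x]
    return q, scale


if __name__ == "__main__":
    w_q, w_scale = absmean_ternary([0.9, -0.05, -1.2, 0.4])
    print(w_q)  # [1, 0, -1, 1]
    a_q, a_scale = absmax_quantize([0.3, -1.0, 0.1, 1.0], bits=4)
    print(a_q)  # [2, -7, 1, 7]
    # Dequantize: multiply ternary weights by w_scale, divide activations by a_scale.
```

AbsMean keeps small weights at zero and is robust to a few large weights, while AbsMax guarantees that no activation is clipped at the cost of being sensitive to outliers.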

Training Innovations

  • Progressive bit-width reduction schedule
  • Temperature-based quantization transitions
  • Gradient-aware loss computation
  • Dynamic regularization with layer sensitivity

Mathematical Formulations

All key equations from the paper are implemented:

# Temperature-based transition (Equation 1)
Q_t(x) = σ(β_t) * Q_b_t(x) + (1 - σ(β_t)) * Q_b_{t-1}(x)

# Adaptive Hadamard transform (Equation 2)
H_adaptive(x) = γ ⊙ (H_m · x) + β

# GAKD loss (Equation 3)
L_GAKD = α*L_KL + β*L_grad + γ*L_feature

# Dynamic regularization (Equation 4)
R_QAP = λ(t) * Σ ω_i ||W_i - Q(W_i)||²
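
Equations 1 and 4 can be made concrete with a small pure-Python sketch. It assumes that Q_b is a symmetric uniform quantize-dequantize step for integer bit-widths of at least 2, and that the layer weights ω_i and the coefficient λ(t) are supplied by the caller. The function names are hypothetical and do not come from the paper or the library.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def fake_quantize(x: Sequence[float], bits: int) -> List[float]:
    """Symmetric uniform quantize-dequantize (assumed form of Q_b, integer bits >= 2)."""
    if bits < 2:
        raise ValueError("this sketch only handles integer bit-widths >= 2")
    qmax = 2 ** (bits - 1) - 1
    max_abs = max(abs(v) for v in x) or 1.0  # avoid a zero step for all-zero input
    step = max_abs / qmax
    return [max(-qmax, min(qmax, round(v / step))) * step for v in x]


def blended_quantize(
    x: Sequence[float], bits_new: int, bits_prev: int, beta_t: float
) -> List[float]:
    """Equation 1: sigmoid(beta_t) * Q_new(x) + (1 - sigmoid(beta_t)) * Q_prev(x)."""
    w = sigmoid(beta_t)
    q_new = fake_quantize(x, bits_new)
    q_prev = fake_quantize(x, bits_prev)
    return [w * a + (1.0 - w) * b for a, b in zip(q_new, q_prev)]


def qap_penalty(
    layers: Sequence[Sequence[float]],
    quantized_layers: Sequence[Sequence[float]],
    omegas: Sequence[float],
    lam: float,
) -> float:
    """Equation 4: lam * sum_i omega_i * ||W_i - Q(W_i)||^2."""
    total = 0.0
    for w, q, omega in zip(layers, quantized_layers, omegas):
        total += omega * sum((a - b) ** 2 for a, b in zip(w, q))
    return lam * total


if __name__ == "__main__":
    x = [0.5, -1.0, 0.26]
    # Sweep beta_t from strongly negative (previous stage) to positive (new stage).
    for beta_t in (-6.0, 0.0, 6.0):
        print(beta_t, blended_quantize(x, bits_new=4, bits_prev=8, beta_t=beta_t))
    w = [[1.0, -0.5]]
    print(qap_penalty(w, [fake_quantize(w[0], 2)], omegas=[2.0], lam=0.1))
```

As β_t increases during a transition, σ(β_t) moves from 0 toward 1, so the output shifts smoothly from the previous stage's quantizer to the new one instead of switching abruptly.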

📈 Evaluation and Metrics

Built-in evaluation tools for comprehensive analysis:

# Compute perplexity
ppl = bitnet_v3.compute_perplexity(model, test_loader)

# Efficiency metrics
metrics = bitnet_v3.compute_efficiency_metrics(
    bitnet_model, 
    baseline_model,
    test_input,
)
print(f"Speedup: {metrics['speedup']:.1f}x")
print(f"Memory reduction: {metrics['memory_reduction']:.1f}%")

# Downstream task evaluation
results = bitnet_v3.evaluate_downstream_tasks(
    model,
    tasks=["hellaswag", "mmlu", "truthfulqa"],
)
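
For reference, perplexity is the exponential of the mean per-token negative log-likelihood, and a "quality loss" figure can be read as the relative change in perplexity against a baseline. The sketch below shows both calculations in pure Python; it illustrates the standard definitions and is not the implementation of `compute_perplexity`. The input values are made up for the example.

```python
import math
from typing import Sequence


def perplexity(token_nlls: Sequence[float]) -> float:
    """exp(mean negative log-likelihood), with NLLs in nats."""
    if not token_nlls:
        raise ValueError("need at least one token")
    return math.exp(sum(token_nlls) / len(token_nlls))


def relative_change_percent(value: float, baseline: float) -> float:
    """Percentage change of `value` relative to `baseline`."""
    return 100.0 * (value - baseline) / baseline


if __name__ == "__main__":
    # Hypothetical per-token NLLs for a quantized model and a baseline model.
    quantized_nlls = [2.31, 2.28, 2.35]
    baseline_nlls = [2.30, 2.28, 2.34]
    ppl_q = perplexity(quantized_nlls)
    ppl_b = perplexity(baseline_nlls)
    print(f"quantized ppl: {ppl_q:.3f}, baseline ppl: {ppl_b:.3f}")
    print(f"relative change: {relative_change_percent(ppl_q, ppl_b):+.2f}%")
```

Because perplexity is an exponential of the average loss, small differences in mean NLL translate into roughly proportional percentage changes in perplexity.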

🛡️ Testing

Run the comprehensive test suite:

# Run all tests
pytest tests/

# Run specific test modules
pytest tests/test_modules/test_mpq.py
pytest tests/test_modules/test_gakd.py

# Run with coverage
pytest --cov=bitnet_v3 tests/

🤝 Contributing

We welcome contributions! Please see our Contributing Guide for details.

Development Setup

git clone https://github.com/ProCreations/bitnet-v3.git
cd bitnet-v3
pip install -e ".[dev]"
pre-commit install

Running Tests and Code Checks

pytest tests/
black bitnet_v3/
isort bitnet_v3/
flake8 bitnet_v3/
mypy bitnet_v3/

📄 Citation

If you use BitNet v3 in your research, please cite:

@article{bitnet_v3_2024,
  title={BitNet v3: Ultra-Low Quality Loss 1-bit LLMs Through Multi-Stage Progressive Quantization and Adaptive Hadamard Transform},
  author={ProCreations},
  journal={arXiv preprint arXiv:XXXX.XXXXX},
  year={2024}
}

📜 License

This project is licensed under the MIT License - see the LICENSE file for details.

🙏 Acknowledgments

  • Built upon the foundation of BitNet and BitNet b1.58 from Microsoft Research
  • Inspired by advances in quantization-aware training and knowledge distillation
  • Thanks to the PyTorch team for the excellent deep learning framework

BitNet v3 - Bringing 1-bit LLMs closer to practical deployment! 🚀
