
BitNet 1.58-bit Masked Diffusion Language Model


BitNet-DLLM

BitNet-DLLM is a scalable, memory-efficient PyTorch framework for training and running inference with Large Language Models that combine Masked Diffusion and 1.58-bit (ternary) weight quantization (BitNet b1.58).

Designed for large-scale workloads, the library is built on top of the Hugging Face ecosystem (accelerate, datasets) to support multi-GPU Distributed Data Parallel (DDP) training, Automatic Mixed Precision (AMP), and zero-copy dataset streaming, avoiding Out-Of-Memory (OOM) errors even on massive datasets like RedPajama.

✨ Key Features

  • 1.58-bit Quantization (BitLinear): Drastically reduces VRAM footprint and memory bandwidth requirements by constraining weights to {-1, 0, 1}, with little to no loss in model quality at scale (see the sketch after this list).
  • Masked Diffusion Language Modeling (MDLM): Replaces traditional auto-regressive generation with a highly parallelizable discrete diffusion process.
  • Scale-Out Ready: Fully integrated with 🤗 accelerate. Hardware-agnostic training out-of-the-box (CPU, single-GPU, multi-GPU, DeepSpeed).
  • Zero-Copy Data Loading: Uses Apache Arrow under the hood via 🤗 datasets, allowing you to train on terabytes of data with minimal RAM overhead.
  • Clean Architecture: Follows SOLID principles with decoupled models, optimizers, and a DRY-compliant denoising sampler.
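
For intuition, here is a minimal sketch of the absmean ternary quantization described in the BitNet b1.58 paper, wrapped in a straight-through linear layer. The names (ternary_quantize, TernaryLinear) are illustrative stand-ins, not the library's actual BitLinear:

import torch
import torch.nn as nn
import torch.nn.functional as F

def ternary_quantize(w):
    # Absmean scaling (BitNet b1.58): normalize by the mean absolute
    # weight, then round and clip every value to {-1, 0, 1}.
    scale = w.abs().mean().clamp(min=1e-5)
    return (w / scale).round().clamp(-1, 1) * scale

class TernaryLinear(nn.Module):
    # Illustrative only; stands in for the library's BitLinear.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)

    def forward(self, x):
        w = self.weight
        # Straight-through estimator: quantized weights in the forward
        # pass, full-precision gradients in the backward pass.
        w_q = w + (ternary_quantize(w) - w).detach()
        return F.linear(x, w_q)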

📦 Installation

Requirements:

  • Python 3.10+
  • PyTorch 2.0+

You can install bitnet-dllm via pip:

# Install from PyPI (Coming soon)
pip install bitnet-dllm

# Or install from source
git clone https://github.com/your-repo/bitnet-dllm.git
cd bitnet-dllm
pip install -e .

⚡ Quickstart

1. Model Initialization

BitNet-DLLM uses a clean configuration object. The model implements only the forward pass, keeping architecture and training logic strictly separated.

from bitnet_dllm.config import BitDiffLMConfig
from bitnet_dllm.model import BitDiffLM

# Define the architecture
config = BitDiffLMConfig(
    vocab_size=32000,
    hidden_size=1024,
    num_layers=12,
    num_heads=8,
    max_seq_len=512,
    mask_token_id=0,
    pad_token_id=1
)

# Initialize the 1.58-bit Diffusion Model
model = BitDiffLM(config)
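
As a quick smoke test, you can push a dummy batch through the model. We assume here that the forward pass maps token IDs to per-token logits over the vocabulary; the exact signature may differ (some diffusion models also take a timestep argument):

import torch

# Assumed signature (illustrative): token IDs in, per-token logits out.
dummy_ids = torch.randint(0, config.vocab_size, (2, config.max_seq_len))
logits = model(dummy_ids)
print(logits.shape)  # expected: (2, max_seq_len, vocab_size)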

2. Text Generation (Inference)

The MDLMAncestralSampler wraps the diffusion sampling math in an easy-to-use API. Standard generation and mask-filling share a single core denoising loop, keeping their behavior consistent.

from bitnet_dllm.sampler import MDLMAncestralSampler
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
sampler = MDLMAncestralSampler(model, tokenizer, device="cuda")

# Generate text from a prompt
generated_text = sampler.generate(
    prompt="The future of 1-bit LLMs is",
    num_steps=50,       # Diffusion steps
    temperature=0.8,
    top_p=0.9
)
print(generated_text)
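
For intuition, ancestral masked-diffusion sampling starts from a fully masked sequence and unmasks it over a fixed number of steps. The toy sketch below illustrates the idea only; the function name and the assumption that the model maps token IDs to per-token logits are ours, and the library's _denoise_loop is more sophisticated:

import torch

@torch.no_grad()
def toy_denoise_loop(model, seq_len, num_steps, mask_id, temperature=1.0):
    # Start from an all-[MASK] sequence (batch size 1).
    x = torch.full((1, seq_len), mask_id, dtype=torch.long)
    for step in range(num_steps):
        logits = model(x)  # assumed shape: (1, seq_len, vocab_size)
        probs = torch.softmax(logits / temperature, dim=-1)
        sampled = torch.distributions.Categorical(probs).sample()
        # Commit predictions only at positions that are still masked.
        still_masked = x == mask_id
        x = torch.where(still_masked, sampled, x)
        # Re-mask a shrinking fraction so tokens are revealed gradually;
        # at the final step the fraction is 0 and everything is committed.
        frac = 1.0 - (step + 1) / num_steps
        remask = (torch.rand(x.shape) < frac) & still_masked
        x = x.masked_fill(remask, mask_id)
    return x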

🚂 End-to-End Training Pipeline

Training a model from scratch is simple. Thanks to the integrated TrainingConfig and BitDiffLMTrainer, you can easily scale from a local Jupyter notebook to a multi-node GPU cluster.

Here is a complete, runnable pipeline:

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from bitnet_dllm.config import BitDiffLMConfig
from bitnet_dllm.model import BitDiffLM
from bitnet_dllm.dataset import MaskedDiffusionDataset
from bitnet_dllm.utils import get_optimizer_groups
from bitnet_dllm.trainer import BitDiffLMTrainer, TrainingConfig

# ==========================================
# 1. Setup Tokenizer & Model
# ==========================================
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
config = BitDiffLMConfig(
    vocab_size=tokenizer.vocab_size,
    mask_token_id=tokenizer.mask_token_id,
    pad_token_id=tokenizer.pad_token_id,
    hidden_size=512,
    num_layers=6,
    num_heads=8,
)
model = BitDiffLM(config)

# ==========================================
# 2. Zero-Copy Dataset & DataLoader
# ==========================================
# Load data dynamically without filling up RAM
raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

def tokenize_fn(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

tokenized_ds = raw_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])

# Wrap with our Diffusion Dataset
train_dataset = MaskedDiffusionDataset(
    hf_dataset=tokenized_ds,
    mask_token_id=config.mask_token_id,
    t_min=1e-4,
    t_max=1.0,
)

# Use get_collate_fn() for extremely fast, vectorized batch masking
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=train_dataset.get_collate_fn(),
)
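
# For intuition: a minimal reference of the vectorized masking such a
# collate function could perform (illustrative only; the real
# get_collate_fn() may differ).
def reference_collate(batch):
    clean_ids = torch.stack([torch.as_tensor(ex["input_ids"]) for ex in batch])
    # Sample one diffusion time t ~ U(t_min, t_max) per sequence...
    t = torch.empty(clean_ids.size(0), 1).uniform_(1e-4, 1.0)
    # ...and mask each token independently with probability t.
    masked = torch.rand(clean_ids.shape) < t
    noisy_ids = clean_ids.masked_fill(masked, config.mask_token_id)
    return {"input_ids": noisy_ids, "labels": clean_ids, "mask": masked, "t": t}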

# ==========================================
# 3. Optimizer Setup (Inversion of Control)
# ==========================================
# Automatically split bit-linear parameters from high-precision norm layers
optim_groups = get_optimizer_groups(model, learning_rate=3e-4, weight_decay=0.01)
optimizer = torch.optim.AdamW(optim_groups)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

# ==========================================
# 4. Distributed Training via Accelerate
# ==========================================
train_config = TrainingConfig(
    learning_rate=3e-4,
    ema_decay=0.9999,
    grad_accum=2,
    grad_clip=1.0,
    epochs=10,
)

trainer = BitDiffLMTrainer(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    config=train_config,
)

# Start training (handles DDP, mixed precision, and EMA automatically)
trainer.train()

Tip: To run the above script across multiple GPUs, save it as train.py and launch it with accelerate launch train.py.

🏗️ Architecture Overview

  • bitnet_dllm.bitlinear: The core BitLinear implementation, representing weights in 1.58 bits {-1, 0, 1}.
  • bitnet_dllm.model: Defines BitDiffLM, a pure PyTorch nn.Module.
  • bitnet_dllm.trainer: An accelerate-backed trainer handling gradient accumulation, asynchronous Exponential Moving Average (EMA) updates (see the sketch after this list), and checkpointing.
  • bitnet_dllm.dataset: Optimized DataLoaders utilizing Hugging Face datasets and dynamic, vectorized batch masking to eliminate CPU bottlenecks.
  • bitnet_dllm.sampler: Unified diffusion logic; generation and mask-filling share a single _denoise_loop for DRY inference.
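
The EMA update mentioned above follows the standard exponential-moving-average rule. A minimal sketch (illustrative, not the trainer's actual code):

import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.9999):
    # Blend each live parameter into its shadow copy:
    # ema = decay * ema + (1 - decay) * param
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)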

📜 License & Acknowledgements

This project is licensed under the MIT License.

The 1.58-bit quantization is heavily inspired by The Era of 1-bit LLMs (Microsoft) and the BitNet architecture.

The Masked Diffusion implementation is inspired by recent advances in Discrete Diffusion models for sequence generation.
