BitNet 1.58-bit Masked Diffusion Language Model

BitNet-DLLM

BitNet-DLLM is a scalable, memory-efficient PyTorch framework for training and inference of Large Language Models that combine Masked Diffusion with 1.58-bit (ternary) weight quantization (BitNet b1.58).

Designed for enterprise-grade scalability, the library builds on the Hugging Face ecosystem (accelerate, datasets) to support multi-GPU Distributed Data Parallel (DDP) training, Automatic Mixed Precision (AMP), and zero-copy dataset streaming, which avoids Out-Of-Memory (OOM) errors even on massive datasets such as RedPajama.

✨ Key Features

  • 1.58-bit Quantization (BitLinear): Drastically reduces VRAM footprint and memory bandwidth requirements by constraining weights to {-1, 0, 1}, with minimal impact on model quality.
  • Masked Diffusion Language Modeling (MDLM): Replaces traditional auto-regressive generation with a highly parallelizable discrete diffusion process.
  • Scale-Out Ready: Fully integrated with 🤗 accelerate. Hardware-agnostic training out-of-the-box (CPU, single-GPU, multi-GPU, DeepSpeed).
  • Zero-Copy Data Loading: Uses Apache Arrow under the hood via 🤗 datasets, allowing you to train on terabytes of data with minimal RAM overhead.
  • Clean Architecture: Follows SOLID principles with decoupled models, optimizers, and a DRY-compliant denoising sampler.
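The ternary constraint can be pictured with the absmean quantization scheme from the BitNet b1.58 paper: scale weights by their mean absolute value, then round to the nearest value in {-1, 0, 1}. A minimal sketch of that idea (the function name quantize_ternary is illustrative; this is not the library's actual BitLinear code):

```python
import torch

def quantize_ternary(w: torch.Tensor, eps: float = 1e-5):
    """Absmean quantization: divide by the mean absolute weight,
    then round each entry to the nearest value in {-1, 0, 1}.
    Returns the ternary tensor and the per-tensor scale."""
    scale = w.abs().mean().clamp(min=eps)
    w_q = (w / scale).round().clamp(-1, 1)
    return w_q, scale
```

In training, a straight-through estimator typically passes gradients through the rounding step so the full-precision shadow weights keep learning.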

📦 Installation

Requirements:

  • Python 3.10+
  • PyTorch 2.0+

You can install bitnet-dllm via pip:

# Install from PyPI (Coming soon)
pip install bitnet-dllm

# Or install from source
git clone https://github.com/your-repo/bitnet-dllm.git
cd bitnet-dllm
pip install -e .

⚡ Quickstart

1. Model Initialization

BitNet-DLLM uses a clean configuration object. The model class implements only the forward pass, keeping architecture and training logic strictly separated.

from bitnet_dllm.config import BitDiffLMConfig
from bitnet_dllm.model import BitDiffLM

# Define the architecture
config = BitDiffLMConfig(
    vocab_size=32000,
    hidden_size=1024,
    num_layers=12,
    num_heads=8,
    max_seq_len=512,
    mask_token_id=0,
    pad_token_id=1
)

# Initialize the 1.58-bit Diffusion Model
model = BitDiffLM(config)

2. Text Generation (Inference)

The MDLMAncestralSampler wraps the complex diffusion math into an easy-to-use API. It shares a core denoising loop to ensure consistency across standard generation and mask-filling.

from bitnet_dllm.sampler import MDLMAncestralSampler
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
sampler = MDLMAncestralSampler(model, tokenizer, device="cuda")

# Generate text from a prompt
generated_text = sampler.generate(
    prompt="The future of 1-bit LLMs is",
    num_steps=50,       # Diffusion steps
    temperature=0.8,
    top_p=0.9
)
print(generated_text)
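Under the hood, ancestral masked-diffusion sampling starts from an all-mask sequence and iteratively commits the most confident predictions until nothing is masked. A self-contained toy sketch of that idea, using a linear unmasking schedule (hypothetical helper name; not the library's _denoise_loop):

```python
import math
import torch

@torch.no_grad()
def toy_denoise_loop(model, seq_len, mask_id, num_steps, temperature=1.0):
    """Start fully masked; each step, predict every position and commit
    the most confident still-masked ones, so the whole sequence is
    revealed by the final step."""
    x = torch.full((1, seq_len), mask_id, dtype=torch.long)
    for step in range(num_steps):
        still_masked = x == mask_id
        if not still_masked.any():
            break
        logits = model(x)                                  # (1, seq_len, vocab)
        probs = torch.softmax(logits / temperature, dim=-1)
        conf, pred = probs.max(dim=-1)
        conf = conf.masked_fill(~still_masked, -1.0)       # never touch committed tokens
        n_reveal = math.ceil(still_masked.sum().item() / (num_steps - step))
        idx = conf.topk(n_reveal, dim=-1).indices
        x[0, idx[0]] = pred[0, idx[0]]
    return x
```

Because every position is predicted in parallel at each step, num_steps can be far smaller than the sequence length, which is the key efficiency argument for diffusion over auto-regressive decoding.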

🚂 End-to-End Training Pipeline

Training a model from scratch is straightforward. Thanks to the integrated TrainingConfig and BitDiffLMTrainer, you can easily scale from a local Jupyter Notebook to a multi-node GPU cluster.

Here is a complete, runnable pipeline:

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from bitnet_dllm.config import BitDiffLMConfig
from bitnet_dllm.model import BitDiffLM
from bitnet_dllm.dataset import MaskedDiffusionDataset
from bitnet_dllm.utils import get_optimizer_groups
from bitnet_dllm.trainer import BitDiffLMTrainer, TrainingConfig

# ==========================================
# 1. Setup Tokenizer & Model
# ==========================================
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
config = BitDiffLMConfig(
    vocab_size=tokenizer.vocab_size,
    mask_token_id=tokenizer.mask_token_id,
    pad_token_id=tokenizer.pad_token_id,
    hidden_size=512,
    num_layers=6,
    num_heads=8,
)
model = BitDiffLM(config)

# ==========================================
# 2. Zero-Copy Dataset & DataLoader
# ==========================================
# Load data dynamically without filling up RAM
raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

def tokenize_fn(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

tokenized_ds = raw_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])

# Wrap with our Diffusion Dataset
train_dataset = MaskedDiffusionDataset(
    hf_dataset=tokenized_ds,
    mask_token_id=config.mask_token_id,
    t_min=1e-4,
    t_max=1.0,
)

# Use get_collate_fn() for extremely fast, vectorized batch masking
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=train_dataset.get_collate_fn(),
)

# ==========================================
# 3. Optimizer Setup (Inversion of Control)
# ==========================================
# Automatically split bit-linear parameters from high-precision norm layers
optim_groups = get_optimizer_groups(model, learning_rate=3e-4, weight_decay=0.01)
optimizer = torch.optim.AdamW(optim_groups)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

# ==========================================
# 4. Distributed Training via Accelerate
# ==========================================
train_config = TrainingConfig(
    learning_rate=3e-4,
    ema_decay=0.9999,
    grad_accum=2,
    grad_clip=1.0,
    epochs=10,
)

trainer = BitDiffLMTrainer(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    config=train_config,
)

# Start training (handles DDP, mixed precision, and EMA automatically)
trainer.train()

Tip: To run the script above across multiple GPUs, save it as train.py and execute it with accelerate launch train.py
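The get_optimizer_groups helper in step 3 presumably separates weight-decayed parameters from those that should stay decay-free (norm scales, biases). A hedged sketch of that common transformer-training pattern, using the tensor dimensionality as the split criterion (the library's exact rule may differ):

```python
import torch.nn as nn

def split_optimizer_groups(model: nn.Module, learning_rate: float, weight_decay: float):
    """Put matrix-shaped weights (linear/embedding) in a decayed group;
    keep 1-D parameters (biases, norm scales) decay-free."""
    decay, no_decay = [], []
    for _, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (decay if p.ndim >= 2 else no_decay).append(p)
    return [
        {"params": decay, "lr": learning_rate, "weight_decay": weight_decay},
        {"params": no_decay, "lr": learning_rate, "weight_decay": 0.0},
    ]
```

The returned list plugs directly into torch.optim.AdamW, which honors per-group weight_decay settings.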

🏗️ Architecture Overview

  • bitnet_dllm.bitlinear: The core BitLinear implementation representing weights in 1.58 bits ({-1, 0, 1}).
  • bitnet_dllm.model: Defines BitDiffLM, a pure PyTorch nn.Module.
  • bitnet_dllm.trainer: An accelerate-backed trainer handling gradient accumulation, Exponential Moving Average (EMA) weight updates, and checkpointing.
  • bitnet_dllm.dataset: Optimized data loading built on Hugging Face datasets, with dynamic, vectorized batch masking to eliminate CPU bottlenecks.
  • bitnet_dllm.sampler: Unified diffusion logic built around a shared _denoise_loop for DRY inference.
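The continuous-time masking performed by the dataset module can be pictured as: sample a noise level t per sequence from U(t_min, t_max), then mask each token independently with probability t. A minimal vectorized sketch of that assumed behavior (the library's collate function may differ in detail):

```python
import torch

def mask_batch(input_ids, mask_token_id, t_min=1e-4, t_max=1.0, generator=None):
    """Sample t ~ U(t_min, t_max) per sequence; replace each token with the
    mask token with probability t. Targets are the original tokens."""
    bsz, seq_len = input_ids.shape
    t = torch.empty(bsz, 1).uniform_(t_min, t_max, generator=generator)
    mask = torch.rand(bsz, seq_len, generator=generator) < t
    noisy = input_ids.masked_fill(mask, mask_token_id)
    return noisy, input_ids, mask, t.squeeze(1)
```

Doing this in one vectorized pass per batch, rather than per example in Python, is what keeps the CPU out of the training-loop critical path.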

📜 License & Acknowledgements

This project is licensed under the MIT License.

The 1.58-bit quantization is heavily inspired by the paper "The Era of 1-bit LLMs" (Microsoft Research) and the BitNet architecture.

The Masked Diffusion implementation is inspired by recent advances in Discrete Diffusion models for sequence generation.
