bitnet - Pytorch

BitNet

PyTorch implementation of "BitNet: Scaling 1-bit Transformers for Large Language Models".

Paper link: arXiv:2310.11453

BitLinear = tensor -> layernorm -> Binarize -> abs max quantization -> dequant
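The pipeline above can be walked through numerically. The following is a minimal, framework-free sketch in plain Python, not the package's actual code. The helper names (`layer_norm`, `binarize`, `absmax_quantize`, `bitlinear_sketch`) are hypothetical, and the 8-bit activation width, the rounding step, and the clipping range are assumptions chosen for illustration.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale or shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def binarize(weights):
    """Return sign(W - mean(W)) with entries in {-1, +1}, plus the scale beta = mean(|W|)."""
    flat = [w for row in weights for w in row]
    mean = sum(flat) / len(flat)
    beta = sum(abs(w) for w in flat) / len(flat)
    signs = [[1.0 if w - mean > 0 else -1.0 for w in row] for row in weights]
    return signs, beta


def absmax_quantize(x, bits=8, eps=1e-5):
    """Scale x so its largest magnitude maps to 2**(bits - 1), then round and clip."""
    q_b = 2 ** (bits - 1)
    gamma = max(max(abs(v) for v in x), eps)  # eps guards against an all-zero input
    q = [max(-q_b, min(q_b - 1, round(v * q_b / gamma))) for v in x]
    return q, gamma, q_b


def bitlinear_sketch(x, weights):
    """tensor -> layernorm -> binarize -> abs max quantization -> dequant."""
    x_norm = layer_norm(x)
    w_bin, beta = binarize(weights)
    x_q, gamma, q_b = absmax_quantize(x_norm)
    # Matrix-vector product with +/-1 weights and integer activations
    y_q = [sum(w * q for w, q in zip(row, x_q)) for row in w_bin]
    # Dequantize: undo the activation scale and reapply the weight scale
    return [v * beta * gamma / q_b for v in y_q]


x = [1.0, 2.0, 3.0, 4.0]
w = [[0.5, -0.5, 0.5, -0.5], [1.5, -1.5, 1.5, -1.5]]
print(bitlinear_sketch(x, w))
```

The product in the middle only ever multiplies integers by +1 or -1, which is where the efficiency argument for 1-bit weights comes from; the two scales `beta` and `gamma` are applied once at the end.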

"The implementation of the BitNet architecture is quite simple, requiring only the replacement of linear projections (i.e., nn.Linear in PyTorch) in the Transformer." -- BitNet is really easy to implement: just swap out the linear layers for BitLinear modules!
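One way to do the swap on an existing model is a small recursive helper. This is a sketch under assumptions, not part of the bitnet package: `replace_linears` is a hypothetical name, and `StandInBitLinear` is a placeholder class used only so the example runs without bitnet installed.

```python
import torch
from torch import nn


class StandInBitLinear(nn.Linear):
    """Hypothetical placeholder so the sketch runs without bitnet installed."""


def replace_linears(module, make_layer):
    """Recursively swap every nn.Linear under `module` for make_layer(in_features, out_features)."""
    # list(...) takes a snapshot so children can be reassigned while looping
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            setattr(module, name, make_layer(child.in_features, child.out_features))
        else:
            replace_linears(child, make_layer)
    return module


model = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Sequential(nn.Linear(16, 4)),
)
replace_linears(model, StandInBitLinear)
out = model(torch.randn(2, 8))
print(out.shape)
```

With the package installed, passing `BitLinear` as `make_layer` should work if its constructor accepts `(in_features, out_features)`, as the call `BitLinear(512, 400)` in the usage section suggests. Note that the new layers are freshly initialized; the helper does not copy the original weights.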

NEWS

  • The BitNet Transformer has been trained with the train.py script on enwik8, a small 100 MB dataset of Wikipedia text: HERE IS THE LINK

Installation

pip install bitnet

Usage:

  • Example of the BitLinear layer, which is the main innovation of the paper:
import torch

from bitnet import BitLinear

# Input
x = torch.randn(10, 512)

# BitLinear layer
layer = BitLinear(512, 400)

# Output
y = layer(x)

print(y)

  • Running random inputs through a full BitNet Transformer, as shown in the paper:
import torch
from bitnet import BitNetTransformer

bitnet = BitNetTransformer(
    num_tokens=20000,
    dim=512,
    depth=6,
    dim_head=64,
    heads=8,
    ff_mult=4,
)

# Random token ids: a batch of 1 sequence with 512 tokens
tokens = torch.randint(0, 20000, (1, 512))

# Forward pass
logits = bitnet(tokens)
print(logits.shape)

Inference

from bitnet import BitNetInference

bitnet = BitNetInference()
bitnet.load_model('../model_checkpoint.pth')  # Download the model checkpoint first
output_str = bitnet.generate("The dog jumped over the ", 512)
print(output_str)

License

MIT

Citation

@misc{2310.11453,
Author = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Huaijie Wang and Lingxiao Ma and Fan Yang and Ruiping Wang and Yi Wu and Furu Wei},
Title = {BitNet: Scaling 1-bit Transformers for Large Language Models},
Year = {2023},
Eprint = {arXiv:2310.11453},
}

Todo

  • Double check the BitLinear implementation and make sure it works exactly as in the paper
  • Implement training script for BitNetTransformer
  • Train on enwik8, copy and paste code and data from Lucidrains repos
  • Benchmark performance
  • Look into Straight Through Estimator for non-differentiable backprop
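For the last item, a common way to write a straight-through estimator in PyTorch is the detach trick: the forward pass returns the binarized value, while the backward pass treats the operation as the identity. This is a minimal sketch, not this repository's implementation; `sign_ste` is a hypothetical name.

```python
import torch


def sign_ste(x):
    """Forward: +1 where x > 0, else -1. Backward: gradient passes through as if identity."""
    hard = torch.where(x > 0, torch.ones_like(x), -torch.ones_like(x))
    # detach() hides the non-differentiable part from autograd, leaving d(out)/dx = 1
    return x + (hard - x).detach()


x = torch.tensor([-0.3, 0.0, 0.7], requires_grad=True)
y = sign_ste(x)
y.sum().backward()
print(y)       # binarized values
print(x.grad)  # gradient of the sum with respect to x
```

Without the trick, the gradient of a sign function is zero almost everywhere, so the underlying full-precision weights would never receive a useful update.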
