bitnet - PyTorch
BitNet
PyTorch implementation of the paper "BitNet: Scaling 1-bit Transformers for Large Language Models".
BitLinear = input tensor -> LayerNorm -> binarize (weights) -> absmax quantization (activations) -> dequantization
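A dependency-free sketch of the BitLinear arithmetic, following the paper's equations. Weights are centered by their mean α and binarized with Sign, and β is the mean absolute weight. Activations go through LayerNorm and absmax quantization to b bits with γ = max|x| and Q_b = 2^(b-1). The integer product is then rescaled by βγ/Q_b. Assumptions for illustration: plain Python lists instead of tensors, a LayerNorm without learnable scale/shift, no bias, and rounding plus clipping to the signed range [-Q_b, Q_b - 1] (the paper writes the clip bounds as -Q_b + ε and Q_b - ε). This is not the bitnet package's code.

```python
import math


def layer_norm(x, eps=1e-5):
    # LayerNorm without learnable scale/shift (an assumption for this sketch)
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def binarize_weights(W):
    # W is a list of rows (out_features x in_features), like nn.Linear.weight
    n, m = len(W), len(W[0])
    alpha = sum(sum(row) for row in W) / (n * m)            # mean weight
    beta = sum(abs(w) for row in W for w in row) / (n * m)  # mean |weight|
    # Sign(W - alpha): +1 where positive, -1 otherwise
    W_bin = [[1 if w - alpha > 0 else -1 for w in row] for row in W]
    return W_bin, beta


def absmax_quantize(x, bits=8):
    qb = 2 ** (bits - 1)
    gamma = max(abs(v) for v in x) or 1.0  # fall back to 1.0 for an all-zero input
    # Scale into the b-bit range, round, and clip to [-qb, qb - 1]
    q = [max(-qb, min(qb - 1, round(v * qb / gamma))) for v in x]
    return q, gamma, qb


def bitlinear_forward(x, W, bits=8):
    xq, gamma, qb = absmax_quantize(layer_norm(x), bits)
    W_bin, beta = binarize_weights(W)
    # +/-1 weights times integer activations: effectively only additions/subtractions
    y = [sum(w * v for w, v in zip(row, xq)) for row in W_bin]
    # Dequantize back to the original scale
    scale = beta * gamma / qb
    return [v * scale for v in y]
```

For x = [1.0, -1.0] and W = [[1.0, -1.0]], the full-precision result W·LN(x) is about 2.0, and bitlinear_forward returns roughly 1.992; the small gap comes from clipping the positive activation to 127.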
From the paper: "The implementation of the BitNet architecture is quite simple, requiring only the replacement of linear projections (i.e., nn.Linear in PyTorch) in the Transformer." In other words, BitNet is easy to implement: just swap out the linear layers for BitLinear modules!
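The swap described in the quote above can be automated by walking a model and replacing every nn.Linear. A minimal sketch: replace_linears is a hypothetical helper, not part of the bitnet package. It assumes the replacement layer is built as make_layer(in_features, out_features), matching the BitLinear(512, 400) call in the usage example in this README. The new layers are freshly initialized, and the original weights and biases are not copied over.

```python
from typing import Callable

import torch.nn as nn


def replace_linears(module: nn.Module, make_layer: Callable[[int, int], nn.Module]) -> nn.Module:
    """Recursively replace every nn.Linear in `module` with make_layer(in, out).

    Hypothetical helper: replacement layers are freshly initialized; the
    original weights and biases are NOT transferred.
    """
    # Materialize the children first so the module is not mutated mid-iteration
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            setattr(module, name, make_layer(child.in_features, child.out_features))
        else:
            replace_linears(child, make_layer)  # descend into nested blocks
    return module
```

Usage, assuming the package is installed: from bitnet import BitLinear, then model = replace_linears(model, BitLinear). Because the replacements start from scratch, the converted model has to be trained (or fine-tuned) afterwards.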
NEWS
- The BitNet Transformer has been trained with the train.py file, which trains on enwik8, a small (~100 MB) Wikipedia dataset: HERE IS THE LINK
Installation
pip install bitnet
Usage
- Example of the BitLinear layer, the main innovation of the paper:
import torch
from bitnet import BitLinear
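# Input: a batch of 10 vectors with 512 features each
x = torch.randn(10, 512)
# BitLinear layer mapping 512 input features to 400 output features
layer = BitLinear(512, 400)
# Output: like nn.Linear, this should have shape (10, 400)
y = layer(x)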
print(y)
- Running random inputs through a full BitNet Transformer, as described in the paper:
import torch
from bitnet import BitNetTransformer
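bitnet = BitNetTransformer(
    num_tokens=20000,  # vocabulary size
    dim=512,           # model (embedding) dimension
    depth=6,           # number of Transformer blocks
    dim_head=64,       # dimension per attention head
    heads=8,           # number of attention heads
    ff_mult=4,         # feedforward expansion factor
)
# Random token ids: a batch of 1 sequence with 512 tokens
tokens = torch.randint(0, 20000, (1, 512))
logits = bitnet(tokens)
print(logits.shape)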
Inference
from bitnet import BitNetInference
bitnet = BitNetInference()
bitnet.load_model('../model_checkpoint.pth')  # path to a downloaded model checkpoint
output_str = bitnet.generate("The dog jumped over the ", 512)
print(output_str)
License
MIT
Citation
@misc{2310.11453,
Author = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Huaijie Wang and Lingxiao Ma and Fan Yang and Ruiping Wang and Yi Wu and Furu Wei},
Title = {BitNet: Scaling 1-bit Transformers for Large Language Models},
Year = {2023},
Eprint = {arXiv:2310.11453},
}
Todo
- Double check BitLinear implementation and make sure it works exactly as in paper
- Implement training script for BitNetTransformer
- Train on enwik8, copy and paste code and data from lucidrains' repos
- Benchmark performance
- Look into the straight-through estimator (STE) for backpropagating through the non-differentiable binarization
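The Sign and Clip functions in BitLinear have zero gradient almost everywhere, so plain backpropagation gives no learning signal. The BitNet paper handles this with the straight-through estimator, which bypasses these functions in the backward pass, while high-precision latent weights accumulate the updates. A minimal dependency-free sketch on a flat list of weights; the function names are hypothetical, and the clipped variant is an optional common extra that this README does not specify.

```python
def sign(w):
    # BitNet-style Sign: +1 for positive values, -1 otherwise (so 0 maps to -1)
    return 1.0 if w > 0 else -1.0


def binarize_forward(latent):
    # Forward pass: the layer only sees the binarized weights
    return [sign(w) for w in latent]


def binarize_backward_ste(latent, grad_out, clip=None):
    # Sign's true derivative is 0 almost everywhere. The STE treats the op as the
    # identity in the backward pass and passes the upstream gradient straight through.
    if clip is None:
        return list(grad_out)
    # Optional "clipped" variant: zero the gradient where |w| > clip
    return [g if abs(w) <= clip else 0.0 for w, g in zip(latent, grad_out)]


def sgd_step(latent, grad, lr=0.1):
    # Updates go to the full-precision latent weights, not to the binary ones
    return [w - lr * g for w, g in zip(latent, grad)]
```

For example, with latent weights [0.3, -0.2, 1.5], one step of sgd_step with gradient [1.0, 1.0, 1.0] and lr=0.5 moves the first weight to -0.2, so its binarized value flips from +1 to -1. In PyTorch the same pass-through is commonly written as w + (torch.sign(w) - w).detach(), which yields sign(w) in the forward pass and an identity gradient in the backward pass (note that torch.sign maps 0 to 0, unlike the Sign above).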
Download files
Source Distribution
bitnet-0.0.3.tar.gz (7.8 kB)
Built Distribution
bitnet-0.0.3-py3-none-any.whl (8.4 kB)