8-bit and 4-bit quantization for PyTorch on Apple Silicon (M1/M2/M3/M4)
MPS BitsAndBytes
8-bit quantization for PyTorch on Apple Silicon (M1/M2/M3/M4).
50% memory savings for storing model weights, with no speed penalty after the first forward pass thanks to smart caching.
Features
- Linear8bit: Drop-in replacement for nn.Linear with int8 weights
- Smart caching: Dequantize once, run fast fp16 matmul (AMX accelerated)
- QLoRA ready: Perfect for fine-tuning large models on Mac
- Pure PyTorch: No custom kernels needed, works out of the box
Installation
pip install mps-bitsandbytes
Or from source:
git clone https://github.com/mpsops/mps-bitsandbytes
cd mps-bitsandbytes
pip install -e .
Quick Start
import torch
from mps_bitsandbytes import Linear8bit, quantize_model
# Convert existing model to 8-bit
model = YourModel().to('mps')
model = quantize_model(model, device='mps')
# Or convert individual layers
linear_8bit = Linear8bit.from_linear(some_linear_layer)
# Use normally - same API, 50% less memory for weights
output = model(input)
How It Works
- Storage: Weights stored as int8 (1 byte per param vs 2 bytes for fp16)
- First forward: Dequantize int8 → fp16, cache the result
- Subsequent forwards: Use cached fp16 weights, fast AMX matmul
This gives you:
- 50% memory savings on disk and when loading weights
- Same inference speed as fp16 (once cached)
- Compatible with QLoRA training
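To make the caching flow concrete, here is a minimal sketch of the idea (an illustration only, with assumed attribute and method names; the real Linear8bit implementation may differ):
import torch
import torch.nn as nn
import torch.nn.functional as F

class CachedInt8Linear(nn.Module):
    """Sketch of dequantize-once caching; not the actual Linear8bit source."""
    def __init__(self, weight_int8, scale, bias=None, use_cache=True):
        super().__init__()
        self.register_buffer("weight_int8", weight_int8)   # (out, in) int8, 1 byte per param
        self.register_buffer("scale", scale)                # per-output-row fp16 scale
        self.bias = bias
        self.use_cache = use_cache
        self._weight_fp16 = None                            # filled on first forward

    def forward(self, x):
        if self._weight_fp16 is not None:
            w = self._weight_fp16                           # fast path: cached fp16, AMX matmul
        else:
            w = self.weight_int8.to(torch.float16) * self.scale   # dequantize int8 -> fp16
            if self.use_cache:
                self._weight_fp16 = w                       # keep it for subsequent forwards
        return F.linear(x, w, self.bias)

    def clear_cache(self):
        self._weight_fp16 = None                            # free the cached fp16 copy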
Memory Savings
| Model Size | FP16 | INT8 | Savings |
|---|---|---|---|
| 7B params | 14 GB | 7 GB | 7 GB |
| 13B params | 26 GB | 13 GB | 13 GB |
| 70B params | 140 GB | 70 GB | 70 GB |
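These figures follow directly from storing one byte per parameter instead of two; a quick back-of-envelope check (plain arithmetic, not library code):
params = 7e9                      # 7B-parameter model
fp16_gb = params * 2 / 1e9        # 2 bytes per param -> 14 GB
int8_gb = params * 1 / 1e9        # 1 byte per param  ->  7 GB
print(f"fp16: {fp16_gb:.0f} GB, int8: {int8_gb:.0f} GB, saved: {fp16_gb - int8_gb:.0f} GB")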
Configuration
# Default: cache enabled (fast, uses memory during inference)
layer = Linear8bit.from_linear(linear, use_cache=True)
# Memory-constrained: no cache (slower, minimum memory)
layer = Linear8bit.from_linear(linear, use_cache=False)
# Clear cache to free memory
layer.clear_cache()
QLoRA Training
from transformers import AutoModelForCausalLM
from peft import get_peft_model, LoraConfig
from mps_bitsandbytes import quantize_model
# Load the model, move it to MPS, then quantize to 8-bit
model = AutoModelForCausalLM.from_pretrained("model_name")
model = quantize_model(model.to('mps'))
# Add LoRA adapters (these stay in fp16 for gradients)
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora_config)
# Train - base weights frozen in int8, LoRA adapters in fp16
# (trainer: e.g. a transformers Trainer set up elsewhere)
trainer.train()
Benchmarks
Tested on M1 Max, batch_size=32, hidden_dim=4096:
| Method | Forward Time | Memory |
|---|---|---|
| FP16 | 1.08 ms | 100 MB |
| INT8 (cached) | 0.98 ms | 50 MB + cache |
| INT8 (no cache) | 9.65 ms | 50 MB |
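If you want to sanity-check numbers of this kind on your own machine, a minimal timing harness might look like the following (the harness itself is a sketch; Linear8bit.from_linear is used as shown in the Configuration section above):
import time
import torch
from mps_bitsandbytes import Linear8bit

linear = torch.nn.Linear(4096, 4096).half().to('mps')
layer = Linear8bit.from_linear(linear, use_cache=True)
x = torch.randn(32, 4096, dtype=torch.float16, device='mps')

layer(x)                          # warm-up: dequantizes once and fills the cache
torch.mps.synchronize()

start = time.perf_counter()
for _ in range(100):
    layer(x)
torch.mps.synchronize()           # wait for MPS kernels before reading the clock
print((time.perf_counter() - start) / 100 * 1e3, "ms per forward")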
Limitations
- First forward is slower: Need to dequantize weights once
- Cache uses memory: During inference, cached fp16 weights use extra memory
- No int8 matmul acceleration: Apple Silicon AMX only supports fp16/fp32
For maximum memory savings during inference, set use_cache=False to disable the cache, but expect roughly 10x slower inference (see the benchmark above).
Credits
- bitsandbytes - Original CUDA implementation
- LLM.int8() - Paper by Tim Dettmers et al.
License
MIT
Download files
Source Distribution
mps_bitsandbytes-0.1.2.tar.gz (11.8 kB)
File details
Details for the file mps_bitsandbytes-0.1.2.tar.gz.
File metadata
- Download URL: mps_bitsandbytes-0.1.2.tar.gz
- Upload date:
- Size: 11.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | fdb9ee2c63429d162877de42f3a8e6ce0b9f82e5475247b2b5ec3e2388251d1b |
| MD5 | 06d4545f6d2711291336f6feccf160fa |
| BLAKE2b-256 | 2ac8033f2b298f48a8ccd1fb2140b9c48ed952211daa34519fad4a3bc96b6fcf |