
8-bit and 4-bit quantization for PyTorch on Apple Silicon (M1/M2/M3/M4)

Project description

MPS BitsAndBytes

8-bit quantization for PyTorch on Apple Silicon (M1/M2/M3/M4).

50% memory savings for storing model weights, with no steady-state speed penalty thanks to smart caching.

Features

  • Linear8bit: Drop-in replacement for nn.Linear with int8 weights
  • Smart caching: Dequantize once, run fast fp16 matmul (AMX accelerated)
  • QLoRA ready: Perfect for fine-tuning large models on Mac
  • Pure PyTorch: No custom kernels needed, works out of the box

Installation

pip install mps-bitsandbytes

Or from source:

git clone https://github.com/mpsops/mps-bitsandbytes
cd mps-bitsandbytes
pip install -e .

Quick Start

import torch
from mps_bitsandbytes import Linear8bit, quantize_model

# Convert existing model to 8-bit
model = YourModel().to('mps')
model = quantize_model(model, device='mps')

# Or convert individual layers
linear_8bit = Linear8bit.from_linear(some_linear_layer)

# Use normally - same API, 50% less memory for weights
output = model(input)

How It Works

  1. Storage: Weights stored as int8 (1 byte per param vs 2 bytes for fp16)
  2. First forward: Dequantize int8 → fp16, cache the result
  3. Subsequent forwards: Use cached fp16 weights, fast AMX matmul
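
A minimal sketch of how these three steps could be implemented in plain PyTorch (illustrative only; the class name CachedInt8Linear is hypothetical and this is not the package's actual Linear8bit code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class CachedInt8Linear(nn.Module):
    """Illustration of store-int8 / dequantize-once / reuse-fp16 caching."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        w = linear.weight.data.to(torch.float32)                  # [out_features, in_features]
        # Per-output-channel scale so weights map into the int8 range [-127, 127]
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        self.register_buffer("weight_int8", torch.round(w / scale).to(torch.int8))
        self.register_buffer("scale", scale.to(torch.float16))
        self.bias = None if linear.bias is None else nn.Parameter(linear.bias.data.to(torch.float16))
        self._weight_fp16 = None                                  # dequantization cache

    def forward(self, x):                                         # x assumed to be fp16 on MPS
        if self._weight_fp16 is None:                             # step 2: dequantize once
            self._weight_fp16 = self.weight_int8.to(torch.float16) * self.scale
        return F.linear(x, self._weight_fp16, self.bias)          # step 3: cached fp16 matmul

    def clear_cache(self):
        self._weight_fp16 = None                                  # free the fp16 copy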

This gives you:

  • 50% memory savings on disk and when loading weights
  • Same inference speed as fp16 (once cached)
  • Compatible with QLoRA training

Memory Savings

Model Size    FP16      INT8     Savings
7B params     14 GB     7 GB     7 GB
13B params    26 GB     13 GB    13 GB
70B params    140 GB    70 GB    70 GB
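
These figures are straight bytes-per-parameter arithmetic: a 7B-parameter model takes roughly 7 × 10⁹ params × 2 bytes ≈ 14 GB in fp16 versus 7 × 10⁹ × 1 byte ≈ 7 GB in int8.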

Configuration

# Default: cache enabled (fast, uses memory during inference)
layer = Linear8bit.from_linear(linear, use_cache=True)

# Memory-constrained: no cache (slower, minimum memory)
layer = Linear8bit.from_linear(linear, use_cache=False)

# Clear cache to free memory
layer.clear_cache()

QLoRA Training

from transformers import AutoModelForCausalLM
from mps_bitsandbytes import quantize_model
from peft import get_peft_model, LoraConfig

# Load model in 8-bit
model = AutoModelForCausalLM.from_pretrained("model_name")
model = quantize_model(model.to('mps'))

# Add LoRA adapters (these stay in fp16 for gradients)
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora_config)

# Train - base weights frozen in int8, LoRA in fp16
trainer.train()
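
A hedged sketch of one way to drive the trainer.train() call above with transformers' Trainer (train_dataset is assumed to be a tokenized dataset you have already prepared; the hyperparameters are placeholders, not recommendations):

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="qlora-out",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    learning_rate=2e-4,
    logging_steps=10,
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()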

Benchmarks

Tested on M1 Max, batch_size=32, hidden_dim=4096:

Method            Forward Time   Memory
FP16              1.08 ms        100 MB
INT8 (cached)     0.98 ms        50 MB + cache
INT8 (no cache)   9.65 ms        50 MB
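
A rough harness for reproducing this kind of comparison (illustrative only, using the Linear8bit.from_linear API shown above; this is not the script that produced the table):

import time
import torch
from mps_bitsandbytes import Linear8bit

batch, hidden = 32, 4096
x = torch.randn(batch, hidden, device="mps", dtype=torch.float16)

fp16_layer = torch.nn.Linear(hidden, hidden, device="mps", dtype=torch.float16)
int8_layer = Linear8bit.from_linear(fp16_layer, use_cache=True)

def avg_ms(layer, iters=100):
    layer(x)                           # warm-up; also populates the int8 layer's cache
    torch.mps.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        layer(x)
    torch.mps.synchronize()
    return (time.perf_counter() - start) / iters * 1e3

print(f"fp16: {avg_ms(fp16_layer):.2f} ms, int8 (cached): {avg_ms(int8_layer):.2f} ms")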

Limitations

  • First forward is slower: Need to dequantize weights once
  • Cache uses memory: During inference, cached fp16 weights use extra memory
  • No int8 matmul acceleration: Apple Silicon AMX only supports fp16/fp32

For maximum memory savings during inference, set use_cache=False, but expect roughly 10x slower forward passes (see the benchmarks above).

Credits

License

MIT

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

mps_bitsandbytes-0.1.2.tar.gz (11.8 kB)

File details

Details for the file mps_bitsandbytes-0.1.2.tar.gz.

File metadata

  • Download URL: mps_bitsandbytes-0.1.2.tar.gz
  • Upload date:
  • Size: 11.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for mps_bitsandbytes-0.1.2.tar.gz
Algorithm Hash digest
SHA256 fdb9ee2c63429d162877de42f3a8e6ce0b9f82e5475247b2b5ec3e2388251d1b
MD5 06d4545f6d2711291336f6feccf160fa
BLAKE2b-256 2ac8033f2b298f48a8ccd1fb2140b9c48ed952211daa34519fad4a3bc96b6fcf

