NF4/FP4/FP8/INT8 quantization for PyTorch on Apple Silicon with Metal GPU acceleration

Project description

MPS BitsAndBytes

Real 4-bit and 8-bit quantization for PyTorch on Apple Silicon (M1/M2/M3/M4).

Full bitsandbytes-compatible API with Metal GPU acceleration for running large models on your Mac.

Features

| Format | Bits | Memory Savings | Best For |
|--------|------|----------------|----------|
| NF4 | 4-bit | ~75% | LLM weights (normally distributed) |
| FP4 | 4-bit | ~75% | Alternative with better dynamic range |
| FP8 E4M3 | 8-bit | ~50% | Better precision than INT8 |
| INT8 | 8-bit | ~50% | General purpose |

Plus:

  • Metal GPU kernels - Fused dequant+matmul, no Python overhead
  • Double quantization - Extra ~10% savings on scales
  • 8-bit Optimizers - Adam8bit, AdamW8bit, Lion8bit, SGD8bit
  • Paged Optimizers - CPU offloading for larger models
  • Quantized Embeddings - Embedding4bit, Embedding8bit
  • Sparse Operations - spmm_coo, spmm_coo_int8
  • LLM.int8 - OutlierAwareLinear with col+row quantization
  • HuggingFace compatible - BitsAndBytesConfig API works out of the box
  • QLoRA training - Freeze quantized weights, train LoRA adapters

Installation

pip install mps-bitsandbytes

Or from source:

git clone https://github.com/mpsops/mps-bitsandbytes
cd mps-bitsandbytes
pip install -e .

Quick Start

4-bit Quantization (NF4 - Recommended for LLMs)

import torch
from mps_bitsandbytes import Linear4bit, BitsAndBytesConfig, quantize_model

# Convert a single layer
linear = torch.nn.Linear(4096, 4096).half().to('mps')
linear_4bit = Linear4bit.from_linear(linear)  # NF4 by default

# Or use FP4
linear_fp4 = Linear4bit.from_linear(linear, quant_type='fp4')

# Quantize entire model
config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = quantize_model(your_model, quantization_config=config, device='mps')

8-bit Quantization (FP8 or INT8)

import torch
from mps_bitsandbytes import Linear8bit, LinearFP8

linear = torch.nn.Linear(4096, 4096).half().to('mps')

# INT8 (traditional)
linear_int8 = Linear8bit.from_linear(linear)

# FP8 E4M3 (better precision)
linear_fp8 = LinearFP8.from_linear(linear)

8-bit Optimizers

Memory-efficient optimizers that store momentum/variance in 8-bit:

from mps_bitsandbytes import Adam8bit, AdamW8bit, Lion8bit, SGD8bit

# Drop-in replacement for torch optimizers
optimizer = Adam8bit(model.parameters(), lr=1e-3)
optimizer = AdamW8bit(model.parameters(), lr=1e-3, weight_decay=0.01)
optimizer = Lion8bit(model.parameters(), lr=1e-4)
optimizer = SGD8bit(model.parameters(), lr=0.1, momentum=0.9)
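
The core trick can be sketched in a few lines of plain PyTorch (an illustrative reference, not the library's implementation): quantize each optimizer state blockwise to int8 with one absmax scale per block, and dequantize before applying the update.

```python
import torch

def quantize_state_8bit(state: torch.Tensor, block_size: int = 256):
    """Blockwise int8: each block keeps int8 codes plus one fp32 absmax scale."""
    flat = state.reshape(-1)
    pad = (-flat.numel()) % block_size
    flat = torch.nn.functional.pad(flat, (0, pad))   # zero-pad to a whole block
    blocks = flat.reshape(-1, block_size)
    absmax = blocks.abs().amax(dim=1, keepdim=True).clamp_min(1e-12)
    codes = torch.round(blocks / absmax * 127).to(torch.int8)
    return codes, absmax

def dequantize_state_8bit(codes, absmax, numel):
    blocks = codes.to(torch.float32) / 127.0 * absmax
    return blocks.reshape(-1)[:numel]

momentum = torch.randn(1000)                      # fp32 state: 4 bytes/element
codes, absmax = quantize_state_8bit(momentum)     # int8 codes: 1 byte/element
restored = dequantize_state_8bit(codes, absmax, momentum.numel())
```

Storing momentum and variance this way cuts optimizer state memory roughly 4x versus fp32, at the cost of a small blockwise rounding error.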

Paged Optimizers

Offload optimizer states to CPU for training larger models:

from mps_bitsandbytes import PagedAdam, PagedAdamW, PagedLion

# States are stored on CPU, copied to GPU during step()
optimizer = PagedAdamW(model.parameters(), lr=1e-3, page_to_cpu=True)

Quantized Embeddings

Reduce embedding table memory by 50-75%:

import torch
from mps_bitsandbytes import Embedding4bit, Embedding8bit, EmbeddingNF4, EmbeddingFP4
# Convert existing embedding
embed = torch.nn.Embedding(50000, 4096).half().to('mps')
embed_4bit = Embedding4bit.from_embedding(embed)  # NF4 by default
embed_fp4 = EmbeddingFP4.from_embedding(embed)    # FP4
embed_8bit = Embedding8bit.from_embedding(embed)  # INT8

Functional API

from mps_bitsandbytes import (
    # 4-bit
    quantize_nf4, dequantize_nf4, matmul_nf4,
    quantize_fp4, dequantize_fp4, matmul_fp4,
    # 8-bit
    quantize_fp8_e4m3, dequantize_fp8_e4m3, matmul_fp8_e4m3,
    quantize_rowwise, dequantize_rowwise, matmul_int8,
    # Col+Row INT8 (LLM.int8 style)
    quantize_colrow, dequantize_colrow, matmul_colrow,
    # Double quantization
    double_quant, dequant_absmax,
    # Sparse
    spmm_coo, spmm_coo_int8, sparse_coo_from_dense, quantize_sparse_coo,
)

# NF4
weight = torch.randn(4096, 4096, device='mps', dtype=torch.float16)
x = torch.randn(8, 4096, device='mps', dtype=torch.float16)
packed, absmax = quantize_nf4(weight, block_size=64)
output = matmul_nf4(x, packed, absmax)

# Double quantization (quantize the scales too)
absmax_quant, absmax_scales = double_quant(absmax)

Memory Savings

| Model | FP16 | INT8/FP8 | NF4/FP4 |
|-------|------|----------|---------|
| 7B params | 14 GB | 7 GB | 3.5 GB |
| 13B params | 26 GB | 13 GB | 6.5 GB |
| 70B params | 140 GB | 70 GB | 35 GB |
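
These figures are just bytes per parameter (2 for fp16, 1 for 8-bit, 0.5 for 4-bit), ignoring the small per-block scale overhead. A quick back-of-the-envelope check:

```python
# Bytes per parameter for each format (per-block scale overhead ignored).
BYTES_PER_PARAM = {"fp16": 2.0, "int8/fp8": 1.0, "nf4/fp4": 0.5}

def weight_gb(n_params: float, fmt: str) -> float:
    """Approximate weight memory in GB (decimal, 1e9 bytes)."""
    return n_params * BYTES_PER_PARAM[fmt] / 1e9

for n_billion in (7, 13, 70):
    sizes = {fmt: weight_gb(n_billion * 1e9, fmt) for fmt in BYTES_PER_PARAM}
    print(f"{n_billion}B params:", sizes)
```

In practice NF4/FP4 add roughly 6% for the per-block absmax scales (one fp32 scale per 64 weights), which double quantization shrinks further.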

HuggingFace Integration

import torch
from transformers import AutoModelForCausalLM
from mps_bitsandbytes import BitsAndBytesConfig, quantize_model

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
)

config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = quantize_model(model, quantization_config=config, device='mps')

QLoRA Training

import torch
from transformers import AutoModelForCausalLM
from mps_bitsandbytes import BitsAndBytesConfig, quantize_model, Adam8bit
from peft import get_peft_model, LoraConfig

# Load in 4-bit
config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
model = AutoModelForCausalLM.from_pretrained("model_name", torch_dtype=torch.float16)
model = quantize_model(model, quantization_config=config, device='mps')

# Add LoRA adapters (train in fp16 while base stays quantized)
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora_config)

# Use 8-bit optimizer for extra memory savings
optimizer = Adam8bit(model.parameters(), lr=1e-4)
# ...then run your usual training loop (e.g. a transformers Trainer)

API Reference

Linear Modules

| Class | Format | Use Case |
|-------|--------|----------|
| Linear4bit | NF4 or FP4 | LLM inference, QLoRA |
| Linear8bit | INT8 | General quantization |
| LinearFP8 | FP8 E4M3 | Better precision 8-bit |
| OutlierAwareLinear | INT8 + FP16 | LLM.int8 mixed precision |
| SwitchBackLinear | INT8 | Training with quantized forward |

Embedding Modules

| Class | Format | Memory Savings |
|-------|--------|----------------|
| Embedding4bit | NF4 (default) | ~75% |
| EmbeddingNF4 | NF4 | ~75% |
| EmbeddingFP4 | FP4 | ~75% |
| Embedding8bit | INT8 | ~50% |

Optimizers

| Class | Description |
|-------|-------------|
| Adam8bit | Adam with 8-bit states |
| AdamW8bit | AdamW with 8-bit states |
| Lion8bit | Lion optimizer with 8-bit momentum |
| SGD8bit | SGD with 8-bit momentum |
| PagedAdam | Adam with CPU offloading |
| PagedAdamW | AdamW with CPU offloading |
| PagedLion | Lion with CPU offloading |

Functional API

4-bit (NF4/FP4):

  • quantize_nf4(tensor, block_size=64) / quantize_fp4(...)
  • dequantize_nf4(packed, absmax, ...) / dequantize_fp4(...)
  • matmul_nf4(input, weight_packed, weight_absmax, bias=None) / matmul_fp4(...)
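
As a reference for what `quantize_nf4`/`dequantize_nf4` compute (an unpacked sketch, not the library's packed Metal kernels): each block of weights is scaled by its absmax into [-1, 1], then every value is snapped to the nearest of 16 fixed NF4 levels, the quantile-based codebook defined in the QLoRA paper.

```python
import torch

# The 16 NF4 levels (normal-distribution quantiles, as published for QLoRA).
NF4_LEVELS = torch.tensor([
    -1.0, -0.6961928, -0.52507305, -0.39491749,
    -0.28444138, -0.18477343, -0.09105004, 0.0,
    0.0795803, 0.1609302, 0.2461123, 0.33791524,
    0.44070989, 0.562617, 0.72295684, 1.0,
])

def quantize_nf4_ref(w: torch.Tensor, block_size: int = 64):
    """Reference NF4: per-block absmax scaling + nearest-level lookup.
    Assumes w.numel() is divisible by block_size; codes are left unpacked
    (the real library packs two 4-bit codes per byte)."""
    blocks = w.reshape(-1, block_size)
    absmax = blocks.abs().amax(dim=1, keepdim=True).clamp_min(1e-12)
    normed = blocks / absmax                                    # in [-1, 1]
    idx = (normed.unsqueeze(-1) - NF4_LEVELS).abs().argmin(dim=-1)
    return idx.to(torch.uint8), absmax

def dequantize_nf4_ref(idx, absmax, shape):
    return (NF4_LEVELS[idx.long()] * absmax).reshape(shape)

w = torch.randn(4, 64)
idx, absmax = quantize_nf4_ref(w)
w_hat = dequantize_nf4_ref(idx, absmax, w.shape)
```

Because the levels are quantiles of a normal distribution, the 16 codes are used roughly uniformly on Gaussian-shaped weight tensors, which is what makes NF4 a good fit for LLM weights.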

8-bit:

  • quantize_fp8_e4m3(tensor) - FP8 quantization
  • quantize_rowwise(tensor) - INT8 row-wise quantization
  • quantize_colrow(tensor) - INT8 col+row quantization (LLM.int8)
  • matmul_fp8_e4m3(...) / matmul_int8(...) / matmul_colrow(...)
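
The row-wise INT8 path is easy to sketch in plain PyTorch (reference semantics only, not the accelerated kernels): each row gets its own absmax scale, the matmul accumulates in int32, and the scales are multiplied back at the end. The colrow/LLM.int8 variant builds on this by also scaling columns and keeping outlier feature columns in fp16.

```python
import torch

def quantize_rowwise_ref(x: torch.Tensor):
    """Row-wise INT8: one absmax scale per row."""
    scale = x.abs().amax(dim=1, keepdim=True).clamp_min(1e-12) / 127
    q = torch.round(x / scale).to(torch.int8)
    return q, scale

def matmul_int8_ref(a_q, a_scale, b_q, b_scale):
    """int8 x int8 -> int32 accumulate, then rescale to float."""
    acc = a_q.to(torch.int32) @ b_q.to(torch.int32).t()
    return acc.to(torch.float32) * a_scale * b_scale.t()

a = torch.randn(8, 64)
b = torch.randn(16, 64)
a_q, a_s = quantize_rowwise_ref(a)
b_q, b_s = quantize_rowwise_ref(b)
out = matmul_int8_ref(a_q, a_s, b_q, b_s)   # close to a @ b.t()
```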

Double Quantization:

  • double_quant(absmax, double_quant_block=256) - Quantize scales
  • dequant_absmax(absmax_quant, absmax_scales) - Restore scales
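
Double quantization compresses the fp32 absmax scales themselves: group the scales into blocks (256 by default, per `double_quant` above) and store each block as int8 plus one fp32 outer scale. A sketch of the idea (the library's actual implementation may differ in detail, e.g. by centering the scales first):

```python
import torch

def double_quant_ref(absmax: torch.Tensor, block: int = 256):
    """Quantize fp32 absmax scales to int8, one fp32 outer scale per block.
    Assumes absmax.numel() is divisible by block."""
    flat = absmax.reshape(-1, block)
    outer = flat.abs().amax(dim=1, keepdim=True).clamp_min(1e-12) / 127
    q = torch.round(flat / outer).to(torch.int8)
    return q, outer

def dequant_absmax_ref(q, outer, shape):
    return (q.to(torch.float32) * outer).reshape(shape)

absmax = torch.rand(1024) + 0.5          # absmax scales are positive
q, outer = double_quant_ref(absmax)
absmax_hat = dequant_absmax_ref(q, outer, absmax.shape)
# storage: 1 byte per scale instead of 4, plus one fp32 per 256 scales
```

This is where the "extra ~10% savings on scales" comes from: the fp32 scales are the main overhead of 4-bit storage, and this shrinks them roughly 4x.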

Sparse Operations:

  • sparse_coo_from_dense(tensor) - Convert to COO format
  • spmm_coo(row_idx, col_idx, values, dense, rows, cols) - Sparse matmul
  • spmm_coo_int8(...) - INT8 sparse matmul
  • quantize_sparse_coo(row_idx, col_idx, values) - Quantize sparse values
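
The COO routines take explicit index/value arrays; their semantics can be sketched with `index_add_` (a dense reference for the signature listed above, not the accelerated path): for every nonzero (r, c, v), the output row r accumulates v times row c of the dense operand.

```python
import torch

def spmm_coo_ref(row_idx, col_idx, values, dense, rows, cols):
    """Reference COO sparse @ dense: out[r] += v * dense[c] per nonzero.
    `cols` is kept for signature parity; duplicates in row_idx accumulate."""
    out = torch.zeros(rows, dense.shape[1], dtype=dense.dtype)
    out.index_add_(0, row_idx, values.unsqueeze(1) * dense[col_idx])
    return out

# 3x4 sparse matrix with three nonzeros, times a 4x2 dense matrix
row_idx = torch.tensor([0, 1, 2])
col_idx = torch.tensor([1, 0, 3])
values  = torch.tensor([2.0, -1.0, 0.5])
dense   = torch.arange(8, dtype=torch.float32).reshape(4, 2)
out = spmm_coo_ref(row_idx, col_idx, values, dense, rows=3, cols=4)
```

The int8 variant applies the same gather-scale-scatter pattern to quantized values with a per-value or per-row scale.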

Utilities:

  • is_available() - Check MPS availability
  • has_native_kernels() - Check Metal kernels loaded
  • get_memory_footprint(model) - Calculate memory usage

Comparison with bitsandbytes

| Feature | bitsandbytes (CUDA) | mps-bitsandbytes |
|---------|---------------------|------------------|
| NF4/FP4 | CUDA | Metal |
| INT8/FP8 | CUDA | Metal |
| Double quant | CUDA | Metal |
| 8-bit Optimizers | CUDA | Pure PyTorch |
| Paged Optimizers | CUDA | Pure PyTorch |
| Quantized Embeddings | CUDA | Pure PyTorch |
| Sparse matmul | CUDA | Pure PyTorch |
| LLM.int8 (col+row) | CUDA | Pure PyTorch |
| Platform | NVIDIA | Apple Silicon |

Demo

# Chat with a quantized LLM
python demo/chat.py

License

MIT

Credits
