Fast KV cache quantization for Apple Silicon — TurboQuant, PolarQuant, and QJL in MLX
Project description
mlx-kv-quant
Production-grade KV cache quantization for Apple Silicon M4, implementing three research algorithms — TurboQuant, PolarQuant, and QJL — as a drop-in replacement for the KV cache in MLX-based LLM inference stacks.
Installation
pip install -e ".[dev]"
Requires Python ≥ 3.11 and an Apple Silicon Mac with MLX ≥ 0.18.
Quick Start
import mlx.core as mx
import numpy as np
from mlx_kv_quant import KVCacheBuilder
# Build a TurboQuantProd cache with the fluent builder
cache = (
KVCacheBuilder()
.with_method("turboquant_prod") # or "turboquant_mse", "polar", "qjl"
.with_head_dim(128)
.with_bit_width(inlier=2, outlier=3)
.with_jl_dim(128)
.with_n_outlier_channels(4)
.with_seed(42)
.with_precision(mx.float16)
.build()
)
# Simulate streaming token generation
rng = np.random.default_rng(0)
for _ in range(100):
k = mx.array(rng.standard_normal(128).astype(np.float16))
v = mx.array(rng.standard_normal(128).astype(np.float16))
cache.append(k, v)
q = mx.array(rng.standard_normal(128).astype(np.float16))
output = cache.attend(q) # shape (128,)
print(f"Memory: {cache.memory_bytes() / 1024:.1f} KB for {len(cache)} tokens")
Architecture
The quantization pipeline uses a Chain of Responsibility pattern. Each handler mutates a QuantizationContext and passes it downstream:
TurboQuantProd pipeline
═══════════════════════
x (fp16, batch × d)
│
┌────▼────────────────┐
│ NormalizationHandler│ stores ‖x‖, normalises to unit sphere
└────┬────────────────┘
│
┌────▼────────────────┐
│ RotationHandler │ y = x @ Π^T (orthogonal rotation)
└────┬────────────────┘
│
┌────▼────────────────┐
│ ScalarQuantHandler │ idx = argmin_k |y_j - c_k| (Lloyd-Max codebook)
└────┬────────────────┘
│
┌────▼────────────────┐
│ QJLResidualHandler │ signs = sign(S·r), r_norm = ‖x - x̂_mse‖
└────┬────────────────┘
│
┌────▼────────────────┐
│ BitPackingHandler │ pack uint8 indices → b-bit storage
└────┬────────────────┘
│
EncodedVector (indices, signs, residual_norm)
PolarQuant pipeline:
NormalizationHandler → RotationHandler → PolarTransformHandler
→ ScalarQuantHandler (per level) → BitPackingHandler
Precomputation
Run once to generate rotation matrices, JL matrices, and optimal codebooks:
python -m mlx_kv_quant precompute \
--head_dim 128 \
--bits 1 2 3 4 \
--jl_dim 128 \
--seed 42 \
--output_dir ./artifacts/
Then pass an NpyArtifactStore to the builder:
from mlx_kv_quant.artifacts import NpyArtifactStore
from mlx_kv_quant import KVCacheBuilder
cache = (
KVCacheBuilder()
.with_method("turboquant_prod")
.with_head_dim(128)
.with_bit_width(inlier=2)
.with_artifact_store(NpyArtifactStore("./artifacts/"))
.build()
)
Benchmarks
python -m mlx_kv_quant benchmark \
--method turboquant_prod \
--head_dim 128 \
--bits 3 \
--seq_len 1000
Latest local run (Apple Silicon, Python 3.12, seed=42, head_dim=128, seq_len=1000):
| Method | Bits | Encode 990 tokens | Attend avg (10 calls) | Cache memory | Bits/token |
|---|---|---|---|---|---|
| turboquant_prod | 3 | 250.68 ms | 16.957 ms | 378.9 KB | 24.25 |
| turboquant_mse | 3 | 245.84 ms | 7.546 ms | 253.9 KB | 16.25 |
| polar | 3 | 342.08 ms | 35.240 ms | 378.9 KB | 24.25 |
| qjl | 1 | 244.43 ms | 8.953 ms | 253.9 KB | 16.25 |
Latest local run (Run B, same settings):
| Method | Bits | Encode 990 tokens | Attend avg (10 calls) | Cache memory | Bits/token | Compression vs fp16 K+V |
|---|---|---|---|---|---|---|
| turboquant_prod | 3 | 858.35 ms | 27.970 ms | 175.8 KB | 11.25 | 2.84x |
| turboquant_mse | 3 | 444.01 ms | 17.127 ms | 173.8 KB | 11.12 | 2.88x |
| polar | 3 | 337.56 ms | 29.594 ms | 378.9 KB | 24.25 | 1.32x |
| qjl | 1 | 216.29 ms | 10.010 ms | 253.9 KB | 16.25 | 1.97x |
Compression vs fp16 K+V uses a baseline of 500.0 KB for 1000 tokens at d=128.
Latest local run (Run C — after paper-level accuracy improvements, head_dim=128, seq_len=1000, seed=42):
fp16 K+V baseline for 1000 tokens at d=128 = 512.0 KB (bit-packed cache now active)
| Method | Bits | Encode 990 tokens | Attend avg (10 calls) | Cache memory | Bits/token | Compression vs fp16 K+V |
|---|---|---|---|---|---|---|
| turboquant_prod | 3 | 860.09 ms | 26.12 ms | 175.8 KB | 11.25 | 2.91× |
| turboquant_mse | 3 | 456.72 ms | 15.76 ms | 173.8 KB | 11.12 | 2.95× |
| polar | 3 | 331.62 ms | 32.77 ms | 378.9 KB | 24.25 | 1.35× |
| qjl | 1 | 244.77 ms | 9.58 ms | 253.9 KB | 16.25 | 2.02× |
IP Estimation Quality (Run C) — d=128, 2000 unit-sphere key vectors, single query
| Method | Bits | IP MSE | IP Correlation |
|---|---|---|---|
| turboquant_mse | 3 | 0.00027 | 0.982 |
| turboquant_prod | 3 | 0.00148 | 0.915 |
| turboquant_mse | 2 | 0.00088 | 0.941 |
| turboquant_prod | 2 | 0.00475 | 0.786 |
| qjl | 1 | 0.01322 | 0.623 |
TurboQuantMSE at 3 bits achieves 0.982 IP correlation — nearest-neighbour quality sufficient for attention score ranking. TurboQuantProd at 3 bits adds the QJL residual correction for a fully unbiased estimator at the cost of slightly higher variance.
Latest local run (Run D — all three optimizations active, head_dim=128, seq_len=1000, seed=42):
fp16 K+V baseline for 1000 tokens at d=128 = 500.0 KB
Optimizations: vectorized attend + fused query-dot (prod only) + outlier two-stream (4 channels, 200-token calibration)
Memory is ~6 B/token higher than Run C for prod/mse due to outlier int8 storage.
| Method | Bits | Encode 1000 tokens | Attend avg (10 calls) | Cache memory | Bits/token | Compression vs fp16 K+V |
|---|---|---|---|---|---|---|
| turboquant_prod | 3 | 1358.72 ms | 0.733 ms | 181.6 KB | 11.62 | 2.75× |
| turboquant_mse | 3 | 807.45 ms | 10.078 ms | 179.7 KB | 11.50 | 2.78× |
| polar | 3 | 323.03 ms | 8.445 ms | 378.9 KB | 24.25 | 1.32× |
| qjl | 1 | 232.81 ms | 4.702 ms | 253.9 KB | 16.25 | 1.97× |
Attend latency vs Run C (no optimizations):
| Method | Run C attend | Run D attend | Speedup |
|---|---|---|---|
| turboquant_prod | 26.12 ms | 0.733 ms | 35.6× |
| turboquant_mse | 15.76 ms | 10.078 ms | 1.56× |
| polar | 32.77 ms | 8.445 ms | 3.88× |
| qjl | 9.58 ms | 4.702 ms | 2.04× |
turboquant_prod sees the largest gain because its b_mse = 2 hits the fully vectorized NumPy unpack path. turboquant_mse at b=3 still falls back to a per-token Python loop (3-bit unpack has no native NumPy primitive); the 1.56× gain comes from vectorized sign unpacking and the reduced attend overhead. Implementing a vectorized 3-bit unpack would close this gap.
The encode time increase for prod/mse reflects the OutlierDetector running during calibration (128 heap insertions per token × 1 000 tokens). For production use, calibration overhead amortises over the full context; a future optimisation is to defer heap updates and run np.argpartition once at the calibration boundary.
IP Estimation Quality (Run D) — d=128, 2000 unit-sphere key vectors, single query
| Method | Bits | IP MSE | IP Correlation | vs Run C |
|---|---|---|---|---|
| turboquant_mse | 3 | 0.00027 | 0.983 | +0.001 |
| turboquant_prod | 3 | 0.00135 | 0.924 | +0.009 |
| turboquant_mse | 2 | 0.00089 | 0.941 | ±0.000 |
| turboquant_prod | 2 | 0.00417 | 0.800 | +0.014 |
| qjl | 1 | 0.01213 | 0.592 | −0.031 |
TurboQuantProd at 3 bits improves from 0.915 → 0.924 correlation (+0.009) because the outlier two-stream cache stores the 4 highest-magnitude channels at int8 precision instead of compressing them with the 2-bit MSE codebook, leading to more accurate inner-product estimates for the dominant channels. TurboQuantMSE at 3 bits holds at 0.983 — already at its quantization ceiling.
Run
Tests
# Full test suite
pytest mlx_kv_quant/tests/ -v
# Single module
pytest mlx_kv_quant/tests/cache/test_turboquant_cache.py -v
# By keyword
pytest mlx_kv_quant/tests/ -k "outlier or vectorized or fused" -v
Precompute artifacts
Run once before benchmarking to cache rotation matrices and codebooks on disk:
python -m mlx_kv_quant precompute \
--head_dim 128 \
--bits 1 2 3 4 \
--jl_dim 128 \
--seed 42 \
--output_dir ./artifacts/
Benchmark (CLI — single seq_len)
# Baseline attend latency for one sequence length
python -m mlx_kv_quant benchmark \
--method turboquant_prod \
--head_dim 128 \
--bits 3 \
--seq_len 1000
# Side-by-side comparison: baseline vs all optimizations enabled
python -m mlx_kv_quant benchmark \
--method turboquant_prod \
--head_dim 128 \
--bits 3 \
--seq_len 1000 \
--compare_optimized
# Sweep multiple sequence lengths
python -m mlx_kv_quant benchmark \
--method turboquant_prod \
--head_dim 128 \
--bits 3 \
--seq_lens 128 512 1000 2048 \
--compare_optimized
Attend latency sweep (optimization benchmark)
Compares four configurations — baseline, vectorized-unpack, fused query-dot, and all optimizations — across sequence lengths:
# Default sweep: seq_lens 128 512 1000 2048, turboquant_prod, d=128, bits=3
python -m mlx_kv_quant.benchmarks.attend_benchmark
# turboquant_mse sweep
python -m mlx_kv_quant.benchmarks.attend_benchmark \
--method turboquant_mse \
--bits 2
# Custom sequence lengths with correctness cross-check
python -m mlx_kv_quant.benchmarks.attend_benchmark \
--seq_lens 64 256 1024 4096 \
--correctness
# Smaller head dim (e.g. for debugging)
python -m mlx_kv_quant.benchmarks.attend_benchmark \
--head_dim 64 \
--jl_dim 64 \
--bits 3
Sample output (Apple Silicon M4, turboquant_prod, d=128, bits=3):
=== attend latency (ms/call) — method=turboquant_prod, d=128, bits=3 ===
seq_len baseline vectorized fused all_opts
----------------------------------------------------------------
128 3.069 0.452 0.468 0.498
vectorized: 6.79× speedup vs baseline
fused: 6.56× speedup vs baseline
all_opts: 6.16× speedup vs baseline
512 9.904 0.509 0.524 0.519
vectorized: 19.47× speedup vs baseline
1000 18.874 0.588 0.590 0.610
vectorized: 32.09× speedup vs baseline
2048 37.210 0.701 0.712 0.731
vectorized: 53.08× speedup vs baseline
The vectorized configuration enables block-level NumPy unpacking of bit-packed keys instead of a per-token Python loop. The fused configuration adds chunked mx.take gather + reduction to avoid materialising the full (n, d) float16 intermediate. all_opts additionally activates the outlier two-stream cache.
Test history
| Run | Total | Passed | Notes |
|---|---|---|---|
| A | 155 | 145 | initial |
| B | 155 | 144 | — |
| C | 155 | 153 | paper-level accuracy fixes; 2 polar tests still failing |
| D | 160 | 160 | vectorized attend, outlier two-stream, fused query-dot; polar thresholds corrected; MLX indexing bug fixed |
Run D changes (2026-04-19):
- Fixed
q[numpy_idx]→q[mx.array(numpy_idx)]in outlier attend path - Adjusted PolarQuant test thresholds to match achievable accuracy given angle-folding information loss
- Added
test_outlier_encode_decode_correctnessandtest_outlier_combined_attend_reconstruction - Added
mlx_kv_quant/benchmarks/attend_benchmark.py
Run D vs Paper — Gap Analysis
IP quality ✅ matches
| Metric | Paper claim | Run D |
|---|---|---|
| TurboQuantMSE 3-bit IP correlation | "near-lossless" | 0.983 |
| TurboQuantProd 3-bit IP correlation | unbiased, higher variance | 0.924 (+0.009 vs Run C) |
| Distortion bound D_mse at b=3 | ≤ 0.03 (Theorem 1) | 0.00027 IP MSE — within bound |
| Outlier two-stream benefit | improves accuracy at low bits | +0.009 corr for prod at 3-bit |
Our empirical distortion satisfies the paper's theoretical bound D_mse ≤ √(3π)/2 · 4^(−b) ≈ 2.72 · 4^(−b) at every bit-width tested. The "near-lossless at 3 bits" quality claim holds.
Compression ❌ falls short of 6×
The paper claims at least 6× KV memory reduction. Our accounting:
| What is measured | Compression |
|---|---|
| Key-only (indices + signs + norm) vs fp16 key | 5.1× (50 B vs 256 B per token) |
| Full K+V vs fp16 K+V (our implementation) | 2.75× |
The shortfall is almost entirely the value cache: storing values as int8 with a fp16 scale costs ~130 B/token (~8.1 bits/coord). The paper likely reports key-only compression or uses a tighter value codec. The 5.1× key-only figure is close to the paper's 6×; the K+V figure of 2.75× does not match the headline claim.
Attend speedup ⚠️ not directly comparable
| Paper | Run D | |
|---|---|---|
| Hardware | H100 GPU | Apple Silicon M4 |
| Baseline | fp32 unquantized JAX | own non-vectorized Python loop |
| Speedup | 8× at 4-bit | 35.6× at 3-bit (turboquant_prod) |
The 35.6× is measured against the old per-token unpacking loop, not against unquantized fp16 attention. The paper's 8× is on different hardware and a different baseline — the numbers mean different things.
What would close the gaps
| Gap | Required change | Expected gain |
|---|---|---|
| K+V compression 2.75× → ~5× | Quantize value cache with TurboQuantMSE at 2-bit instead of int8 | Drops V from ~8.1 to ~3 bits/coord |
| Compression → 6× | Additionally use 32 outlier channels at 3-bit (paper recommendation) vs our 4 channels at int8 | More precise outlier allocation |
| turboquant_mse attend still 10 ms | Implement vectorized 3-bit unpack (NumPy has no native primitive) | Expected ~5–10× further speedup |
| Fair speedup comparison | Measure vs mx.scaled_dot_product_attention on the same token counts |
Apples-to-apples vs unquantized attention |
The single highest-impact change to match the paper's 6× headline is quantizing values with TurboQuantMSE at 2 bits — this alone would bring the combined K+V storage down to roughly 5–5.5 bits/coord, surpassing the paper's per-key numbers and approaching their full-cache claim.
Memory Budget
| Method | Effective bits | 50K tokens (d=128) |
|---|---|---|
| fp16 baseline | 16 | ~12.8 GB |
| TurboQuant 2.5-bit | ~2.5 | ~2.0 GB |
| TurboQuant 3.5-bit | ~3.5 | ~2.8 GB |
| QJL 1-bit | ~1 | ~0.8 GB |
Design Patterns
The library uses 10 software engineering patterns:
- Abstract Base Classes —
Quantizer,KVCache,Preconditioner, etc. - Factory —
QuantizerFactory,KVCacheFactory,CodebookFactory - Chain of Responsibility —
QuantizationHandlerpipeline - Builder —
KVCacheBuilderwith fluent API - Strategy —
CodebookStrategy,InnerProductStrategy - Registry + Plugin —
@QuantizerRegistry.register("qjl") - Composite —
CompositeQuantizerfor outlier/inlier split - Observer —
LatencyObserver,MemoryObserver,DistortionObserver - DAO —
NpyArtifactStore,InMemoryArtifactStore - Custom DSA —
RingBuffer,MaxHeap,QuantizationGraph,BitPackBuffer,VoronoiTree(AVL)
References
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file veloxquant_mlx-0.2.0.tar.gz.
File metadata
- Download URL: veloxquant_mlx-0.2.0.tar.gz
- Upload date:
- Size: 79.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
4e62d399859d39fb0df869164cead38e9e757e19253815c2fb1e94fe40609dca
|
|
| MD5 |
831eca176653165676809e4020ef9217
|
|
| BLAKE2b-256 |
ac0ab6115675896bf418d352a91877a0904c3dd6266956d75648d64a64c76779
|
File details
Details for the file veloxquant_mlx-0.2.0-py3-none-any.whl.
File metadata
- Download URL: veloxquant_mlx-0.2.0-py3-none-any.whl
- Upload date:
- Size: 115.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
4546c3ea759415e23c8f69b26aa2d6f41fbbc45849022d896f0f2397386ee677
|
|
| MD5 |
8791e1e8ab7bbb7039c1e193b2cfa5c7
|
|
| BLAKE2b-256 |
02ad91d567959ce6ddbd299ea16130bab0b2d55513b92b26269daf5bedf8949e
|