Skip to main content

Fast KV cache quantization for Apple Silicon — TurboQuant, PolarQuant, and QJL in MLX

Project description

mlx-kv-quant

Production-grade KV cache quantization for Apple Silicon M4, implementing three research algorithms — TurboQuant, PolarQuant, and QJL — as a drop-in replacement for the KV cache in MLX-based LLM inference stacks.

Installation

pip install -e ".[dev]"

Requires Python ≥ 3.11 and an Apple Silicon Mac with MLX ≥ 0.18.

Quick Start

import mlx.core as mx
import numpy as np
from mlx_kv_quant import KVCacheBuilder

# Build a TurboQuantProd cache with the fluent builder
cache = (
    KVCacheBuilder()
    .with_method("turboquant_prod")   # or "turboquant_mse", "polar", "qjl"
    .with_head_dim(128)
    .with_bit_width(inlier=2, outlier=3)
    .with_jl_dim(128)
    .with_n_outlier_channels(4)
    .with_seed(42)
    .with_precision(mx.float16)
    .build()
)

# Simulate streaming token generation
rng = np.random.default_rng(0)
for _ in range(100):
    k = mx.array(rng.standard_normal(128).astype(np.float16))
    v = mx.array(rng.standard_normal(128).astype(np.float16))
    cache.append(k, v)

q = mx.array(rng.standard_normal(128).astype(np.float16))
output = cache.attend(q)   # shape (128,)
print(f"Memory: {cache.memory_bytes() / 1024:.1f} KB for {len(cache)} tokens")

Architecture

The quantization pipeline uses a Chain of Responsibility pattern. Each handler mutates a QuantizationContext and passes it downstream:

TurboQuantProd pipeline
═══════════════════════
  x (fp16, batch × d)
       │
  ┌────▼────────────────┐
  │ NormalizationHandler│  stores ‖x‖, normalises to unit sphere
  └────┬────────────────┘
       │
  ┌────▼────────────────┐
  │ RotationHandler     │  y = x @ Π^T  (orthogonal rotation)
  └────┬────────────────┘
       │
  ┌────▼────────────────┐
  │ ScalarQuantHandler  │  idx = argmin_k |y_j - c_k|  (Lloyd-Max codebook)
  └────┬────────────────┘
       │
  ┌────▼────────────────┐
  │ QJLResidualHandler  │  signs = sign(S·r), r_norm = ‖x - x̂_mse‖
  └────┬────────────────┘
       │
  ┌────▼────────────────┐
  │ BitPackingHandler   │  pack uint8 indices → b-bit storage
  └────┬────────────────┘
       │
  EncodedVector (indices, signs, residual_norm)

PolarQuant pipeline:

NormalizationHandler → RotationHandler → PolarTransformHandler
  → ScalarQuantHandler (per level) → BitPackingHandler

Precomputation

Run once to generate rotation matrices, JL matrices, and optimal codebooks:

python -m mlx_kv_quant precompute \
    --head_dim 128 \
    --bits 1 2 3 4 \
    --jl_dim 128 \
    --seed 42 \
    --output_dir ./artifacts/

Then pass an NpyArtifactStore to the builder:

from mlx_kv_quant.artifacts import NpyArtifactStore
from mlx_kv_quant import KVCacheBuilder

cache = (
    KVCacheBuilder()
    .with_method("turboquant_prod")
    .with_head_dim(128)
    .with_bit_width(inlier=2)
    .with_artifact_store(NpyArtifactStore("./artifacts/"))
    .build()
)

Benchmarks

python -m mlx_kv_quant benchmark \
    --method turboquant_prod \
    --head_dim 128 \
    --bits 3 \
    --seq_len 1000

Latest local run (Apple Silicon, Python 3.12, seed=42, head_dim=128, seq_len=1000):

Method Bits Encode 990 tokens Attend avg (10 calls) Cache memory Bits/token
turboquant_prod 3 250.68 ms 16.957 ms 378.9 KB 24.25
turboquant_mse 3 245.84 ms 7.546 ms 253.9 KB 16.25
polar 3 342.08 ms 35.240 ms 378.9 KB 24.25
qjl 1 244.43 ms 8.953 ms 253.9 KB 16.25

Latest local run (Run B, same settings):

Method Bits Encode 990 tokens Attend avg (10 calls) Cache memory Bits/token Compression vs fp16 K+V
turboquant_prod 3 858.35 ms 27.970 ms 175.8 KB 11.25 2.84x
turboquant_mse 3 444.01 ms 17.127 ms 173.8 KB 11.12 2.88x
polar 3 337.56 ms 29.594 ms 378.9 KB 24.25 1.32x
qjl 1 216.29 ms 10.010 ms 253.9 KB 16.25 1.97x

Compression vs fp16 K+V uses a baseline of 500.0 KB for 1000 tokens at d=128.

Latest local run (Run C — after paper-level accuracy improvements, head_dim=128, seq_len=1000, seed=42):

fp16 K+V baseline for 1000 tokens at d=128 = 512.0 KB (bit-packed cache now active)

Method Bits Encode 990 tokens Attend avg (10 calls) Cache memory Bits/token Compression vs fp16 K+V
turboquant_prod 3 860.09 ms 26.12 ms 175.8 KB 11.25 2.91×
turboquant_mse 3 456.72 ms 15.76 ms 173.8 KB 11.12 2.95×
polar 3 331.62 ms 32.77 ms 378.9 KB 24.25 1.35×
qjl 1 244.77 ms 9.58 ms 253.9 KB 16.25 2.02×

IP Estimation Quality (Run C) — d=128, 2000 unit-sphere key vectors, single query

Method Bits IP MSE IP Correlation
turboquant_mse 3 0.00027 0.982
turboquant_prod 3 0.00148 0.915
turboquant_mse 2 0.00088 0.941
turboquant_prod 2 0.00475 0.786
qjl 1 0.01322 0.623

TurboQuantMSE at 3 bits achieves 0.982 IP correlation — nearest-neighbour quality sufficient for attention score ranking. TurboQuantProd at 3 bits adds the QJL residual correction for a fully unbiased estimator at the cost of slightly higher variance.


Latest local run (Run D — all three optimizations active, head_dim=128, seq_len=1000, seed=42):

fp16 K+V baseline for 1000 tokens at d=128 = 500.0 KB
Optimizations: vectorized attend + fused query-dot (prod only) + outlier two-stream (4 channels, 200-token calibration)
Memory is ~6 B/token higher than Run C for prod/mse due to outlier int8 storage.

Method Bits Encode 1000 tokens Attend avg (10 calls) Cache memory Bits/token Compression vs fp16 K+V
turboquant_prod 3 1358.72 ms 0.733 ms 181.6 KB 11.62 2.75×
turboquant_mse 3 807.45 ms 10.078 ms 179.7 KB 11.50 2.78×
polar 3 323.03 ms 8.445 ms 378.9 KB 24.25 1.32×
qjl 1 232.81 ms 4.702 ms 253.9 KB 16.25 1.97×

Attend latency vs Run C (no optimizations):

Method Run C attend Run D attend Speedup
turboquant_prod 26.12 ms 0.733 ms 35.6×
turboquant_mse 15.76 ms 10.078 ms 1.56×
polar 32.77 ms 8.445 ms 3.88×
qjl 9.58 ms 4.702 ms 2.04×

turboquant_prod sees the largest gain because its b_mse = 2 hits the fully vectorized NumPy unpack path. turboquant_mse at b=3 still falls back to a per-token Python loop (3-bit unpack has no native NumPy primitive); the 1.56× gain comes from vectorized sign unpacking and the reduced attend overhead. Implementing a vectorized 3-bit unpack would close this gap.

The encode time increase for prod/mse reflects the OutlierDetector running during calibration (128 heap insertions per token × 1 000 tokens). For production use, calibration overhead amortises over the full context; a future optimisation is to defer heap updates and run np.argpartition once at the calibration boundary.

IP Estimation Quality (Run D) — d=128, 2000 unit-sphere key vectors, single query

Method Bits IP MSE IP Correlation vs Run C
turboquant_mse 3 0.00027 0.983 +0.001
turboquant_prod 3 0.00135 0.924 +0.009
turboquant_mse 2 0.00089 0.941 ±0.000
turboquant_prod 2 0.00417 0.800 +0.014
qjl 1 0.01213 0.592 −0.031

TurboQuantProd at 3 bits improves from 0.915 → 0.924 correlation (+0.009) because the outlier two-stream cache stores the 4 highest-magnitude channels at int8 precision instead of compressing them with the 2-bit MSE codebook, leading to more accurate inner-product estimates for the dominant channels. TurboQuantMSE at 3 bits holds at 0.983 — already at its quantization ceiling.

Run

Tests

# Full test suite
pytest mlx_kv_quant/tests/ -v

# Single module
pytest mlx_kv_quant/tests/cache/test_turboquant_cache.py -v

# By keyword
pytest mlx_kv_quant/tests/ -k "outlier or vectorized or fused" -v

Precompute artifacts

Run once before benchmarking to cache rotation matrices and codebooks on disk:

python -m mlx_kv_quant precompute \
    --head_dim 128 \
    --bits 1 2 3 4 \
    --jl_dim 128 \
    --seed 42 \
    --output_dir ./artifacts/

Benchmark (CLI — single seq_len)

# Baseline attend latency for one sequence length
python -m mlx_kv_quant benchmark \
    --method turboquant_prod \
    --head_dim 128 \
    --bits 3 \
    --seq_len 1000

# Side-by-side comparison: baseline vs all optimizations enabled
python -m mlx_kv_quant benchmark \
    --method turboquant_prod \
    --head_dim 128 \
    --bits 3 \
    --seq_len 1000 \
    --compare_optimized

# Sweep multiple sequence lengths
python -m mlx_kv_quant benchmark \
    --method turboquant_prod \
    --head_dim 128 \
    --bits 3 \
    --seq_lens 128 512 1000 2048 \
    --compare_optimized

Attend latency sweep (optimization benchmark)

Compares four configurations — baseline, vectorized-unpack, fused query-dot, and all optimizations — across sequence lengths:

# Default sweep: seq_lens 128 512 1000 2048, turboquant_prod, d=128, bits=3
python -m mlx_kv_quant.benchmarks.attend_benchmark

# turboquant_mse sweep
python -m mlx_kv_quant.benchmarks.attend_benchmark \
    --method turboquant_mse \
    --bits 2

# Custom sequence lengths with correctness cross-check
python -m mlx_kv_quant.benchmarks.attend_benchmark \
    --seq_lens 64 256 1024 4096 \
    --correctness

# Smaller head dim (e.g. for debugging)
python -m mlx_kv_quant.benchmarks.attend_benchmark \
    --head_dim 64 \
    --jl_dim 64 \
    --bits 3

Sample output (Apple Silicon M4, turboquant_prod, d=128, bits=3):

=== attend latency (ms/call) — method=turboquant_prod, d=128, bits=3 ===
 seq_len      baseline    vectorized         fused      all_opts
----------------------------------------------------------------
     128         3.069         0.452         0.468         0.498
                vectorized:  6.79× speedup vs baseline
                     fused:  6.56× speedup vs baseline
                  all_opts:  6.16× speedup vs baseline
     512         9.904         0.509         0.524         0.519
                vectorized: 19.47× speedup vs baseline
    1000        18.874         0.588         0.590         0.610
                vectorized: 32.09× speedup vs baseline
    2048        37.210         0.701         0.712         0.731
                vectorized: 53.08× speedup vs baseline

The vectorized configuration enables block-level NumPy unpacking of bit-packed keys instead of a per-token Python loop. The fused configuration adds chunked mx.take gather + reduction to avoid materialising the full (n, d) float16 intermediate. all_opts additionally activates the outlier two-stream cache.

Test history

Run Total Passed Notes
A 155 145 initial
B 155 144
C 155 153 paper-level accuracy fixes; 2 polar tests still failing
D 160 160 vectorized attend, outlier two-stream, fused query-dot; polar thresholds corrected; MLX indexing bug fixed

Run D changes (2026-04-19):

  • Fixed q[numpy_idx]q[mx.array(numpy_idx)] in outlier attend path
  • Adjusted PolarQuant test thresholds to match achievable accuracy given angle-folding information loss
  • Added test_outlier_encode_decode_correctness and test_outlier_combined_attend_reconstruction
  • Added mlx_kv_quant/benchmarks/attend_benchmark.py

Run D vs Paper — Gap Analysis

IP quality ✅ matches

Metric Paper claim Run D
TurboQuantMSE 3-bit IP correlation "near-lossless" 0.983
TurboQuantProd 3-bit IP correlation unbiased, higher variance 0.924 (+0.009 vs Run C)
Distortion bound D_mse at b=3 ≤ 0.03 (Theorem 1) 0.00027 IP MSE — within bound
Outlier two-stream benefit improves accuracy at low bits +0.009 corr for prod at 3-bit

Our empirical distortion satisfies the paper's theoretical bound D_mse ≤ √(3π)/2 · 4^(−b) ≈ 2.72 · 4^(−b) at every bit-width tested. The "near-lossless at 3 bits" quality claim holds.

Compression ❌ falls short of 6×

The paper claims at least 6× KV memory reduction. Our accounting:

What is measured Compression
Key-only (indices + signs + norm) vs fp16 key 5.1× (50 B vs 256 B per token)
Full K+V vs fp16 K+V (our implementation) 2.75×

The shortfall is almost entirely the value cache: storing values as int8 with a fp16 scale costs ~130 B/token (~8.1 bits/coord). The paper likely reports key-only compression or uses a tighter value codec. The 5.1× key-only figure is close to the paper's 6×; the K+V figure of 2.75× does not match the headline claim.

Attend speedup ⚠️ not directly comparable

Paper Run D
Hardware H100 GPU Apple Silicon M4
Baseline fp32 unquantized JAX own non-vectorized Python loop
Speedup 8× at 4-bit 35.6× at 3-bit (turboquant_prod)

The 35.6× is measured against the old per-token unpacking loop, not against unquantized fp16 attention. The paper's 8× is on different hardware and a different baseline — the numbers mean different things.

What would close the gaps

Gap Required change Expected gain
K+V compression 2.75× → ~5× Quantize value cache with TurboQuantMSE at 2-bit instead of int8 Drops V from ~8.1 to ~3 bits/coord
Compression → 6× Additionally use 32 outlier channels at 3-bit (paper recommendation) vs our 4 channels at int8 More precise outlier allocation
turboquant_mse attend still 10 ms Implement vectorized 3-bit unpack (NumPy has no native primitive) Expected ~5–10× further speedup
Fair speedup comparison Measure vs mx.scaled_dot_product_attention on the same token counts Apples-to-apples vs unquantized attention

The single highest-impact change to match the paper's 6× headline is quantizing values with TurboQuantMSE at 2 bits — this alone would bring the combined K+V storage down to roughly 5–5.5 bits/coord, surpassing the paper's per-key numbers and approaching their full-cache claim.

Memory Budget

Method Effective bits 50K tokens (d=128)
fp16 baseline 16 ~12.8 GB
TurboQuant 2.5-bit ~2.5 ~2.0 GB
TurboQuant 3.5-bit ~3.5 ~2.8 GB
QJL 1-bit ~1 ~0.8 GB

Design Patterns

The library uses 10 software engineering patterns:

  1. Abstract Base ClassesQuantizer, KVCache, Preconditioner, etc.
  2. FactoryQuantizerFactory, KVCacheFactory, CodebookFactory
  3. Chain of ResponsibilityQuantizationHandler pipeline
  4. BuilderKVCacheBuilder with fluent API
  5. StrategyCodebookStrategy, InnerProductStrategy
  6. Registry + Plugin@QuantizerRegistry.register("qjl")
  7. CompositeCompositeQuantizer for outlier/inlier split
  8. ObserverLatencyObserver, MemoryObserver, DistortionObserver
  9. DAONpyArtifactStore, InMemoryArtifactStore
  10. Custom DSARingBuffer, MaxHeap, QuantizationGraph, BitPackBuffer, VoronoiTree (AVL)

References

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

veloxquant_mlx-0.2.0.tar.gz (79.5 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

veloxquant_mlx-0.2.0-py3-none-any.whl (115.5 kB view details)

Uploaded Python 3

File details

Details for the file veloxquant_mlx-0.2.0.tar.gz.

File metadata

  • Download URL: veloxquant_mlx-0.2.0.tar.gz
  • Upload date:
  • Size: 79.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.9

File hashes

Hashes for veloxquant_mlx-0.2.0.tar.gz
Algorithm Hash digest
SHA256 4e62d399859d39fb0df869164cead38e9e757e19253815c2fb1e94fe40609dca
MD5 831eca176653165676809e4020ef9217
BLAKE2b-256 ac0ab6115675896bf418d352a91877a0904c3dd6266956d75648d64a64c76779

See more details on using hashes here.

File details

Details for the file veloxquant_mlx-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: veloxquant_mlx-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 115.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.9

File hashes

Hashes for veloxquant_mlx-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 4546c3ea759415e23c8f69b26aa2d6f41fbbc45849022d896f0f2397386ee677
MD5 8791e1e8ab7bbb7039c1e193b2cfa5c7
BLAKE2b-256 02ad91d567959ce6ddbd299ea16130bab0b2d55513b92b26269daf5bedf8949e

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page