
3-D Stacked-Plane KV Cache Quantizer — defensive prior art publication


PrismKV: 3-D Stacked-Plane KV Cache Quantization


First published: 2026-03-30
Author: Dan Hicks · github.com/danhicks96
License: Apache-2.0
Version: 1.0.0
Status: Defensive prior-art publication. All ideas herein are released under Apache-2.0.


The Idea

Large language models cache key (K) and value (V) tensors for every previously seen token — the "KV cache." At long context lengths this cache dominates GPU memory. Recent work (Google's TurboQuant, ICLR 2026) showed that quantizing KV vectors to 3–4 bits using 2-D polar coordinates after a random rotation achieves near-lossless compression at 6× memory reduction.

PrismKV extends this to 3-D.

TurboQuant groups each d-dimensional KV vector into d/2 independent pairs (x, y) and quantizes each pair in polar form (r, θ). Each pair is quantized without context from its neighbors. This is optimal for isotropic Gaussian data but misses cross-dimensional correlations that real KV distributions exhibit.

PrismKV introduces a conditional stacked-plane structure:

  • Group dimensions into triplets (z, x, y) instead of pairs
  • Coarsely quantize the z coordinate into B_z bins → index i_z
  • Use i_z to condition the 2-D polar quantization of (x, y) — selecting a per-z-slice codebook
  • This creates a 3-D quantization cell: a wedge of polar space at a specific z level

The result is a hierarchical encoding that captures relationships between the three coordinates. At the same bits-per-dimension budget (e.g., B_z=4, B_r=4, B_θ=4 → 4.0 bits/dim), the conditional structure allows per-slice codebook adaptation that flat 2-D schemes cannot express.


The Math

Notation

v ∈ R^d         — a rotated KV vector (after global rotation R)
d = 3 * m       — dim must be divisible by 3; m = number of triplet groups
B_z, B_r, B_θ  — bits allocated to z, radius, and angle
C_z = 2^B_z     — number of z-bins
C_r = 2^B_r     — number of radius bins
C_θ = 2^B_θ     — number of angle bins

Step 0 — Global Rotation (same as TurboQuant)

v_rot = R @ v

R is a (d, d) random orthogonal matrix (QR decomposition of a seeded Gaussian draw). This spreads energy approximately uniformly across dimensions, making coordinates approximately independent, a prerequisite for efficient scalar quantization.
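As an illustration of Step 0, here is a minimal NumPy sketch. Assumptions: NumPy stands in for the project's PyTorch code, and `random_rotation` is a hypothetical helper, not PrismKV's API. The sign correction on the QR factors makes the draw uniform over orthogonal matrices; without it the matrix is still orthogonal, which is all the scheme strictly needs.

```python
import numpy as np

def random_rotation(d: int, seed: int = 0) -> np.ndarray:
    """Hypothetical helper: a (d, d) random orthogonal matrix from a seeded Gaussian via QR."""
    rng = np.random.default_rng(seed)
    g = rng.standard_normal((d, d))
    q, r = np.linalg.qr(g)
    # Scale column j of Q by sign(R[j, j]) so the draw is Haar-uniform,
    # not merely orthogonal.
    return q * np.sign(np.diag(r))

d = 12
R = random_rotation(d, seed=42)
v = np.arange(d, dtype=float)
v_rot = R @ v              # Step 0: v_rot = R @ v
v_back = R.T @ v_rot       # R is orthogonal, so R^T undoes the rotation
```

The same seed always reproduces the same R, which is what lets an encoder and decoder agree on the rotation without storing it.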

Step 1 — Triplet Extraction

After rotation, index the d dimensions as:

z-dim for group k:  index 3k       (k = 0, 1, ..., m-1)
x-dim for group k:  index 3k + 1
y-dim for group k:  index 3k + 2

No dimension is shared between groups (no overlapping). Each group k gives a triplet (z_k, x_k, y_k).
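The index layout above is a plain reshape. A minimal sketch, assuming NumPy and the hypothetical helper names `split_triplets` / `merge_triplets`; unlike the project's dim_aligner.py, which pads head_dim to a multiple of 3, this version simply rejects other sizes.

```python
import numpy as np

def split_triplets(v_rot: np.ndarray):
    """Split (..., d) with d = 3m into z, x, y arrays of shape (..., m)."""
    d = v_rot.shape[-1]
    if d % 3 != 0:
        raise ValueError(f"dim {d} must be divisible by 3")
    # Row k of the last two axes is (v[3k], v[3k+1], v[3k+2])
    t = v_rot.reshape(*v_rot.shape[:-1], d // 3, 3)
    return t[..., 0], t[..., 1], t[..., 2]

def merge_triplets(z: np.ndarray, x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Inverse of split_triplets: interleave back to indices 3k, 3k+1, 3k+2."""
    t = np.stack([z, x, y], axis=-1)
    return t.reshape(*t.shape[:-2], -1)

v = np.arange(9.0)
z, x, y = split_triplets(v)   # z = [0, 3, 6], x = [1, 4, 7], y = [2, 5, 8]
```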

Step 2 — Coarse z Quantization

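Δ_z = (z_max - z_min) / C_z
i_z = clamp(floor((z - z_min) / Δ_z), 0, C_z - 1)  ∈ {0, ..., C_z - 1}

z_min, z_max are set conservatively to ±sqrt(d) (or tightened via calibrate()). The clamp is needed for the stated range: without it, z = z_max would give i_z = C_z, and values outside [z_min, z_max] would fall off either end.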
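A sketch of Step 2 for a single coordinate, in plain Python. The explicit clamp is this sketch's way of honoring the stated index range; `quantize_z` is a hypothetical name, and this does not show how PrismKV itself handles out-of-range z.

```python
import math

def quantize_z(z: float, z_min: float, z_max: float, bits_z: int) -> int:
    """Hypothetical helper: coarse z quantization with clamping into {0, ..., C_z - 1}."""
    c_z = 1 << bits_z                      # C_z = 2^B_z
    delta_z = (z_max - z_min) / c_z        # Δ_z
    i_z = math.floor((z - z_min) / delta_z)
    return min(max(i_z, 0), c_z - 1)       # z == z_max (or beyond) lands in the last bin

d = 48
z_min, z_max = -math.sqrt(d), math.sqrt(d)   # conservative default range ±sqrt(d)
i_z = quantize_z(0.0, z_min, z_max, bits_z=4)
```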

Step 3 — Conditional 2-D Polar Quantization

Convert (x, y) to polar form:

r     = sqrt(x^2 + y^2)
θ     = atan2(y, x)          ∈ (-π, π]

Quantize uniformly, clipping r to the radius bound r_max (v1 uses the same table for all z-slices; per-slice learned tables are v2):

i_r     = round(min(r, r_max) / r_max * (C_r - 1))   ∈ {0, ..., C_r - 1}
i_θ     = round((θ + π) / (2π) * (C_θ - 1))          ∈ {0, ..., C_θ - 1}
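A per-pair sketch of Step 3 (hypothetical `quantize_polar`, plain Python; the radius clip mirrors the formula above). Python's `round()` rounds exact halves to even, which still keeps each index within half a step of the true value.

```python
import math

def quantize_polar(x: float, y: float, r_max: float, bits_r: int, bits_theta: int):
    """Hypothetical helper: uniform polar quantization of (x, y), v1 style (one table for all z-slices)."""
    c_r, c_t = 1 << bits_r, 1 << bits_theta
    r = min(math.hypot(x, y), r_max)               # clip radii beyond r_max
    theta = math.atan2(y, x)                       # in (-π, π]
    i_r = round(r / r_max * (c_r - 1))
    i_theta = round((theta + math.pi) / (2 * math.pi) * (c_t - 1))
    return i_r, i_theta

i_r, i_theta = quantize_polar(0.3, -0.4, r_max=1.0, bits_r=4, bits_theta=4)
```

As written, the angle formula maps θ = −π to index 0 and θ = π to index C_θ − 1, and those two indices decode to the same direction.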

Step 4 — Packing

code = (i_z << (B_r + B_θ)) | (i_r << B_θ) | i_θ

Total bits per triplet: B_z + B_r + B_θ. Bits per dimension: (B_z + B_r + B_θ) / 3.
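The packing is ordinary bit-field arithmetic with i_z in the high bits. A minimal sketch with hypothetical `pack` / `unpack` helpers, including the inverse a decoder would need:

```python
def pack(i_z: int, i_r: int, i_theta: int, bits_r: int, bits_theta: int) -> int:
    """Hypothetical helper: code = (i_z << (B_r + B_θ)) | (i_r << B_θ) | i_θ."""
    return (i_z << (bits_r + bits_theta)) | (i_r << bits_theta) | i_theta

def unpack(code: int, bits_z: int, bits_r: int, bits_theta: int):
    """Inverse of pack: mask each field back out of the code."""
    i_theta = code & ((1 << bits_theta) - 1)
    i_r = (code >> bits_theta) & ((1 << bits_r) - 1)
    i_z = (code >> (bits_r + bits_theta)) & ((1 << bits_z) - 1)
    return i_z, i_r, i_theta

code = pack(5, 3, 9, bits_r=4, bits_theta=4)   # 5*256 + 3*16 + 9 = 1337
```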

Dequantization

z_q   = z_min + (i_z + 0.5) * Δ_z          ← bin center (unbiased when z is uniform within the bin)
r_q   = i_r / (C_r - 1) * r_max
θ_q   = i_θ / (C_θ - 1) * 2π - π

x_q   = r_q * cos(θ_q)
y_q   = r_q * sin(θ_q)

v_hat = R^T @ reassembled(z_q, x_q, y_q)   ← triplets interleaved back to indices 3k, 3k+1, 3k+2
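A per-triplet sketch of the dequantization formulas (hypothetical `dequantize_triplet`, plain Python). Rebuilding the full vector would then interleave the m decoded triplets and apply R^T, as in the last formula above.

```python
import math

def dequantize_triplet(i_z, i_r, i_theta, z_min, z_max, r_max, bits_z, bits_r, bits_theta):
    """Hypothetical helper: map (i_z, i_r, i_θ) back to an approximate (z, x, y)."""
    c_z, c_r, c_t = 1 << bits_z, 1 << bits_r, 1 << bits_theta
    delta_z = (z_max - z_min) / c_z
    z_q = z_min + (i_z + 0.5) * delta_z                    # bin center
    r_q = i_r / (c_r - 1) * r_max
    theta_q = i_theta / (c_t - 1) * 2 * math.pi - math.pi
    return z_q, r_q * math.cos(theta_q), r_q * math.sin(theta_q)

z_q, x_q, y_q = dequantize_triplet(8, 15, 8, -1.0, 1.0, 1.0, 4, 4, 4)
```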

Error Bound

Worst-case per-triplet Euclidean reconstruction error for in-range inputs (z_min ≤ z ≤ z_max, r ≤ r_max; design doc §3.5):

‖(z, x, y) - (z_q, x_q, y_q)‖
  ≤ sqrt( (Δ_r/2)^2 + (r_max · Δ_θ/2)^2 + (Δ_z/2)^2 )

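where Δ_r = r_max / (C_r - 1) and Δ_θ = 2π / (C_θ - 1). The three terms are the radial rounding error, the displacement caused by an angular error of at most Δ_θ/2 at radius r_max, and the z error of a bin-center reconstruction. Because R is orthogonal, the full-vector error ‖v − v_hat‖ is the root-sum-square of the m per-triplet errors.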
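The bound can be checked empirically. This sketch (plain Python, hypothetical helpers `roundtrip` and `error_bound`) round-trips random in-range triplets through the v1 formulas at B_z=3, B_r=3, B_θ=2 and compares the worst observed error with the bound. Sampling z, r and θ uniformly is an assumption for the check, not a model of real KV data.

```python
import math
import random

def roundtrip(z, x, y, z_min, z_max, r_max, bz, br, bt):
    """Quantize then dequantize one in-range triplet using the v1 formulas."""
    cz, cr, ct = 1 << bz, 1 << br, 1 << bt
    dz = (z_max - z_min) / cz
    i_z = min(max(math.floor((z - z_min) / dz), 0), cz - 1)
    i_r = round(min(math.hypot(x, y), r_max) / r_max * (cr - 1))
    i_t = round((math.atan2(y, x) + math.pi) / (2 * math.pi) * (ct - 1))
    z_q = z_min + (i_z + 0.5) * dz
    r_q = i_r / (cr - 1) * r_max
    t_q = i_t / (ct - 1) * 2 * math.pi - math.pi
    return z_q, r_q * math.cos(t_q), r_q * math.sin(t_q)

def error_bound(z_min, z_max, r_max, bz, br, bt):
    """Worst-case per-triplet error from the bound above."""
    d_z = (z_max - z_min) / (1 << bz)
    d_r = r_max / ((1 << br) - 1)
    d_t = 2 * math.pi / ((1 << bt) - 1)
    return math.sqrt((d_r / 2) ** 2 + (r_max * d_t / 2) ** 2 + (d_z / 2) ** 2)

rng = random.Random(0)
params = dict(z_min=-1.0, z_max=1.0, r_max=1.0, bz=3, br=3, bt=2)
bound = error_bound(**params)
worst = 0.0
for _ in range(20000):
    # In-range sample: z in [z_min, z_max], radius <= r_max
    z = rng.uniform(-1.0, 1.0)
    r, t = rng.uniform(0.0, 1.0), rng.uniform(-math.pi, math.pi)
    x, y = r * math.cos(t), r * math.sin(t)
    zq, xq, yq = roundtrip(z, x, y, **params)
    worst = max(worst, math.dist((z, x, y), (zq, xq, yq)))
```

With only 2 angle bits the angular term dominates, which is why the bound here is close to r_max.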


Bit Budget

Scheme                          Bits per KV vector    Bits/dim   vs FP32
FP32 (no compression)           32d                   32.0       1×
FP16 (no compression)           16d                   16.0       2×
2-D polar, 4+4 bits             8 × (d/2) = 4d        4.0        8×
3-D stacked-plane, 4+4+4 bits   12 × (d/3) = 4d       4.0        8×
3-D stacked-plane, 3+3+2 bits   8 × (d/3) ≈ 2.67d     2.67       12×

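The 3-D scheme at B_z=3, B_r=3, B_θ=2 (2.67 bits/dim) has no 2-D equivalent: integer-bit 2-D polar spends B_r + B_θ bits per pair, so it can only reach multiples of 1/2 bit/dim, while triplets reach any multiple of 1/3. This is one regime where the 3-D layout strictly enables a finer bit budget than 2-D.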
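The granularity argument in a few lines: with integer bits per group, pairs can only hit multiples of 1/2 bit/dim and triplets multiples of 1/3. The helper `achievable_rates` and its 16-bit cap are illustrative.

```python
from fractions import Fraction

def achievable_rates(group_size: int, max_bits_per_group: int = 16) -> set:
    """Bits/dim reachable with an integer number of bits per group of `group_size` dims."""
    return {Fraction(b, group_size) for b in range(1, max_bits_per_group + 1)}

rates_2d = achievable_rates(2)   # pairs: multiples of 1/2 bit/dim
rates_3d = achievable_rates(3)   # triplets: multiples of 1/3 bit/dim
target = Fraction(8, 3)          # B_z=3, B_r=3, B_θ=2 -> 8 bits per triplet

print(target in rates_3d, target in rates_2d)                    # True False
print([str(r) for r in sorted(rates_3d - rates_2d) if r < 4])
# ['1/3', '2/3', '4/3', '5/3', '7/3', '8/3', '10/3', '11/3']
```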


Comparison to Related Work

Method                 Training required         Conditioning                   Bias correction   Adaptive bits
TurboQuant (2026)      None                      None (independent 2-D pairs)   Yes (QJL)         No
PrismKV v1             None                      z-conditioned 2-D polar        No                No
PrismKV v2             K-means calibration       Per-z-bin learned codebooks    Yes (BiasTable)   Yes (entropy water-filling)
KIVI                   None (tuning-free)        None                           No                No
SnapKV                 None (fine-tuning-free)   None                           No                No
Product Quantization   Dataset training          None                           No                No

What is new in PrismKV:

  1. The triplet partition (z, x, y) with no overlapping coordinates
  2. Using the coarsely quantized z index to select per-slice codebooks for (x, y): a conditional product quantizer in 3-D
  3. Per-z-slice learned codebooks trained via pure-torch k-means — not possible in any 2-D scheme without a separate full-dimensional index
  4. Per-z-bin bias correction table (QJL-style, no training required beyond calibration)
  5. Water-filling adaptive bit allocation from per-head attention entropy

Quick Start

git clone https://github.com/danhicks96/PrismKV
cd PrismKV
pip install -e .
python3 examples/demo.py

Expected output (CPU, <5 seconds):

══════════════════════════════════════════════════════════════
  PrismKV  ·  3-D Stacked-Plane KV Cache Quantizer
  ...
  2D Polar (baseline)            4.0  (1024, 96)   ...
  3D Stacked-Plane (PrismKV)     4.0  (1024, 64)   ...
══════════════════════════════════════════════════════════════

Run tests

pip install -e ".[dev]"
pytest tests/ -v

All 131 tests pass (36 core, 95 eval+cache+RAG).

For the full suite including RAG and cache tests:

pip install -e ".[dev,eval,cache,rag]"
pytest tests/ -v

Repository Layout

PrismKV/
├── src/prismkv/
│   ├── quantizer/
│   │   ├── stacked_plane.py      — 3-D conditional quantizer (core prior art)
│   │   ├── baseline_2d.py        — 2-D polar baseline (TurboQuant-style)
│   │   ├── learned_codebook.py   — per-z-bin k-means codebooks (M1)
│   │   ├── bias_correction.py    — QJL-style per-z-bin bias table (M4)
│   │   └── bit_alloc.py          — water-filling adaptive bit allocation (M7)
│   ├── eval/
│   │   ├── kv_collector.py       — transformers 5.x KV hook collector (M2)
│   │   ├── benchmark.py          — RMSE / cosine / throughput benchmarks (M2)
│   │   ├── attention_entropy.py  — per-head Shannon entropy (M7)
│   │   ├── model_arch.py         — ModelArchRegistry + GQA support (M8)
│   │   └── e2e_benchmark.py      — memory table + quality report (M11)
│   ├── cache/
│   │   ├── kv_cache.py           — PrismKVCache(DynamicCache) drop-in (M3)
│   │   ├── cache_config.py       — PrismKVConfig dataclass
│   │   ├── dim_aligner.py        — pad head_dim to multiple of 3
│   │   └── cache_store.py        — save_cache / load_cache NPZ (M10)
│   └── rag/
│       ├── rag_engine.py         — RAGEngine public API (M6)
│       ├── vector_store.py       — SQLite + pure-torch cosine store
│       ├── graph_index.py        — NetworkX DiGraph + BFS expansion
│       ├── ingestion.py          — IngestionEngine with deduplication
│       ├── retriever.py          — hybrid vector + graph retrieval
│       ├── context_assembler.py  — token-budget-aware context builder
│       ├── adapters.py           — TextAdapter, DictAdapter, FileAdapter, APIAdapter
│       └── schema.py             — Chunk, Node, RetrievalResult
├── tests/                        — 170+ tests across all modules
├── examples/
│   ├── demo.py                   — 2-D vs 3-D quantizer comparison
│   ├── hf_integration.py         — GPT-2 with PrismKVCache
│   ├── rag_demo.py               — CPU-only RAG pipeline demo
│   ├── usurper_rag_demo.py       — 50-dict game-state ingestion
│   └── adaptive_demo.py          — BitAllocator → PrismKVCache
├── scripts/
│   ├── build_codebooks.py        — CLI: train learned codebooks
│   ├── collect_kv_calibration.py — extract KV tensors from GPT-2
│   └── run_e2e_benchmark.py      — CLI: memory + quality benchmark (M11)
├── design.md                     — full architecture & math specification
└── pyproject.toml

What's Shipped

Milestone Version Description
M1 0.2.0 Learned per-z-slice codebooks — k-means on real KV distributions
M2 0.2.0 KV benchmarking eval layer — RMSE, cosine sim, throughput
M3 0.2.0 PrismKVCache(DynamicCache) — drop-in HuggingFace cache replacement
M4 0.3.0 QJL-style bias correction — per-z-bin BiasTable
M5 0.4.0 CI/CD — GitHub Actions + PyPI OIDC trusted publishing
M6 0.5.0 RAG framework — vector store, graph index, adapters, RAGEngine
M7 0.6.0 Adaptive bit allocation — water-filling from attention entropy
M8 0.7.0 Multi-model support — ModelArchRegistry, GQA-aware KVCollector
M9 0.8.0 Polar-space attention approximation — novel prior-art contribution
M10 0.9.0 Cache persistence (save_cache/load_cache) + APIAdapter
M11 1.0.0 End-to-end benchmark — memory table + quality comparison

Future work

  • CUDA kernel — fused dequantization + QK dot product in a single kernel

RAG Framework (M6)

PrismKV ships a complete RAG pipeline that uses the compressed KV cache internally:

from prismkv.rag import RAGEngine
from prismkv.rag.adapters import DictAdapter

# my_embed_fn, game_states and my_llm are user-supplied placeholders
engine = RAGEngine(db_path=":memory:", embedder=my_embed_fn)

# Ingest — any adapter: text file, dict list, plain string, REST endpoint
engine.ingest(DictAdapter(game_states, entity_key="name"))

# Query
results = engine.retrieve("throne room conflict", top_k=5)

# Generate
response = engine.generate("What happened at the throne room?", generation_fn=my_llm)

Hybrid retrieval: cosine vector search + NetworkX graph BFS expansion. SHA-256 content deduplication. Token-budget-aware context assembly.

Adaptive Bit Allocation (M7)

Per-head bit budgets derived from attention entropy — sharp heads (low entropy) get more bits:

from prismkv.quantizer.bit_alloc import BitAllocator
from prismkv.cache import PrismKVCache

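# entropy: per-(layer, head) attention entropies computed beforehand (see eval/attention_entropy.py)
allocator = BitAllocator(entropy, target_avg_bits_per_dim=4.0).compute()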
configs = allocator.to_prism_configs(per_head=False)  # one PrismKVConfig per layer

cache = PrismKVCache(configs=configs)

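The allocator uses water-filling (sensitivity = 1/H(l,h)) with a greedy post-rounding correction that guarantees the mean bits/dim is within 1/(6n) of target after discretisation, where n is the number of budgets being allocated. Each budget is a whole number of bits per triplet, i.e. a multiple of 1/3 bit/dim, so the mean of n budgets lies on a grid with spacing 1/(3n), and the closest grid point is at most half a step away.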
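The exact rule inside BitAllocator is not spelled out here, so the following is a hypothetical sketch of one standard reading of "water-filling with sensitivity 1/H": the classic rate-distortion allocation b_i = b̄ + ½·log2(s_i / geometric mean of s), counted in whole bits per triplet, clamped, then nudged one bit at a time until the total matches the budget. `allocate_bits`, `lo` and `hi` are assumptions, not PrismKV's API.

```python
import math

def allocate_bits(entropies, target_bits_per_dim=4.0, lo=6, hi=18):
    """Hypothetical sketch: per-head bits/dim from attention entropy (all entropies > 0).

    Budgets are tracked in whole bits per triplet (1 unit = 1/3 bit/dim);
    lo/hi clamp each head to [lo/3, hi/3] bits/dim.
    """
    n = len(entropies)
    total = round(3 * target_bits_per_dim * n)          # integer triplet-bit budget
    if not n * lo <= total <= n * hi:
        raise ValueError("target not reachable within [lo, hi]")
    log_s = [-math.log2(h) for h in entropies]          # log2 of sensitivity s = 1/H
    mean_log_s = sum(log_s) / n
    # Reverse water-filling, scaled by 3 units per bit/dim
    ideal = [3 * target_bits_per_dim + 1.5 * (ls - mean_log_s) for ls in log_s]
    alloc = [min(max(math.floor(u), lo), hi) for u in ideal]
    # Greedy post-rounding correction: move one unit at a time toward the budget
    while sum(alloc) < total:
        j = max((i for i in range(n) if alloc[i] < hi), key=lambda i: ideal[i] - alloc[i])
        alloc[j] += 1
    while sum(alloc) > total:
        j = max((i for i in range(n) if alloc[i] > lo), key=lambda i: alloc[i] - ideal[i])
        alloc[j] -= 1
    return [a / 3 for a in alloc]

bits = allocate_bits([0.5, 1.0, 2.0, 4.0], target_bits_per_dim=4.0)
# sharpest head (entropy 0.5) gets the most: [14/3, 13/3, 11/3, 10/3]
```

Because the final total is round(3·n·target) triplet bits, the mean lands within 1/(6n) bits/dim of the target, consistent with the guarantee stated above.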

Multi-Model Support (M8)

Auto-detect transformer architecture and collect real KV vectors from any supported model:

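from transformers import AutoModelForCausalLM, AutoTokenizer

from prismkv.eval.model_arch import ModelArchRegistry
from prismkv.eval.kv_collector import KVCollector

# Any supported causal LM works; GPT-2 is used here only as a small example
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids

# Supports GPT-2, OPT, LLaMA/Mistral/Gemma/CodeLlama, Falcon, Qwen2, Phi
arch = ModelArchRegistry.detect(model)

collector = KVCollector(model, device="cpu")
kv_data = collector.collect(input_ids, layer_idx=0)  # {0: {"keys": ..., "values": ...}}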

GQA-aware: reads num_key_value_heads from config for LLaMA-2-70B, Mistral-7B, etc.

Polar-Space Attention Approximation (M9)

Approximate attention scores directly from compressed PrismKV codes — no full dequantization:

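import math

from prismkv.quantizer.polar_attention import PolarAttentionApprox, measure_polar_approx_error

# qtz: a PrismKV stacked-plane quantizer instance; head_dim: per-head dimension
# q, k, v: query, key and value tensors; k_codes: the compressed codes for k
# Drop-in scaled dot-product approximation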
approx = PolarAttentionApprox(
    bits_z=4, bits_r=4, bits_theta=4,
    z_min=qtz.z_min, z_max=qtz.z_max, r_max=qtz.r_max,
    scale=1/math.sqrt(head_dim), R=qtz.R,
)
output, weights = approx.forward(q, k_codes, v)  # (b, nh, sq, d), (b, nh, sq, sk)

# Measure approximation error vs exact Cartesian dot product
err = measure_polar_approx_error(q, k, k_codes, ..., R=qtz.R)
# {'mean_abs_error': ..., 'max_abs_error': ..., 'cosine_sim': ...}

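Because R is orthogonal, <q, k> = <Rq, Rk>, so scores can be computed in the rotated frame as a sum over the m triplet groups j: <q, k> = Σ_j [ q_{z,j}·k_{z,j} + r_{q,j}·r_{k,j}·cos(θ_{q,j} − θ_{k,j}) ], where (r, θ) are the polar coordinates of each (x, y) pair. This enables computing attention scores from codes without materialising full FP16 key tensors.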
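The identity is exact algebra, which a quick NumPy check confirms on unquantized vectors (with codes, r and θ are replaced by their dequantized values). `polar_dot` is a hypothetical helper, not the PolarAttentionApprox implementation.

```python
import numpy as np

def polar_dot(q_rot: np.ndarray, k_rot: np.ndarray) -> float:
    """<q, k> summed per triplet as q_z*k_z + r_q*r_k*cos(θ_q - θ_k)."""
    qz, qx, qy = q_rot.reshape(-1, 3).T
    kz, kx, ky = k_rot.reshape(-1, 3).T
    r_q, t_q = np.hypot(qx, qy), np.arctan2(qy, qx)
    r_k, t_k = np.hypot(kx, ky), np.arctan2(ky, kx)
    return float(np.sum(qz * kz + r_q * r_k * np.cos(t_q - t_k)))

rng = np.random.default_rng(0)
d = 12
R, _ = np.linalg.qr(rng.standard_normal((d, d)))   # orthogonal, so <Rq, Rk> = <q, k>
q, k = rng.standard_normal(d), rng.standard_normal(d)
exact = float(q @ k)
approx = polar_dot(R @ q, R @ k)
```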

Cache Persistence (M10)

Save and load compressed KV caches to disk:

from prismkv.cache.cache_store import save_cache, load_cache

# Serialize compressed codes + config to NPZ
save_cache(cache, "checkpoint.npz")

# Reconstruct — returns PrismKVCache with fully seeded DynamicCache layers
cache = load_cache("checkpoint.npz", device="cpu")

REST API ingestion for the RAG engine:

from prismkv.rag.adapters import APIAdapter

engine.ingest(APIAdapter(
    "https://api.example.com/articles",
    text_field="body",
    source_id="api_articles",
))

End-to-End Benchmark (M11)

Memory footprint and reconstruction quality comparison — no model download required:

from prismkv.eval.e2e_benchmark import run_e2e_benchmark, print_e2e_table

report = run_e2e_benchmark(head_dim=64, n_heads=12, n_layers=12)
print_e2e_table(report)

Example output:

KV Cache Memory Footprint  (12L × 12H × d=64)
Context     FP16       3bit       4bit       5bit
  1,024    18.0MB   3.4MB(5.3×)  4.5MB(4.0×)  5.6MB(3.2×)
  4,096    72.0MB  13.5MB(5.3×) 18.0MB(4.0×) 22.5MB(3.2×)
 16,384   288.0MB  54.0MB(5.3×) 72.0MB(4.0×) 90.0MB(3.2×)
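The table's figures can be reproduced by hand. They match a count of one tensor per position (K or V, not both) in MiB (2^20 bytes): 12 × 12 × 64 × 1,024 FP16 values is exactly 18.0 MiB, and counting both K and V would double every entry. `cache_mib` is a hypothetical helper that ignores any codebook or metadata overhead.

```python
def cache_mib(context, n_layers=12, n_heads=12, head_dim=64, bits_per_dim=16.0, tensors=1):
    """KV-cache size in MiB; tensors=1 counts K or V only, tensors=2 counts both."""
    n_values = tensors * n_layers * n_heads * head_dim * context
    return n_values * bits_per_dim / 8 / 2**20

for ctx in (1024, 4096, 16384):
    fp16 = cache_mib(ctx)
    row = [f"{cache_mib(ctx, bits_per_dim=b):.1f}MB({16 / b:.1f}×)" for b in (3, 4, 5)]
    print(f"{ctx:>7,}  {fp16:6.1f}MB  " + "  ".join(row))
```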

For pseudo-perplexity measurement (requires GPT-2 download):

python scripts/run_e2e_benchmark.py --pseudo-ppl

Citation / Prior Art

This repository was publicly released on 2026-03-30 as a defensive publication. If you build on these ideas, a citation is appreciated but not required under the Apache-2.0 license:

@misc{hicks2026prismkv,
  author = {Dan Hicks},
  title  = {PrismKV: 3-D Stacked-Plane KV Cache Quantization},
  year   = {2026},
  url    = {https://github.com/danhicks96/PrismKV}
}

License

Apache 2.0 — see LICENSE for details.


