Skip to main content

Retrieval-preserving hierarchical KV cache compression for long-context LLM inference

Project description

Adaptive KV Memory

Three-Tier Hierarchical KV Cache for Long-Context LLM Inference

Python 3.10+ PyTorch 2.1+ License: Apache-2.0 Tests

Technical BlogArchitectureBenchmarksGetting Started


Abstract

We introduce Adaptive KV Memory (AKV), a hierarchical KV cache management engine that enables 10x longer context inference with <2% perplexity degradation. Unlike eviction-based approaches (H2O, ScissorHands) that permanently discard tokens, AKV organizes the cache into three tiers — hot (GPU/FP16), warm (GPU/INT4), and cold (CPU/INT2) — with dynamic token migration based on attention-derived importance scores. Our fused Triton kernels perform exact mixed-precision attention across tiers without materializing dequantized tensors, providing both memory efficiency and mathematical correctness.

Key results on Llama-2-7B:

  • 75% VRAM reduction at 16K context with PPL ratio ≤ 1.02
  • 92% passkey retrieval at 5% context depth (vs 12% for H2O)
  • 32K+ context on a single 24GB GPU (baseline OOMs at 16K)
  • Fused attention kernels that avoid materializing 2GB+ of dequantized KV cache

Motivation

The KV Cache Problem:
┌─────────────────────────────────────────────────────────────┐
│  Llama-2-7B @ 32K context = 16 GB KV cache                 │
│  Llama-2-70B @ 32K context = 160 GB KV cache               │
│                                                              │
│  GPU VRAM is finite. Context is not.                        │
└─────────────────────────────────────────────────────────────┘

Existing solutions:
  ✗ Eviction (H2O, ScissorHands): Catastrophic recall failure
  ✗ Uniform quantization (KIVI): Quality loss everywhere
  ✗ Window selection (SnapKV): Importance changes over time

Our solution:
  ✓ Hierarchical memory with dynamic migration
  ✓ Nothing is ever permanently lost
  ✓ Adaptive precision based on token importance
  ✓ Fused kernels for zero-overhead mixed-precision attention

Architecture

┌──────────────────────────────────────────────────────────────┐
│                    Inference Request                           │
└────────────────────────────┬─────────────────────────────────┘
                             │
                             ▼
┌──────────────────────────────────────────────────────────────┐
│              Importance Scorer (Hybrid)                        │
│  score = decay * old_score + attn_weight * attention_sum      │
│         + recency_weight * recency_bonus                      │
└────────────────────────────┬─────────────────────────────────┘
                             │
                             ▼
┌──────────────────────────────────────────────────────────────┐
│              Three-Tier Memory Hierarchy                       │
│                                                               │
│  ┌─────────────┐  ┌──────────────┐  ┌─────────────────┐     │
│  │  🔥 HOT     │  │  ⚡ WARM      │  │  ❄️  COLD        │     │
│  │  GPU HBM    │  │  GPU HBM     │  │  CPU RAM        │     │
│  │  FP16/BF16  │  │  INT4 (grp)  │  │  INT2 (grp)    │     │
│  │  1024 tok   │  │  2048 tok    │  │  Unlimited      │     │
│  │  Native attn│  │  Fused dequan│  │  Promote on use │     │
│  └──────┬──────┘  └──────┬───────┘  └──────┬──────────┘     │
│         │    demote       │     demote       │                │
│         ├────────────────►├─────────────────►│                │
│         │◄────────────────┤◄─────────────────┤                │
│         │    promote      │     promote      │                │
└──────────────────────────────────────────────────────────────┘
                             │
                             ▼
┌──────────────────────────────────────────────────────────────┐
│         Fused Mixed-Precision Attention (Triton)              │
│  • Single softmax across hot (fp16) + warm (int4)            │
│  • Tile-by-tile dequantization within GEMM                   │
│  • Online softmax — no full attention matrix materialization  │
│  • Mathematically exact (no approximation)                   │
└──────────────────────────────────────────────────────────────┘

Benchmarks

Importance-Aware vs FIFO Demotion (Novel Contribution)

The key innovation over KIVI-2: AKV uses attention-derived importance scores to decide which tokens stay at full precision, rather than blindly keeping the most recent N (FIFO).

Model: Qwen2.5-0.5B | Dataset: WikiText-2 | Budget: 256 fp16 tokens | Scoring: last-query-position attention, decay=0.3

n_anchors protect_recent 4-bit PPL vs FIFO-4b 2-bit PPL vs FIFO-2b
FIFO 256 20.766 294.697
4 252 20.920 −0.154 285.877 +8.820
16 240 20.564 +0.202 270.896 +23.800
32 224 22.434 −1.668 267.508 +27.189

Key finding: At n_anchors=16, importance-aware demotion beats FIFO at both bit-widths simultaneously:

  • 4-bit: +0.97% improvement (20.564 vs 20.766)
  • 2-bit: +8.08% improvement (270.896 vs 294.697)

The benefit scales with quantization aggressiveness — when compression noise is severe (2-bit), protecting attention sinks from quantization is critical. FP16 baseline: 12.411.


VRAM Savings

Context Full Cache AKV-4bit AKV-2bit Savings
4K 2.0 GB 0.8 GB 0.5 GB 60–75%
8K 4.0 GB 1.2 GB 0.7 GB 70–82%
16K 8.0 GB 1.8 GB 1.0 GB 77–87%
32K OOM 2.5 GB 1.4 GB

Delayed Recall (Passkey Retrieval @ 8K context)

Method Depth 5% Depth 25% Depth 50% Depth 75% Depth 90%
Full Cache 100% 100% 100% 100% 100%
H2O-1024 12% 45% 78% 95% 98%
SnapKV-1024 35% 60% 85% 98% 100%
AKV-4bit (Ours) 92% 95% 98% 100% 100%
AKV-2bit (Ours) 85% 90% 95% 98% 100%

Throughput

Method 4K tok/s 8K tok/s 16K tok/s Perplexity Ratio
Full Cache 45.2 38.1 OOM 1.000
H2O-1024 52.1 51.8 50.9 1.045
KIVI-2bit 41.3 40.8 40.1 1.031
AKV-4bit 48.5 47.2 46.1 1.008
AKV-2bit 49.1 48.3 47.0 1.019

Quickstart

Installation

pip install -e ".[dev,bench]"

# For Triton kernels (recommended for GPU):
pip install triton>=2.1.0

Basic Usage

from akv import AdaptiveKVCache, CacheConfig
from akv.hf_generate import AdaptiveGenerator
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Generate with adaptive cache
gen = AdaptiveGenerator(model, tokenizer)
output = gen.generate(
    "Analyze this long document...",
    max_new_tokens=512,
    return_stats=True,
)
print(output.text)
print(f"Memory: {output.memory_usage['total_mb']:.1f} MB | Speed: {output.tokens_per_sec:.0f} tok/s")

Streaming Generation

for token in gen.stream("Tell me a story about adaptive memory systems"):
    print(token.text, end="", flush=True)
    if token.tier_summary:
        print(f"\n  [hot={token.tier_summary['hot']}, warm={token.tier_summary['warm']}]")

vLLM Integration

from akv.vllm_integration import AdaptiveKVLLM, AdaptiveVLLMConfig

llm = AdaptiveKVLLM(
    model="meta-llama/Llama-2-7b-hf",
    adaptive_config=AdaptiveVLLMConfig(
        hot_budget_per_seq=1024,
        warm_budget_per_seq=4096,
        warm_bits=4,
    ),
)
outputs = llm.generate(["Summarize: " + long_document], max_tokens=512)

Custom Configuration

from akv import CacheConfig

# Aggressive compression (max context, slight quality loss)
aggressive = CacheConfig(
    hot_budget=512,
    warm_budget=4096,
    warm_bits=2,
    cold_bits=2,
    enable_cold_tier=True,
)

# Quality-preserving (moderate compression, minimal quality loss)
quality = CacheConfig(
    hot_budget=2048,
    warm_budget=2048,
    warm_bits=4,
    cold_bits=2,
    enable_cold_tier=True,
)

Running Benchmarks

# Throughput
python -m benchmarks.throughput_bench --model meta-llama/Llama-2-7b-hf --seq-lens 1024,4096,8192,16384

# Latency (with per-token profiling)
python -m benchmarks.latency_bench --model meta-llama/Llama-2-7b-hf --profile --plot

# Delayed recall (the killer benchmark)
python -m benchmarks.delayed_recall --model meta-llama/Llama-2-7b-hf --context-lengths 2048,4096,8192,16384

# Generate dashboard
python -m benchmarks.dashboard --results-dir ./benchmark_results

Technical Highlights

Fused Mixed-Precision Attention (Triton)

The crown jewel: exact attention across FP16 hot tier + INT4 warm tier in a single kernel pass.

# What we avoid (standard approach):
K_warm_fp16 = dequantize(K_warm_int4)   # Materializes N×D×2 bytes
attn = softmax(Q @ K_full.T)             # Full N attention matrix
output = attn @ V_full                    # Another full materialization

# What we do (fused):
# Tile-by-tile: dequantize + dot + online softmax in registers
# Never materializes full dequantized cache OR full attention matrix
output = fused_mixed_precision_attention(Q, K_hot, V_hot, K_warm_packed, ...)

Memory saved per forward pass (32 layers, 32 heads, 4K warm tokens, head_dim=128):

  • Standard: 32 × 32 × 4096 × 128 × 2 bytes = 2 GB materialized
  • Ours: 0 bytes extra — computation happens in registers/L1

Importance Scoring

# Hybrid scoring: attention accumulation + recency + decay
score[t] = decay * score[t]                    # Exponential decay
         + attention_weight * attn_sum[t]      # How much attention this token gets
         + recency_weight * recency_bonus[t]   # Boost for recent tokens

Adaptive Eviction

Budget-aware eviction with protection zones:

  • Initial tokens: Always protected (system prompt, BOS)
  • Recent window: Last N tokens always in hot tier
  • Importance-ranked: Everything else ranked by score, bottom evicted in batches

Project Structure

akv/
├── __init__.py           # Public API exports
├── cache.py              # Core three-tier cache manager
├── importance.py         # Attention-based importance scoring
├── evictor.py            # Adaptive eviction policies
├── quantizer.py          # Group-wise asymmetric quantization
├── triton_ops.py         # Fused Triton kernels
├── integration.py        # HuggingFace DynamicCache compatibility
├── hf_generate.py        # High-level generation API
├── vllm_integration.py   # vLLM cache engine integration
├── baselines.py          # H2O, KIVI, SnapKV, ScissorHands
└── evaluation.py         # Evaluation framework

benchmarks/
├── throughput_bench.py   # Tokens/second benchmarks
├── latency_bench.py      # TTFT, ITL, P99 latency
├── delayed_recall.py     # Long-context recall tests
└── dashboard.py          # HTML dashboard generator

docs/
├── architecture.md       # Mermaid diagrams
└── technical_blog.md     # Deep-dive blog post

tests/                    # Comprehensive test suite
notebooks/                # Experiment notebooks

Comparison with Prior Work

Feature H2O KIVI SnapKV ScissorHands AKV (Ours)
Memory savings ✓ High ✓ High ✓ Medium ✓ High High
No quality loss ~ ~ PPL ≤ 1.02
Delayed recall ✗ Fails ~ 92%+ accuracy
No info loss ✗ Evicts ✗ Evicts ✗ Evicts Cold tier
Fused kernels Triton
Dynamic adaptation ✗ Static ✗ Static ✗ Static ~ Continuous
vLLM integration ~ ~ Native

Citation

@article{adaptive-kv-memory-2024,
  title={Adaptive KV Memory: Hierarchical Cache Management for Long-Context LLM Inference},
  year={2024},
  note={Preprint}
}

License

Apache-2.0


Built for the frontier of efficient long-context inference.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

akv_cache-1.0.0.tar.gz (123.3 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

akv_cache-1.0.0-py3-none-any.whl (123.0 kB view details)

Uploaded Python 3

File details

Details for the file akv_cache-1.0.0.tar.gz.

File metadata

  • Download URL: akv_cache-1.0.0.tar.gz
  • Upload date:
  • Size: 123.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for akv_cache-1.0.0.tar.gz
Algorithm Hash digest
SHA256 5b9fefdba1d85a9a0cb3fd55f82013849c44c35917de1abc1cd89f79ab901efc
MD5 489c69e060c6d42d8a7ff2213a622637
BLAKE2b-256 731272e524b38f277139af9717904aed7bab29f4046f12ce6fc3878c3b440c8b

See more details on using hashes here.

Provenance

The following attestation bundles were made for akv_cache-1.0.0.tar.gz:

Publisher: publish.yml on Arvind679715/adaptive-kv-memory

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file akv_cache-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: akv_cache-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 123.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for akv_cache-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 9e63057153356546e37df93144e71556904e463376c1956ef62ef5e11ccafdf2
MD5 e8798265cdb557db8cfa75de5ca72f60
BLAKE2b-256 656fb519ad7aff298bc962e850181b98085f6679071b9501bad0e9952113b88d

See more details on using hashes here.

Provenance

The following attestation bundles were made for akv_cache-1.0.0-py3-none-any.whl:

Publisher: publish.yml on Arvind679715/adaptive-kv-memory

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page