
cacheshrink

KV Cache Compression via Multi-Head Latent Attention with Riemannian Optimization

Achieve 4-16x KV cache compression on LLaMA, Mistral, GPT-2, and other transformer models while maintaining model quality through mathematically principled compression and fine-tuning.

Overview

cacheshrink converts HuggingFace transformer models to use Multi-Head Latent Attention (MLA), dramatically reducing KV cache memory during inference. The library uses Riemannian optimization on Stiefel manifolds to ensure orthonormality constraints are preserved during fine-tuning, enabling high compression ratios with minimal quality loss.

Key Features

  • 4-16x KV cache compression - Reduce memory from GBs to MBs for long sequences
  • Drop-in replacement - Works with existing HuggingFace models and generation pipelines
  • Mathematically principled - Stiefel manifold constraints ensure stable compression/decompression
  • Calibration-aware initialization - SVD-based initialization using real activation statistics
  • Knowledge distillation - Fine-tune compressed models to match original model outputs
  • Multiple model support - LLaMA 2/3, Mistral, Qwen, GPT-2, and extensible to others

How It Works

The KV Cache Problem

In standard transformer attention, the KV cache grows linearly with sequence length:

Standard KV Cache Size = 2 × n_layers × seq_len × d_kv × bytes_per_element

For LLaMA-2 7B at 4096 tokens (float16):
= 2 × 32 × 4096 × 4096 × 2 = 2 GB

This becomes a major bottleneck for long-context inference and high-throughput serving.
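
As a quick back-of-envelope check of the formula above, the sketch below (plain Python, not a library function) reproduces the LLaMA-2 7B number:

def kv_cache_bytes(n_layers, seq_len, d_kv, bytes_per_element=2):
    # The leading 2 accounts for storing both K and V at every layer for every token
    return 2 * n_layers * seq_len * d_kv * bytes_per_element

# LLaMA-2 7B at 4096 tokens in float16 (batch size 1)
size = kv_cache_bytes(n_layers=32, seq_len=4096, d_kv=4096, bytes_per_element=2)
print(f"{size / 1024**3:.1f} GB")  # -> 2.0 GB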

Multi-Head Latent Attention (MLA)

MLA compresses the KV cache by projecting keys and values into a shared low-dimensional latent space:

┌─────────────────────────────────────────────────────────────────┐
│                    Standard Attention                            │
│                                                                  │
│  Hidden States ─┬─► W_k ─► K (d_kv) ──┐                         │
│       (d_model) │                      ├─► Attention ─► Output   │
│                 └─► W_v ─► V (d_kv) ──┘                         │
│                                                                  │
│  Cache stores: K + V = 2 × d_kv per token                       │
└─────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│                    MLA (Multi-Head Latent Attention)            │
│                                                                  │
│                 ┌─► W_down_k ─► c_k (d_latent) ─► W_uk ─► K     │
│  Hidden States ─┤                                      │        │
│       (d_model) │                                      ├─► Attn │
│                 └─► W_down_v ─► c_v (d_latent) ─► W_uv ─► V     │
│                                                                  │
│  Cache stores: c_k + c_v = 2 × d_latent per token               │
│  Compression: d_kv / d_latent (e.g., 4x, 8x, 16x)               │
└─────────────────────────────────────────────────────────────────┘
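
The toy PyTorch module below mirrors the diagram: hidden states are projected down to latents (which are what gets cached) and projected back up to full-size K and V just before attention. It is a simplified illustration with made-up dimensions, not cacheshrink's actual MLA implementation:

import torch
import torch.nn as nn

class ToyMLAProjection(nn.Module):
    """Illustrative K/V compression: cache the latents, rebuild K/V on the fly."""
    def __init__(self, d_model=512, d_kv=512, d_latent=128):
        super().__init__()
        self.W_down_k = nn.Linear(d_model, d_latent, bias=False)  # compression
        self.W_down_v = nn.Linear(d_model, d_latent, bias=False)
        self.W_uk = nn.Linear(d_latent, d_kv, bias=False)          # decompression
        self.W_uv = nn.Linear(d_latent, d_kv, bias=False)

    def compress(self, hidden_states):
        # Only these latents go into the KV cache: 2 x d_latent per token
        return self.W_down_k(hidden_states), self.W_down_v(hidden_states)

    def decompress(self, c_k, c_v):
        # Full-dimensional K and V are reconstructed just before attention
        return self.W_uk(c_k), self.W_uv(c_v)

mla = ToyMLAProjection()
h = torch.randn(1, 8, 512)        # (batch, seq, d_model)
c_k, c_v = mla.compress(h)         # cached: (1, 8, 128) each -> 4x smaller
k, v = mla.decompress(c_k, c_v)    # used for attention: (1, 8, 512) each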

Stiefel Manifold Constraints

The decompression matrices W_uk and W_uv are constrained to have orthonormal columns (Stiefel manifold). This ensures:

  1. Stable decompression - The projection preserves geometric relationships
  2. Energy preservation - Orthonormal matrices preserve vector norms
  3. Left-invertibility - Orthonormal columns satisfy W^T W = I, giving clean mathematical properties for initialization via SVD

During fine-tuning, we use Riemannian optimization (via geoopt) to maintain these constraints while updating the parameters.
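
As a rough standalone illustration of this pattern with geoopt (shapes and names are illustrative, not the library's internals): the decompression weight is declared as a ManifoldParameter on the Stiefel manifold and updated with RiemannianAdam, which keeps its columns orthonormal without any explicit re-orthonormalization step.

import torch
import geoopt

d_latent, d_kv = 1024, 4096
stiefel = geoopt.Stiefel()

# Decompression matrix with orthonormal columns (a point on the Stiefel manifold);
# the QR factor of a random matrix gives a valid starting point
W_uk = geoopt.ManifoldParameter(
    torch.linalg.qr(torch.randn(d_kv, d_latent))[0], manifold=stiefel
)

# The Riemannian optimizer projects gradients onto the tangent space and retracts
# back onto the manifold after each step
optimizer = geoopt.optim.RiemannianAdam([W_uk], lr=1e-4)

c = torch.randn(8, d_latent)            # a batch of compressed latents
loss = (c @ W_uk.t()).pow(2).mean()     # stand-in for a real training loss
loss.backward()
optimizer.step()

# Columns remain (numerically) orthonormal after the update
err = (W_uk.t() @ W_uk - torch.eye(d_latent)).abs().max()
print(f"max |W^T W - I| = {err:.2e}")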

Compression Math

Original KV dimension:   2 × d_kv     (K and V each have dimension d_kv)
Compressed dimension:    2 × d_latent (c_k and c_v each have dimension d_latent)

Compression ratio = (2 × d_kv) / (2 × d_latent) = d_kv / d_latent

Example (LLaMA-2 7B with 4x compression):
  d_kv = 4096 (32 heads × 128 head_dim)
  d_latent = 1024
  Compression = 4096 / 1024 = 4x
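
In code, picking the latent dimension for a target ratio is just the following arithmetic (illustrative, not a library call):

d_kv = 4096                                  # 32 heads x 128 head_dim
target_ratio = 4.0
d_latent = int(d_kv / target_ratio)          # -> 1024
print(d_latent, d_kv / d_latent)             # 1024 4.0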

Installation

pip install cacheshrink

# For development
git clone https://github.com/your-repo/cacheshrink
cd cacheshrink
pip install -e ".[dev]"

Requirements

  • Python >= 3.9
  • PyTorch >= 2.0
  • transformers >= 4.35
  • geoopt >= 0.5 (for Riemannian optimization)

Quick Start

Basic Conversion

import torch
from cacheshrink import convert_to_mla, save_mla_model, load_mla_model

# Convert a HuggingFace model to MLA with 4x compression
model, tokenizer = convert_to_mla(
    "meta-llama/Llama-2-7b-hf",
    compression_ratio=4.0,
    device="cuda",
    dtype=torch.float16,
)

# Generate text (works exactly like the original model)
inputs = tokenizer("The theory of relativity", return_tensors="pt").to("cuda")
outputs = model.generate(inputs.input_ids, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))

# Save the compressed model
save_mla_model(model, tokenizer, "./llama-7b-mla-4x")

# Load it later
model, tokenizer = load_mla_model("./llama-7b-mla-4x", device="cuda")

With Calibration (Recommended)

Calibration-aware initialization uses real activation statistics for better compression:

model, tokenizer = convert_to_mla(
    "meta-llama/Llama-2-7b-hf",
    compression_ratio=4.0,
    device="cuda",
    use_calibration=True,                    # Enable calibration
    calibration_dataset="wikitext",          # Dataset for calibration
    calibration_config="wikitext-2-raw-v1",
    num_calibration_samples=128,             # More samples = better init
    max_calibration_length=512,
)

Fine-tuning with Knowledge Distillation

For best results, fine-tune the compressed model using knowledge distillation:

import torch
from cacheshrink import MLATrainer
from transformers import AutoModelForCausalLM

# Load original model as teacher
teacher = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto",
)

# Create trainer with Riemannian optimization
trainer = MLATrainer(
    model=model,
    tokenizer=tokenizer,
    teacher_model=teacher,
    euclidean_lr=1e-5,      # Learning rate for W_down (compression)
    riemannian_lr=1e-4,     # Learning rate for W_uk, W_uv (decompression)
    use_distillation=True,  # Match teacher's output distribution
)

# Train on your data
trainer.train(
    train_texts,           # List of strings or HuggingFace dataset
    num_epochs=10,
    batch_size=4,
    max_length=512,
)

Evaluation

from cacheshrink import compute_perplexity, measure_cache_memory

# Measure perplexity
ppl = compute_perplexity(model, tokenizer, eval_texts, max_length=512)
print(f"Perplexity: {ppl:.2f}")

# Analyze KV cache compression
stats = measure_cache_memory(model, sequence_lengths=[512, 1024, 2048, 4096])

print(f"Compression ratio: {stats['compression_ratio']:.1f}x")
for seq_len, info in stats['per_sequence_length'].items():
    print(f"  {seq_len} tokens: {info['standard_cache_formatted']} -> {info['mla_cache_formatted']}")

Benchmark Results

LLaMA-2 7B with 4x Compression

Metric                    Original   After MLA       After Fine-tuning
Perplexity (WikiText-2)   11.75      16.61 (+41%)    13.36 (+14%)
KV Cache @ 2048 tokens    1.00 GB    256 MB          256 MB
Memory Saved              -          768 MB          768 MB

Memory Savings by Sequence Length (4x compression)

Sequence Length   Standard Cache   MLA Cache   Saved
512               256 MB           64 MB       192 MB
1024              512 MB           128 MB      384 MB
2048              1.00 GB          256 MB      768 MB
4096              2.00 GB          512 MB      1.5 GB
8192              4.00 GB          1.00 GB     3.0 GB

Supported Models

Model Family   Status      Notes
LLaMA 2/3      Supported   Full support, including GQA
Mistral        Supported   Based on the LLaMA handler
Qwen/Qwen2     Supported   Based on the LLaMA handler
GPT-2          Supported   Combined QKV projection

Adding New Models

Extend ModelHandler to support additional architectures:

from cacheshrink import ModelHandler, register_handler

class MyModelHandler(ModelHandler):
    def get_num_layers(self) -> int:
        return len(self.model.my_layers)

    def extract_qkv_weights(self, layer_idx: int):
        layer = self.model.my_layers[layer_idx]
        return layer.W_q, layer.W_k, layer.W_v, layer.W_o

    # ... implement other methods

register_handler("my_model", MyModelHandler)

API Reference

Core Functions

# Conversion
convert_to_mla(
    model_name_or_path: str,
    compression_ratio: float = 4.0,      # Target compression (4-16x typical)
    d_latent: int = None,                # Override latent dim (auto if None)
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
    use_calibration: bool = True,        # Use activation statistics
    num_calibration_samples: int = 128,
    verbose: bool = True,
) -> Tuple[nn.Module, Tokenizer]

# Save/Load
save_mla_model(model, tokenizer, path: str)
load_mla_model(path: str, device: str = "cuda") -> Tuple[nn.Module, Tokenizer]

Training

MLATrainer(
    model: nn.Module,
    tokenizer: Tokenizer,
    teacher_model: nn.Module = None,     # For distillation
    euclidean_lr: float = 1e-5,          # W_down learning rate
    riemannian_lr: float = 1e-4,         # W_uk, W_uv learning rate
    use_distillation: bool = True,       # Recommended
)

Evaluation

compute_perplexity(model, tokenizer, texts, max_length=512) -> float
measure_cache_memory(model, sequence_lengths=[128, 512, 2048]) -> dict
generate_samples(model, tokenizer, prompts, max_new_tokens=50) -> List[str]

Configuration

MLAConfig(
    model_name: str,
    model_type: str,           # "llama", "mistral", "gpt2", "qwen"
    compression_ratio: float,  # 4.0, 8.0, 16.0, etc.
    n_heads: int,
    n_kv_heads: int,           # For GQA models
    d_model: int,
    d_head: int,
    n_layers: int,
    d_latent: int = None,      # Auto-computed if None
)

Compression Methods

cacheshrink supports three compression methods, selectable via the compression_method parameter:

model, tokenizer = convert_to_mla(
    "meta-llama/Llama-2-7b-hf",
    compression_ratio=4.0,
    compression_method="separate",  # "separate", "joint", or "decoupled_rope"
)

Separate K/V Compression (Default, Recommended)

The default method compresses K and V independently:

  • Cache stores [c_k, c_v] where each has dimension d_latent
  • Best reconstruction quality for post-hoc conversion
  • Recommended for all use cases

Joint K/V Compression (Experimental)

Warning: Joint compression does not work well for post-hoc conversion of pre-trained models. It achieves 85-300% reconstruction error compared to 10-75% for separate compression. Only use this if you plan to train a model from scratch with this architecture.

DeepSeek-style compression with a single shared latent:

  • Cache stores only c (single latent for both K and V)
  • 2x more memory efficient than separate at the same d_latent (see the comparison sketch below)
  • Works well when models are trained from scratch with joint compression
  • Does NOT work well for converting existing pre-trained models because K and V have different statistical structures
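
To make the memory trade-off concrete, here is a tiny illustrative comparison of per-token, per-layer cache size for the two layouts (assuming the separate method caches c_k and c_v while the joint method caches a single shared latent, as described above):

# Per-token, per-layer cache bytes at the same d_latent (float16 = 2 bytes/element)
d_latent = 1024
separate = 2 * d_latent * 2   # c_k + c_v              -> 4096 bytes
joint    = 1 * d_latent * 2   # single shared latent c -> 2048 bytes
print(separate, joint)        # joint is 2x smaller, but reconstructs K/V worse post-hoc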

Decoupled RoPE Compression (Experimental)

Warning: This method is experimental. It preserves a portion of the keys/values uncompressed for positional information, which limits the maximum achievable compression ratio.

Separates positional encoding from compressed content:

  • Keeps d_rope dimensions uncompressed for RoPE
  • Compresses remaining dimensions
  • Effective compression is limited by d_rope (e.g., with d_rope=64, an 8x overall ratio is only achievable if d_kv > 128); see the rough estimate after the example below

model, tokenizer = convert_to_mla(
    "meta-llama/Llama-2-7b-hf",
    compression_ratio=4.0,
    compression_method="decoupled_rope",
    d_rope=64,  # Uncompressed dimensions for positional info
)
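
Assuming the per-token cache holds d_rope uncompressed dimensions plus d_latent compressed ones (my reading of the bullets above; cacheshrink may account for this differently), a rough estimate of the effective ratio looks like:

# Rough, assumption-laden estimate of effective compression with decoupled RoPE:
# per-token cache = d_latent compressed dims + d_rope uncompressed dims
d_kv, d_rope = 4096, 64
for d_latent in (1024, 512, 256):
    effective = d_kv / (d_latent + d_rope)
    print(f"d_latent={d_latent}: ~{effective:.1f}x effective compression")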

Advanced Usage

Custom Initialization

from cacheshrink import balanced_svd_init, MLACompression

# Manual SVD initialization
W_down_k, W_down_v, W_uk, W_uv = balanced_svd_init(
    W_k=original_key_weights,
    W_v=original_value_weights,
    d_latent=1024,
)

# Create compression module
compression = MLACompression(config)
compression.init_from_weights(W_down_k, W_down_v, W_uk, W_uv)

Checking Orthonormality

# Verify Stiefel constraints are maintained
for layer_idx in range(model.mla_config.n_layers):
    attn = model.model.layers[layer_idx].self_attn
    errors = attn.mla.check_orthonormality()
    print(f"Layer {layer_idx}: W_uk error={errors['W_uk'][0]:.2e}, W_uv error={errors['W_uv'][0]:.2e}")

Training Configuration

from cacheshrink import TrainingConfig

config = TrainingConfig(
    euclidean_lr=1e-5,
    riemannian_lr=1e-4,
    num_epochs=10,
    batch_size=4,
    max_length=512,
    use_distillation=True,
    distillation_temperature=2.0,
    distillation_alpha=0.9,           # 90% distillation, 10% LM loss
    check_orthonormality_steps=100,   # Monitor constraint satisfaction
)

How the Training Works

Two-Optimizer Approach

cacheshrink uses separate optimizers for the two parameter types (a minimal setup sketch follows the list below):

  1. AdamW for Euclidean parameters (W_down_k, W_down_v)

    • Standard gradient descent
    • Updates compression matrices freely
  2. RiemannianAdam for Stiefel parameters (W_uk, W_uv)

    • Gradients are projected to tangent space of Stiefel manifold
    • Updates follow geodesics to maintain orthonormality
    • No explicit re-orthonormalization needed
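
A minimal self-contained sketch of this split (toy dimensions and a toy reconstruction loss; the real trainer wires this up internally) might look like:

import torch
import geoopt

d_kv, d_latent = 256, 64

# Euclidean parameter: compression matrix, free to move anywhere in weight space
W_down_k = torch.nn.Parameter(torch.randn(d_kv, d_latent) * 0.02)
# Stiefel parameter: decompression matrix constrained to orthonormal columns
W_uk = geoopt.ManifoldParameter(
    torch.linalg.qr(torch.randn(d_kv, d_latent))[0], manifold=geoopt.Stiefel()
)

adamw = torch.optim.AdamW([W_down_k], lr=1e-5)        # standard optimizer
riem = geoopt.optim.RiemannianAdam([W_uk], lr=1e-4)   # manifold-aware optimizer

x = torch.randn(8, d_kv)
recon = (x @ W_down_k) @ W_uk.t()     # compress, then decompress
loss = (recon - x).pow(2).mean()      # toy reconstruction loss
loss.backward()
adamw.step(); riem.step()
adamw.zero_grad(); riem.zero_grad()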

Knowledge Distillation

Instead of training on language modeling loss alone (which can cause the compressed model to drift), we use knowledge distillation:

Loss = α × KL(student || teacher) + (1-α) × LM_loss

where:
  - student = softmax(MLA_logits / T)
  - teacher = softmax(original_logits / T)
  - T = temperature (default 2.0)
  - α = distillation weight (default 0.9)

This keeps the compressed model's behavior close to the original.
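
A minimal sketch of the combined objective in PyTorch (function and variable names are illustrative, not cacheshrink's internal API):

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, lm_loss, T=2.0, alpha=0.9):
    # Soften both distributions with the temperature T
    student_log_probs = F.log_softmax(student_logits / T, dim=-1)
    teacher_probs = F.softmax(teacher_logits / T, dim=-1)
    # KL divergence between the softened teacher and student distributions
    kd = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    return alpha * kd + (1.0 - alpha) * lm_loss

# Toy example: 4 token positions, vocabulary of 10
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
print(distillation_loss(student_logits, teacher_logits, lm_loss=torch.tensor(2.5)))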

Citation

If you use cacheshrink in your research, please cite:

@software{cacheshrink2024,
  title = {cacheshrink: KV Cache Compression via Multi-Head Latent Attention},
  year = {2024},
  url = {https://github.com/your-repo/cacheshrink}
}

License

Apache 2.0
