
cacheshrink

KV Cache Compression via Multi-Head Latent Attention with Riemannian Optimization

Achieve 4-16x KV cache compression on LLaMA, Mistral, GPT-2, and other transformer models while maintaining model quality through mathematically principled compression and fine-tuning.

Overview

cacheshrink converts HuggingFace transformer models to use Multi-Head Latent Attention (MLA), dramatically reducing KV cache memory during inference. The library uses Riemannian optimization on Stiefel manifolds to ensure orthonormality constraints are preserved during fine-tuning, enabling high compression ratios with minimal quality loss.

Key Features

  • 4-16x KV cache compression - Reduce memory from GBs to MBs for long sequences
  • Drop-in replacement - Works with existing HuggingFace models and generation pipelines
  • Mathematically principled - Stiefel manifold constraints ensure stable compression/decompression
  • Calibration-aware initialization - SVD-based initialization using real activation statistics
  • Knowledge distillation - Fine-tune compressed models to match original model outputs
  • Multiple model support - LLaMA 2/3, Mistral, Qwen, GPT-2, and extensible to others

How It Works

The KV Cache Problem

In standard transformer attention, the KV cache grows linearly with sequence length:

Standard KV Cache Size = 2 × n_layers × seq_len × d_kv × bytes_per_element

For LLaMA-2 7B at 4096 tokens (float16):
= 2 × 32 × 4096 × 4096 × 2 = 2 GB

This becomes a major bottleneck for long-context inference and high-throughput serving.
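The formula above can be checked with a few lines of Python (a hypothetical helper, not part of cacheshrink):

```python
def kv_cache_bytes(n_layers, seq_len, d_kv, bytes_per_element=2):
    # K and V are each cached per layer, hence the leading factor of 2
    return 2 * n_layers * seq_len * d_kv * bytes_per_element

# LLaMA-2 7B at 4096 tokens in float16
size = kv_cache_bytes(n_layers=32, seq_len=4096, d_kv=4096)
print(f"{size / 2**30:.0f} GB")  # -> 2 GB
```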

Multi-Head Latent Attention (MLA)

MLA compresses the KV cache by projecting keys and values into a shared low-dimensional latent space:

┌─────────────────────────────────────────────────────────────────┐
│                    Standard Attention                            │
│                                                                  │
│  Hidden States ─┬─► W_k ─► K (d_kv) ──┐                         │
│       (d_model) │                      ├─► Attention ─► Output   │
│                 └─► W_v ─► V (d_kv) ──┘                         │
│                                                                  │
│  Cache stores: K + V = 2 × d_kv per token                       │
└─────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│                    MLA (Multi-Head Latent Attention)            │
│                                                                  │
│                 ┌─► W_down_k ─► c_k (d_latent) ─► W_uk ─► K     │
│  Hidden States ─┤                                      │        │
│       (d_model) │                                      ├─► Attn │
│                 └─► W_down_v ─► c_v (d_latent) ─► W_uv ─► V     │
│                                                                  │
│  Cache stores: c_k + c_v = 2 × d_latent per token               │
│  Compression: d_kv / d_latent (e.g., 4x, 8x, 16x)               │
└─────────────────────────────────────────────────────────────────┘

Stiefel Manifold Constraints

The decompression matrices W_uk and W_uv are constrained to have orthonormal columns (Stiefel manifold). This ensures:

  1. Stable decompression - The projection preserves geometric relationships
  2. Energy preservation - Orthonormal matrices preserve vector norms
  3. Invertibility - Clean mathematical properties for initialization via SVD

During fine-tuning, we use Riemannian optimization (via geoopt) to maintain these constraints while updating the parameters.
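As a quick illustration of the constraint: a matrix with orthonormal columns (a point on the Stiefel manifold) satisfies W_u^T W_u = I, which is exactly the norm-preservation property listed above. A small self-contained check, with illustrative dimensions:

```python
import torch

# A point on the Stiefel manifold St(d_kv, d_latent), built via QR
d_kv, d_latent = 256, 64
W_u, _ = torch.linalg.qr(torch.randn(d_kv, d_latent))

# Orthonormality: W_u^T W_u = I
err = (W_u.T @ W_u - torch.eye(d_latent)).abs().max()
assert err < 1e-5

# Norm preservation: |W_u c| = |c| for any latent vector c
c = torch.randn(d_latent)
assert torch.allclose((W_u @ c).norm(), c.norm(), atol=1e-5)
```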

Compression Math

Original KV dimension:   2 × d_kv     (K and V each have dimension d_kv)
Compressed dimension:    2 × d_latent (c_k and c_v each have dimension d_latent)

Compression ratio = (2 × d_kv) / (2 × d_latent) = d_kv / d_latent

Example (LLaMA-2 7B with 4x compression):
  d_kv = 4096 (32 heads × 128 head_dim)
  d_latent = 1024
  Compression = 4096 / 1024 = 4x

Installation

pip install cacheshrink

# For development
git clone https://github.com/your-repo/cacheshrink
cd cacheshrink
pip install -e ".[dev]"

Requirements

  • Python >= 3.9
  • PyTorch >= 2.0
  • transformers >= 4.35
  • geoopt >= 0.5 (for Riemannian optimization)

Quick Start

Basic Conversion

import torch

from cacheshrink import convert_to_mla, save_mla_model, load_mla_model

# Convert a HuggingFace model to MLA with 4x compression
model, tokenizer = convert_to_mla(
    "meta-llama/Llama-2-7b-hf",
    compression_ratio=4.0,
    device="cuda",
    dtype=torch.float16,
)

# Generate text (works exactly like the original model)
inputs = tokenizer("The theory of relativity", return_tensors="pt").to("cuda")
outputs = model.generate(inputs.input_ids, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))

# Save the compressed model
save_mla_model(model, tokenizer, "./llama-7b-mla-4x")

# Load it later
model, tokenizer = load_mla_model("./llama-7b-mla-4x", device="cuda")

With Calibration (Recommended)

Calibration-aware initialization uses real activation statistics for better compression:

model, tokenizer = convert_to_mla(
    "meta-llama/Llama-2-7b-hf",
    compression_ratio=4.0,
    device="cuda",
    use_calibration=True,                    # Enable calibration
    calibration_dataset="wikitext",          # Dataset for calibration
    calibration_config="wikitext-2-raw-v1",
    num_calibration_samples=128,             # More samples = better init
    max_calibration_length=512,
)

Fine-tuning with Knowledge Distillation

For best results, fine-tune the compressed model using knowledge distillation:

import torch

from cacheshrink import MLATrainer
from transformers import AutoModelForCausalLM

# Load original model as teacher
teacher = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto",
)

# Create trainer with Riemannian optimization
trainer = MLATrainer(
    model=model,
    tokenizer=tokenizer,
    teacher_model=teacher,
    euclidean_lr=1e-5,      # Learning rate for W_down (compression)
    riemannian_lr=1e-4,     # Learning rate for W_uk, W_uv (decompression)
    use_distillation=True,  # Match teacher's output distribution
)

# Train on your data
trainer.train(
    train_texts,           # List of strings or HuggingFace dataset
    num_epochs=10,
    batch_size=4,
    max_length=512,
)

Evaluation

from cacheshrink import compute_perplexity, measure_cache_memory

# Measure perplexity
ppl = compute_perplexity(model, tokenizer, eval_texts, max_length=512)
print(f"Perplexity: {ppl:.2f}")

# Analyze KV cache compression
stats = measure_cache_memory(model, sequence_lengths=[512, 1024, 2048, 4096])

print(f"Compression ratio: {stats['compression_ratio']:.1f}x")
for seq_len, info in stats['per_sequence_length'].items():
    print(f"  {seq_len} tokens: {info['standard_cache_formatted']} -> {info['mla_cache_formatted']}")

Benchmark Results

LLaMA-2 7B with 4x Compression

Metric                    Original   After MLA      After Fine-tuning
Perplexity (WikiText-2)   11.75      16.61 (+41%)   13.36 (+14%)
KV Cache @ 2048 tokens    1.00 GB    256 MB         256 MB
Memory Saved              -          768 MB         768 MB

Memory Savings by Sequence Length (4x compression)

Sequence Length   Standard Cache   MLA Cache   Saved
512               256 MB           64 MB       192 MB
1024              512 MB           128 MB      384 MB
2048              1.00 GB          256 MB      768 MB
4096              2.00 GB          512 MB      1.5 GB
8192              4.00 GB          1.00 GB     3.0 GB

Supported Models

Model Family   Status      Notes
LLaMA 2/3      Supported   Full support including GQA
Mistral        Supported   Based on LLaMA handler
Qwen/Qwen2     Supported   Based on LLaMA handler
GPT-2          Supported   Combined QKV projection

Adding New Models

Extend ModelHandler to support additional architectures:

from cacheshrink import ModelHandler, register_handler

class MyModelHandler(ModelHandler):
    def get_num_layers(self) -> int:
        return len(self.model.my_layers)

    def extract_qkv_weights(self, layer_idx: int):
        layer = self.model.my_layers[layer_idx]
        return layer.W_q, layer.W_k, layer.W_v, layer.W_o

    # ... implement other methods

register_handler("my_model", MyModelHandler)

API Reference

Core Functions

# Conversion
convert_to_mla(
    model_name_or_path: str,
    compression_ratio: float = 4.0,      # Target compression (4-16x typical)
    d_latent: int = None,                # Override latent dim (auto if None)
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
    use_calibration: bool = True,        # Use activation statistics
    num_calibration_samples: int = 128,
    verbose: bool = True,
) -> Tuple[nn.Module, Tokenizer]

# Save/Load
save_mla_model(model, tokenizer, path: str)
load_mla_model(path: str, device: str = "cuda") -> Tuple[nn.Module, Tokenizer]

Training

MLATrainer(
    model: nn.Module,
    tokenizer: Tokenizer,
    teacher_model: nn.Module = None,     # For distillation
    euclidean_lr: float = 1e-5,          # W_down learning rate
    riemannian_lr: float = 1e-4,         # W_uk, W_uv learning rate
    use_distillation: bool = True,       # Recommended
)

Evaluation

compute_perplexity(model, tokenizer, texts, max_length=512) -> float
measure_cache_memory(model, sequence_lengths=[128, 512, 2048]) -> dict
generate_samples(model, tokenizer, prompts, max_new_tokens=50) -> List[str]

Configuration

MLAConfig(
    model_name: str,
    model_type: str,           # "llama", "mistral", "gpt2", "qwen"
    compression_ratio: float,  # 4.0, 8.0, 16.0, etc.
    n_heads: int,
    n_kv_heads: int,           # For GQA models
    d_model: int,
    d_head: int,
    n_layers: int,
    d_latent: int = None,      # Auto-computed if None
)

Advanced Usage

Custom Initialization

from cacheshrink import balanced_svd_init, MLACompression

# Manual SVD initialization
W_down_k, W_down_v, W_uk, W_uv = balanced_svd_init(
    W_k=original_key_weights,
    W_v=original_value_weights,
    d_latent=1024,
)

# Create compression module
compression = MLACompression(config)
compression.init_from_weights(W_down_k, W_down_v, W_uk, W_uv)
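The core idea behind SVD-based initialization can be illustrated for a single projection matrix: truncate the SVD to rank d_latent, keep the left singular vectors as the orthonormal decompression matrix, and fold the spectrum into the compression matrix. This is a simplified sketch (balanced_svd_init factorizes K and V jointly; the code below is not the library's implementation):

```python
import torch

d_model, d_kv, d_latent = 512, 512, 128
W_k = torch.randn(d_kv, d_model)  # stand-in for an original key projection

U, S, Vh = torch.linalg.svd(W_k, full_matrices=False)
W_u = U[:, :d_latent]                              # orthonormal columns (Stiefel)
W_down = torch.diag(S[:d_latent]) @ Vh[:d_latent]  # absorbs the singular values

# W_u @ W_down is the best rank-d_latent approximation of W_k,
# and W_u starts exactly on the Stiefel manifold.
approx = W_u @ W_down
assert (W_u.T @ W_u - torch.eye(d_latent)).abs().max() < 1e-4
```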

Checking Orthonormality

# Verify Stiefel constraints are maintained
for layer_idx in range(model.mla_config.n_layers):
    attn = model.model.layers[layer_idx].self_attn
    errors = attn.mla.check_orthonormality()
    print(f"Layer {layer_idx}: W_uk error={errors['W_uk'][0]:.2e}, W_uv error={errors['W_uv'][0]:.2e}")

Training Configuration

from cacheshrink import TrainingConfig

config = TrainingConfig(
    euclidean_lr=1e-5,
    riemannian_lr=1e-4,
    num_epochs=10,
    batch_size=4,
    max_length=512,
    use_distillation=True,
    distillation_temperature=2.0,
    distillation_alpha=0.9,           # 90% distillation, 10% LM loss
    check_orthonormality_steps=100,   # Monitor constraint satisfaction
)

How the Training Works

Two-Optimizer Approach

cacheshrink uses separate optimizers for different parameter types:

  1. AdamW for Euclidean parameters (W_down_k, W_down_v)

    • Standard gradient descent
    • Updates compression matrices freely
  2. RiemannianAdam for Stiefel parameters (W_uk, W_uv)

    • Gradients are projected to tangent space of Stiefel manifold
    • Updates follow geodesics to maintain orthonormality
    • No explicit re-orthonormalization needed
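The tangent-space projection step can be illustrated without geoopt (which handles it internally). For a Stiefel point X and Euclidean gradient G, the Riemannian gradient is G minus the component that would break orthonormality:

```python
import torch

d, p = 64, 16
X, _ = torch.linalg.qr(torch.randn(d, p))   # a point on the Stiefel manifold
G = torch.randn(d, p)                       # Euclidean gradient at X

# Project G onto the tangent space at X: xi = G - X * sym(X^T G)
A = X.T @ G
xi = G - X @ ((A + A.T) / 2)

# Tangent vectors satisfy X^T xi + xi^T X = 0, so a retraction of
# X + step * xi lands back on (or very near) the manifold.
skew = X.T @ xi + xi.T @ X
assert skew.abs().max() < 1e-5
```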

Knowledge Distillation

Instead of training on language modeling loss alone (which can cause the compressed model to drift), we use knowledge distillation:

Loss = α × KL(student || teacher) + (1-α) × LM_loss

where:
  - student = softmax(MLA_logits / T)
  - teacher = softmax(original_logits / T)
  - T = temperature (default 2.0)
  - α = distillation weight (default 0.9)

This keeps the compressed model's behavior close to the original.
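The combined loss above can be sketched in a few lines of PyTorch. This is an illustration of the formula, not the library's training code; the parameter names mirror the TrainingConfig defaults:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, lm_loss, T=2.0, alpha=0.9):
    # F.kl_div expects log-probabilities for the student and
    # probabilities for the teacher
    kl = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    )
    return alpha * kl + (1 - alpha) * lm_loss

s = torch.randn(4, 100)   # student logits (batch, vocab)
t = torch.randn(4, 100)   # teacher logits
loss = distillation_loss(s, t, lm_loss=torch.tensor(2.5))
```

When student and teacher agree exactly, the KL term vanishes and only the language-modeling loss remains.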

Citation

If you use cacheshrink in your research, please cite:

@software{cacheshrink2024,
  title = {cacheshrink: KV Cache Compression via Multi-Head Latent Attention},
  year = {2024},
  url = {https://github.com/your-repo/cacheshrink}
}

License

Apache 2.0
