cacheshrink
KV Cache Compression via Multi-Head Latent Attention with Riemannian Optimization
Achieve 2-16x KV cache compression on LLaMA, Mistral, Qwen, GPT-2, and other transformer models while maintaining model quality through mathematically principled compression and fine-tuning.
Table of Contents
- Overview
- Installation
- Quick Start
- How It Works
- Choosing a Compression Method
- Usage Guide
- Parameter Reference
- Training Details
- Benchmark Results
- Supported Models
- Adding New Models
- License
Overview
During inference, transformer models store keys and values (the KV cache) for every token in the sequence. This cache grows linearly with sequence length and becomes the dominant memory bottleneck for long-context and high-throughput serving:
Standard KV Cache Size = 2 x n_layers x seq_len x d_kv x bytes_per_element
Qwen2.5-7B at 4096 tokens (float16):
= 2 x 28 x 4096 x 512 x 2 = 224 MB
LLaMA-2 7B at 4096 tokens (float16):
= 2 x 32 x 4096 x 4096 x 2 = 2 GB
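These figures follow directly from the formula above; a few lines of Python reproduce them (illustrative helper, not part of the library):

```python
def standard_kv_cache_bytes(n_layers: int, seq_len: int, d_kv: int, bytes_per_element: int = 2) -> int:
    """Keys + values cached for every layer and token."""
    return 2 * n_layers * seq_len * d_kv * bytes_per_element

# Qwen2.5-7B: 28 layers, d_kv = 4 KV heads x 128 = 512, float16
print(standard_kv_cache_bytes(28, 4096, 512) / 2**20)   # 224.0 MB
# LLaMA-2 7B: 32 layers, d_kv = 32 KV heads x 128 = 4096, float16
print(standard_kv_cache_bytes(32, 4096, 4096) / 2**30)  # 2.0 GB
```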
cacheshrink compresses this cache by converting attention layers to use Multi-Head Latent Attention (MLA). Instead of caching full-dimensional keys and values, it caches compact latent representations and reconstructs K/V on the fly during attention.
Key Features
- 2-16x KV cache compression with minimal quality loss
- Drop-in replacement for HuggingFace models — generation, evaluation, and pipelines work unchanged
- Mathematically principled — Stiefel manifold constraints ensure stable compression/decompression
- Calibration-aware initialization — SVD on real activation statistics for high-quality starting point
- Two training approaches — reconstruction loss (no teacher model) or knowledge distillation
- Automatic method selection — detects MHA/GQA/MQA and picks the optimal compression strategy
- Cross-layer compression (xKV) — shared basis vectors across layer groups for GQA models
Installation
pip install cacheshrink
# For development
git clone https://github.com/your-repo/cacheshrink
cd cacheshrink
pip install -e ".[dev]"
Requirements
- Python >= 3.9
- PyTorch >= 2.0
- transformers >= 4.35
- geoopt >= 0.5 (Riemannian optimization)
Quick Start
The simplest way to use cacheshrink is with compression_method="auto", which detects your model's attention type and picks the best method:
import torch
from cacheshrink import convert_to_mla, compute_perplexity
# Convert any supported model — auto-selects the best compression method
model, tokenizer = convert_to_mla(
"Qwen/Qwen2.5-7B",
compression_ratio=2.0,
compression_method="auto",
device="cuda",
dtype=torch.bfloat16,
use_calibration=True,
)
# Use it exactly like the original model
inputs = tokenizer("The theory of relativity", return_tensors="pt").to("cuda")
outputs = model.generate(inputs.input_ids, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))
# Measure perplexity on a list of evaluation strings
eval_texts = ["The theory of relativity was developed by Einstein."]  # replace with your evaluation data
ppl = compute_perplexity(model, tokenizer, eval_texts, max_length=512)
How It Works
Standard Attention vs MLA
In standard attention, hidden states are projected to full-dimensional keys and values, which are cached:
Standard Attention:
  h -> W_k -> K (d_kv) ]
                       ]-> cache stores K + V = 2 x d_kv per token
  h -> W_v -> V (d_kv) ]
MLA adds a compression bottleneck — hidden states are first projected to a small latent space, and only this latent representation is cached. Keys and values are reconstructed on the fly:
MLA (Multi-Head Latent Attention):
  h -> W_down_k -> c_k (d_latent) -> W_uk -> K (d_kv)
                    |
                    [cached: c_k, c_v]   cache = 2 x d_latent per token
                    |
  h -> W_down_v -> c_v (d_latent) -> W_uv -> V (d_kv)
The compression ratio is d_kv / d_latent. For example, with d_kv=512 and d_latent=256, you get 2x compression.
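A minimal PyTorch sketch of the per-layer data flow (illustrative dimensions and random weights; the library's converted attention modules handle this internally):

```python
import torch

d_model, d_kv, d_latent = 3584, 512, 256          # illustrative, Qwen2.5-7B-like sizes
W_down_k = 0.02 * torch.randn(d_latent, d_model)  # compression (trained with AdamW)
W_uk = torch.linalg.qr(torch.randn(d_kv, d_latent)).Q  # orthonormal decompression basis

h = torch.randn(1, 8, d_model)        # hidden states for 8 tokens
c_k = h @ W_down_k.T                  # (1, 8, d_latent) -- this latent is what gets cached
K = c_k @ W_uk.T                      # (1, 8, d_kv)     -- keys rebuilt on the fly

print(d_kv / d_latent)                # 2.0x compression: 2*d_latent cached instead of 2*d_kv
```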
The Matrices
Each compressed attention layer has four learned matrices:
| Matrix | Shape | Role | Optimizer |
|---|---|---|---|
| W_down_k | (d_latent, d_model) | Compresses hidden states to key latent | AdamW |
| W_down_v | (d_latent, d_model) | Compresses hidden states to value latent | AdamW |
| W_uk | (d_kv, d_latent) | Decompresses key latent back to keys | RiemannianAdam |
| W_uv | (d_kv, d_latent) | Decompresses value latent back to values | RiemannianAdam |
The decompression matrices (W_uk, W_uv) are constrained to have orthonormal columns — they live on the Stiefel manifold. This constraint ensures:
- Stable reconstruction — orthonormal projections preserve geometric relationships
- Energy preservation — vector norms are maintained through decompression
- Clean initialization — SVD naturally produces orthonormal matrices
During training, geoopt's RiemannianAdam optimizer updates these matrices while automatically staying on the manifold. No explicit re-orthonormalization is needed.
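A minimal standalone geoopt sketch of this setup (a toy example, not cacheshrink's internal code):

```python
import torch
import geoopt

d_kv, d_latent = 512, 256
# Start from an orthonormal matrix (as an SVD-based init would) and declare it a Stiefel point.
W_uk = geoopt.ManifoldParameter(
    torch.linalg.qr(torch.randn(d_kv, d_latent)).Q,
    manifold=geoopt.Stiefel(),
)
optimizer = geoopt.optim.RiemannianAdam([W_uk], lr=1e-4)

loss = (W_uk.sum() - 1.0) ** 2        # placeholder loss just to produce a gradient
loss.backward()
optimizer.step()                      # the update is projected and retracted onto the manifold

# Columns remain orthonormal without any explicit re-orthonormalization.
err = (W_uk.T @ W_uk - torch.eye(d_latent)).abs().max()
print(err)
```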
Per-Layer vs Cross-Layer Compression
Per-layer (separate): Each layer has its own W_down_k, W_down_v, W_uk, W_uv. Works well for MHA models where each layer's KV space is high-dimensional (many KV heads).
Cross-layer (xKV): Groups of adjacent layers share W_uk and W_uv. Only W_down_k and W_down_v differ per layer. This exploits the fact that adjacent layers in GQA models have similar KV structure, so a shared basis can represent them all efficiently.
xKV Group (layers 4-7):
Shared: W_uk, W_uv (one pair for the whole group)
Per-layer: W_down_k, W_down_v (different for each layer)
Layer 4: h -> W_down_k_4 -> c_k -> [shared W_uk] -> K_4
Layer 5: h -> W_down_k_5 -> c_k -> [shared W_uk] -> K_5
Layer 6: h -> W_down_k_6 -> c_k -> [shared W_uk] -> K_6
Layer 7: h -> W_down_k_7 -> c_k -> [shared W_uk] -> K_7
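In code terms, a group can be pictured as one shared decompression matrix plus a stack of per-layer down-projections (illustrative sketch with toy dimensions, not the library's internal classes):

```python
import torch
import torch.nn as nn

d_model, d_kv, d_latent, group_size = 3584, 512, 256, 4   # illustrative sizes

shared_W_uk = nn.Parameter(torch.empty(d_kv, d_latent))   # one basis for the whole group
nn.init.orthogonal_(shared_W_uk)
per_layer_W_down_k = nn.ParameterList(
    [nn.Parameter(0.02 * torch.randn(d_latent, d_model)) for _ in range(group_size)]
)

h = torch.randn(1, 16, d_model)                 # hidden states for 16 tokens
for layer, W_down in enumerate(per_layer_W_down_k):
    c_k = h @ W_down.T                          # per-layer latent (this is what gets cached)
    K = c_k @ shared_W_uk.T                     # keys rebuilt with the shared basis
```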
Choosing a Compression Method
| Method | When to Use | Models |
|---|---|---|
| "auto" | Always safe. Auto-detects and picks the best method. | Any |
| "separate" | MHA models with many KV heads. Each layer compressed independently. | GPT-2, LLaMA 2 |
| "xkv" | GQA/MQA models with few KV heads. Shares basis across layer groups. | Qwen, Mistral, LLaMA 3 |
| "joint" | Experimental. Only useful if training from scratch. | N/A |
| "decoupled_rope" | Experimental. Preserves positional information separately. | N/A |
Why do GQA models need xKV? GQA models like Qwen2.5-7B have only 4 KV heads (d_kv=512). Per-layer SVD on a 512-dimensional space has limited redundancy to exploit. Cross-layer SVD stacks K/V from multiple layers, giving a much larger matrix with more structure to compress.
Why don't MHA models need xKV? MHA models like LLaMA-2 7B have 32 KV heads (d_kv=4096). Per-layer SVD already finds ample redundancy within each layer. Forcing layers to share a basis would be unnecessarily constraining.
Use compression_method="auto" and it handles this automatically.
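The heuristic boils down to comparing head counts. A hedged sketch of that check (the library's actual detection logic may differ):

```python
from transformers import AutoConfig

def guess_method(model_name: str) -> str:
    cfg = AutoConfig.from_pretrained(model_name)
    n_heads = cfg.num_attention_heads
    n_kv_heads = getattr(cfg, "num_key_value_heads", n_heads)  # MHA configs may omit this field
    return "separate" if n_kv_heads == n_heads else "xkv"      # GQA/MQA -> cross-layer

print(guess_method("Qwen/Qwen2.5-7B"))        # "xkv"      (4 KV heads, 28 query heads)
print(guess_method("openai-community/gpt2"))  # "separate" (MHA: 12 == 12)
```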
Usage Guide
Converting MHA Models (GPT-2, LLaMA 2)
For MHA models, use "separate" (or "auto", which selects "separate" for MHA):
import torch
from cacheshrink import convert_to_mla
model, tokenizer = convert_to_mla(
"meta-llama/Llama-2-7b-hf",
compression_ratio=4.0, # 4x KV cache reduction
compression_method="auto", # auto-selects "separate" for MHA
device="cuda",
dtype=torch.float16,
use_calibration=True, # recommended: uses activation statistics
num_calibration_samples=128,
max_calibration_length=512,
store_original_weights=True, # needed for reconstruction loss training
verbose=True,
)
Converting GQA Models (Qwen, Mistral, LLaMA 3)
For GQA models, use "auto" or "xkv":
import torch
from cacheshrink import convert_to_mla
model, tokenizer = convert_to_mla(
"Qwen/Qwen2.5-7B",
compression_ratio=2.0, # 2x KV cache reduction
compression_method="auto", # auto-selects "xkv" for GQA
cross_layer_group_size=4, # 4 layers share decompression matrices
device="cuda",
dtype=torch.bfloat16,
use_calibration=True,
num_calibration_samples=256,
max_calibration_length=1024,
store_original_weights=True,
verbose=True,
)
Early Layer Skipping (xKV only)
Early transformer layers (layers 0-3 typically) capture low-level features where information is spread uniformly across dimensions, making compression lossier. You can skip these layers to preserve quality at a small cost to compression ratio:
model, tokenizer = convert_to_mla(
"Qwen/Qwen2.5-7B",
compression_ratio=2.0,
compression_method="auto",
cross_layer_group_size=4,
xkv_skip_early_layers=4, # don't compress layers 0-3
keep_early_layers_original=True, # keep them as original attention
device="cuda",
dtype=torch.bfloat16,
use_calibration=True,
store_original_weights=True,
)
What the parameters mean:
- xkv_skip_early_layers=4 — the first 4 layers are excluded from xKV compression groups
- keep_early_layers_original=True — those skipped layers keep their original attention (no compression at all). If False, they would get per-layer MLA compression instead.
Trade-off: Skipping 4 out of 28 layers reduces effective compression from 2.0x to ~1.75x, but significantly improves quality (see Benchmark Results).
These parameters are xKV-only. Passing them with non-xKV methods ("separate", "joint", "decoupled_rope") raises a ValueError.
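The ~1.75x figure above follows from a quick calculation (assuming the skipped layers keep full-size caches while the rest are compressed 2x):

```python
n_layers, skipped, ratio = 28, 4, 2.0            # Qwen2.5-7B with xkv_skip_early_layers=4
effective = n_layers / (skipped + (n_layers - skipped) / ratio)
print(f"{effective:.2f}x")                       # 1.75x effective compression
```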
Important: Use bfloat16, Not float16
Always use dtype=torch.bfloat16 (not torch.float16) when converting and evaluating models. Qwen and other large models have weight values that exceed float16's narrow dynamic range (max ~65,504), causing overflow to inf/NaN during inference. bfloat16 has the same exponent range as float32 and handles these values correctly.
# GOOD
model, tokenizer = convert_to_mla("Qwen/Qwen2.5-7B", dtype=torch.bfloat16, ...)
# BAD — will produce NaN logits and inf perplexity
model, tokenizer = convert_to_mla("Qwen/Qwen2.5-7B", dtype=torch.float16, ...)
This also applies when loading saved models:
model, tokenizer = load_mla_model("./my-model", device="cuda", dtype=torch.bfloat16)
Fine-tuning with Reconstruction Loss
Reconstruction loss directly optimizes the compression matrices to minimize the difference between original and reconstructed K/V projections. This is the recommended training approach because it doesn't require loading a teacher model:
from cacheshrink import MLATrainer
trainer = MLATrainer(
model=model,
tokenizer=tokenizer,
euclidean_lr=1e-5, # learning rate for W_down_k, W_down_v
riemannian_lr=1e-4, # learning rate for W_uk, W_uv (on Stiefel manifold)
use_distillation=False,
use_reconstruction_loss=True, # optimize K/V reconstruction directly
reconstruction_alpha=0.3, # blend: 70% LM loss + 30% reconstruction loss
)
stats = trainer.train(
train_texts, # list of strings or HuggingFace dataset
num_epochs=3,
batch_size=4,
max_length=512,
)
Requirements: The model must be converted with store_original_weights=True. This stores the original W_k and W_v weight matrices as frozen buffers so the reconstruction loss can compute what K/V should be.
How it works:
Total loss = (1 - alpha) x LM_loss + alpha x reconstruction_loss

reconstruction_loss = average over layers of:
    MSE(K_reconstructed, K_original) + MSE(V_reconstructed, V_original)

where, per layer:
    K_original      = h @ W_k_original.T + b_k          (what the original attention produces)
    K_reconstructed = (h @ W_down_k.T) @ W_uk.T + b_k   (what the compressed attention produces)
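A hedged sketch of this blend in PyTorch (illustrative function; MLATrainer computes this internally and its exact reduction may differ):

```python
import torch
import torch.nn.functional as F

def blended_loss(lm_loss, per_layer_kv, alpha=0.3):
    """per_layer_kv: list of (K_original, K_reconstructed, V_original, V_reconstructed) per layer."""
    rec = torch.stack([
        F.mse_loss(k_rec, k_orig) + F.mse_loss(v_rec, v_orig)
        for k_orig, k_rec, v_orig, v_rec in per_layer_kv
    ]).mean()
    return (1 - alpha) * lm_loss + alpha * rec
```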
Fine-tuning with Knowledge Distillation
Alternatively, train the compressed model to match the original model's output distribution. This requires 2x GPU memory (both models loaded simultaneously):
from cacheshrink import MLATrainer
trainer = MLATrainer(
model=model,
tokenizer=tokenizer,
euclidean_lr=1e-5,
riemannian_lr=1e-4,
use_distillation=True, # match teacher's output distribution
# teacher_model is auto-loaded from model.mla_config.model_name
# or pass teacher_model=your_model explicitly
)
stats = trainer.train(train_texts, num_epochs=3, batch_size=4, max_length=512)
How it works:
Total loss = alpha x KL(student || teacher) + (1 - alpha) x LM_loss
where:
student = softmax(compressed_model_logits / T)
teacher = softmax(original_model_logits / T)
T = temperature (default 2.0)
alpha = distillation weight (default 0.9)
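A sketch of one common formulation of this objective (temperature-scaled KL against the teacher plus LM loss; the trainer's exact reduction and KL direction may differ):

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, lm_loss, T=2.0, alpha=0.9):
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    # Standard distillation KL (teacher as target), scaled by T^2 to keep gradient magnitudes comparable.
    kl = F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (T * T)
    return alpha * kl + (1 - alpha) * lm_loss
```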
When to use distillation vs reconstruction loss:
| | Reconstruction Loss | Knowledge Distillation |
|---|---|---|
| GPU memory | 1x model | 2x model (student + teacher) |
| Setup | Simpler (just store_original_weights=True) | Needs teacher model loaded |
| What it optimizes | K/V reconstruction fidelity | Output distribution match |
| Best for | Most cases | When output distribution matters more than K/V accuracy |
Evaluation
from cacheshrink import compute_perplexity, measure_cache_memory, generate_samples
# Perplexity on evaluation data
ppl = compute_perplexity(model, tokenizer, eval_texts, max_length=512, batch_size=1)
print(f"Perplexity: {ppl:.2f}")
# KV cache memory analysis
stats = measure_cache_memory(model, sequence_lengths=[512, 1024, 2048, 4096])
for seq_len, info in stats['per_sequence_length'].items():
    print(f"  {seq_len} tokens: {info['standard_cache_formatted']} -> {info['mla_cache_formatted']}")
# Text generation
outputs = generate_samples(
model, tokenizer,
prompts=["The theory of relativity states that"],
max_new_tokens=100,
temperature=0.7,
)
Saving and Loading
from cacheshrink import save_mla_model, load_mla_model
# Save (uses safetensors by default)
save_mla_model(model, tokenizer, "./my-compressed-model")
# Load
model, tokenizer = load_mla_model("./my-compressed-model", device="cuda", dtype=torch.bfloat16)
Saved files include the model weights, tokenizer, MLA config, and xKV group state (if applicable). Shared tensor references (xKV's shared W_uk/W_uv) are correctly preserved through save/load.
Parameter Reference
convert_to_mla()
| Parameter | Type | Default | Description |
|---|---|---|---|
| model_name_or_path | str | required | HuggingFace model name or local path |
| compression_ratio | float | 4.0 | Target compression ratio. d_latent = d_kv / ratio |
| compression_method | str | "separate" | "auto", "separate", "xkv", "joint", "decoupled_rope" |
| d_latent | int | None | Override latent dimension (auto-computed from ratio if None) |
| device | str | "cuda" | Device to load model on |
| dtype | torch.dtype | None | Model dtype (e.g., torch.bfloat16) |
| use_calibration | bool | True | Use activation statistics for SVD initialization |
| calibration_dataset | str | "wikitext" | HuggingFace dataset for calibration |
| calibration_dataset_subset | str | "wikitext-2-raw-v1" | Dataset subset name (second arg to load_dataset()) |
| num_calibration_samples | int | 128 | Number of calibration samples (more = better init) |
| max_calibration_length | int | 512 | Max sequence length for calibration |
| store_original_weights | bool | False | Store W_k/W_v for reconstruction loss training |
| cross_layer_group_size | int | 4 | Layers per xKV group (xKV only) |
| xkv_skip_early_layers | int | 0 | Number of early layers to exclude from xKV (xKV only) |
| keep_early_layers_original | bool | False | Keep skipped layers as original attention (xKV only) |
| verbose | bool | True | Print progress and stats |
MLATrainer()
| Parameter | Type | Default | Description |
|---|---|---|---|
| model | nn.Module | required | MLA-converted model |
| tokenizer | Tokenizer | required | HuggingFace tokenizer |
| euclidean_lr | float | 1e-5 | Learning rate for W_down_k, W_down_v |
| riemannian_lr | float | 1e-4 | Learning rate for W_uk, W_uv (Stiefel manifold) |
| use_distillation | bool | True | Use knowledge distillation (needs teacher model) |
| use_reconstruction_loss | bool | False | Use K/V reconstruction loss |
| reconstruction_alpha | float | 0.3 | Weight of reconstruction loss (0.0-1.0) |
| teacher_model | nn.Module | None | Teacher model for distillation (auto-loaded if None) |
use_distillation and use_reconstruction_loss are mutually exclusive.
trainer.train()
| Parameter | Type | Default | Description |
|---|---|---|---|
| dataset | list or Dataset | required | List of strings or HuggingFace dataset |
| num_epochs | int | 3 | Number of training epochs |
| batch_size | int | 4 | Training batch size |
| max_length | int | 512 | Maximum sequence length |
Training Details
What Gets Trained
Only the compression matrices are trained. All other model parameters (embeddings, layer norms, Q/O projections, FFN layers) are frozen:
Trainable:
W_down_k, W_down_v (per-layer, Euclidean params) -> AdamW optimizer
W_uk, W_uv (shared in xKV, Stiefel params) -> RiemannianAdam optimizer
Frozen:
Everything else (embeddings, attention Q/O, FFN, layer norms)
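A hedged sketch of how such a split could be wired up manually (parameter-name matching below is illustrative; MLATrainer sets up both optimizers for you):

```python
import torch
import geoopt

euclidean_params, stiefel_params = [], []
for name, p in model.named_parameters():          # `model` from convert_to_mla(...)
    if "W_down_k" in name or "W_down_v" in name:
        p.requires_grad_(True)
        euclidean_params.append(p)
    elif "W_uk" in name or "W_uv" in name:
        p.requires_grad_(True)
        stiefel_params.append(p)
    else:
        p.requires_grad_(False)                   # everything else stays frozen

opt_euclidean = torch.optim.AdamW(euclidean_params, lr=1e-5)
opt_riemannian = geoopt.optim.RiemannianAdam(stiefel_params, lr=1e-4)
```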
Typical trainable parameter counts:
| Model | Method | Trainable | Total | % |
|---|---|---|---|---|
| Qwen2.5-7B | xKV 2x | ~708M | 7.6B | 9.3% |
| Qwen2.5-7B | xKV 2x (skip 4) | ~826M | 7.6B | 10.9% |
| LLaMA-2 7B | separate 4x | ~537M | 6.7B | 8.0% |
Riemannian Optimization
The decompression matrices W_uk and W_uv must maintain orthonormal columns (Stiefel manifold constraint: W.T @ W = I). Standard gradient descent would violate this constraint.
geoopt's RiemannianAdam handles this by:
- Computing the Euclidean gradient
- Projecting it onto the tangent space of the Stiefel manifold
- Taking a step along a geodesic (retraction)
This keeps the matrices on the manifold at every step. After training, orthonormality errors are typically < 1e-6.
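You can verify this after training with a simple check (the attribute filter below is illustrative; adapt it to the actual module layout):

```python
import torch

def orthonormality_error(W: torch.Tensor) -> float:
    """Max absolute deviation of W.T @ W from the identity."""
    eye = torch.eye(W.shape[1], device=W.device, dtype=W.dtype)
    return (W.T @ W - eye).abs().max().item()

# e.g. over all decompression matrices (name matching is illustrative)
errors = [orthonormality_error(p) for name, p in model.named_parameters()
          if "W_uk" in name or "W_uv" in name]
print(max(errors))   # typically < 1e-6 after training
```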
Reconstruction Loss vs LM Loss
The training loss is a weighted blend:
- LM loss (language modeling cross-entropy) — keeps the model's text generation quality
- Reconstruction loss (MSE on K/V) — directly improves compression fidelity
With reconstruction_alpha=0.3, the total loss is:
loss = 0.7 * LM_loss + 0.3 * reconstruction_loss
Higher alpha puts more weight on K/V reconstruction accuracy. Lower alpha preserves more of the original model's generation behavior. 0.3 is a good default.
Benchmark Results
Qwen2.5-7B with xKV 2x Compression
| Configuration | Baseline | After Conversion | After 3 Epochs | Degradation |
|---|---|---|---|---|
| xKV (skip 4 early layers) | 10.58 | 17.44 | 12.46 | +17.8% |
| xKV (all layers compressed) | 10.58 | 46.67 | 14.29 | +35.1% |
Training: 1000 WikiText-2 samples, 3 epochs, batch_size=4, reconstruction_alpha=0.3
Key observations:
- Skipping 4 early layers improves post-training perplexity by 1.8 points (12.46 vs 14.29)
- The quality gap with skipping is much smaller (17.8% vs 35.1% degradation from baseline)
- Effective compression is 1.75x with skip vs 2.0x without — a good trade-off
Mistral 7B with 4x Compression
| Metric | Original | After MLA | After Fine-tuning |
|---|---|---|---|
| Perplexity (WikiText-2) | 12.90 | 21.84 (+69%) | 13.00 (+0.8%) |
| KV Cache @ 2048 tokens | 512 MB | 128 MB | 128 MB |
Training: 1000 WikiText-2 samples, 3 epochs, batch_size=4, reconstruction_alpha=0.3. 98.9% of conversion degradation recovered.
LLaMA-2 7B with 16x Compression
| Metric | Original | After MLA | After Fine-tuning |
|---|---|---|---|
| Perplexity (WikiText-2) | 11.70 | 255.90 (+2088%) | 17.68 (+51.2%) |
| KV Cache @ 2048 tokens | 1.00 GB | 64 MB | 64 MB |
Training: 1000 WikiText-2 samples, 3 epochs, batch_size=4, reconstruction_alpha=0.3. Even at extreme 16x compression (4096 → 256 dims), fine-tuning recovers most quality — perplexity drops from 255.90 to 17.68.
Memory Savings
Qwen2.5-7B (2x compression, 28 layers):
| Sequence Length | Standard Cache | MLA Cache | Saved |
|---|---|---|---|
| 512 | 28 MB | 14 MB | 14 MB |
| 2048 | 112 MB | 56 MB | 56 MB |
| 4096 | 224 MB | 112 MB | 112 MB |
LLaMA-2 7B (16x compression, 32 layers):
| Sequence Length | Standard Cache | MLA Cache | Saved |
|---|---|---|---|
| 512 | 256 MB | 16 MB | 240 MB |
| 2048 | 1.00 GB | 64 MB | 960 MB |
| 4096 | 2.00 GB | 128 MB | 1.9 GB |
| 8192 | 4.00 GB | 256 MB | 3.75 GB |
Supported Models
Tested Open-Source Models
| Model | HuggingFace ID | Attention | KV Heads | Recommended Method |
|---|---|---|---|---|
| GPT-2 | openai-community/gpt2 | MHA | 12 | separate |
| GPT-2 Medium | openai-community/gpt2-medium | MHA | 16 | separate |
| GPT-2 Large | openai-community/gpt2-large | MHA | 20 | separate |
| GPT-2 XL | openai-community/gpt2-xl | MHA | 25 | separate |
| LLaMA 2 7B | meta-llama/Llama-2-7b-hf | MHA | 32 | separate |
| LLaMA 2 13B | meta-llama/Llama-2-13b-hf | MHA | 40 | separate |
| LLaMA 2 70B | meta-llama/Llama-2-70b-hf | GQA | 8 | xkv |
| LLaMA 3 8B | meta-llama/Meta-Llama-3-8B | GQA | 8 | xkv |
| LLaMA 3 70B | meta-llama/Meta-Llama-3-70B | GQA | 8 | xkv |
| LLaMA 3.1 8B | meta-llama/Llama-3.1-8B | GQA | 8 | xkv |
| LLaMA 3.1 70B | meta-llama/Llama-3.1-70B | GQA | 8 | xkv |
| Mistral 7B v0.1 | mistralai/Mistral-7B-v0.1 | GQA | 8 | xkv |
| Mistral 7B v0.3 | mistralai/Mistral-7B-v0.3 | GQA | 8 | xkv |
| Mixtral 8x7B | mistralai/Mixtral-8x7B-v0.1 | GQA | 8 | xkv |
| Qwen2 7B | Qwen/Qwen2-7B | GQA | 4 | xkv |
| Qwen2.5 7B | Qwen/Qwen2.5-7B | GQA | 4 | xkv |
| Qwen2.5 14B | Qwen/Qwen2.5-14B | GQA | 4 | xkv |
| Qwen2.5 32B | Qwen/Qwen2.5-32B | GQA | 8 | xkv |
| Qwen2.5 72B | Qwen/Qwen2.5-72B | GQA | 8 | xkv |
Instruction-tuned variants (e.g., Llama-2-7b-chat-hf, Mistral-7B-Instruct-v0.3, Qwen2.5-7B-Instruct) are also supported — they share the same architecture as their base models.
Model Family Summary
| Family | Handler | Attention Types | Notes |
|---|---|---|---|
| GPT-2 | gpt2 | MHA | Combined QKV projection, learned position embeddings |
| LLaMA | llama | MHA (2-7B/13B), GQA (2-70B, 3.x) | Separate Q/K/V/O projections, RoPE |
| Mistral | mistral | GQA | Same as LLaMA with sliding window attention |
| Qwen/Qwen2 | qwen | GQA | Has K/V biases (handled automatically) |
| DeepSeek V2 | N/A | Native MLA | Already uses MLA, not supported |
Not Supported
Models that already use compressed KV caching (e.g., DeepSeek V2/V3 with native MLA) are detected and rejected with a clear error message. Models with architectures not listed above will raise an Unknown model type error — see Adding New Models to add support.
Attention type glossary:
- MHA (Multi-Head Attention): n_kv_heads == n_heads. Every query head has its own KV head. Large d_kv.
- GQA (Grouped Query Attention): 1 < n_kv_heads < n_heads. Multiple query heads share each KV head. Smaller d_kv.
- MQA (Multi-Query Attention): n_kv_heads == 1. All query heads share a single KV head. Smallest d_kv.
Adding New Models
Extend ModelHandler to support additional architectures:
from cacheshrink.model_handlers.base import ModelHandler
class MyModelHandler(ModelHandler):
    def get_num_layers(self) -> int:
        return len(self.model.my_layers)

    def get_attention_module(self, layer_idx: int):
        return self.model.my_layers[layer_idx].attention

    def extract_qkv_weights(self, layer_idx: int):
        attn = self.get_attention_module(layer_idx)
        return attn.W_q.data, attn.W_k.data, attn.W_v.data, attn.W_o.data

    def extract_qkv_biases(self, layer_idx: int):
        attn = self.get_attention_module(layer_idx)
        return attn.b_q, attn.b_k, attn.b_v, attn.b_o  # None if no biases

    def replace_attention(self, layer_idx: int, new_attn):
        self.model.my_layers[layer_idx].attention = new_attn
Register it in model_handlers/__init__.py to make it available via get_handler().
Roadmap
- HuggingFace Hub integration — Save and load compressed models with AutoModelForCausalLM.from_pretrained() via trust_remote_code=True. Publish compressed models to the Hub for anyone to use without installing cacheshrink.
- vLLM support — Custom model plugin for vLLM so compressed models can serve with PagedAttention, continuous batching, and all the production features vLLM provides.
- Custom Triton kernels — Fused compress/decompress + attention kernels to eliminate the overhead of reconstructing K/V into a separate buffer. Directly compute attention from the latent representations.
- Combining with KV cache quantization — Stack dimensional compression (cacheshrink) with precision reduction (FP8/INT8) for even higher compression ratios.
License
Apache 2.0