Efficiently compress the KV cache of any pretrained transformer

kvpress

Deploying long-context LLMs is costly due to the linear growth of the key-value (KV) cache in transformer models. For example, handling 1M tokens with Llama 3.1-70B in float16 requires up to 330GB of memory. kvpress implements multiple KV cache compression methods and benchmarks using 🤗 transformers, aiming to simplify the development of new methods for researchers and developers in this field.
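The 330GB figure can be checked with a back-of-the-envelope calculation. The sketch below assumes the published Llama 3.1-70B configuration (80 layers, 8 key-value heads, head dimension 128), which is not stated above; the helper name kv_cache_bytes is purely illustrative and is not part of kvpress.

```python
def kv_cache_bytes(num_tokens, num_layers, num_kv_heads, head_dim, bytes_per_value=2):
    """Approximate KV cache size in bytes (float16 -> 2 bytes per value)."""
    # Leading factor 2: each token stores one key and one value vector
    # per layer and per KV head.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value * num_tokens


# Assumed Llama 3.1-70B configuration: 80 layers, 8 KV heads, head_dim 128.
size_70b = kv_cache_bytes(num_tokens=1_000_000, num_layers=80, num_kv_heads=8, head_dim=128)
print(f"{size_70b / 1e9:.1f} GB")  # 327.7 GB, consistent with "up to 330GB"
```

The size grows linearly with num_tokens, which is why pruning KV pairs translates directly into memory savings.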

Installation

pip install kvpress

If possible, install flash attention:

pip install flash-attn --no-build-isolation

Usage

kvpress provides a set of "presses" that compress the KV cache during the pre-filling phase. Each press has a compression_ratio attribute that sets the fraction of the cache to prune. The easiest way to use a press is through our custom KVPressTextGenerationPipeline. It is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported, and it handles chat templates and tokenization for you:

from transformers import pipeline
from kvpress import ExpectedAttentionPress

device = "cuda:0"
model = "meta-llama/Llama-3.1-8B-Instruct"
model_kwargs = {"attn_implementation": "flash_attention_2"}
pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)

context = "A very long text you want to compress once and for all"
question = "\nA question about the compressed context"  # optional

press = ExpectedAttentionPress(compression_ratio=0.5)
answer = pipe(context, question=question, press=press)["answer"]

In the snippet above, compression is applied only to the context tokens, so you can evaluate the same compressed context against different questions. Check the Wikipedia notebook demo for a more detailed example (also available on Colab here).

[!IMPORTANT]
We focus on compression during the pre-filling phase because the KV cache becomes a bottleneck for long-context sequences (100k - 1M tokens), which are essentially long-context prompts. This typically applies to improving prompt caching systems.

[!NOTE]
Use model_kwargs={"attn_implementation":"flash_attention_2"} to enable flash attention. To use ObservedAttentionPress, you need to specify model_kwargs={"attn_implementation":"eager"} because this press requires the attention weights to be materialized.

Contributing

We welcome contributions! To add a new press, simply open an issue or submit a pull request. Check the new_press.ipynb notebook for a step-by-step guide.

Available presses

All current presses are training-free and inherit from BasePress (source).

Several presses inherit from ScorerPress (source) and rely on a score to prune the KV pairs with lowest importance:

  • RandomPress (source): random score
  • KnormPress (source, paper): inverse norm of the key
  • SnapKVPress (source, paper): average attention weight of the last queries
  • ExpectedAttentionPress (source, notebook): expected attention weight during the generation phase
  • StreamingLLMPress (source, paper): keep only the initial and recent tokens
  • TOVAPress (source, paper): attention weight of the last query averaged across heads
  • ObservedAttentionPress (source, paper): average attention weight observed during the pre-filling phase
  • QFilterPress (source, paper): project the Key representations on the main SVD component of the Query vectors to approximate the attention scores.
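The score-then-prune idea shared by these presses can be sketched in plain Python. This is a toy illustration for a single attention head, not the library's implementation: the helper names are hypothetical, rounding the number of kept tokens down is an assumption, and the Knorm-style score is written as the negative L2 norm, which ranks keys in the same order as the inverse norm.

```python
import math


def prune_by_score(scores, compression_ratio):
    """Return the sorted indices of the tokens kept after pruning."""
    # Assumption: the number of kept tokens is rounded down.
    n_kept = int(len(scores) * (1 - compression_ratio))
    # Rank token indices from highest to lowest score and keep the best n_kept.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:n_kept])


def knorm_scores(keys):
    """Knorm-style score: keys with a smaller L2 norm get a higher score."""
    return [-math.sqrt(sum(x * x for x in key)) for key in keys]


# Five toy key vectors (one per token) for a single attention head.
keys = [[3.0, 4.0], [0.1, 0.2], [1.0, 1.0], [6.0, 8.0], [0.5, 0.5]]
kept = prune_by_score(knorm_scores(keys), compression_ratio=0.4)
print(kept)  # [1, 2, 4]
```

Swapping knorm_scores for another scoring function (random values, attention weights, and so on) changes which tokens survive without touching the pruning step.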

Some presses rely on a different logic:

  • ThinKPress (source, paper): compress the dimensions of the keys based on the channel attention score on the last queries
  • SimLayerKVPress (source, paper): identify "lazy" layers, and apply the StreamingLLM approach to them
  • DuoAttentionPress (source, paper): split heads into retrieval heads (no compression) and streaming heads (StreamingLLM approach)
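The StreamingLLM approach referenced above keeps only the initial ("sink") tokens and the most recent ones. A minimal sketch of that selection follows; the function name, the default of 4 sink tokens, and the way the budget is derived from compression_ratio are assumptions made for illustration.

```python
def streaming_keep_indices(seq_len, compression_ratio, n_sink=4):
    """Indices kept when retaining only the initial and most recent tokens."""
    n_kept = int(seq_len * (1 - compression_ratio))
    if n_kept >= seq_len:
        return list(range(seq_len))  # nothing to prune
    n_sink = min(n_sink, n_kept)     # never keep more sinks than the budget
    n_recent = n_kept - n_sink       # the rest of the budget goes to recent tokens
    sinks = list(range(n_sink))
    recent = list(range(seq_len - n_recent, seq_len))
    return sinks + recent


print(streaming_keep_indices(seq_len=20, compression_ratio=0.5))
# [0, 1, 2, 3, 14, 15, 16, 17, 18, 19]
```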

Finally, we provide wrapper presses that can be combined with other presses:

  • AdaKVPress (source, paper): prune the lowest scores of any ScorerPress across all heads rather than per head, achieving head-wise compression
  • PerLayerCompressionPress (source): compress each layer with a different compression ratio (experimental)
  • ComposedPress (source): compose multiple presses together by chaining their forward hooks
  • KeyRerotationPress (source): rerotate pruned keys to have continuous RoPE embeddings
  • ChunkKVPress (source, paper): compresses by selecting important chunks, preserving semantic coherence
  • ChunkPress (source, paper): compress the KV cache on each sequence chunk separately. This can yield more uniform compression across long sequences
  • CriticalKVPress and CriticalAdaKVPress (source, paper): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.
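To illustrate why compressing each chunk separately can give more uniform compression, the toy sketch below compares a global top-k selection with a per-chunk selection. The function names and the chunk_length parameter are hypothetical and only serve the illustration.

```python
def top_k_indices(scores, n_kept):
    """Sorted indices of the n_kept highest scores."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:n_kept])


def chunked_keep_indices(scores, compression_ratio, chunk_length):
    """Apply the same compression ratio independently to each chunk."""
    kept = []
    for start in range(0, len(scores), chunk_length):
        chunk = scores[start:start + chunk_length]
        n_kept = int(len(chunk) * (1 - compression_ratio))
        # Shift chunk-local indices back to positions in the full sequence.
        kept.extend(start + i for i in top_k_indices(chunk, n_kept))
    return kept


# Toy scores where the highest values all sit at the end of the sequence.
scores = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
global_kept = top_k_indices(scores, n_kept=4)
chunked_kept = chunked_keep_indices(scores, compression_ratio=0.5, chunk_length=4)
print(global_kept)   # [4, 5, 6, 7]
print(chunked_kept)  # [2, 3, 6, 7]
```

The global selection drops the whole first half of the sequence, while the chunked selection keeps tokens from both halves.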

For a detailed list of existing KV cache compression methods, check Awesome-KV-Cache-Compression or Awesome-LLM-Compression.

Evaluation

The speed_and_memory.ipynb notebook can help you to measure peak memory usage and total time gain.

We provide a simple CLI to evaluate the performance of the different presses on several long-context datasets. Below we report the average performance on the RULER dataset with 4k context length for different presses.

[Figure: average performance of the different presses on RULER with 4k context length]

Please refer to the evaluation directory for more details and results.

Quantization

We support KV cache quantization through the transformers QuantizedCache class (see HF blog post). To use it, simply pass a cache object to your pipeline:

from transformers import QuantizedCacheConfig, QuantoQuantizedCache

config = QuantizedCacheConfig(nbits=4)
cache = QuantoQuantizedCache(config)

pipe(..., cache=cache)

By default, the DynamicCache is used (no quantization).

[!IMPORTANT]
To use the QuantizedCache, you need to install additional dependencies (e.g. pip install optimum-quanto).

FAQ

Which models are supported?

Some presses depend on the model architecture (e.g. ExpectedAttentionPress or SnapKVPress), so they might not work with all models. We tested support for LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM and Qwen2ForCausalLM, but many other models might work out of the box because their implementations in transformers are often similar.

How to run inference on multiple GPUs?

kvpress supports multi-GPU inference through accelerate:

pipe = pipeline("kv-press-text-generation", model=model, device_map="auto")

What are the memory and throughput gains?

Memory usage should be reduced by around compression_ratio * kv_cache_size. As the KV cache is smaller, decoding should also be faster. You can measure peak memory usage gain and total time gain using this notebook.
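As a rough worked example of this formula, the sketch below assumes a Llama 3.1-8B-style configuration (32 layers, 8 KV heads, head dimension 128, float16) and a 100k-token context. These numbers are illustrative assumptions, not measurements, and the helper name is hypothetical.

```python
def kv_cache_savings_gb(kv_cache_size_gb, compression_ratio):
    """Return (memory saved, memory remaining) in GB for a given compression ratio."""
    saved = compression_ratio * kv_cache_size_gb
    return saved, kv_cache_size_gb - saved


# Assumed float16 cache: 2 (keys and values) * 32 layers * 8 KV heads
# * 128 head_dim * 2 bytes per value * 100k tokens.
kv_cache_size_gb = 2 * 32 * 8 * 128 * 2 * 100_000 / 1e9
saved, remaining = kv_cache_savings_gb(kv_cache_size_gb, compression_ratio=0.5)
print(f"cache: {kv_cache_size_gb:.2f} GB, saved: {saved:.2f} GB, remaining: {remaining:.2f} GB")
# cache: 13.11 GB, saved: 6.55 GB, remaining: 6.55 GB
```

Only the KV cache shrinks; the memory taken by the model weights is unchanged, so measured peak-memory gains will be smaller in relative terms.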

How does a press work?

A press registers a forward hook (press.forward_hook method) on each attention layer during the pre-filling phase. The hooks are registered by using the press as a context manager (press.__call__ method):

import torch
from transformers import AutoModelForCausalLM
from kvpress import KnormPress

device = "cuda:0"
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(ckpt).to(device)
press = KnormPress(compression_ratio=0.4)

inputs = model.dummy_inputs["input_ids"].to(device)

with torch.no_grad():
    print(model(inputs).past_key_values[0][0].shape)
    # torch.Size([3, 8, 5, 128])
    
with torch.no_grad(), press(model):
    print(model(inputs).past_key_values[0][0].shape)
    # torch.Size([3, 8, 3, 128])
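The hook-plus-context-manager pattern can be mimicked without any framework. The toy sketch below is not the kvpress implementation: ToyAttentionLayer and ToyPress are hypothetical stand-ins, and the "compression" simply truncates a list. It only shows how hooks are attached on entering the context and removed on exit.

```python
from contextlib import contextmanager


class ToyAttentionLayer:
    """Hypothetical stand-in for an attention layer that returns its KV cache."""

    def __init__(self):
        self.hooks = []

    def forward(self, tokens):
        cache = list(tokens)  # pretend each token yields one KV pair
        for hook in self.hooks:
            cache = hook(self, cache)  # a hook may replace the layer output
        return cache


class ToyPress:
    """Hypothetical press keeping the first (1 - compression_ratio) share of the cache."""

    def __init__(self, compression_ratio):
        self.compression_ratio = compression_ratio

    def forward_hook(self, layer, cache):
        n_kept = int(len(cache) * (1 - self.compression_ratio))
        return cache[:n_kept]

    @contextmanager
    def __call__(self, layers):
        # Register the hook on every layer when entering the context...
        for layer in layers:
            layer.hooks.append(self.forward_hook)
        try:
            yield
        finally:
            # ...and remove it on exit, even if an exception was raised.
            for layer in layers:
                layer.hooks.remove(self.forward_hook)


layers = [ToyAttentionLayer(), ToyAttentionLayer()]
press = ToyPress(compression_ratio=0.4)
tokens = ["a", "b", "c", "d", "e"]

print(len(layers[0].forward(tokens)))  # 5: no hook registered
with press(layers):
    print(len(layers[0].forward(tokens)))  # 3: hook active inside the context
print(len(layers[0].forward(tokens)))  # 5: hook removed on exit
```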

Why not use model.generate?

In fact, you can use model.generate with a press by using the press as a context manager:

with press(model):
    outputs = model.generate(inputs)

However, the generate method does not allow excluding the question from the compression, which would artificially favor methods such as SnapKV. Ideally, we want a compression method that works regardless of what comes after the context (e.g. for use cases such as chat or document question answering). Finally, the generate method does not allow generating answers for multiple questions at once.
