FADE: Frequency-Adaptive Decay Encoding — attention-aware tiered KV cache compression for LLM inference.
# FADE

Frequency-Adaptive Decay Encoding — drop-in KV cache compression for HuggingFace transformers. Shrinks the KV cache 3–5× with near-baseline quality.

```python
from fade import FadeConfig, create_tiered_cache

cache = create_tiered_cache(model, config=FadeConfig.safe())
out = model.generate(input_ids, past_key_values=cache, max_new_tokens=256)
```

Works with `model.generate()` — greedy, sampling, and beam search. No manual decode loop needed.
## How it works

Tokens live in tiers based on age and attention importance:
| Tier | What's stored | When |
|---|---|---|
| FP16 | Full precision | First N_SINK tokens + last RECENT_WINDOW tokens |
| INT4 | Bit-packed 4-bit | Middle-aged tokens (the bulk of the cache) |
| INT2 | Grouped 2-bit | Optional deeper compression (lossy) |
| PQ | Product-quantized codes | ~2 bits/element via trained codebook (Phase 3) |
| Evicted | Nothing | Dropped when INT4_BUDGET is finite |
When tokens are evicted, surviving K tensors are un-RoPE'd at old positions and re-RoPE'd with contiguous StreamingLLM positions.
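Because RoPE rotations compose, un-rotating at the old position and re-rotating at the new one collapses into a single rotation by the position delta. A minimal numpy sketch of this, assuming the standard rotate-half RoPE layout (`rope_rotate` and `re_rope` are illustrative names, not the library's API):

```python
import numpy as np

def rope_rotate(x, pos, base=10000.0):
    """Rotate a (dim,) vector by pos steps (rotate-half RoPE layout)."""
    half = x.shape[0] // 2
    theta = base ** (-np.arange(half) / half)   # per-pair frequencies
    ang = pos * theta
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[:half], x[half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos])

def re_rope(k, old_pos, new_pos):
    """Un-RoPE at old_pos + re-RoPE at new_pos == one rotation by the delta."""
    return rope_rotate(k, new_pos - old_pos)

k = rope_rotate(np.ones(8), pos=100)          # key cached at position 100
k_moved = re_rope(k, old_pos=100, new_pos=3)  # compacted StreamingLLM position
# Same result as encoding the raw key at position 3 directly:
assert np.allclose(k_moved, rope_rotate(np.ones(8), pos=3))
```

This is why re-positioning survivors after eviction costs only one extra rotation per key, not a full round trip through the raw (pre-RoPE) cache.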
## Install

```powershell
python -m venv .venv
.\.venv\Scripts\Activate.ps1
pip install torch  # match your CUDA version: https://pytorch.org/get-started/locally/
pip install -e ".[dev]"
```

Optional extras: `[cuda]` (accelerate), `[eval]` (datasets), `[codebook]` (scikit-learn for PQ).
## Quick start

### Presets

```python
from fade import FadeConfig, create_tiered_cache

# Safe: ~3-4x compression, 100% greedy match. No eviction.
cache = create_tiered_cache(model, config=FadeConfig.safe())

# Balanced: ~5x compression with H2O eviction.
cache = create_tiered_cache(model, config=FadeConfig.balanced())

# Aggressive: ~7-8x compression. Validate on your workload first.
cache = create_tiered_cache(model, config=FadeConfig.aggressive())
```
### Custom config

```python
cache = create_tiered_cache(model, config=FadeConfig(
    phase="2",
    n_sink=4,
    recent_window=64,
    int4_budget=400,
    eviction_policy="h2o",  # "h2o", "ema", "position", or "learned"
    middle_k_bits=4,        # K stays INT4 (outlier-sensitive)
    middle_v_bits=2,        # V at INT2 (~30% more compression)
))
```
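To make the `middle_k_bits` / `middle_v_bits` trade-off concrete, here is a minimal symmetric INT4 quantize-and-pack sketch in numpy. These are hypothetical helpers for illustration only; the real `quant.py` grouping, layout, and scale handling may differ:

```python
import numpy as np

def pack_int4(x):
    """Symmetric per-tensor quant to [-8, 7]; two values per byte."""
    scale = max(float(np.abs(x).max()) / 7.0, 1e-8)
    q = np.clip(np.round(x / scale), -8, 7).astype(np.int8)
    u = (q + 8).astype(np.uint8)            # shift into unsigned nibbles
    return (u[0::2] << 4) | u[1::2], scale  # even index -> high nibble

def unpack_int4(packed, scale, n):
    """Inverse of pack_int4 for an even-length vector of size n."""
    hi = (packed >> 4).astype(np.int8) - 8
    lo = (packed & 0x0F).astype(np.int8) - 8
    out = np.empty(n, dtype=np.float32)
    out[0::2], out[1::2] = hi, lo
    return out * scale

x = np.array([0.5, -1.0, 0.25, 1.0], dtype=np.float32)
packed, scale = pack_int4(x)
assert packed.nbytes == 2  # 4 values -> 2 bytes (4x smaller than FP16's 8 bytes)
assert np.abs(unpack_int4(packed, scale, 4) - x).max() <= scale / 2 + 1e-7
```

The reconstruction error is bounded by half a quantization step, which is why outlier-sensitive K tensors are kept at 4 bits while V can usually drop to 2.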
### Manual decode with tier reassignment

```python
from fade import FadeConfig, create_tiered_cache
from fade.patch import forward_with_tracking, load_model
from fade.policy import reassign_tiers
from fade.tracker import AttentionTracker

model, tokenizer = load_model("Qwen/Qwen2.5-3B-Instruct", attn_impl="auto", need_attentions=True)
cache = create_tiered_cache(model, config=FadeConfig.balanced())
tracker = AttentionTracker(num_layers=model.config.num_hidden_layers)

# Prefill, then decode token by token, reassigning tiers every 64 steps.
out = forward_with_tracking(model, input_ids, cache, tracker=tracker)
for step in range(max_tokens):
    next_token = out.logits[:, -1:].argmax(dim=-1)  # greedy; swap in sampling as needed
    out = forward_with_tracking(model, next_token, cache, tracker=tracker)
    if (step + 1) % 64 == 0:
        reassign_tiers(cache, tracker, model.config.num_hidden_layers)
```
## Eviction policies

| Policy | Quality | Speed | Needs attention? |
|---|---|---|---|
| `h2o` | Best | Normal | Yes (prefill only) |
| `ema` | Good | Normal | Yes (decode only) |
| `position` | Fair | Fast | No |
| `learned` | Good* | Fast | No |

*The learned policy requires a trained checkpoint: `python scripts/train_eviction_mlp.py`
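The H2O-style selection above boils down to: accumulate each key's received attention mass, always keep sinks and the recent window, and fill the remaining budget with the highest-scoring middle tokens. A numpy sketch with illustrative names (the library's actual policy code lives in `policy.py`):

```python
import numpy as np

def h2o_keep_mask(attn, n_sink, recent, budget):
    """attn: (num_queries, num_keys) attention weights.
    Returns a boolean mask over keys: True = keep, False = evict."""
    num_keys = attn.shape[1]
    score = attn.sum(axis=0)                 # accumulated attention per key
    keep = np.zeros(num_keys, dtype=bool)
    keep[:n_sink] = True                     # attention sinks
    keep[num_keys - recent:] = True          # recent window
    middle = np.flatnonzero(~keep)
    heavy = middle[np.argsort(score[middle])[::-1][:budget]]
    keep[heavy] = True                       # heavy hitters fill the budget
    return keep

attn = np.array([[0.1, 0.0, 0.6, 0.1, 0.1, 0.1]])
mask = h2o_keep_mask(attn, n_sink=1, recent=1, budget=1)
# keeps the sink (0), the recent token (5), and the heaviest middle token (2)
assert mask.tolist() == [True, False, True, False, False, True]
```

The `ema` policy differs only in how `score` is maintained (an exponential moving average updated during decode rather than a prefill sum).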
## Supported models

FADE auto-detects the RoPE scheme from the model config:

- Qwen2 / Qwen3 — vanilla RoPE, GQA
- Llama / Llama-3.1 — vanilla + frequency-dependent scaling
- Mistral — vanilla RoPE, sliding-window
- Phi-3 — vanilla RoPE
- Gemma-2 — vanilla RoPE
- Gemma 4 — proportional RoPE with `partial_rotary_factor` + per-layer-type dispatch
- Falcon — ALiBi (non-RoPE; re-RoPE is a no-op)
- Qwen 3.5 / 3.6 — hybrid DeltaNet + softmax attention. FADE auto-detects `layer_types` and skips DeltaNet layers (only full-attention layers are tiered).

RoPE scaling types: linear, llama3, ntk, dynamic, yarn, proportional. Non-RoPE models (ALiBi, Bloom, MPT) work via the `NoRope` sentinel.
## Batching

Two modes:

- Shared-tier (default): all rows share positions and tier decisions. For lockstep decoding.
- Per-sequence (`apply_tier_assignment_per_sequence`): each row gets independent `[B, S]` tiers. For continuous batching where sequences diverge.
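The difference between the two modes can be sketched as follows (hypothetical numpy helpers; the tier codes are illustrative: 0 = FP16, 1 = INT4, -1 = padding):

```python
import numpy as np

def shared_tiers(tiers, batch_size):
    """Shared-tier: one (S,) tier vector broadcast to every row."""
    return np.broadcast_to(tiers, (batch_size, tiers.shape[0]))

def per_sequence_tiers(seq_lens, recent_window):
    """Per-sequence: each row keeps its own recent window in FP16."""
    max_len = max(seq_lens)
    tiers = np.full((len(seq_lens), max_len), -1, dtype=np.int8)
    for row, length in enumerate(seq_lens):
        tiers[row, :length] = 1                                # older -> INT4
        tiers[row, max(0, length - recent_window):length] = 0  # recent -> FP16
    return tiers

# Two sequences of different lengths get different recent windows:
t = per_sequence_tiers([3, 5], recent_window=2)
assert t.tolist() == [[1, 0, 0, -1, -1], [1, 1, 1, 0, 0]]
```

Under continuous batching, sequences finish and restart at different times, so a shared `(S,)` assignment would misplace every shorter row's recent window; the per-sequence `[B, S]` form avoids that.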
## Performance

- Pre-allocated FP16 buffer — a doubling buffer eliminates `torch.cat` on every decode step.
- torch.compile — `cache.enable_compile()` wraps `_materialize` between graph-break boundaries.
- Triton INT4 kernel — `int4_sdpa(q, k_packed, k_scale, v_packed, v_scale, force_triton=True)` runs fused INT4 unpack on CUDA. Exact parity validated on RTX 3060.
- Dequant-cache age eviction — `cache.max_dequant_age = N` periodically refreshes cached dequant buffers.
- Benchmarks — `python benchmarks/tps.py` (decode throughput), `python benchmarks/divergence.py` (quality).
## Checkpointing

```python
sd = cache.cache_state_dict()
torch.save(sd, "cache.pt")
cache.load_cache_state_dict(torch.load("cache.pt"))
```
## Observability

```python
from fade.telemetry import JsonlExporter, attach_telemetry

attach_telemetry(cache, JsonlExporter("events.jsonl"))
```

Debug dump: `cache.dump_debug("snapshot.json")`
## PQ codebook

```python
from fade.codebook import PQCodebook

cb = PQCodebook.train(calibration_vectors, sub_dim=32, num_centroids=256)
cache.set_codebooks(cb)  # enables TIER_PQ in tier assignment
```

Train codebooks from a real model: `python scripts/train_codebook.py`
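Product quantization splits each vector into sub-vectors and stores one centroid index per sub-vector, so with 256 centroids each sub-vector costs a single byte. A minimal encode/decode sketch with pre-trained codebooks (illustrative functions, not the `PQCodebook` API):

```python
import numpy as np

def pq_encode(x, codebooks):
    """x: (D,) vector; codebooks: list of (K, sub_dim) centroid arrays."""
    subs = np.split(x, len(codebooks))
    return np.array(
        [int(np.argmin(((cb - s) ** 2).sum(axis=1))) for s, cb in zip(subs, codebooks)],
        dtype=np.uint8,  # one byte per sub-vector with K <= 256
    )

def pq_decode(codes, codebooks):
    """Reconstruct by concatenating the chosen centroids."""
    return np.concatenate([cb[c] for c, cb in zip(codes, codebooks)])

# Two sub-vectors of dim 2, two centroids each (a trained codebook in practice).
codebooks = [np.array([[0.0, 0.0], [1.0, 1.0]]),
             np.array([[0.0, 1.0], [1.0, 0.0]])]
codes = pq_encode(np.array([0.9, 1.1, 0.1, 0.9]), codebooks)
assert codes.tolist() == [1, 0]                               # nearest centroids
assert pq_decode(codes, codebooks).tolist() == [1.0, 1.0, 0.0, 1.0]
```

Encoding is lossy: decode returns the nearest centroid per sub-vector, not the original values, which is why PQ sits below INT4/INT2 in the tier hierarchy.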
## Results
| Config | Model | KV cache | Compression |
|---|---|---|---|
| Phase 1-A | Qwen2.5-0.5B, 782 tok | 4.0 MiB | 67% smaller, 100% token match |
| Phase 2 H2O | Qwen2.5-3B, 595 tok | 6.3 MiB | 79% smaller, coherent output |
## Project layout

```
fade/
  cache.py           # TieredKVCache with 5 tiers (FP16/INT4/INT2/PQ/evict)
  config.py          # FadeConfig with presets
  quant.py           # INT4/INT2 quantization + bit-packing
  rope.py            # 7 RoPE schemes incl. Gemma 4 proportional
  policy.py          # Tier assignment: h2o, ema, position
  learned_policy.py  # Learned eviction MLP
  tracker.py         # AttentionTracker (per-layer EMA)
  patch.py           # load_model, create_tiered_cache, forward_with_tracking
  codebook.py        # PQ codebook train/encode/decode
  telemetry.py       # Structured telemetry + exporters
  kernels/           # Triton INT4 unpack kernel + torch fallback
  serving/           # vLLM / SGLang adapter stubs
eval/                # Perplexity, needle, quality suite
examples/            # quickstart.py
experiments/         # run_baseline.py, run_tiered.py
benchmarks/          # tps.py, divergence.py
scripts/             # train_eviction_mlp.py, train_codebook.py
tests/               # 136 tests, all CPU, no downloads
```
## Gotchas

- Attention impl: `eager` is only needed for H2O prefill. Use `load_model(attn_impl="auto")`.
- Transformers version: verified on 4.45 and 5.3. Weekly canary CI runs against `transformers@main`.
- Memory: use `cache.compressed_storage_bytes()`, not `nvidia-smi`.
- RoPE precision: all math is done in float32, then cast through the model dtype to match rounding.
- Hybrid models: Qwen 3.5/3.6 DeltaNet layers are auto-skipped — only full-attention layers are tiered.
- Triton kernel: opt-in via `force_triton=True`. Run `check_parity()` on your hardware first.
## License

Apache-2.0. See LICENSE.