SAM-Gate: Semantic-Aware Memory Gate for heterogeneous KV-cache compression in transformer models
Semantic-Aware Memory Gate — adaptive KV-cache compression for transformer models, guided by the intrinsic geometry of the attention flow.
SAM-Gate estimates the harmonic curvature f(t) per layer at inference time and assigns each layer to a compression regime:
| Regime | Condition | KV precision | Active heads | KV window |
|---|---|---|---|---|
| Flat | f(t) < f_flat_max | int4 | 40% | short |
| Transition | f_flat_max ≤ f(t) < f_obs_min | int8 | 70% | medium |
| Obstructed | f(t) ≥ f_obs_min | fp16 | 100% | long |
The KV window is fixed regardless of context length, so KV memory is O(1) rather than O(T).
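As a concrete illustration of the gating rule, here is a minimal sketch of the per-layer regime assignment. The `Regime` dataclass and `assign_regime` are hypothetical names, not the library's API; the thresholds and windows come from `SAMConfig` (see the configuration reference below):

```python
from dataclasses import dataclass

@dataclass
class Regime:
    kv_bits: int      # quantization width for cached K/V
    head_frac: float  # fraction of query heads kept active
    window: int       # fixed KV window in tokens

def assign_regime(f_t: float, cfg) -> Regime:
    """Map a layer's harmonic curvature f(t) to a compression regime."""
    if f_t < cfg.f_flat_max:     # flat: compress aggressively
        return Regime(cfg.flat_bits, cfg.flat_heads, cfg.max_ctx_flat)
    if f_t < cfg.f_obs_min:      # transition: compress moderately
        return Regime(cfg.trans_bits, cfg.trans_heads, cfg.max_ctx_trans)
    return Regime(cfg.obs_bits, 1.0, cfg.max_ctx_obs)  # obstructed: keep fp16
```

Every branch caps the window at a constant from the config, which is what makes the per-layer footprint independent of the context length T.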
Benchmark results (Qwen2.5-3B-Instruct, RTX 4070 8GB)
| Engine | KV @ ctx=512 (KB) | KV @ ctx=1024 (KB) | TPS @ ctx=512 |
|---|---|---|---|
| Baseline HF | 18828 | 37260 | 30.4 |
| SAM-Gate | 4608 | 4608 | 8.5 |
| RCI (int8) | 9702 | 19206 | 9.2 |
SAM-Gate KV footprint does not grow with context. At ctx=1024, baseline uses 8x more memory than SAM-Gate.
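These numbers line up with a back-of-the-envelope check, assuming Qwen2.5-3B-Instruct's published geometry (36 layers, 2 KV heads under GQA, head dim 128) and an fp16 cache:

```python
# Bytes per cached token: K and V, per layer, per KV head, per channel, fp16.
kb_per_token = 2 * 36 * 2 * 128 * 2 / 1024
print(kb_per_token)   # 36.0 -> the baseline grows ~36 KB per cached token
print(36 * 128)       # 4608 -> one fixed 128-token fp16 window per layer
```

So the constant 4608 KB SAM-Gate footprint is exactly a 128-token fp16 window replicated across the 36 layers, while the baseline column scales linearly with the number of cached tokens.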
Requirements
- Python >= 3.10
- PyTorch >= 2.1.0 (install manually: https://pytorch.org/get-started/locally/)
- CUDA-capable GPU recommended
Install
Install PyTorch first for your CUDA / platform (see Requirements), then:
```bash
# Base install (Windows / macOS / Linux; uses PyTorch SDPA)
pip install sam-gate

# Linux / WSL with an NVIDIA GPU; enables Flash Attention (2-3x TPS gain)
pip install "sam-gate[flash]"

# With SpectralRCI support
pip install "sam-gate[spectral]"
```
Note:
flash-attn requires Linux or WSL2 with CUDA 11.7+. On Windows, SAM-Gate automatically falls back to PyTorch SDPA; it remains fully functional, with roughly 50% lower TPS at long contexts.
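To confirm which backend you ended up with, a quick check (this probes for the flash-attn package directly; SAM-Gate itself may expose its own diagnostics):

```python
import importlib.util

# flash_attn is the import name of the flash-attn wheel; if it is absent,
# SAM-Gate falls back to PyTorch SDPA.
print("flash-attn available:", importlib.util.find_spec("flash_attn") is not None)
```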
Quick start
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sam_gate import attach_semantic_hooks, SAMConfig

# Load in fp16 and move to the GPU (the inputs below are placed on "cuda").
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-3B-Instruct", torch_dtype=torch.float16
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")

cfg = SAMConfig(
    max_ctx_flat=64,
    max_ctx_trans=128,
    max_ctx_obs=128,
)

# SAM-Gate keeps its own per-layer caches in kv_caches, so the HF cache
# is disabled in generate() with use_cache=False.
kv_caches = {}
handles = attach_semantic_hooks(model, kv_caches, cfg=cfg)

inputs = tokenizer("Explain how attention works.", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=200, use_cache=False)

for h in handles:
    h.remove()

print(tokenizer.decode(out[0], skip_special_tokens=True))
print(f"KV used: {sum(c.ram_bytes() for c in kv_caches.values()) / 1024:.1f} KB")
```
Calibration
Thresholds f_flat_max and f_obs_min are model-specific. Before using
SAM-Gate on a new model, run calibration to observe the real f(t) distribution:
```bash
python -m sam_gate.sam --model Qwen/Qwen2.5-3B-Instruct --device cuda --calibrate --verbose
```
Then set f_flat_max and f_obs_min in SAMConfig based on the observed
values per layer.
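One illustrative way to turn the logged per-layer values into thresholds is a percentile split. This heuristic is not part of SAM-Gate; `f_values` is whatever the calibration run printed:

```python
import numpy as np

# Per-layer f(t) samples from the calibration run; replace with your own values.
f_values = np.array([3.2e-3, 8.1e-3, 0.4, 2.7, 12.5, 61.0, 88.3])

# Illustrative split: bottom ~40% of layers flat, top ~10% obstructed.
f_flat_max = float(np.percentile(f_values, 40))
f_obs_min = float(np.percentile(f_values, 90))
print(f"f_flat_max={f_flat_max:.3g}  f_obs_min={f_obs_min:.3g}")
```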
Configuration reference
```python
from sam_gate import SAMConfig

cfg = SAMConfig(
    tau=0.05,            # harmonic kernel threshold
    flat_bits=4,         # int4 KV in the flat regime
    trans_bits=8,        # int8 in transition
    obs_bits=16,         # fp16 in obstructed
    flat_heads=0.5,      # fraction of query heads kept in the flat regime
    trans_heads=0.75,    # fraction kept in transition
    f_flat_max=1e-2,     # calibrated for Qwen2.5-3B-Instruct
    f_obs_min=50.0,      # calibrated for Qwen2.5-3B-Instruct
    max_ctx_flat=64,     # KV window in the flat regime (tokens)
    max_ctx_trans=128,   # KV window in transition (tokens)
    max_ctx_obs=128,     # KV window in obstructed (tokens)
)
```
How it works
SAM-Gate estimates the Morse functional f(t) via Hutchinson probing, at O(K·H·d) cost instead of an O(d³) eigendecomposition:

```
Δ_t  = Σ_i (T_i − I)*(T_i − I)       ← semantic Laplacian
f(t) = Σ_{i<j} ||[T_i, T_j]||²_F     ← harmonic curvature
```
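For intuition, here is a minimal sketch of the Hutchinson step, using the identity ||A||²_F = E_z ||Az||² for z ~ N(0, I). The dense (d, d) per-head maps `T` and the function name are illustrative assumptions, not the library's internals:

```python
import torch

def harmonic_curvature(T: list[torch.Tensor], K: int = 8) -> float:
    """Estimate f(t) = Σ_{i<j} ||[T_i, T_j]||_F^2 with K random probes per pair."""
    d = T[0].shape[0]
    f = 0.0
    for i in range(len(T)):
        for j in range(i + 1, len(T)):
            z = torch.randn(d, K, dtype=T[i].dtype, device=T[i].device)
            # Commutator applied to the probes: [T_i, T_j] z = T_i(T_j z) - T_j(T_i z)
            Az = T[i] @ (T[j] @ z) - T[j] @ (T[i] @ z)
            f += float(Az.pow(2).sum() / K)  # Hutchinson: mean ||Az||^2 ≈ ||A||_F^2
    return f
```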
At decode time the regime policy is cached per layer, so the prober adds no overhead after prefill.
The KV cache uses a dense ring buffer (fp16, fixed capacity) initialized
once from the quantized prefill chunks. At each decode step, reconstruct_for_attn
is O(1) — a single slice, no dequantization loop.
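A minimal sketch of that buffer shape, under assumed names (`RingKV` is a hypothetical stand-in for the library's cache class; only `reconstruct_for_attn` is named above):

```python
import torch

class RingKV:
    """Fixed-capacity fp16 KV buffer: appends overwrite the oldest slot,
    and reading back for attention is a single O(1) slice."""

    def __init__(self, capacity: int, n_kv_heads: int, head_dim: int, device="cuda"):
        shape = (capacity, n_kv_heads, head_dim)
        self.k = torch.empty(shape, dtype=torch.float16, device=device)
        self.v = torch.empty(shape, dtype=torch.float16, device=device)
        self.capacity, self.pos, self.filled = capacity, 0, 0

    def append(self, k_t: torch.Tensor, v_t: torch.Tensor) -> None:
        self.k[self.pos], self.v[self.pos] = k_t, v_t
        self.pos = (self.pos + 1) % self.capacity
        self.filled = min(self.filled + 1, self.capacity)

    def reconstruct_for_attn(self):
        # One slice, no dequantization: entries are already stored in fp16.
        return self.k[: self.filled], self.v[: self.filled]
```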
License
MIT
Download files
File details
Details for the file sam_gate-0.1.3.tar.gz.
File metadata
- Download URL: sam_gate-0.1.3.tar.gz
- Size: 68.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 659ac8af2c4daa21a0e415b1d502454613194dfd92ba6c58114305408e335d31 |
| MD5 | 8ef448aad7be9a930199b16b0f5a5130 |
| BLAKE2b-256 | bd40601633fcf183b83da40ed16d50c9e23b6c321379f515f4a49ffc6fa3af38 |
File details
Details for the file sam_gate-0.1.3-py3-none-any.whl.
File metadata
- Download URL: sam_gate-0.1.3-py3-none-any.whl
- Size: 63.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6149168385514f2f2060386d72e7f2a03d2b8655dd25cffcd22d8fa1cbf3c0c0 |
| MD5 | e199c6c2918b5522998ec41c792ba9ae |
| BLAKE2b-256 | af07af00fa2b3d9d4c156628e9bd64bbf0044fb854f770ed55ac36035b7e3430 |