SAM-Gate: Semantic-Aware Memory Gate for heterogeneous KV-cache compression in transformer models

Semantic-Aware Memory Gate — adaptive KV-cache compression for transformer models, guided by the intrinsic geometry of the attention flow.

SAM-Gate estimates the harmonic curvature f(t) per layer at inference time and assigns each layer to a compression regime:

| Regime | Condition | KV bits | Heads | Window |
|---|---|---|---|---|
| Flat | f(t) < f_flat_max | int4 | 40% | short |
| Transition | f_flat_max ≤ f(t) < f_obs_min | int8 | 70% | medium |
| Obstructed | f(t) ≥ f_obs_min | fp16 | 100% | long |

The KV window is fixed regardless of context length — memory is O(1), not O(T).
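The per-layer gating can be sketched as follows (an illustrative helper, not the package API; thresholds and percentages are taken from the regime table above):

```python
# Hypothetical sketch of the regime assignment described above. Default
# thresholds mirror the calibrated Qwen2.5-3B-Instruct values; the tuple
# layout (bits, head fraction, window) is for illustration only.

def assign_regime(f_t: float, f_flat_max: float = 1e-2, f_obs_min: float = 50.0):
    """Return (kv_bits, head_fraction, window) for a layer with curvature f_t."""
    if f_t < f_flat_max:
        return (4, 0.40, "short")    # flat: aggressive int4, few heads
    if f_t < f_obs_min:
        return (8, 0.70, "medium")   # transition: int8, most heads
    return (16, 1.00, "long")        # obstructed: full fp16, all heads

print(assign_regime(0.001))   # flat
print(assign_regime(1.0))     # transition
print(assign_regime(120.0))   # obstructed
```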

Benchmark results (Qwen2.5-3B-Instruct, RTX 4070 8GB)

| Engine | ctx=512 KV | ctx=1024 KV | TPS (ctx=512) |
|---|---|---|---|
| Baseline HF | 18828 KB | 37260 KB | 30.4 |
| SAM-Gate | 4608 KB | 4608 KB | 8.5 |
| RCI (int8) | 9702 KB | 19206 KB | 9.2 |

SAM-Gate's KV footprint does not grow with context: at ctx=1024 the baseline uses about 8x more KV memory than SAM-Gate.

Install

# Base install (Windows / Mac / Linux — uses PyTorch SDPA)
pip install sam-gate

# Linux / WSL with NVIDIA GPU — enables Flash Attention (2-3x TPS gain)
pip install "sam-gate[flash]"

# With SpectralRCI support
pip install "sam-gate[spectral]"

Note: flash-attn requires Linux or WSL2 with CUDA 11.7+. On Windows, SAM-Gate automatically falls back to PyTorch SDPA — fully functional, ~50% lower TPS at long contexts.
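The fallback described above amounts to a runtime check like this (a sketch, not SAM-Gate's actual detection code; `"flash_attention_2"` and `"sdpa"` are the attention-implementation names used by Transformers):

```python
# Prefer Flash Attention when the flash_attn package is importable,
# otherwise fall back to PyTorch's built-in scaled_dot_product_attention.

def pick_attention_backend() -> str:
    try:
        import flash_attn  # noqa: F401  (only builds on Linux / WSL2 + CUDA)
        return "flash_attention_2"
    except ImportError:
        return "sdpa"  # always available in PyTorch >= 2.0

print(f"Using attention implementation: {pick_attention_backend()}")
```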

Quick start

from transformers import AutoModelForCausalLM, AutoTokenizer
from sam_gate import attach_semantic_hooks, SAMConfig

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")

cfg = SAMConfig(
    max_ctx_flat  = 64,
    max_ctx_trans = 128,
    max_ctx_obs   = 128,
)

# The hooks intercept attention and populate kv_caches with compressed entries.
kv_caches = {}
handles = attach_semantic_hooks(model, kv_caches, cfg=cfg)

inputs = tokenizer("Explain how attention works.", return_tensors="pt").to("cuda")
# use_cache=False: the standard HF cache is bypassed; the hooks manage
# the KV state through kv_caches instead.
out = model.generate(**inputs, max_new_tokens=200, use_cache=False)

# Detach the hooks when done.
for h in handles:
    h.remove()

print(tokenizer.decode(out[0], skip_special_tokens=True))
print(f"KV used: {sum(c.ram_bytes() for c in kv_caches.values()) / 1024:.1f} KB")

Calibration

Thresholds f_flat_max and f_obs_min are model-specific. Before using SAM-Gate on a new model, run calibration to observe the real f(t) distribution:

python -m sam_gate.sam --model Qwen/Qwen2.5-3B-Instruct --device cuda --calibrate --verbose

Then set f_flat_max and f_obs_min in SAMConfig based on the observed values per layer.
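For example (the threshold values below are placeholders; substitute the f(t) range your calibration run reports):

```python
from sam_gate import SAMConfig

# Hypothetical calibrated thresholds for some new model. Pick f_flat_max
# below the smallest f(t) seen in transition-regime layers, and f_obs_min
# above the largest f(t) seen in transition-regime layers.
cfg = SAMConfig(
    f_flat_max=5e-3,
    f_obs_min=80.0,
)
```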

Configuration reference

from sam_gate import SAMConfig

cfg = SAMConfig(
    tau           = 0.05,    # harmonic kernel threshold
    flat_bits     = 4,       # int4 in flat regime
    trans_bits    = 8,       # int8 in transition
    obs_bits      = 16,      # fp16 in obstructed
    flat_heads    = 0.5,     # fraction of query heads in flat regime
    trans_heads   = 0.75,    # fraction in transition
    f_flat_max    = 1e-2,    # calibrated for Qwen2.5-3B-Instruct
    f_obs_min     = 50.0,    # calibrated for Qwen2.5-3B-Instruct
    max_ctx_flat  = 64,      # KV window in flat regime (tokens)
    max_ctx_trans = 128,     # KV window in transition (tokens)
    max_ctx_obs   = 128,     # KV window in obstructed (tokens)
)

How it works

SAM-Gate estimates the Morse functional f(t) via Hutchinson probing — O(K·H·d) instead of an O(d³) eigendecomposition:

Δ_t  = Σ_i (T_i - I)*(T_i - I)     ← semantic Laplacian
f(t) = Σ_{i<j} ||[T_i, T_j]||²_F   ← harmonic curvature

where [T_i, T_j] = T_i T_j - T_j T_i is the commutator and ||·||_F is the Frobenius norm.
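The probing idea can be illustrated with a generic Hutchinson estimator (a sketch, not SAM-Gate's actual code; `hutchinson_frob_sq` is a hypothetical helper):

```python
import numpy as np

# Hutchinson trace estimation: tr(M) ≈ (1/K) Σ_k z_kᵀ M z_k with Rademacher
# probes z_k. With M = CᵀC and C = [T_i, T_j], this estimates ||[T_i, T_j]||²_F
# from matrix-vector products alone — no O(d³) eigendecomposition.

def hutchinson_frob_sq(T_i, T_j, K=32, seed=None):
    rng = np.random.default_rng(seed)
    d = T_i.shape[0]

    def commutator_mv(z):
        # z ↦ [T_i, T_j] z, via four matrix-vector products
        return T_i @ (T_j @ z) - T_j @ (T_i @ z)

    est = 0.0
    for _ in range(K):
        z = rng.choice([-1.0, 1.0], size=d)  # Rademacher probe
        Cz = commutator_mv(z)
        est += Cz @ Cz                       # zᵀ CᵀC z = ||Cz||²
    return est / K

rng = np.random.default_rng(0)
A, B = rng.standard_normal((2, 16, 16))
exact = np.linalg.norm(A @ B - B @ A, "fro") ** 2
approx = hutchinson_frob_sq(A, B, K=2000, seed=1)
print(exact, approx)  # the estimate converges to the exact value as K grows
```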

At decode time, the policy is cached per layer — zero overhead from the prober after prefill.

The KV cache uses a dense ring buffer (fp16, fixed capacity) initialized once from the quantized prefill chunks. At each decode step, reconstruct_for_attn is O(1) — a single slice, no dequantization loop.
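A minimal ring buffer in the spirit described above (illustrative only; the class name and `append` method are assumptions, and only `reconstruct_for_attn` mirrors a name from the text):

```python
import numpy as np

class RingKV:
    """Fixed-capacity KV buffer: memory stays O(1) in sequence length."""

    def __init__(self, capacity: int, d: int):
        self.buf = np.zeros((capacity, d), dtype=np.float16)  # dense fp16 storage
        self.capacity = capacity
        self.count = 0  # total tokens written so far

    def append(self, kv_vec):
        # Overwrite the oldest slot once the buffer is full.
        self.buf[self.count % self.capacity] = kv_vec
        self.count += 1

    def reconstruct_for_attn(self):
        # O(1): a single slice of the live rows, no dequantization loop.
        return self.buf[: min(self.count, self.capacity)]

kv = RingKV(capacity=4, d=8)
for t in range(10):
    kv.append(np.full(8, t, dtype=np.float16))
print(kv.reconstruct_for_attn().shape)  # (4, 8) regardless of how many tokens arrived
```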

License

MIT
