
SAM-Gate

Semantic-Aware Memory Gate — adaptive KV-cache compression for transformer models, guided by the intrinsic geometry of the attention flow.

SAM-Gate estimates the harmonic curvature f(t) per layer at inference time and assigns each layer to a compression regime:

Regime       Condition                       KV bits   Heads   Window
Flat         f(t) < f_flat_max               int4      40%     short
Transition   f_flat_max ≤ f(t) < f_obs_min   int8      70%     medium
Obstructed   f(t) ≥ f_obs_min                fp16      100%    long

The KV window is fixed regardless of context length — memory is O(1), not O(T).
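To make the gating concrete, here is a minimal sketch of the per-layer decision. The class and function names are illustrative, not SAM-Gate's internal API; the cut-offs and settings mirror the regime table above and the SAMConfig defaults listed later.

from dataclasses import dataclass

@dataclass
class RegimePolicy:
    kv_bits: int       # quantization width for cached K/V
    head_frac: float   # fraction of query heads kept active
    window: int        # fixed KV window length in tokens

def assign_regime(f_t: float, f_flat_max: float, f_obs_min: float) -> RegimePolicy:
    """Map a layer's harmonic curvature f(t) to a compression regime (illustrative)."""
    if f_t < f_flat_max:     # flat: aggressive compression
        return RegimePolicy(kv_bits=4, head_frac=0.40, window=64)
    if f_t < f_obs_min:      # transition: moderate compression
        return RegimePolicy(kv_bits=8, head_frac=0.70, window=128)
    # obstructed: keep full precision and all heads
    return RegimePolicy(kv_bits=16, head_frac=1.00, window=128)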

Benchmark results (Qwen2.5-3B-Instruct, RTX 4070 8GB)

Engine        ctx=512 KV   ctx=1024 KV   TPS (ctx=512)
Baseline HF   18828 KB     37260 KB      30.4
SAM-Gate      4608 KB      4608 KB       8.5
RCI (int8)    9702 KB      19206 KB      9.2

SAM-Gate KV footprint does not grow with context. At ctx=1024, baseline uses 8x more memory than SAM-Gate.
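For intuition on these numbers, here is a rough back-of-the-envelope estimate assuming Qwen2.5-3B-Instruct's published shape (36 layers, 2 KV heads per layer, head dim 128) and fp16 values: each cached token costs about 36 KB across all layers, so a full cache grows linearly with context, while a 128-token window stays at 128 × 36 KB = 4608 KB. The measured baseline figures are slightly higher because the prompt tokens are cached as well.

# Rough KV footprint: full cache vs. a fixed 128-token window.
# Assumed Qwen2.5-3B-Instruct shape: 36 layers, 2 KV heads (GQA), head_dim 128, fp16.
layers, kv_heads, head_dim, bytes_fp16 = 36, 2, 128, 2
per_token = 2 * layers * kv_heads * head_dim * bytes_fp16    # K and V: 36864 B/token

for ctx in (512, 1024, 4096):
    full   = per_token * ctx / 1024    # grows with context
    capped = per_token * 128 / 1024    # fixed window (max_ctx_obs = 128)
    print(f"ctx={ctx:5d}  full={full:8.0f} KB  fixed-window={capped:5.0f} KB")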

Requirements

Install

Install PyTorch first for your CUDA / platform (see Requirements), then:

# Base install (Windows / Mac / Linux — uses PyTorch SDPA)
pip install sam-gate

# Linux / WSL with NVIDIA GPU — enables Flash Attention (2-3x TPS gain)
pip install sam-gate[flash]

# With SpectralRCI support
pip install sam-gate[spectral]

Note: flash-attn requires Linux or WSL2 with CUDA 11.7+. On Windows, SAM-Gate automatically falls back to PyTorch SDPA — fully functional, ~50% lower TPS at long contexts.

Quick start

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sam_gate import attach_semantic_hooks, SAMConfig

# Load in fp16 and place the model on the GPU (the inputs below go to CUDA as well).
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-3B-Instruct", torch_dtype=torch.float16
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")

cfg = SAMConfig(
    max_ctx_flat  = 64,
    max_ctx_trans = 128,
    max_ctx_obs   = 128,
)

# SAM-Gate manages its own per-layer caches, so HF's cache is disabled in generate().
kv_caches = {}
handles = attach_semantic_hooks(model, kv_caches, cfg=cfg)

inputs = tokenizer("Explain how attention works.", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=200, use_cache=False)

# Detach the hooks when done.
for h in handles:
    h.remove()

print(tokenizer.decode(out[0], skip_special_tokens=True))
print(f"KV used: {sum(c.ram_bytes() for c in kv_caches.values()) / 1024:.1f} KB")

Calibration

Thresholds f_flat_max and f_obs_min are model-specific. Before using SAM-Gate on a new model, run calibration to observe the real f(t) distribution:

python -m sam_gate.sam --model Qwen/Qwen2.5-3B-Instruct --device cuda --calibrate --verbose

Then set f_flat_max and f_obs_min in SAMConfig based on the observed values per layer.
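As an illustration only (the calibration output format, and where to place the cut-offs, are up to you), observed per-layer values could be turned into thresholds like this; the numbers below are made up:

import numpy as np
from sam_gate import SAMConfig

# Hypothetical per-layer f(t) means taken from a calibration run (made-up values).
f_per_layer = [3.2e-3, 8.1e-3, 0.4, 1.7, 12.0, 85.0, 140.0]

# One simple heuristic: put the flat/transition cut near the lower tail
# and the transition/obstructed cut near the upper tail of the distribution.
f_flat_max = float(np.percentile(f_per_layer, 25))
f_obs_min  = float(np.percentile(f_per_layer, 90))

cfg = SAMConfig(f_flat_max=f_flat_max, f_obs_min=f_obs_min)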

Configuration reference

from sam_gate import SAMConfig

cfg = SAMConfig(
    tau           = 0.05,    # harmonic kernel threshold
    flat_bits     = 4,       # int4 in flat regime
    trans_bits    = 8,       # int8 in transition
    obs_bits      = 16,      # fp16 in obstructed
    flat_heads    = 0.5,     # fraction of query heads in flat regime
    trans_heads   = 0.75,    # fraction in transition
    f_flat_max    = 1e-2,    # calibrated for Qwen2.5-3B-Instruct
    f_obs_min     = 50.0,    # calibrated for Qwen2.5-3B-Instruct
    max_ctx_flat  = 64,      # KV window in flat regime (tokens)
    max_ctx_trans = 128,     # KV window in transition (tokens)
    max_ctx_obs   = 128,     # KV window in obstructed (tokens)
)

How it works

SAM-Gate estimates the Morse functional f(t) via Hutchinson probing — O(K·H·d) instead of O(d³) eigendecomposition:

Δ_t  = Σ_i (T_i - I)*(T_i - I)     ← semantic Laplacian
f(t) = Σ_{i<j} ||[T_i, T_j]||²_F   ← harmonic curvature
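As a sketch of the general technique (not SAM-Gate's internal prober): Hutchinson-style probing estimates the trace of an operator such as Δ_t from a few random matrix-vector products, never forming the d×d matrix or running an eigendecomposition. The tensors and names below are illustrative.

import torch

def hutchinson_trace(matvec, d, K=16):
    """Estimate tr(B^T B) from products z -> B @ z, using
    E_z[||B z||^2] = tr(B^T B) with Rademacher (+/-1) probes."""
    est = 0.0
    for _ in range(K):
        z = (torch.randint(0, 2, (d,)) * 2 - 1).float()   # random ±1 probe
        est += (matvec(z) ** 2).sum().item()
    return est / K

# Illustrative use on Δ_t = Σ_i (T_i - I)*(T_i - I):
# tr(Δ_t) = Σ_i tr((T_i - I)*(T_i - I)), one probe loop per map T_i.
d = 64
Ts = [torch.randn(d, d) for _ in range(4)]    # stand-ins for the per-head maps T_i
I = torch.eye(d)
trace_delta = sum(hutchinson_trace(lambda z, T=T: (T - I) @ z, d) for T in Ts)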

At decode time, the policy is cached per layer — zero overhead from the prober after prefill.

The KV cache uses a dense ring buffer (fp16, fixed capacity) initialized once from the quantized prefill chunks. At each decode step, reconstruct_for_attn is O(1) — a single slice, no dequantization loop.
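A minimal sketch of such a fixed-capacity buffer (the class and method names are illustrative, not the library's): appending is O(1), memory is bounded by the capacity, and handing K/V to attention is a plain slice.

import torch

class RingKV:
    """Illustrative fixed-capacity fp16 KV ring buffer: old entries are
    overwritten once full, so memory stays O(1) in the context length."""
    def __init__(self, capacity, n_kv_heads, head_dim, device="cpu"):
        shape = (capacity, n_kv_heads, head_dim)
        self.k = torch.zeros(shape, dtype=torch.float16, device=device)
        self.v = torch.zeros(shape, dtype=torch.float16, device=device)
        self.capacity, self.pos, self.filled = capacity, 0, 0

    def append(self, k_t, v_t):
        # Write the newest step into the next slot, wrapping around when full.
        self.k[self.pos], self.v[self.pos] = k_t, v_t
        self.pos = (self.pos + 1) % self.capacity
        self.filled = min(self.filled + 1, self.capacity)

    def view_for_attn(self):
        # O(1): expose the filled region as-is; no dequantization, no copying.
        return self.k[: self.filled], self.v[: self.filled]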

License

MIT
