
Bounded-memory differential attention with value-routed landmark banks


CoDA-GQA-L

Bounded-memory differential attention that actually works.

pip install coda-gqa-l

A 70B model serving 128K context burns 160 GB on KV cache alone. CoDA-GQA-L does it in 136 MB -- a fixed-size buffer that doesn't grow no matter how long the input is.

Context   Standard KV (70B, 80 layers)   CoDA-GQA-L   Compression
2K        2.56 GB                        136 MB       18.8x
32K       40 GB                          136 MB       294x
128K      160 GB                         136 MB       1,176x

The mechanism replaces the O(L) KV cache with three fixed-size segments per layer:

  • Recent window (W=256) -- ring buffer of exact recent tokens, FIFO eviction
  • Exact landmark bank (Me=64) -- novelty-filtered LRU cache holding the tokens that a learned write gate decides are worth keeping
  • Summary landmark bank (Ms=64) -- EMA prototypes compressing older context into semantic cluster centroids

384 slots per layer (256 + 64 + 64), always, whether you processed 2K or 128K tokens.
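The fixed footprint can be reproduced with back-of-envelope arithmetic. A minimal sketch, assuming K and V are stored separately at 2 bytes per element (bf16) with no per-slot metadata; the function names are illustrative and not part of the package.

```python
def kv_bytes_per_token(num_kv_heads: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    """One cached token in one layer: a K and a V vector for every KV head."""
    return 2 * num_kv_heads * head_dim * bytes_per_elem


def standard_kv_bytes(context_len: int, num_layers: int, num_kv_heads: int,
                      head_dim: int, bytes_per_elem: int = 2) -> int:
    """Unbounded cache: linear in context length."""
    return context_len * num_layers * kv_bytes_per_token(num_kv_heads, head_dim, bytes_per_elem)


def bounded_kv_bytes(window: int, exact: int, summary: int, num_layers: int,
                     num_kv_heads: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    """Bounded cache: W + Me + Ms slots per layer, whatever the context length."""
    slots = window + exact + summary
    return slots * num_layers * kv_bytes_per_token(num_kv_heads, head_dim, bytes_per_elem)


if __name__ == "__main__":
    # One layer with Hkv=8 and head_dim=128 (D=4096, H=32), bf16, medium config.
    bounded = bounded_kv_bytes(256, 64, 64, num_layers=1, num_kv_heads=8, head_dim=128)
    print(f"bounded per layer: {bounded:,} bytes")  # 1,572,864 bytes (1.5 MiB)
    for ctx in (2_048, 8_192, 131_072):
        std = standard_kv_bytes(ctx, num_layers=1, num_kv_heads=8, head_dim=128)
        print(f"standard per layer @ {ctx:,} tokens: {std:,} bytes")
```

Under these assumptions one cached token costs 4 KiB per layer, so the medium-config buffer is 1.5 MiB per layer before any bookkeeping, and a standard cache reaches 32 MiB per layer at 8K tokens. Precision, batch size, and per-slot metadata all shift the totals, so the module's own cache_bytes() is the figure to trust for a given configuration.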

The bounded state is serializable. torch.save() it, load it a week later, query it without re-reading the original document. At 7B scale each state is 54 MB. We call this pattern stateful neural databases.

How it works

Three ideas, briefly:

Differential attention via orthogonal rotation. Builds on Microsoft's Diff Transformer (Ye et al., 2024) but drops the second Wq projection: a learned per-head rotation produces the noise query from the signal query. The output is signal attention minus gated noise attention, computed in one head-stacked SDPA call and followed by HeadwiseRMSNorm.
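The description does not spell out how the rotation is parameterized. A minimal sketch, assuming a block-diagonal stack of 2x2 (Givens) rotations with one learned angle per head and dimension pair; `PerHeadRotation` and its initialization are hypothetical, not the package's API.

```python
import torch
import torch.nn as nn


class PerHeadRotation(nn.Module):
    """Hypothetical sketch: a learned orthogonal map per head built from 2x2 rotations.

    Each head rotates consecutive dimension pairs (x[2i], x[2i+1]) by its own angle
    theta[h, i]. A block-diagonal stack of 2x2 rotations is orthogonal, so the query
    keeps its norm while its direction changes.
    """

    def __init__(self, num_heads: int, head_dim: int):
        super().__init__()
        assert head_dim % 2 == 0, "head_dim must be even to form dimension pairs"
        # Small random init is an arbitrary choice for this sketch.
        self.theta = nn.Parameter(0.1 * torch.randn(num_heads, head_dim // 2))

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        # q: (B, H, L, Dh) -> same shape
        cos = torch.cos(self.theta)[None, :, None, :]  # (1, H, 1, Dh/2)
        sin = torch.sin(self.theta)[None, :, None, :]
        q_even, q_odd = q[..., 0::2], q[..., 1::2]
        rot_even = q_even * cos - q_odd * sin
        rot_odd = q_even * sin + q_odd * cos
        # Re-interleave the rotated pairs back into (B, H, L, Dh).
        return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)
```

In this sketch `q_noise = rot(q_signal)` replaces a second query projection, and only H × Dh/2 angles are learned per layer instead of a full D × D matrix.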

Value-routing. Keys have RoPE baked in, so their cosine similarity is position-dependent -- the same word at position 100 vs. 5000 can look nearly orthogonal in key space. Memory banks route on values instead (RoPE-free, pure semantic content). This is what makes deduplication and EMA blending actually work.
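An idealized illustration of why key-space routing fails: the same pre-RoPE key placed at positions 100 and 5000 with standard RoPE (base 10000, interleaved pairs), compared with the position-free value. `apply_rope` is a hypothetical helper written for this sketch; Llama-style `rotate_half` pairs dimensions differently but behaves the same way here.

```python
import torch
import torch.nn.functional as F


def apply_rope(x: torch.Tensor, pos: int, base: float = 10000.0) -> torch.Tensor:
    """Rotary embedding (interleaved-pair form) for vectors whose last dim is even."""
    d = x.shape[-1]
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float64) / d)  # (d/2,)
    angle = pos * inv_freq
    cos, sin = torch.cos(angle), torch.sin(angle)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.stack((x_even * cos - x_odd * sin, x_even * sin + x_odd * cos), dim=-1)
    return out.flatten(-2)


torch.manual_seed(0)
d = 128
k_raw = torch.randn(d, dtype=torch.float64)  # pre-RoPE key of one token (idealized: same at both positions)
v = torch.randn(d, dtype=torch.float64)      # its value: no RoPE, identical at both positions

key_cos = F.cosine_similarity(apply_rope(k_raw, 100), apply_rope(k_raw, 5000), dim=0).item()
value_cos = F.cosine_similarity(v, v, dim=0).item()
print(f"key cosine, pos 100 vs 5000: {key_cos:.3f}")
print(f"value cosine, pos 100 vs 5000: {value_cos:.3f}")
```

RoPE only preserves relative position in query-key dot products; two copies of the same key at different absolute positions are rotated apart, which is why a duplicate would look novel to a key-routed bank.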

Two-phase training. Phase 1 trains unbounded differential attention. Phase 2 switches to bounded cache with gradient flow through evictions so the write gate learns what to keep. Without Phase 2, bounded eval is catastrophic (PPL 5.62 to 2,464 on Mistral 7B).

Quick start

import torch
from coda_gqa_l import CoDAGQALandmarkPerf2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32

attn = CoDAGQALandmarkPerf2(
    embed_dim=512,
    num_heads=8,
    num_kv_heads=2,          # GQA: 4x head sharing
    window=256,
    num_landmarks_exact=64,
    num_landmarks_summary=64,
).to(device=device, dtype=dtype).eval()

state = attn.init_state(batch_size=1, device=device, dtype=dtype)

# Prefill
prompt = torch.randn(1, 2048, 512, device=device, dtype=dtype)
y, state = attn.prefill_chunked(prompt, state, block_size=256)

# Decode -- memory stays constant
for _ in range(1000):
    x_t = torch.randn(1, 1, 512, device=device, dtype=dtype)
    y_t, state = attn.step(x_t, state)

print(f"Tokens seen: {state.pos}")        # 3048
print(f"Cache size: {attn.cache_bytes(1, dtype):,} bytes")  # constant

Stateful neural databases

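import torch

# Ingest a document (document_embeddings: (1, L, 512) input embeddings, attn as in Quick start)
state = attn.init_state(batch_size=1, device="cuda", dtype=torch.bfloat16)
_, state = attn.prefill_chunked(document_embeddings, state, block_size=256)

# Save -- fixed size regardless of document length
torch.save(state, "document_42.pt")

# Later: load and query without re-reading the document.
# The state is a dataclass, not a plain tensor dict, so PyTorch >= 2.6
# needs weights_only=False here (only load files you trust).
state = torch.load("document_42.pt", map_location="cuda", weights_only=False)
answer, state = attn.step(query_embedding, state)  # query_embedding: (1, 1, 512)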

At 7B scale (32 layers): 54 MB per document state. 100 documents = 5.4 GB. Route queries to the right state on demand -- no vector database, no chunk-and-retrieve.
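A minimal sketch of keeping one saved state per document and loading it on demand, assuming the state is serializable with `torch.save` as in the example above. `StateStore` is hypothetical, and how a query is mapped to a document id is left to the caller.

```python
import os
import tempfile

import torch


class StateStore:
    """Hypothetical sketch: one serialized bounded state per document, loaded on demand."""

    def __init__(self, root: str):
        self.root = root
        os.makedirs(root, exist_ok=True)

    def _path(self, doc_id: str) -> str:
        return os.path.join(self.root, f"{doc_id}.pt")

    def save(self, doc_id: str, state) -> None:
        torch.save(state, self._path(doc_id))

    def load(self, doc_id: str, device: str = "cpu"):
        # States are Python objects, not bare tensors, so weights_only=False is needed
        # on recent PyTorch. It unpickles arbitrary code: only load files you wrote.
        return torch.load(self._path(doc_id), map_location=device, weights_only=False)

    def list_docs(self) -> list:
        return sorted(name[:-3] for name in os.listdir(self.root) if name.endswith(".pt"))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        store = StateStore(root)
        # A dict of tensors stands in for the real state object here.
        store.save("document_42", {"k_buf": torch.zeros(1, 2, 384, 64), "pos": 4096})
        print(store.list_docs(), store.load("document_42")["pos"])  # ['document_42'] 4096
```

Because every state has the same size, memory for the hot set is simply (number of loaded documents) × (per-state size).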

Drop-in model adapters

Llama family (Llama 2/3, Mistral, SmolLM)

from coda_gqa_l import LlamaCoDAAdapter

adapter = LlamaCoDAAdapter.from_llama_attention(
    llama_block.self_attn,
    bounded=False,  # unbounded for training, True for inference
)
llama_block.self_attn = adapter

Eve-2 MoE

from coda_gqa_l import EveCoDAAdapter

adapter = EveCoDAAdapter.from_eve_attention(eve_block.attn, bounded=True)
eve_block.attn = adapter
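The adapters above replace one block at a time. A hypothetical helper for swapping every decoder layer; `swap_attention` is not part of the package, and the commented usage assumes the Hugging Face layout where Llama/Mistral decoder layers live at `model.model.layers`.

```python
from typing import Callable

import torch.nn as nn


def swap_attention(layers: nn.ModuleList, attr: str,
                   factory: Callable[[nn.Module], nn.Module]) -> int:
    """Hypothetical helper: replace `layer.<attr>` in every layer with factory(old_module).

    The factory is responsible for returning a module on the right device and dtype.
    Returns the number of layers swapped.
    """
    swapped = 0
    for layer in layers:
        old = getattr(layer, attr)
        setattr(layer, attr, factory(old))  # nn.Module.__setattr__ re-registers the submodule
        swapped += 1
    return swapped


# Usage sketch for a Hugging Face Llama/Mistral model (not executed here):
# from coda_gqa_l import LlamaCoDAAdapter
# swap_attention(model.model.layers, "self_attn",
#                lambda attn: LlamaCoDAAdapter.from_llama_attention(attn, bounded=False))
```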

Training

Two-phase protocol. Phase 1 teaches differential attention with full context. Phase 2 switches to bounded cache so the model adapts to limited memory.

# Smoke test (~5 min)
python benchmarks/train_coda.py \
    --model HuggingFaceTB/SmolLM2-135M \
    --max-steps 200 --bounded-steps 100 --bounded-config medium

# Mistral 7B (~6 hours on H100)
python benchmarks/train_coda.py \
    --model mistralai/Mistral-7B-v0.3 \
    --max-steps 2000 --bounded-steps 2000 \
    --bounded-lr-scale 0.5 --bounded-block-size 128 \
    --no-detach-evicted --batch-size 1 --grad-accum 8

Bounded configs:

Config   Window   Exact   Summary   Total slots
tiny     128      32      32        192
medium   256      64      64        384
large    512      128     128       768

Training results (Mistral-7B-v0.3 on H200 NVL)

Phase                       Steps   PPL start   PPL end   Throughput
Phase 1 (unbounded)         2,000   23.50       5.75      ~4,950 tok/s
Phase 2 (bounded, medium)   600     27.88       6.31      ~2,000 tok/s

Bounded inference costs 23.5% in perplexity (+1.13 PPL over the 4.81 baseline) in exchange for 9.5x memory compression. Context-length scaling is remarkably flat: 5.94 at 2K, 5.95 at 4K. Total training time: ~1.6 hours on H200.

Phase 2 trains with detach_evicted=False so gradients flow through bank updates.
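A minimal sketch of what the flag changes, assuming eviction feeds an EMA update like the summary bank's: when detached, the evicted value is a constant and no gradient reaches the layers that produced it. `ema_update` and the stand-in `v_proj` are hypothetical, and the write gate is left out for brevity.

```python
import torch


def ema_update(prototype: torch.Tensor, evicted_v: torch.Tensor,
               alpha: float, detach_evicted: bool) -> torch.Tensor:
    """Blend an evicted value into a prototype: p <- (1 - alpha) * p + alpha * v."""
    v = evicted_v.detach() if detach_evicted else evicted_v
    return (1 - alpha) * prototype + alpha * v


def grad_reaches_v_proj(detach_evicted: bool) -> bool:
    """Does a loss on the updated bank send gradient back to the value projection?"""
    torch.manual_seed(0)
    v_proj = torch.nn.Linear(4, 4)  # stand-in for the layer that produced the value
    x = torch.randn(1, 4)
    prototype = torch.zeros(1, 4)
    new_proto = ema_update(prototype, v_proj(x), alpha=0.1, detach_evicted=detach_evicted)
    if not new_proto.requires_grad:
        return False  # the graph was cut at the eviction
    new_proto.sum().backward()  # stand-in for a downstream loss that reads the bank
    grad = v_proj.weight.grad
    return grad is not None and bool(grad.abs().sum() > 0)


if __name__ == "__main__":
    print("detach_evicted=True :", grad_reaches_v_proj(True))   # False
    print("detach_evicted=False:", grad_reaches_v_proj(False))  # True
```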

Benchmarks

All numbers from H200 NVL, bf16, standalone attention modules.

python benchmarks/run_suite.py                          # 5 configs, JSON output
python benchmarks/run_suite.py --embed-dim 4096 \
    --num-heads 32 --num-kv-heads 8                     # 7B scale
python benchmarks/render_tables.py                      # markdown tables

Memory (per layer, medium-cache W=256 Me=64 Ms=64)

Scale                       Standard KV   CoDA bounded   Compression
7B (D=4096, H=32, Hkv=8)    32.0 MB       1.7 MB         18.8x
70B (D=8192, H=64, Hkv=8)   32.0 MB       1.7 MB         18.8x

Across all layers at 128K context:

Model             Standard KV total   CoDA total   Compression
7B (32 layers)    64 GB               54 MB        1,185x
70B (80 layers)   160 GB              136 MB       1,176x

Throughput (70B scale, tokens/sec)

Config              Prefill 2K   Prefill 8K   Decode   Peak VRAM
Baseline GQA        1,336,349    966,464      4,676    1.1 GB
CoDA unbounded      889,203      598,314      2,914    1.6 GB
CoDA medium-cache   149,832      153,716      1,753    568.6 MB
CoDA window-only    359,417      356,026      1,773    546.0 MB

Bounded prefill throughput is flat (~150K tok/s) regardless of sequence length. Baseline drops 28% from 2K to 8K. Bounded VRAM is 1.9x lower.

Architecture

Attention flow

x -> q_proj -> RoPE(q) -> q_signal
                        \-> R(theta) -> q_noise

x -> k_proj -> RoPE(k) -> k  (stored in buffer)
x -> v_proj -> v             (stored in buffer)

SDPA([q_signal; q_noise], k_buf, v_buf)
  -> split -> out_sig, out_noise
  -> out_sig - lambda * out_noise
  -> HeadwiseRMSNorm -> o_proj -> output
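The flow above can be sketched in PyTorch as follows. This is an illustrative reconstruction, not the package's implementation: `differential_attention`, the scalar-or-per-head `lam`, and the RMSNorm details are assumptions, and head merging plus o_proj are omitted.

```python
import torch
import torch.nn.functional as F


def differential_attention(q_signal: torch.Tensor, q_noise: torch.Tensor,
                           k: torch.Tensor, v: torch.Tensor, lam,
                           norm_weight: torch.Tensor, is_causal: bool = True,
                           eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical sketch of the attention flow.

    q_signal, q_noise: (B, H, L, Dh) queries, RoPE already applied
    k, v:              (B, Hkv, S, Dh) buffered keys (with RoPE) and values
    lam:               float or (H,) tensor scaling the noise stream
    norm_weight:       (Dh,) gain of the per-head RMSNorm
    is_causal:         only meaningful for prefill with L == S; pass False when decoding
    """
    H = q_signal.shape[1]
    group = H // k.shape[1]
    # GQA: each KV head is shared by `group` consecutive query heads.
    k = k.repeat_interleave(group, dim=1)
    v = v.repeat_interleave(group, dim=1)
    # Head-stack both streams so a single SDPA call computes them together.
    q = torch.cat([q_signal, q_noise], dim=1)  # (B, 2H, L, Dh)
    out = F.scaled_dot_product_attention(
        q, torch.cat([k, k], dim=1), torch.cat([v, v], dim=1), is_causal=is_causal)
    out_sig, out_noise = out.split(H, dim=1)
    lam = torch.as_tensor(lam, dtype=out.dtype, device=out.device).reshape(-1, 1, 1)
    diff = out_sig - lam * out_noise
    # Per-head RMSNorm over the Dh features of each head's output.
    diff = diff * torch.rsqrt(diff.pow(2).mean(dim=-1, keepdim=True) + eps)
    return diff * norm_weight
```

Concatenating K/V doubles their transient memory in this sketch; a fused kernel can instead read each K/V tile once for both streams.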

Buffer layout

k_buf / v_buf: (B, Hkv, W + Me + Ms, Dh)

Slots:  [0..W-1]      [W..W+Me-1]      [W+Me..W+Me+Ms-1]
         recent ring    exact bank       summary bank
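A small helper that turns a bounded config into the slot ranges shown above, checked against the presets from the bounded-configs table; `slot_layout` and `PRESETS` are hypothetical names.

```python
from typing import Dict


def slot_layout(window: int, exact: int, summary: int) -> Dict[str, range]:
    """Index ranges of the three segments along the slot dimension of k_buf / v_buf."""
    return {
        "recent": range(0, window),
        "exact": range(window, window + exact),
        "summary": range(window + exact, window + exact + summary),
    }


# (W, Me, Ms) for the three bounded presets
PRESETS = {"tiny": (128, 32, 32), "medium": (256, 64, 64), "large": (512, 128, 128)}

for name, (w, me, ms) in PRESETS.items():
    layout = slot_layout(w, me, ms)
    total = sum(len(r) for r in layout.values())
    print(name, {seg: (r.start, r.stop - 1) for seg, r in layout.items()}, "total", total)
```

In tensor code the same ranges become slices, e.g. `k_buf[:, :, 256:320]` for the medium exact bank.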

On eviction from the ring:

  1. Write gate check -- is this token worth remembering?
  2. Exact bank: V-routing cosine similarity. Novel? Insert. Duplicate? Update LRU.
  3. Summary bank: V-routing cosine similarity. EMA blend into best-matching prototype.
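A single-head sketch of the three eviction steps. The thresholds, seeding empty summary prototypes before blending, blending keys and values with the same EMA weight, and applying the gate before both banks are all assumptions; the real module works on batched tensors and may use different policies. `BankSketch` is hypothetical.

```python
import torch
import torch.nn.functional as F


class BankSketch:
    """Hypothetical single-head sketch of the eviction path; thresholds are assumptions."""

    def __init__(self, me: int, ms: int, dh: int, dup_thresh: float = 0.95,
                 gate_thresh: float = 0.5, ema: float = 0.1):
        # Exact bank: verbatim K/V plus validity flags and a last-use clock for LRU.
        self.ek, self.ev = torch.zeros(me, dh), torch.zeros(me, dh)
        self.e_valid = torch.zeros(me, dtype=torch.bool)
        self.e_last_used = torch.zeros(me, dtype=torch.long)
        # Summary bank: EMA prototypes.
        self.sk, self.sv = torch.zeros(ms, dh), torch.zeros(ms, dh)
        self.s_valid = torch.zeros(ms, dtype=torch.bool)
        self.dup_thresh, self.gate_thresh, self.ema = dup_thresh, gate_thresh, ema
        self.clock = 0

    def evict(self, k: torch.Tensor, v: torch.Tensor, gate_logit: torch.Tensor) -> str:
        """Handle one token leaving the ring; k, v: (Dh,). Returns what happened."""
        self.clock += 1
        # 1. Write gate: drop tokens the (learned) gate scores as not worth keeping.
        if torch.sigmoid(gate_logit) < self.gate_thresh:
            return "gated_out"
        result = self._update_exact(k, v)  # 2. exact bank
        self._update_summary(k, v)         # 3. summary bank
        return result

    def _update_exact(self, k: torch.Tensor, v: torch.Tensor) -> str:
        if self.e_valid.any():
            # Route on V (no RoPE), ignoring empty slots.
            sims = F.cosine_similarity(self.ev, v.unsqueeze(0), dim=-1)
            sims = sims.masked_fill(~self.e_valid, -2.0)
            best = int(sims.argmax())
            if sims[best] >= self.dup_thresh:
                self.e_last_used[best] = self.clock  # duplicate: refresh LRU only
                return "exact_hit"
        free = (~self.e_valid).nonzero()
        # Novel: take a free slot, otherwise overwrite the least recently used one.
        slot = int(free[0]) if len(free) > 0 else int(self.e_last_used.argmin())
        self.ek[slot], self.ev[slot] = k, v
        self.e_valid[slot] = True
        self.e_last_used[slot] = self.clock
        return "exact_insert"

    def _update_summary(self, k: torch.Tensor, v: torch.Tensor) -> None:
        empty = (~self.s_valid).nonzero()
        if len(empty) > 0:
            slot = int(empty[0])  # seed an empty prototype first
            self.sk[slot], self.sv[slot] = k, v
            self.s_valid[slot] = True
            return
        sims = F.cosine_similarity(self.sv, v.unsqueeze(0), dim=-1)
        best = int(sims.argmax())  # best-matching prototype, routed on V
        a = self.ema
        self.sk[best] = (1 - a) * self.sk[best] + a * k
        self.sv[best] = (1 - a) * self.sv[best] + a * v
```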

Dense packing

For B=1, valid slots are dense-packed before SDPA. This removes the boolean mask, which rules out the FlashAttention backend and can push PyTorch onto the slow Math backend, so the FlashAttention/MemEfficient kernels become usable.
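A sketch of why packing is safe: for one sequence, gathering only the valid slots and calling SDPA without a mask gives the same result as masking the full buffer. The helper names are hypothetical. With B > 1 each row can have a different set of valid slots, which is presumably why packing is limited to B=1.

```python
import torch
import torch.nn.functional as F


def masked_attention(q, k_buf, v_buf, valid):
    """Reference path: attend over every slot and mask out the empty ones."""
    # valid: (S,) bool, True where the slot holds a real token; broadcast to (B, H, L, S).
    return F.scaled_dot_product_attention(q, k_buf, v_buf, attn_mask=valid.view(1, 1, 1, -1))


def dense_packed_attention(q, k_buf, v_buf, valid):
    """B=1 path: gather the valid slots into a contiguous buffer, then call SDPA unmasked."""
    idx = valid.nonzero(as_tuple=True)[0]
    return F.scaled_dot_product_attention(q, k_buf.index_select(2, idx), v_buf.index_select(2, idx))
```

On CUDA with fp16 or bf16 inputs, wrapping the packed call in `torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION)` (PyTorch 2.3+) is one way to check that it is eligible for the FlashAttention backend.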

Correctness

Bounded is mathematically identical to unbounded when W >= L (no evictions); measured outputs agree to within floating-point error:

W      L      Max diff
512    512    1.2e-7
1024   1024   1.5e-7
2048   2048   1.8e-7
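The equivalence follows from the ring never wrapping. A hypothetical pure-Python sketch of the recent window: with W >= L no token is ever evicted, so nothing reaches the banks and the buffer holds the same keys and values as an unbounded cache.

```python
from typing import Any, List, Optional


class RingWindow:
    """Hypothetical sketch of the recent window: fixed W slots, FIFO eviction."""

    def __init__(self, window: int):
        self.window = window
        self.slots: List[Optional[Any]] = [None] * window
        self.pos = 0  # total tokens written so far

    def push(self, item: Any) -> Optional[Any]:
        """Write one token; return the token it overwrote (the evictee), if any."""
        idx = self.pos % self.window
        evicted = self.slots[idx]
        self.slots[idx] = item
        self.pos += 1
        return evicted

    def in_order(self) -> List[Any]:
        """Occupied slots from oldest to newest."""
        n = min(self.pos, self.window)
        start = self.pos % self.window if self.pos >= self.window else 0
        return [self.slots[(start + i) % self.window] for i in range(n)]
```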

56 tests covering correctness, determinism, edge configs, invariants, and backward pass safety.

python -m pytest tests/ -v

What's in the repo

src/coda_gqa_l/
  attention.py            CoDAGQALandmarkPerf2 (main module)
  memory_banks.py         Exact/summary bank updates (mixin)
  state.py                KV buffer + ring metadata dataclass
  primitives.py           RoPE, RMSNorm, GQA utils
  baseline.py             Unbounded CoDAGQA + standard GQA
  llama_adapter.py        Drop-in for Llama/Mistral/SmolLM
  eve_adapter.py          Drop-in for Eve-2 MoE
  triton_diff_flash/      Fused differential FlashAttention kernel
  triton_bank_routing/    Fused exact-bank routing kernel

benchmarks/
  train_coda.py           Two-phase training pipeline
  run_suite.py            Perf benchmarks (5 configs, JSON output)
  eval_llm.py             Full-model perplexity evaluation
  run_ablation_h100.sh    Differential attention ablation

tests/                    56 tests

Metrics

Optional instrumentation for understanding bank behavior:

attn = CoDAGQALandmarkPerf2(..., collect_metrics=True)
state = attn.init_state(...)

# After processing...
print(state.metrics)
# {'exact_hits': 42, 'exact_inserts': 18, 'exact_fill_ratio': 0.28,
#  'summary_updates': 156, 'summary_fill_ratio': 1.0,
#  'tokens_gated_out': 89, 'total_evictions': 768}

Zero overhead when disabled (default).
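A hypothetical helper for turning the counters above into rates. It assumes `exact_hits` counts duplicates found in the exact bank, `exact_inserts` counts new entries, and `exact_fill_ratio` is the fraction of the Me slots in use; check these against the module's own definitions.

```python
def summarize_metrics(m: dict, num_exact_slots: int) -> dict:
    """Hypothetical helper: derived ratios from a metrics dict like the one above."""
    exact_lookups = m["exact_hits"] + m["exact_inserts"]
    evictions = m["total_evictions"]
    return {
        "exact_hit_rate": m["exact_hits"] / exact_lookups if exact_lookups else 0.0,
        "gate_drop_rate": m["tokens_gated_out"] / evictions if evictions else 0.0,
        "exact_slots_used": round(m["exact_fill_ratio"] * num_exact_slots),
    }


if __name__ == "__main__":
    example = {"exact_hits": 42, "exact_inserts": 18, "exact_fill_ratio": 0.28,
               "summary_updates": 156, "summary_fill_ratio": 1.0,
               "tokens_gated_out": 89, "total_evictions": 768}
    print(summarize_metrics(example, num_exact_slots=64))
```

For the example dict with Me=64 this gives a 70% exact hit rate, about 11.6% of evictions gated out, and 18 exact slots in use, which matches the 18 inserts.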

Custom Triton kernels

Two fused kernels address throughput bottlenecks (verified on H200 NVL with Triton 3.4.0):

  • triton_diff_flash -- Fused differential FlashAttention forward kernel. A single HBM pass computes both signal and noise attention with online softmax, then applies the differential epilogue and optional HeadwiseRMSNorm in registers.
  • triton_bank_routing -- Fused exact-bank routing kernel replacing ~15 PyTorch micro-kernel launches (matmul, mean, max, masked_fill, topk, cumsum, clamp, gather, scatter, where) with a single deterministic GPU kernel.
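A pure-PyTorch reference for the single-pass idea: each K/V block is read once and used by both the signal and the noise stream, with running max/sum rescaling (online softmax) so no full score matrix is materialized. This is a sketch for checking the math, not the Triton kernel; it omits the causal mask and the optional RMSNorm, and `diff_attention_online` is a hypothetical name.

```python
import torch


def diff_attention_online(q_sig: torch.Tensor, q_noise: torch.Tensor,
                          k: torch.Tensor, v: torch.Tensor, lam: float,
                          block: int = 4) -> torch.Tensor:
    """Two-stream online softmax. q_sig, q_noise: (L, Dh); k, v: (S, Dh) -> (L, Dh)."""
    scale = q_sig.shape[-1] ** -0.5
    qs = torch.stack([q_sig, q_noise]) * scale                       # (2, L, Dh)
    m = torch.full(qs.shape[:2], float("-inf"), dtype=qs.dtype)     # running row max
    l = torch.zeros(qs.shape[:2], dtype=qs.dtype)                   # running denominator
    acc = torch.zeros_like(qs)                                      # running numerator @ V
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]     # one K/V block, shared
        s = qs @ kb.T                                               # (2, L, blk)
        m_new = torch.maximum(m, s.amax(dim=-1))
        correction = torch.exp(m - m_new)                           # rescale earlier partial sums
        p = torch.exp(s - m_new.unsqueeze(-1))
        l = l * correction + p.sum(dim=-1)
        acc = acc * correction.unsqueeze(-1) + p @ vb
        m = m_new
    out = acc / l.unsqueeze(-1)
    return out[0] - lam * out[1]                                    # differential epilogue
```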

Install with: pip install 'coda-gqa-l[triton]'

Limitations

  • 2x attention FLOPs from dual-stream differential attention. Fused Triton forward kernel partially addresses this; backward pass kernel is future work.
  • Fine-tuning required. Cold-swap doesn't work -- differential attention reshapes activations (cold-swap PPL: 2,464).
  • +23.5% PPL gap between bounded and unbounded at 7B scale. Real information loss from context compression.
  • Configuration sensitivity: the bounded config used at inference must match the one used in training (evaluating a medium-trained model with the large cache yields worse PPL).
  • No distributed cache sharding. No quantized KV storage.


Citation

@software{coda_gqa_l_2026,
  title  = {CoDA-GQA-L: Bounded-Memory Differential Attention
            with Value-Routed Landmark Banks},
  author = {Maio, Anthony},
  year   = {2026},
}

License

MIT



Download files


Source Distribution

coda_gqa_l-1.2.0.tar.gz (61.8 kB)

Uploaded Source

Built Distribution


coda_gqa_l-1.2.0-py3-none-any.whl (55.0 kB)

Uploaded Python 3

File details

Details for the file coda_gqa_l-1.2.0.tar.gz.

File metadata

  • Download URL: coda_gqa_l-1.2.0.tar.gz
  • Upload date:
  • Size: 61.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for coda_gqa_l-1.2.0.tar.gz
Algorithm     Hash digest
SHA256        57eec2e91084bb764414e79b15605874dea2f850ece66a9495605858a0fa9f14
MD5           eac8bd1b5416bb8d416787c908ee707f
BLAKE2b-256   db7e344a7c9bce7072dc0540ad2ace8ddf178074b0cd6fa87c90affaf29b1e4d


Provenance

The following attestation bundles were made for coda_gqa_l-1.2.0.tar.gz:

Publisher: workflow.yml on anthony-maio/CoDA-GQA-L

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file coda_gqa_l-1.2.0-py3-none-any.whl.

File metadata

  • Download URL: coda_gqa_l-1.2.0-py3-none-any.whl
  • Upload date:
  • Size: 55.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for coda_gqa_l-1.2.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        f527268a910510f02914ae102e4ef3490ae7e15f9c00d762efa8ce79cec1e59f
MD5           27979fd0cf2675e5d49d38d014a424f1
BLAKE2b-256   3ed9238bb165298bea42af79aa675576e9a4c18ac71f493b9647afa6213762b1


Provenance

The following attestation bundles were made for coda_gqa_l-1.2.0-py3-none-any.whl:

Publisher: workflow.yml on anthony-maio/CoDA-GQA-L

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
