pascal-attn
Memory-efficient tiled online-softmax attention with fused GQA KV expansion, tuned for Pascal and later NVIDIA GPUs.
What it is
pascal-attn implements the tiled online-softmax attention algorithm (a pure-PyTorch variant of FlashAttention) with two improvements over a naive chunked implementation:
- Fused GQA KV expansion — in Grouped-Query Attention (GQA), the key/value tensors have fewer heads than the query (n_kv < n_h). A naive implementation expands [B, n_kv, N, d_h] → [B, n_h, N, d_h] upfront. pascal-attn slices the unexpanded KV per tile and expands only that small [B, n_h, tile, d_h] slice, so peak KV memory is constant in sequence length (see the sketch below).
- Tile size tuned to GPU L2 cache — the dominant cost in the inner loop is the QK matmul. Keeping the tile small enough to fit in L2 eliminates cache thrashing. The library includes auto-detection via recommended_tile_size().
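The expansion trick itself is just a slice followed by a head-repeat. A minimal sketch of the idea (shapes are illustrative, and the consecutive-heads grouping shown is the common GQA convention, assumed here; tiled_attention performs this step internally):

import torch

B, n_h, n_kv, N, d_h, tile = 2, 32, 8, 2048, 64, 64
k = torch.randn(B, n_kv, N, d_h)                   # unexpanded K stays small
n_groups = n_h // n_kv                             # query heads per KV head

j = 0                                              # offset of one KV tile
k_raw = k[:, :, j:j + tile, :]                     # [B, n_kv, tile, d_h] slice
k_tile = k_raw.repeat_interleave(n_groups, dim=1)  # [B, n_h, tile, d_h]

Only the [B, n_h, tile, d_h] slice is ever materialised at the full head count, so the extra memory is independent of sequence length.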
Installation
pip install pascal-attn
With HuggingFace Transformers integration:
pip install "pascal-attn[transformers]"
From source:
git clone https://github.com/Hraisikai/pascal-attn
cd pascal-attn
pip install -e .
Quick start
Functional API
import torch
from pascal_attn import tiled_attention, recommended_tile_size

tile = recommended_tile_size()  # auto-detect from GPU L2

# MHA
out = tiled_attention(q, k, v, tile_size=tile)

# GQA — n_kv < n_h, inferred from tensor shapes
out = tiled_attention(q, k_gqa, v_gqa, tile_size=tile)

# With causal mask
N = q.shape[2]
causal = torch.zeros(1, 1, N, N).masked_fill(
    torch.ones(N, N, dtype=torch.bool).triu(1), -1e9
)
out = tiled_attention(q, k, v, mask=causal, tile_size=tile)
Input/output shapes:
query: [B, n_h, N_q, d_h]
key: [B, n_kv, N_k, d_h] # n_kv <= n_h, n_h % n_kv == 0
value: [B, n_kv, N_k, d_h]
output: [B, N_q, n_h, d_h]
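A minimal end-to-end call under these conventions (random tensors; the expected output layout is taken from the table above):

import torch
from pascal_attn import tiled_attention

B, n_h, n_kv, N, d_h = 2, 32, 8, 1024, 64
q = torch.randn(B, n_h, N, d_h)
k = torch.randn(B, n_kv, N, d_h)  # 4 query heads share each KV head
v = torch.randn(B, n_kv, N, d_h)

out = tiled_attention(q, k, v, tile_size=64)
print(out.shape)  # [B, N_q, n_h, d_h] per the table above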
nn.Module API
import torch
from pascal_attn import TiledAttention

attn = TiledAttention(
    n_heads=32,
    n_kv_heads=8,       # GQA with 4 groups
    head_dim=64,
    tile_size='auto',   # detect from GPU at first forward call
)
# Packed inputs [B, N, H]
q = torch.randn(4, 2048, 32 * 64)
k = torch.randn(4, 2048, 8 * 64)
v = torch.randn(4, 2048, 8 * 64)
out = attn(q, k, v) # [4, 2048, 32 * 64]
# Or unpacked inputs [B, N, n_h, d_h]
q = torch.randn(4, 2048, 32, 64)
k = torch.randn(4, 2048, 8, 64)
v = torch.randn(4, 2048, 8, 64)
out = attn(q, k, v) # [4, 2048, 32, 64]
HuggingFace Transformers integration
from transformers import AutoConfig, AutoModelForCausalLM
from pascal_attn.hf import register_with_transformers

# Register once before loading your model
register_with_transformers(tile_size=64, name="pascal_chunked")

# Now any model config can use it
config = AutoConfig.from_pretrained("...")
config._attn_implementation = "pascal_chunked"
model = AutoModelForCausalLM.from_pretrained("...", config=config)
Or assign per-layer:
from pascal_attn.hf import make_hf_attention_fn

fn = make_hf_attention_fn(tile_size=64)
for layer in model.model.layers:
    layer.self_attn._attention_forward_fn = fn
Memory savings
Configuration: N=2048, n_h=32, n_kv=8 (GQA 4:1), d_h=64, B=4, fp16.
| Implementation | KV peak per layer | Attn scores peak | Total peak (approx) |
|---|---|---|---|
| Naive (expanded KV) | 84 MB | 536 MB | ~620 MB |
| tiled tile=256 | 84 MB* | 33 MB | ~117 MB |
| tiled tile=64 | 2.6 MB | 2.1 MB | ~5 MB |
* The tile=256 row still expands KV upfront (tiling alone, without fused GQA expansion); pascal-attn with tile=64 avoids both the upfront expansion and the large scores buffer.
Across a 28-layer 3B model: ~2.4 GB saved from fused GQA alone (84 MB × 28), before accounting for the scores buffer reduction.
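The naive-path figures follow directly from element counts. A quick check (this assumes the 84 MB KV peak counts the unexpanded and expanded K/V tensors coexisting during the copy):

# fp16 = 2 bytes per element
B, n_h, n_kv, N, d_h = 4, 32, 8, 2048, 64

kv_unexpanded = 2 * B * n_kv * N * d_h * 2  # K and V before expansion
kv_expanded   = 2 * B * n_h  * N * d_h * 2  # K and V after expansion
print((kv_unexpanded + kv_expanded) / 1e6)  # ~83.9 -> the 84 MB KV peak

scores = B * n_h * N * N * 2                # [B, n_h, N, N] attention logits
print(scores / 1e6)                         # ~536.9 -> the 536 MB scores peak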
GPU tile size guide
| GPU family | L2 cache | Recommended tile_size |
|---|---|---|
| Pascal (GTX 1080 Ti, P40) | 3 MB | 64 |
| Volta (V100) / Turing (T4, 20xx) | 6–8 MB | 128 |
| Ampere (A100, 30xx) / Ada (40xx) | 20–80 MB | 256 |
Use recommended_tile_size() to detect automatically:
from pascal_attn import recommended_tile_size
tile = recommended_tile_size()
print(tile) # e.g. 64 on a P40
The heuristic queries torch.cuda.get_device_properties().l2_cache_size and
returns a conservative tile that fits within L2.
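A sketch of what such a heuristic can look like (the thresholds mirror the table above but are illustrative, and the property name is probed defensively because its casing has varied across PyTorch builds):

import torch

def recommended_tile_size_sketch(device=0, default=64):
    """Pick a tile size from the GPU's L2 cache size (illustrative)."""
    if not torch.cuda.is_available():
        return default
    props = torch.cuda.get_device_properties(device)
    l2 = getattr(props, "L2_cache_size", None) or getattr(props, "l2_cache_size", None)
    if l2 is None:
        return default
    if l2 >= 20 * 2**20:   # Ampere/Ada class, 20 MB or more
        return 256
    if l2 >= 6 * 2**20:    # Volta/Turing class, 6-8 MB
        return 128
    return 64              # Pascal class, ~3 MB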
How it works
The algorithm is a pure-PyTorch implementation of the online-softmax tiling used in FlashAttention, extended with fused per-tile GQA expansion:
for each query tile q_i [B, n_h, tile, d_h]:
    m = -1e9, l = 0, o = 0                     # running max, denominator, output
    for each KV tile (k_raw, v_raw) [B, n_kv, tile, d_h]:
        k_tile = expand(k_raw, n_groups)       # [B, n_h, tile, d_h] ← fused
        v_tile = expand(v_raw, n_groups)
        s = q_i @ k_tile.T * scale             # [B, n_h, tile, tile]
        s += mask[q_i, k_j]                    # optional
        m_new = max(m, rowmax(s))
        o = exp(m - m_new) * o + exp(s - m_new) @ v_tile
        l = exp(m - m_new) * l + rowsum(exp(s - m_new))
        m = m_new
    output[q_i] = o / (l + 1e-8)               # normalise once per query tile
Key numerical choices:
- Accumulators m, l, o are kept in float32 even for fp16/bf16 inputs.
- m is initialised to -1e9 (not -inf) to avoid nan from exp(-inf - (-inf)) when every position in a tile is masked.
- l + 1e-8 guards normalisation against fully-masked rows.

All three choices are visible in the sketch below.
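Putting the loop and the numerical choices together, here is a self-contained pure-PyTorch sketch (an illustration, not the library's internals: no mask handling, and N is assumed divisible by the tile size):

import torch

def tiled_attention_sketch(q, k, v, tile=64):
    """Online-softmax attention with fused per-tile GQA expansion.

    q: [B, n_h, N, d_h]; k, v: [B, n_kv, N, d_h], with n_h % n_kv == 0.
    Returns [B, n_h, N, d_h]. Illustrative only: no mask, N % tile == 0.
    """
    B, n_h, N, d_h = q.shape
    n_groups = n_h // k.shape[1]
    scale = d_h ** -0.5
    out = torch.empty_like(q)
    for i in range(0, N, tile):
        q_i = (q[:, :, i:i + tile, :] * scale).float()
        # float32 accumulators; m starts at -1e9 rather than -inf
        m = torch.full((B, n_h, tile, 1), -1e9, device=q.device)
        l = torch.zeros(B, n_h, tile, 1, device=q.device)
        o = torch.zeros(B, n_h, tile, d_h, device=q.device)
        for j in range(0, N, tile):
            # fused GQA expansion: slice first, expand only the tile
            k_t = k[:, :, j:j + tile, :].repeat_interleave(n_groups, dim=1).float()
            v_t = v[:, :, j:j + tile, :].repeat_interleave(n_groups, dim=1).float()
            s = q_i @ k_t.transpose(-2, -1)               # [B, n_h, tile, tile]
            m_new = torch.maximum(m, s.amax(-1, keepdim=True))
            p = torch.exp(s - m_new)
            alpha = torch.exp(m - m_new)                  # rescale old accumulators
            o = alpha * o + p @ v_t
            l = alpha * l + p.sum(-1, keepdim=True)
            m = m_new
        out[:, :, i:i + tile, :] = (o / (l + 1e-8)).to(q.dtype)
    return out

On small inputs this agrees with an eager reference that expands KV upfront, up to fp16 rounding.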
Running tests
pip install -e ".[dev]"
pytest tests/ -v
Correctness tests run on CPU; VRAM tests are skipped automatically if CUDA is not available.
Running benchmarks
# Default: CUDA, fp16, tile=auto, N=[512,1024,2048,4096]
python benchmarks/benchmark.py
# CPU run (fp32)
python benchmarks/benchmark.py --device cpu --dtype fp32 --seq-lens 128 256 512
# Custom config
python benchmarks/benchmark.py --tile 128 --batch 2 --heads 16 --kv-heads 4
License
MIT