
pascal-attn

Memory-efficient tiled online-softmax attention with fused GQA KV expansion, tuned for Pascal and later NVIDIA GPUs.

What it is

pascal-attn implements the tiled online-softmax attention algorithm (a pure-PyTorch variant of FlashAttention) with two improvements over a naive chunked implementation:

  1. Fused GQA KV expansion — in Grouped-Query Attention (GQA), the key/value tensors have fewer heads than the query (n_kv < n_h). A naive implementation expands [B, n_kv, N, d_h] → [B, n_h, N, d_h] upfront. pascal-attn instead slices the unexpanded KV per tile and expands only that small [B, n_h, tile, d_h] slice (sketched after this list), so peak KV memory is constant in sequence length.

  2. Tile size tuned to GPU L2 cache — the dominant cost in the inner loop is the QK matmul. Keeping the tile small enough that its working set fits in L2 avoids cache thrashing. The library includes auto-detection via recommended_tile_size().
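
To make the memory difference concrete, here is a minimal sketch of the two expansion strategies, using the same shapes as the memory table below. repeat_interleave stands in for whatever expansion mechanism the library actually uses:

import torch

B, n_h, n_kv, N, d_h, tile = 4, 32, 8, 2048, 64, 64
k = torch.randn(B, n_kv, N, d_h, dtype=torch.float16)

# Naive: expand all KV heads before the attention loop.
# Peak extra memory grows linearly with sequence length N.
k_full = k.repeat_interleave(n_h // n_kv, dim=1)               # [B, n_h, N, d_h]

# Fused: slice the unexpanded KV first, expand only the tile.
# Peak extra memory is constant in N.
k_tile = k[:, :, :tile].repeat_interleave(n_h // n_kv, dim=1)  # [B, n_h, tile, d_h]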

Installation

pip install pascal-attn

With HuggingFace Transformers integration:

pip install "pascal-attn[transformers]"

From source:

git clone https://github.com/Hraisikai/pascal-attn
cd pascal-attn
pip install -e .

Quick start

Functional API

import torch
from pascal_attn import tiled_attention, recommended_tile_size

tile = recommended_tile_size()   # auto-detect from GPU L2

# MHA: q, k, v all carry n_h heads, layout [B, n_h, N, d_h]
q = torch.randn(4, 32, 2048, 64, device="cuda", dtype=torch.float16)
k = torch.randn(4, 32, 2048, 64, device="cuda", dtype=torch.float16)
v = torch.randn(4, 32, 2048, 64, device="cuda", dtype=torch.float16)
out = tiled_attention(q, k, v, tile_size=tile)

# GQA — n_kv < n_h, inferred from tensor shapes
k_gqa = torch.randn(4, 8, 2048, 64, device="cuda", dtype=torch.float16)
v_gqa = torch.randn(4, 8, 2048, 64, device="cuda", dtype=torch.float16)
out = tiled_attention(q, k_gqa, v_gqa, tile_size=tile)

# With causal mask
N = q.shape[2]
causal = torch.zeros(1, 1, N, N, device=q.device).masked_fill(
    torch.ones(N, N, dtype=torch.bool, device=q.device).triu(1), -1e9
)
out = tiled_attention(q, k, v, mask=causal, tile_size=tile)

Input/output shapes:

query:  [B, n_h,  N_q, d_h]
key:    [B, n_kv, N_k, d_h]   # n_kv <= n_h, n_h % n_kv == 0
value:  [B, n_kv, N_k, d_h]
output: [B, n_h, N_q, d_h]    # same layout as the query
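
A quick way to sanity-check the functional API is to compare it against PyTorch's reference attention with the KV heads expanded by hand. This sketch assumes the output uses the same [B, n_h, N_q, d_h] layout as the query, consistent with the pseudocode under "How it works":

import torch
import torch.nn.functional as F
from pascal_attn import tiled_attention

B, n_h, n_kv, N, d_h = 2, 8, 2, 256, 64
q = torch.randn(B, n_h, N, d_h)
k = torch.randn(B, n_kv, N, d_h)
v = torch.randn(B, n_kv, N, d_h)

# Reference: expand KV to n_h heads, then run standard attention.
g = n_h // n_kv
ref = F.scaled_dot_product_attention(
    q, k.repeat_interleave(g, dim=1), v.repeat_interleave(g, dim=1)
)

out = tiled_attention(q, k, v, tile_size=64)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)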

nn.Module API

import torch
from pascal_attn import TiledAttention

attn = TiledAttention(
    n_heads=32,
    n_kv_heads=8,       # GQA with 4 groups
    head_dim=64,
    tile_size='auto',   # detect from GPU at first forward call
)

# Packed inputs [B, N, H]
q = torch.randn(4, 2048, 32 * 64)
k = torch.randn(4, 2048,  8 * 64)
v = torch.randn(4, 2048,  8 * 64)
out = attn(q, k, v)   # [4, 2048, 32 * 64]

# Or unpacked inputs [B, N, n_h, d_h]
q = torch.randn(4, 2048, 32, 64)
k = torch.randn(4, 2048,  8, 64)
v = torch.randn(4, 2048,  8, 64)
out = attn(q, k, v)   # [4, 2048, 32, 64]

HuggingFace Transformers integration

from transformers import AutoConfig, AutoModelForCausalLM
from pascal_attn.hf import register_with_transformers

# Register once before loading your model
register_with_transformers(tile_size=64, name="pascal_chunked")

# Now any model config can use it
config = AutoConfig.from_pretrained("...")
config._attn_implementation = "pascal_chunked"
model = AutoModelForCausalLM.from_pretrained("...", config=config)

Or assign per-layer:

from pascal_attn.hf import make_hf_attention_fn

fn = make_hf_attention_fn(tile_size=64)
for layer in model.model.layers:
    layer.self_attn._attention_forward_fn = fn

Memory savings

Configuration: N=2048, n_h=32, n_kv=8 (GQA 4:1), d_h=64, B=4, fp16.

Implementation        KV peak per layer   Attn scores peak   Total peak (approx)
Naive (expanded KV)   84 MB               536 MB             ~620 MB
Tiled, tile=256       84 MB*              33 MB              ~117 MB
Tiled, tile=64        2.6 MB              2.1 MB             ~5 MB

* tile=256 without fused GQA expansion still expands KV upfront. pascal-attn with tile=64 avoids both the upfront expansion and the large scores buffer.

Across a 28-layer 3B model: ~2.4 GB saved from fused GQA alone (84 MB × 28), before accounting for the scores buffer reduction.
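
The KV column can be reproduced with back-of-envelope arithmetic, assuming the naive path briefly holds the unexpanded and expanded tensors at the same time:

# fp16 = 2 bytes per element; leading factor 2 counts both K and V.
B, n_h, n_kv, N, d_h, tile, bytes_el = 4, 32, 8, 2048, 64, 64, 2

kv_unexpanded = 2 * B * n_kv * N * d_h * bytes_el       # ~16.8 MB
kv_expanded   = 2 * B * n_h  * N * d_h * bytes_el       # ~67.1 MB
print((kv_unexpanded + kv_expanded) / 1e6)              # ~83.9 -> the 84 MB peak

# Fused per-tile expansion: only tile-sized slices exist at once.
slice_unexpanded = 2 * B * n_kv * tile * d_h * bytes_el  # ~0.5 MB
slice_expanded   = 2 * B * n_h  * tile * d_h * bytes_el  # ~2.1 MB
print((slice_unexpanded + slice_expanded) / 1e6)         # ~2.6 -> the tile=64 KV peak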

GPU tile size guide

GPU family                         L2 cache   Recommended tile_size
Pascal (GTX 1080 Ti, P40)          3 MB       64
Volta (V100) / Turing (T4, 20xx)   6–8 MB     128
Ampere (A100, 30xx) / Ada (40xx)   20–80 MB   256

Use recommended_tile_size() to detect automatically:

from pascal_attn import recommended_tile_size
tile = recommended_tile_size()
print(tile)  # e.g. 64 on a P40

The heuristic queries torch.cuda.get_device_properties().l2_cache_size and returns a conservative tile that fits within L2.
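
For reference, a heuristic with the same behaviour could look like the sketch below. The thresholds are taken from the tile size table above, not from the library's source, and the property name follows the description here (its availability may vary across PyTorch versions):

import torch

def recommended_tile_size_sketch(device=0):
    # Conservative fallback when no GPU is visible (assumption, not library behaviour).
    if not torch.cuda.is_available():
        return 64
    l2 = torch.cuda.get_device_properties(device).l2_cache_size  # bytes
    if l2 < 4 * 1024 ** 2:        # Pascal-class: ~3 MB L2
        return 64
    if l2 < 16 * 1024 ** 2:       # Volta/Turing: 6-8 MB L2
        return 128
    return 256                    # Ampere/Ada: 20 MB and up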

How it works

The algorithm is a pure-PyTorch implementation of the online-softmax tiling used in FlashAttention, extended with fused per-tile GQA expansion:

For each query tile q_i  [B, n_h, tile, d_h]:
    m = -1e9, l = 0, o = 0          # running max, denominator, output

    For each KV tile (k_raw, v_raw)  [B, n_kv, tile, d_h]:
        k_tile = expand(k_raw, n_groups)   # [B, n_h, tile, d_h]  ← fused
        v_tile = expand(v_raw, n_groups)

        s     = q_i @ k_tile.T * scale    # [B, n_h, tile, tile]
        s    += mask[q_i, k_j]            # optional
        m_new = max(m, rowmax(s))
        o     = exp(m - m_new) * o  +  exp(s - m_new) @ v_tile
        l     = exp(m - m_new) * l  +  rowsum(exp(s - m_new))
        m     = m_new

    output[q_i] = o / (l + 1e-8)        # normalise once per query tile

Key numerical choices:

  • Accumulators m, l, o are kept in float32 even for fp16/bf16 inputs.
  • m is initialised to -1e9 (not -inf) to avoid nan in exp(-inf - (-inf)) when every position in a tile is masked.
  • l + 1e-8 guards normalisation against fully-masked rows.
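
Put together, a self-contained PyTorch version of this loop, implementing the pseudocode and the numerical choices above, looks roughly like this (an illustration of the algorithm, not the library's code):

import torch

def tiled_attention_sketch(q, k, v, tile=64, mask=None):
    # q: [B, n_h, N_q, d_h]; k, v: [B, n_kv, N_k, d_h] with n_h % n_kv == 0.
    B, n_h, N_q, d_h = q.shape
    n_kv, N_k = k.shape[1], k.shape[2]
    groups = n_h // n_kv
    scale = d_h ** -0.5
    out = torch.empty_like(q)

    for i0 in range(0, N_q, tile):
        q_i = q[:, :, i0:i0 + tile].float()        # fp32 accumulators
        m = q_i.new_full(q_i.shape[:3], -1e9)      # -1e9, not -inf (see above)
        l = torch.zeros_like(m)                    # running denominator
        o = torch.zeros_like(q_i)                  # running output

        for j0 in range(0, N_k, tile):
            # Fused GQA expansion: slice the small KV tile, then expand it.
            k_t = k[:, :, j0:j0 + tile].repeat_interleave(groups, dim=1).float()
            v_t = v[:, :, j0:j0 + tile].repeat_interleave(groups, dim=1).float()

            s = q_i @ k_t.transpose(-2, -1) * scale          # [B, n_h, tq, tk]
            if mask is not None:
                s = s + mask[..., i0:i0 + s.shape[-2], j0:j0 + s.shape[-1]]

            m_new = torch.maximum(m, s.amax(dim=-1))
            alpha = torch.exp(m - m_new)                     # rescale old state
            p = torch.exp(s - m_new.unsqueeze(-1))
            o = alpha.unsqueeze(-1) * o + p @ v_t
            l = alpha * l + p.sum(dim=-1)
            m = m_new

        # 1e-8 guards normalisation for fully-masked rows (see above).
        out[:, :, i0:i0 + tile] = (o / (l + 1e-8).unsqueeze(-1)).to(q.dtype)
    return out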

Running tests

pip install -e ".[dev]"
pytest tests/ -v

Correctness tests run on CPU; VRAM tests are skipped automatically if CUDA is not available.

Running benchmarks

# Default: CUDA, fp16, tile=auto, N=[512,1024,2048,4096]
python benchmarks/benchmark.py

# CPU run (fp32)
python benchmarks/benchmark.py --device cpu --dtype fp32 --seq-lens 128 256 512

# Custom config
python benchmarks/benchmark.py --tile 128 --batch 2 --heads 16 --kv-heads 4

License

MIT
