Cross-framework ragged tensor primitive with reference varlen kernels

scree

A cross-framework ragged tensor primitive for variable-length sequence data.

import scree
import numpy as np

# Three sequences of different lengths.
seqs = [np.random.randn(n, 8).astype(np.float32) for n in [4, 2, 7]]

# Pack them into one scree.Array — no padding.
arr = scree.pack(seqs)
# arr.values: shape (13, 8), arr.offsets: [0, 4, 6, 13]

# Run varlen attention. Each sequence attends only to itself.
from scree.kernels.reference import varlen_attention
out = varlen_attention(arr, arr, arr, causal=True)
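
To make the packed layout concrete, here is a minimal NumPy sketch of what packing amounts to: concatenate the sequences along the first axis and record cumulative row offsets. `pack_sketch` and `unpack_sketch` are hypothetical names used for illustration only; they are not scree's implementation and say nothing about how `scree.pack` is written internally.

```python
import numpy as np

def pack_sketch(seqs):
    """Concatenate variable-length sequences along axis 0 and record row offsets."""
    lengths = np.array([len(s) for s in seqs], dtype=np.int64)
    offsets = np.concatenate(([0], np.cumsum(lengths)))
    values = np.concatenate(seqs, axis=0)
    return values, offsets

def unpack_sketch(values, offsets):
    """Slice the packed buffer back into one array per sequence."""
    return [values[s:e] for s, e in zip(offsets[:-1], offsets[1:])]

rng = np.random.default_rng(0)
seqs = [rng.standard_normal((n, 8)).astype(np.float32) for n in [4, 2, 7]]
values, offsets = pack_sketch(seqs)
print(values.shape, offsets.tolist())  # (13, 8) [0, 4, 6, 13]
```

Sequence i occupies rows offsets[i] through offsets[i+1] - 1 of the packed buffer, so no padding rows are ever stored.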

Why

Variable-length sequence data is everywhere in modern ML — transformer training, inference batching, multimodal interleaving, MoE routing — yet every team carries their own incompatible representation:

  • torch.nested (PyTorch only, in beta since 2021)
  • TF RaggedTensor (TensorFlow only)
  • FlashAttention cu_seqlens (a convention, not a typed primitive)
  • vLLM / SGLang packed batches (internal data structures)
  • HuggingFace attention_mask (pads, then masks — wasting memory and FLOPs)

scree ships one primitive — a packed values + offsets + ragged_dim array — that bridges across frameworks and ships with reference varlen kernels for attention, layernorm, softmax, and scatter/gather.
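
As an illustration of what a reference varlen kernel has to do, the sketch below computes attention over a packed buffer one sequence at a time, so tokens never attend across sequence boundaries. It is a single-head NumPy sketch written for this README under the assumption that ragged_dim is axis 0; `varlen_attention_sketch` is a hypothetical name, not the code in `scree.kernels.reference`.

```python
import numpy as np

def varlen_attention_sketch(q, k, v, offsets, causal=True):
    """Single-head attention over packed (total_tokens, head_dim) arrays.

    Each [offsets[i], offsets[i+1]) slice is treated as an independent sequence.
    """
    out = np.empty_like(v)
    scale = 1.0 / np.sqrt(q.shape[-1])
    for start, end in zip(offsets[:-1], offsets[1:]):
        if start == end:  # empty sequence: nothing to compute
            continue
        scores = (q[start:end] @ k[start:end].T) * scale
        if causal:
            n = end - start
            # Mask out positions strictly above the diagonal (future tokens).
            future = np.triu(np.ones((n, n), dtype=bool), k=1)
            scores = np.where(future, -np.inf, scores)
        # Numerically stable softmax over the key axis.
        scores = scores - scores.max(axis=-1, keepdims=True)
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)
        out[start:end] = weights @ v[start:end]
    return out

offsets = np.array([0, 4, 6, 13])
x = np.random.default_rng(0).standard_normal((13, 8)).astype(np.float32)
out = varlen_attention_sketch(x, x, x, offsets, causal=True)
```

The Python loop is what makes a kernel like this slow but easy to check; a fused kernel would do the same per-sequence masking inside a single launch.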

What you get

Memory savings vs HF padded on realistic LLM length distributions (log-normal):

| Workload | Mean savings | Min – Max |
|---|---|---|
| Training-style (batch 64, mean_len 256, σ=0.6) | 71% | 63% – 84% |
| Inference-style (batch 32, mean_len 1024, σ=1.2) | 85% | 75% – 94% |

Reproduce: python benchmarks/bench_memory.py
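
One plausible way to read these percentages is the fraction of a (batch, max_len) padded buffer that holds padding rather than real tokens. The sketch below computes that quantity; the definition and the log-normal parameterization are assumptions made for illustration, and `benchmarks/bench_memory.py` may measure savings differently.

```python
import numpy as np

def padded_savings(lengths):
    """Fraction of a (batch, max_len) padded buffer that holds padding."""
    lengths = np.asarray(lengths)
    real = int(lengths.sum())
    padded = int(len(lengths) * lengths.max())
    return 1.0 - real / padded

print(f"{padded_savings([4, 2, 7]):.1%}")  # 13 real vs 21 padded slots -> 38.1%

# Assumed parameterization: log-normal with median 256 and sigma 0.6.
rng = np.random.default_rng(0)
lengths = np.maximum(1, rng.lognormal(mean=np.log(256), sigma=0.6, size=64).astype(np.int64))
print(f"{padded_savings(lengths):.1%}")
```

Under the same definition, a batch with 1980 real and 4464 padded tokens corresponds to about 55.6% savings.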

CPU throughput vs a naive padded attention baseline on a real batch (16 seqs × log-normal lengths, 1980 real / 4464 padded tokens, 4 heads × head_dim 32, fp32, no mask optimization):

| Operation | scree | Padded baseline | Speedup |
|---|---|---|---|
| varlen_attention | 34.7 ms | 228.3 ms | 6.6× |
| varlen_rmsnorm | 0.13 ms | 0.28 ms | 2.1× |

Reproduce: python benchmarks/bench_throughput.py
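
Part of the normalization speedup has a simple explanation: RMSNorm acts on each token independently, so on a packed buffer it needs no offsets and only touches real tokens (1980 rows instead of 4464 in the batch above). A minimal NumPy sketch, using the standard RMSNorm formula and a hypothetical function name rather than scree's `varlen_rmsnorm`:

```python
import numpy as np

def rmsnorm_packed(values, weight, eps=1e-6):
    """RMSNorm over the last axis of a packed (total_tokens, hidden) array."""
    rms = np.sqrt(np.mean(values * values, axis=-1, keepdims=True) + eps)
    return values / rms * weight

values = np.random.default_rng(0).standard_normal((13, 8)).astype(np.float32)
out = rmsnorm_packed(values, np.ones(8, dtype=np.float32))
```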

GPU forward + training step vs FlashAttention-2 on H100. Headline workload: 16 seqs × log-normal lengths, 12160 total tokens, 16 heads × head_dim 64, fp16, causal:

| Operation | FA-2 | scree-Triton | Ratio (scree / FA-2) |
|---|---|---|---|
| forward only | 0.165 ms | 0.216 ms | 1.30× |
| forward + backward (training step) | 0.688 ms | 1.106 ms | 1.61× |

Correctness: forward max abs diff 4.88e-04; backward dq 9.77e-04, dk 1.95e-03, dv 1.95e-03 vs FA-2 (all PASS within fp16 tolerance). Reproduce: run modal run benchmarks/modal_bench.py, then modal run benchmarks/modal_autograd_bench.py (~$0.40 of Modal credit total).

Across 27 shapes (head_dim × n_heads × mean_len): scree is closer to FA-2 on large workloads, slower on small ones (wrapper allocation overhead is per-call; it amortizes as the kernel grows). See benchmarks/modal_multishape_sweep.py.

| Workload range | Forward ratio | Training-step ratio |
|---|---|---|
| Best (large: head_dim=64, n_heads=16, mean_len=2048) | 1.21× | 1.45× |
| Median across 27 shapes | 1.95× | 2.01× |
| Worst (toy: head_dim=32, n_heads=4, mean_len=256) | 3.53× | 1.77× |

For production LLM training (head_dim ≥ 64, n_heads ≥ 8, mean_len ≥ 1024), expect scree-Triton to take 1.2–2.0× as long as FA-2.

Zero-copy bridges to the things you already use:

import scree
import scree.bridges as bridges

arr = scree.from_cu_seqlens(values, cu_seqlens)         # FlashAttention
arr = bridges.from_hf_padded(hidden_states, attn_mask)  # HuggingFace
arr = bridges.from_torch_nested(nt)                     # torch.nested

bridges.to_torch_nested(arr)   # → torch.NestedTensor
bridges.to_hf_padded(arr)      # → (hidden_states, attention_mask)
bridges.to_torch(arr)          # numpy values → torch tensors via DLPack
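
To show how these representations relate, the NumPy sketch below converts a HuggingFace-style padded batch plus attention mask into packed values and offsets, and back again. The function names are hypothetical and the code is not scree's bridge implementation; in particular, boolean indexing copies data, so this sketch is not zero-copy. The resulting offsets have the same layout as FlashAttention's cu_seqlens: cumulative sequence lengths starting at 0.

```python
import numpy as np

def from_padded_sketch(hidden_states, attention_mask):
    """(batch, max_len, hidden) + (batch, max_len) mask -> packed values, offsets."""
    mask = attention_mask.astype(bool)
    values = hidden_states[mask]  # row-major order keeps tokens grouped by sequence
    lengths = mask.sum(axis=1)
    offsets = np.concatenate(([0], np.cumsum(lengths)))
    return values, offsets

def to_padded_sketch(values, offsets):
    """Packed values, offsets -> right-padded (batch, max_len, hidden) + mask."""
    lengths = np.diff(offsets)
    batch, max_len = len(lengths), int(lengths.max())
    hidden = np.zeros((batch, max_len) + values.shape[1:], dtype=values.dtype)
    mask = np.arange(max_len)[None, :] < lengths[:, None]
    hidden[mask] = values
    return hidden, mask.astype(np.int64)

# A right-padded batch with sequence lengths 4, 2 and 7.
seq_lengths = np.array([4, 2, 7])
attn_mask = (np.arange(7)[None, :] < seq_lengths[:, None]).astype(np.int64)
hidden_states = np.random.default_rng(0).standard_normal((3, 7, 8)).astype(np.float32)
hidden_states[attn_mask == 0] = 0.0  # zero the padding positions

values, offsets = from_padded_sketch(hidden_states, attn_mask)
restored, restored_mask = to_padded_sketch(values, offsets)
print(values.shape, offsets.tolist())  # (13, 8) [0, 4, 6, 13]
```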

One primitive, every framework — values and offsets can be NumPy, PyTorch, MLX (Apple Silicon, via Metal), or JAX. All four backends pass the same correctness suite end-to-end.

The name

A scree is the irregular pile of rock fragments accumulated on a mountain slope. Variable-length sequences pack against each other the same way: irregular shapes, fitted by their irregularity, not despite it.

Status

v0.0.1, pre-alpha. The reference (slow but correct) Python kernels are present. Triton kernels at FlashAttention-varlen parity ship in v0.1.

| Component | Status |
|---|---|
| scree.Array type + invariants | |
| pack / unpack / to_padded / from_padded | |
| Reference varlen attention / layernorm / softmax | |
| Bridges: torch.nested, HF padded, FA cu_seqlens, DLPack | |
| NumPy + PyTorch + MLX + JAX backends | |
| Triton kernels at FA-varlen parity | 🟡 next |
| Triton autotune (Triton 3.1+) | 🟡 next |
Install

pip install scree              # numpy backend
pip install "scree[torch]"     # + PyTorch backend
pip install "scree[mlx]"       # + MLX backend (Apple Silicon, Metal)
pip install "scree[jax]"       # + JAX backend

Examples

Documentation

Contributing

PRs welcome. See CONTRIBUTING.md for the workflow. Open a GitHub Discussion for anything beyond a small fix.

License

Apache-2.0
