Cross-framework ragged tensor primitive with reference varlen kernels
scree
A cross-framework ragged tensor primitive for variable-length sequence data.
import scree
import numpy as np
# Three sequences of different lengths.
seqs = [np.random.randn(n, 8).astype(np.float32) for n in [4, 2, 7]]
# Pack them into one scree.Array — no padding.
arr = scree.pack(seqs)
# arr.values: shape (13, 8), arr.offsets: [0, 4, 6, 13]
# Run varlen attention. Each sequence attends only to itself.
from scree.kernels.reference import varlen_attention
out = varlen_attention(arr, arr, arr, causal=True)
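To make "each sequence attends only to itself" concrete, here is a minimal single-head NumPy sketch of what a varlen attention kernel computes over packed values and offsets. It is an illustration only: `varlen_attention_sketch` is a hypothetical name, and the function's signature and single-head layout are assumptions, not the actual `scree.kernels.reference` implementation.

```python
import numpy as np

def varlen_attention_sketch(q, k, v, offsets, causal=True):
    """Single-head attention over packed (total_tokens, dim) arrays.

    Hypothetical helper for illustration; not scree's reference kernel.
    """
    out = np.zeros_like(v)
    scale = q.shape[-1] ** -0.5
    for start, end in zip(offsets[:-1], offsets[1:]):
        if end == start:
            continue  # empty sequence: nothing to attend over
        # Scores are computed inside one sequence only, never across boundaries.
        scores = (q[start:end] @ k[start:end].T) * scale
        if causal:
            n = end - start
            future = np.triu(np.ones((n, n), dtype=bool), k=1)
            scores = np.where(future, -np.inf, scores)
        # Numerically stable softmax over the key axis.
        scores = scores - scores.max(axis=-1, keepdims=True)
        weights = np.exp(scores)
        weights = weights / weights.sum(axis=-1, keepdims=True)
        out[start:end] = weights @ v[start:end]
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((13, 8)).astype(np.float32)
offsets = np.array([0, 4, 6, 13])
out = varlen_attention_sketch(x, x, x, offsets, causal=True)
```

With `causal=True`, the first token of every sequence can only see itself, so its output row equals its own value row. That makes a quick sanity check for any varlen kernel.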
Why
Variable-length sequence data is everywhere in modern ML — transformer training, inference batching, multimodal interleaving, MoE routing — yet every team carries their own incompatible representation:
- `torch.nested` (PyTorch only, in beta since 2021)
- TF `RaggedTensor` (TensorFlow only)
- FlashAttention `cu_seqlens` (a convention, not a typed primitive)
- vLLM / SGLang packed batches (internal data structures)
- HuggingFace `attention_mask` (pads, then masks, wasting memory and FLOPs)
scree ships one primitive — a packed values + offsets + ragged_dim
array — that bridges across frameworks and ships with reference varlen
kernels for attention, layernorm, softmax, and scatter/gather.
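As a rough picture of the packed layout, the sketch below builds values and offsets by hand in plain NumPy. It assumes the ragged dimension is axis 0 and uses hypothetical helper names (`pack_sketch`, `unpack_sketch`); the real `scree.pack` may differ in dtype handling and validation.

```python
import numpy as np

def pack_sketch(seqs):
    """Concatenate variable-length sequences along axis 0 and record boundaries."""
    lengths = [s.shape[0] for s in seqs]
    # offsets[i] is where sequence i starts; offsets[-1] is the total token count.
    offsets = np.concatenate([[0], np.cumsum(lengths)]).astype(np.int64)
    values = np.concatenate(seqs, axis=0)
    return values, offsets

def unpack_sketch(values, offsets):
    """Recover the original sequences as slices of the packed buffer."""
    return [values[offsets[i]:offsets[i + 1]] for i in range(len(offsets) - 1)]

seqs = [np.random.randn(n, 8).astype(np.float32) for n in [4, 2, 7]]
values, offsets = pack_sketch(seqs)
print(values.shape)      # (13, 8)
print(offsets.tolist())  # [0, 4, 6, 13]
restored = unpack_sketch(values, offsets)
```

Sequence `i` always lives at `values[offsets[i]:offsets[i + 1]]`, so no padding token is ever stored.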
What you get
Memory savings vs HF padded on realistic LLM length distributions (log-normal):
| Workload | Mean savings | Min – Max |
|---|---|---|
| Training-style (batch 64, mean_len 256, σ=0.6) | 71% | 63% – 84% |
| Inference-style (batch 32, mean_len 1024, σ=1.2) | 85% | 75% – 94% |
Reproduce: python benchmarks/bench_memory.py
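The savings figure comes down to simple token counting: a padded batch stores `batch × max_len` tokens, a packed one stores only the real tokens. The sketch below shows that arithmetic with the standard library. It is not `benchmarks/bench_memory.py`; in particular, setting the log-normal `mu` to `log(mean_len)` is an assumption about how the benchmark parameterizes its distribution.

```python
import math
import random

def padded_vs_packed_savings(lengths):
    """Fraction of token slots saved by packing instead of padding to max length."""
    packed = sum(lengths)                  # real tokens only
    padded = len(lengths) * max(lengths)   # every row padded to the longest
    return 1.0 - packed / padded

# Tiny worked case: lengths 4, 2, 7 -> 13 real tokens vs 3 * 7 = 21 padded slots.
small = padded_vs_packed_savings([4, 2, 7])
print(f"{small:.1%}")  # 38.1%

# Training-style draw. Assumption: mu = log(mean_len), i.e. mean_len is the median.
random.seed(0)
batch, mean_len, sigma = 64, 256, 0.6
lengths = [max(1, round(random.lognormvariate(math.log(mean_len), sigma)))
           for _ in range(batch)]
print(f"{padded_vs_packed_savings(lengths):.0%}")
```

Because padding is set by the longest sequence in the batch, heavier-tailed length distributions (larger σ) waste more padded slots, which is consistent with the inference-style row showing larger savings.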
CPU throughput vs a naive padded attention baseline on a real batch (16 seqs × log-normal lengths, 1980 real / 4464 padded tokens, 4 heads × head_dim 32, fp32, no mask optimization):
| Operation | scree | padded baseline | Speedup |
|---|---|---|---|
| varlen_attention | 34.7 ms | 228.3 ms | 6.6× |
| varlen_rmsnorm | 0.13 ms | 0.28 ms | 2.1× |
Reproduce: python benchmarks/bench_throughput.py
GPU forward + training step vs FlashAttention-2 on H100. Headline workload: 16 seqs × log-normal lengths, 12160 total tokens, 16 heads × head_dim 64, fp16, causal:
| Operation | FA-2 | scree-Triton | Ratio (scree / FA-2) |
|---|---|---|---|
| forward only | 0.165 ms | 0.216 ms | 1.30× |
| forward + backward (training step) | 0.688 ms | 1.106 ms | 1.61× |
Correctness: forward max abs diff 4.88e-04; backward dq 9.77e-04, dk
1.95e-03, dv 1.95e-03 vs FA-2 (all PASS within fp16 tolerance).
Reproduce: modal run benchmarks/modal_bench.py + modal run benchmarks/modal_autograd_bench.py (~$0.40 of Modal credit total).
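One way to read the correctness numbers, offered as an observation rather than the project's stated tolerance: each reported difference is a small multiple of fp16 machine epsilon (2⁻¹⁰), which is what rounding differences of about one unit in the last place look like for values of magnitude roughly 0.5 to 4. The sketch below only does that arithmetic on the figures quoted above.

```python
import numpy as np

# fp16 machine epsilon: spacing between 1.0 and the next representable value.
eps = float(np.finfo(np.float16).eps)  # 2**-10

# Max abs diffs vs FA-2, copied from the text above.
reported = {"forward": 4.88e-04, "dq": 9.77e-04, "dk": 1.95e-03, "dv": 1.95e-03}

ratios = {name: diff / eps for name, diff in reported.items()}
for name, ratio in ratios.items():
    print(f"{name}: {ratio:.2f} x fp16 eps")
# forward: 0.50 x fp16 eps
# dq: 1.00 x fp16 eps
# dk: 2.00 x fp16 eps
# dv: 2.00 x fp16 eps
```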
Across 27 shapes (head_dim × n_heads × mean_len): scree is closer
to FA-2 on large workloads, slower on small ones (wrapper allocation
overhead is per-call; it amortizes as the kernel grows). See benchmarks/modal_multishape_sweep.py.
| Workload range | Forward ratio | Training-step ratio |
|---|---|---|
| Best (large: head_dim=64, n_heads=16, mean_len=2048) | 1.21× | 1.45× |
| Median across 27 shapes | 1.95× | 2.01× |
| Worst (toy: head_dim=32, n_heads=4, mean_len=256) | 3.53× | 1.77× |
For production LLM training (head_dim ≥ 64, n_heads ≥ 8, mean_len ≥ 1024), expect scree to take 1.2–2.0× the FA-2 runtime.
Zero-copy bridges to the things you already use:
import scree.bridges as bridges
arr = scree.from_cu_seqlens(values, cu_seqlens) # FlashAttention
arr = bridges.from_hf_padded(hidden_states, attn_mask) # HuggingFace
arr = bridges.from_torch_nested(nt) # torch.nested
bridges.to_torch_nested(arr) # → torch.NestedTensor
bridges.to_hf_padded(arr) # → (hidden_states, attention_mask)
bridges.to_torch(arr) # numpy values → torch tensors via DLPack
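To show what a padded-to-packed conversion involves, here is a NumPy sketch that turns HuggingFace-style `(hidden_states, attention_mask)` into packed values plus offsets and back. The helper names are hypothetical and this is not the `scree.bridges` code; note that boolean indexing copies data, so this sketch makes no zero-copy claim. The resulting offsets follow the same layout as FlashAttention's `cu_seqlens`: cumulative sequence lengths starting at 0, with `batch + 1` entries.

```python
import numpy as np

def from_padded_sketch(hidden_states, attention_mask):
    """(batch, max_len, dim) + (batch, max_len) mask -> packed values, offsets."""
    mask = attention_mask.astype(bool)
    lengths = mask.sum(axis=1)
    offsets = np.concatenate([[0], np.cumsum(lengths)])
    values = hidden_states[mask]  # keeps real tokens in batch order (copies)
    return values, offsets

def to_padded_sketch(values, offsets, pad_value=0.0):
    """Packed values, offsets -> right-padded array and 0/1 attention mask."""
    lengths = np.diff(offsets)
    batch, max_len = len(lengths), int(lengths.max())
    padded = np.full((batch, max_len) + values.shape[1:], pad_value,
                     dtype=values.dtype)
    mask = np.arange(max_len)[None, :] < lengths[:, None]
    padded[mask] = values
    return padded, mask.astype(np.int64)

hidden = np.arange(12, dtype=np.float32).reshape(2, 3, 2)
attn = np.array([[1, 1, 0], [1, 1, 1]])
values, offsets = from_padded_sketch(hidden, attn)
print(values.shape, offsets.tolist())  # (5, 2) [0, 2, 5]
padded, mask = to_padded_sketch(values, offsets)
```

The round trip restores every real token; only the contents of padded positions are replaced by `pad_value`.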
One primitive, every framework — values and offsets can be NumPy, PyTorch, MLX (Apple Silicon, via Metal), or JAX. All four backends pass the same correctness suite end-to-end.
The name
A scree is the irregular pile of rock fragments accumulated on a mountain slope. Variable-length sequences pack against each other the same way: irregular shapes, fitted by their irregularity, not despite it.
Status
v0.0.1, pre-alpha. The reference (slow but correct) Python kernels are present. Triton kernels at FlashAttention-varlen parity ship in v0.1.
| Component | Status |
|---|---|
| `scree.Array` type + invariants | ✅ |
| `pack` / `unpack` / `to_padded` / `from_padded` | ✅ |
| Reference varlen attention / layernorm / softmax | ✅ |
| Bridges: torch.nested, HF padded, FA cu_seqlens, DLPack | ✅ |
| NumPy + PyTorch + MLX + JAX backends | ✅ |
| Triton kernels at FA-varlen parity | 🟡 next |
| Triton autotune (Triton 3.1+) | 🟡 next |
Install
pip install scree # numpy backend
pip install "scree[torch]" # + PyTorch backend
pip install "scree[mlx]" # + MLX backend (Apple Silicon, Metal)
pip install "scree[jax]" # + JAX backend
Examples
- `examples/01_quickstart.py`: pack/unpack + varlen attention
- `examples/02_no_pad_transformer.py`: full transformer block, zero padding
- `examples/03_train_step.py`: training step with PyTorch autograd flowing through scree (loss drops 80× in 30 steps)
- `examples/04_hf_compat.py`: HuggingFace Transformers migration recipe (drop-in via `bridges.from_hf_padded` / `to_hf_padded`)
- `examples/05_multimodal_interleaved.py`: interleaved text + image-patch sequences packed into one `scree.Array`
Documentation
- Getting started: install, first program, common patterns
- Concepts: the mental model behind `values + offsets + ragged_dim`
- API reference: every public function and class
- Bridges & migration: moving from `torch.nested`, HuggingFace, FlashAttention
- Kernels: reference and Triton kernel design
- Architecture: internal layout for contributors
- Benchmarks: methodology and reproduction
- FAQ
Contributing
PRs welcome. See CONTRIBUTING.md for the workflow. Open a GitHub Discussion for anything beyond a small fix.
License
Apache-2.0