Accelerated Scan
This package implements the fastest first-order parallel associative scan on the GPU, for both the forward and backward passes.
The scan efficiently solves first-order recurrences of the form x[t] = gate[t] * x[t-1] + token[t], common in state space models and linear RNNs.
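For reference, the recurrence can be evaluated with a plain sequential loop. This is a minimal sketch that assumes the state before the first step is x[-1] = 0; `sequential_scan` is a hypothetical helper, not part of the package. Its T dependent steps are exactly the serial chain that a parallel scan is designed to avoid.

```python
from typing import List, Sequence


def sequential_scan(gates: Sequence[float], tokens: Sequence[float], x0: float = 0.0) -> List[float]:
    """Evaluate x[t] = gates[t] * x[t-1] + tokens[t] one step at a time (x[-1] = x0)."""
    x = x0
    out = []
    for g, tok in zip(gates, tokens):
        x = g * x + tok  # each step depends on the previous one
        out.append(x)
    return out


print(sequential_scan([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]))  # [1.0, 1.5, 1.75]
```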
The accelerated_scan.warp C++ CUDA kernel uses a chunked processing algorithm that leverages the fastest GPU communication primitive available at each level of the hierarchy: warp shuffles within warps of 32 threads, and shared memory (SRAM) between warps within a thread block. Each sequence (one per channel dimension) is confined to a single thread block.
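The two-level structure can be simulated on the CPU. In this sketch each chunk plays the role of a warp: it scans its own elements independently, tracking the pair (product of gates, state reached from x = 0). A second pass then feeds each chunk's final state into the next chunk, the job that shared memory does between warps. The chunk size of 32 (mirroring the warp width) and the function names are illustrative assumptions, and the cross-chunk pass is a short sequential loop here for brevity rather than a parallel scan.

```python
from typing import List, Sequence, Tuple

Pair = Tuple[float, float]  # (product of gates, state reached from x = 0)


def combine(earlier: Pair, later: Pair) -> Pair:
    # Applying `earlier` then `later`: x -> a2 * (a1 * x + b1) + b2
    a1, b1 = earlier
    a2, b2 = later
    return a1 * a2, a2 * b1 + b2


def chunked_scan(gates: Sequence[float], tokens: Sequence[float], chunk: int = 32) -> List[float]:
    """Two-level scan of x[t] = gates[t] * x[t-1] + tokens[t] with x[-1] = 0."""
    n = len(gates)
    # Level 1 ("warp" level): every chunk computes its local inclusive prefixes on its own.
    local: List[List[Pair]] = []
    for start in range(0, n, chunk):
        acc: Pair = (1.0, 0.0)  # identity element: x -> x
        prefixes = []
        for t in range(start, min(start + chunk, n)):
            acc = combine(acc, (gates[t], tokens[t]))
            prefixes.append(acc)
        local.append(prefixes)
    # Level 2 ("shared memory" level): push each chunk's incoming state through its prefixes.
    out: List[float] = []
    carry = 0.0  # state entering the current chunk
    for prefixes in local:
        out.extend(a * carry + b for a, b in prefixes)
        carry = out[-1]  # final state of this chunk is the next chunk's input
    return out


print(chunked_scan([0.5, 0.5, 0.5], [1.0, 1.0, 1.0], chunk=2))  # [1.0, 1.5, 1.75]
```

In this sketch only the second pass needs data from other chunks; everything in the first pass stays local to one chunk.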
The derivation from Chunked Scan was used to extend the tree-level Blelloch algorithm to the block level.
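For readers unfamiliar with the tree-level algorithm, the sketch below runs the classic Blelloch up-sweep/down-sweep on (gate, token) pairs, using the rule for composing two steps of the recurrence. It is a sequential CPU simulation of the tree, assumes a power-of-two length, and does not reproduce the package's block-level extension; `blelloch_scan` is a hypothetical name.

```python
from typing import List, Tuple

Pair = Tuple[float, float]
IDENTITY: Pair = (1.0, 0.0)  # the step x -> x


def combine(earlier: Pair, later: Pair) -> Pair:
    # Order matters: the operator is associative but not commutative.
    a1, b1 = earlier
    a2, b2 = later
    return a1 * a2, a2 * b1 + b2


def blelloch_scan(gates: List[float], tokens: List[float]) -> List[float]:
    n = len(gates)
    if n == 0 or n & (n - 1):
        raise ValueError("this sketch assumes a power-of-two length")
    tree: List[Pair] = list(zip(gates, tokens))
    # Up-sweep: each right node accumulates the total of its subtree.
    half = 1
    while half < n:
        for i in range(2 * half - 1, n, 2 * half):
            tree[i] = combine(tree[i - half], tree[i])
        half *= 2
    # Down-sweep: turn subtree totals into exclusive prefixes.
    tree[n - 1] = IDENTITY
    half = n // 2
    while half >= 1:
        for i in range(2 * half - 1, n, 2 * half):
            left_total = tree[i - half]
            tree[i - half] = tree[i]                # left child inherits the parent's prefix
            tree[i] = combine(tree[i], left_total)  # right child also includes the left subtree
        half //= 2
    # Fold each element into its exclusive prefix; the b component is x[t] with x[-1] = 0.
    return [combine(prefix, (g, tok))[1] for prefix, g, tok in zip(tree, gates, tokens)]


print(blelloch_scan([0.5] * 4, [1.0] * 4))  # [1.0, 1.5, 1.75, 1.875]
```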
A similar implementation is available in accelerated_scan.scalar, built on Triton's tl.associative_scan primitive. It requires Triton 2.2 or later for the enable_fp_fusion flag.
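tl.associative_scan is only correct when its combine function is associative. For this recurrence, each step can be viewed as the affine map x -> a * x + b stored as the pair (a, b), and composing two steps yields another such pair. The pure-Python sketch below is not the package's Triton code: it checks associativity numerically on random steps and uses itertools.accumulate as a sequential stand-in for an inclusive scan.

```python
import itertools
import math
import random


def combine(earlier, later):
    """Compose step `earlier` then step `later`; each step is the map x -> a * x + b."""
    a1, b1 = earlier
    a2, b2 = later
    return a1 * a2, a2 * b1 + b2


# Numerical associativity check on random steps: (p . q) . r == p . (q . r)
random.seed(0)
p, q, r = [(random.random(), random.random()) for _ in range(3)]
lhs = combine(combine(p, q), r)
rhs = combine(p, combine(q, r))
associative = all(math.isclose(u, v) for u, v in zip(lhs, rhs))
print(associative)  # True

# Inclusive scan of (gate, token) pairs; the b component is x[t] when x[-1] = 0
pairs = [(0.5, 1.0)] * 4
states = [b for _, b in itertools.accumulate(pairs, combine)]
print(states)  # [1.0, 1.5, 1.75, 1.875]
```

Because the operator is associative, the pairs may be grouped in any tree shape, which is what lets a parallel scan reorder the work.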
Quick Start:
```bash
pip install accelerated-scan
```
```python
import torch
from accelerated_scan.scalar import scan  # uses tl.associative_scan in chunks
# from accelerated_scan.warp import scan  # a pure C++ kernel, faster than CUB
# from accelerated_scan.ref import scan   # reference PyTorch implementation

batch_size, dim, seqlen = 1, 512, 131072
# Gates and tokens both have shape (batch_size, dim, seqlen); gates lie in [0.999, 1.0).
forget = 0.999 + 0.001 * torch.rand(batch_size, dim, seqlen, device="cuda")
inputs = torch.rand(batch_size, dim, seqlen, device="cuda")
# out[..., t] = forget[..., t] * out[..., t-1] + inputs[..., t]
out = scan(forget, inputs)
```
To check numerical equivalence, a reference implementation of the tree algorithm is provided in PyTorch. It can be sped up with torch.compile.
Benchmarks:
See more benchmarks in nanokitchen: https://github.com/proger/nanokitchen
Forward-only timings for inputs of shape (8, 1536, seqlen), accelerated-scan version 0.2.0:
| Sequence length | accelerated_scan.triton (triton 2.2.0) | accelerated_scan.ref | accelerated_scan.warp |
|---|---|---|---|
| 128 | 0.027382 | 0.380874 | 0.026844 |
| 256 | 0.049104 | 0.567916 | 0.048593 |
| 512 | 0.093008 | 1.067906 | 0.092923 |
| 1024 | 0.181856 | 2.048471 | 0.183581 |
| 2048 | 0.358250 | 3.995369 | 0.355414 |
| 4096 | 0.713511 | 7.897022 | 0.714536 |
| 8192 | 1.433052 | 15.698944 | 1.411390 |
| 16384 | 3.260965 | 31.305046 | 2.817152 |
| 32768 | 31.459671 | 62.557182 | 5.645697 |
| 65536 | 66.787331 | 125.208572 | 11.297921 |
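Taking the ratio of the triton and warp columns makes the crossover explicit. The snippet below only does arithmetic on the numbers in the benchmark table; by this measure the two kernels are within a few percent of each other up to length 8192, while the Triton column jumps roughly tenfold between 16384 and 32768, leaving warp about 5.6x to 5.9x faster at the two longest lengths.

```python
# Forward-only timings from the benchmark table (accelerated-scan 0.2.0).
seqlens = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]
triton = [0.027382, 0.049104, 0.093008, 0.181856, 0.358250,
          0.713511, 1.433052, 3.260965, 31.459671, 66.787331]
warp = [0.026844, 0.048593, 0.092923, 0.183581, 0.355414,
        0.714536, 1.411390, 2.817152, 5.645697, 11.297921]

# A ratio above 1 means accelerated_scan.warp took less time than the Triton kernel.
speedup = {n: t / w for n, t, w in zip(seqlens, triton, warp)}
for n, s in speedup.items():
    print(f"{n:>6}: {s:.2f}x")
```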
Notes on Precision
When gates and tokens are sampled uniformly from [0, 1], the limited precision of bfloat16 dominates the error relative to the reference implementation.
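The effect can be reproduced on the CPU. The sketch below emulates bfloat16 by rounding float32 bit patterns to their top 16 bits (round-to-nearest-even), then runs the recurrence with every intermediate rounded and compares against a float64 run. Rounding every multiply and add separately (no float32 accumulation, no fused multiply-add) is a pessimistic assumption of this sketch, not a description of how the package's kernels accumulate; the helper names are hypothetical.

```python
import random
import struct


def round_to_fp32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def round_to_bf16(x: float) -> float:
    """Round to bfloat16 (8 significand bits), round-to-nearest-even, via float32.

    NaN handling is omitted in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # bias so truncation rounds to nearest even
    bits = (bits >> 16) << 16            # keep sign, 8 exponent bits, 7 stored mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def scan(gates, tokens, rnd):
    """Sequential x[t] = gates[t] * x[t-1] + tokens[t] with x[-1] = 0, rounding each value."""
    x, out = 0.0, []
    for g, tok in zip(gates, tokens):
        x = rnd(rnd(rnd(g) * x) + rnd(tok))
        out.append(x)
    return out


random.seed(0)
T = 4096
gates = [random.random() for _ in range(T)]
tokens = [random.random() for _ in range(T)]

exact = scan(gates, tokens, rnd=lambda v: v)  # float64 stands in for the reference
err_fp32 = max(abs(a - b) for a, b in zip(scan(gates, tokens, round_to_fp32), exact))
err_bf16 = max(abs(a - b) for a, b in zip(scan(gates, tokens, round_to_bf16), exact))
print(f"max abs error  float32: {err_fp32:.2e}  bfloat16: {err_bf16:.2e}")
```

With only 8 significand bits, each bfloat16 rounding step is about 2^16 times coarser than a float32 one, so the bfloat16 error dwarfs everything else in this comparison.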
Download files
File details
Details for the file accelerated_scan-0.3.1.tar.gz.
File metadata
- Download URL: accelerated_scan-0.3.1.tar.gz
- Upload date:
- Size: 77.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | cb1f38993e5c81f52ecc3ed4b8060ab8ee7292beacc818f173b22341331c9c9e |
| MD5 | dfa70235f1d8989555f1f5aafd798aab |
| BLAKE2b-256 | 8778f6a8ea3be54e0272388ac089d7819dbd1c4f47c7d0026d9e92278c775dcd |
File details
Details for the file accelerated_scan-0.3.1-py3-none-any.whl.
File metadata
- Download URL: accelerated_scan-0.3.1-py3-none-any.whl
- Upload date:
- Size: 13.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 81676a54162bb7b522bbf9eee3a093b18c327be92f3f4c82aab3bc3d6372ec97 |
| MD5 | 68ee76bfef2d1caf0fc92ec113384eca |
| BLAKE2b-256 | e389cc53b59d0a6cd6ba1d2270764121c7b1ab66c5553ecbe445adb6f872435e |