Accelerated Scan
This package implements the fastest first-order parallel associative scan on the GPU, for both the forward and backward passes.
The scan efficiently solves first-order recurrences of the form x[t] = gate[t] * x[t-1] + token[t], which are common in state space models and linear RNNs.
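For reference, the recurrence can be written as a plain sequential loop over one (batch, channel) sequence. This is a minimal Python sketch; the name `scan_sequential` is hypothetical, and the initial state x[-1] = 0 is an assumption rather than something stated above.

```python
def scan_sequential(gates, tokens):
    """Naive O(T) loop: x[t] = gates[t] * x[t-1] + tokens[t], assuming x[-1] = 0."""
    x, out = 0.0, []
    for g, tok in zip(gates, tokens):
        x = g * x + tok
        out.append(x)
    return out


print(scan_sequential([0.5, 0.5, 0.5, 0.5], [1.0, 1.0, 1.0, 1.0]))  # [1.0, 1.5, 1.75, 1.875]
```

Each step depends on the previous one, which is why a parallel GPU version has to restructure the computation as an associative scan.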
The accelerated_scan.warp C++ CUDA kernel uses a chunked processing algorithm that leverages the fastest GPU communication primitive available at each level of the hierarchy: warp shuffles within warps of 32 threads, and shared memory (SRAM) between warps within a thread block. Each sequence (one per channel dimension) is confined to a single thread block.
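The two-level idea can be illustrated in plain Python: scan each fixed-size chunk independently assuming a zero incoming state, then fix up each chunk using the final state of the chunk before it. This is a hypothetical sequential sketch (`scan_chunked` and the chunk size of 32 are illustrative, chosen to echo a 32-thread warp), not the package's CUDA code; on a GPU the chunks would run in parallel and the carry pass would itself be a scan.

```python
import random


def scan_chunked(gates, tokens, chunk=32):
    """Two-level scan: chunk-local scans, then a carry pass across chunks (x[-1] = 0 assumed)."""
    n = len(gates)
    local_x = [0.0] * n  # chunk-local state, assuming zero state entering the chunk
    local_a = [0.0] * n  # product of gates from the chunk start through t
    # Level 1: chunks are independent of each other (the role a warp plays above).
    for start in range(0, n, chunk):
        x, a = 0.0, 1.0
        for t in range(start, min(start + chunk, n)):
            x = gates[t] * x + tokens[t]
            a = gates[t] * a
            local_x[t], local_a[t] = x, a
    # Level 2: feed each chunk's final state into the next chunk (sequential here for clarity).
    out = [0.0] * n
    carry = 0.0
    for start in range(0, n, chunk):
        end = min(start + chunk, n)
        for t in range(start, end):
            out[t] = local_x[t] + local_a[t] * carry
        carry = out[end - 1]
    return out


# Compare with the plain sequential recurrence on random data.
random.seed(0)
gates = [random.random() for _ in range(256)]
tokens = [random.random() for _ in range(256)]
expected, x = [], 0.0
for g, tok in zip(gates, tokens):
    x = g * x + tok
    expected.append(x)
assert all(abs(p - q) < 1e-9 for p, q in zip(scan_chunked(gates, tokens), expected))
```

The fix-up works because x[t] = (product of the chunk's gates up to t) * (state entering the chunk) + (chunk-local result), so only one number per chunk has to cross chunk boundaries.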
The derivation of Chunked Scan has been used to extend the tree-level Blelloch algorithm to the block level.
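For readers unfamiliar with Blelloch's algorithm, here is a sequential Python sketch of the textbook work-efficient up-sweep/down-sweep exclusive scan, applied to (gate, token) pairs treated as affine maps x -> gate * x + token. The names `combine`, `IDENTITY`, and `scan_blelloch` are hypothetical; this illustrates the classic tree algorithm under the assumption x[-1] = 0, not the package's block-level extension.

```python
def combine(f, g):
    """Compose affine maps (a, b): apply f first, then g. x -> g.a * (f.a * x + f.b) + g.b"""
    return (f[0] * g[0], g[0] * f[1] + g[1])


IDENTITY = (1.0, 0.0)


def scan_blelloch(gates, tokens):
    n = len(gates)
    assert n > 0 and n & (n - 1) == 0, "this sketch assumes a power-of-2 length"
    elems = list(zip(gates, tokens))
    tree = list(elems)
    # Up-sweep (reduce): each right node accumulates its left sibling's subtree.
    stride = 2
    while stride <= n:
        for k in range(0, n, stride):
            left, right = k + stride // 2 - 1, k + stride - 1
            tree[right] = combine(tree[left], tree[right])
        stride *= 2
    # Down-sweep: push exclusive prefixes from the root back down to the leaves.
    tree[n - 1] = IDENTITY
    stride = n
    while stride >= 2:
        for k in range(0, n, stride):
            left, right = k + stride // 2 - 1, k + stride - 1
            left_sum = tree[left]
            tree[left] = tree[right]                      # left child inherits the prefix
            tree[right] = combine(tree[right], left_sum)  # right child: prefix, then left subtree
        stride //= 2
    # tree[t] is now the exclusive prefix; fold in element t and read x[t] as the offset b.
    return [combine(tree[t], elems[t])[1] for t in range(n)]


print(scan_blelloch([0.5] * 4, [1.0] * 4))  # [1.0, 1.5, 1.75, 1.875]
```

Because `combine` is not commutative, the down-sweep must keep the prefix on the left when updating the right child.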
A similar implementation is available in accelerated_scan.triton, using Triton's tl.associative_scan primitive. It requires Triton 2.2 for its enable_fp_fusion flag.
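An associative-scan primitive like tl.associative_scan needs an associative combine function. For this recurrence, a natural choice is composing (gate, token) pairs as affine maps. The plain-Python check below is a hypothetical sketch, not the Triton code in accelerated_scan.triton. It confirms that regrouping changes the result only up to floating-point rounding, and that the operator is not commutative.

```python
import random


def combine(f, g):
    """Apply f = (gate, token) first, then g; returns the composed (gate, token) pair."""
    return (f[0] * g[0], g[0] * f[1] + g[1])


random.seed(0)
for _ in range(1000):
    f, g, h = [(random.random(), random.random()) for _ in range(3)]
    lhs = combine(combine(f, g), h)
    rhs = combine(f, combine(g, h))
    # Associative up to rounding: this is what lets a scan primitive regroup the work.
    assert all(abs(p - q) <= 1e-12 for p, q in zip(lhs, rhs))

# Order matters, so the scan must preserve left-to-right order of the sequence.
print(combine((0.5, 1.0), (0.25, 2.0)))  # (0.125, 2.25)
print(combine((0.25, 2.0), (0.5, 1.0)))  # (0.125, 2.0)
```

Since regrouping changes results only by rounding, choices such as whether multiply-adds are fused can make a parallel scan differ slightly from a sequential reference.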
Quick Start:
```shell
pip install accelerated-scan
```

```python
import torch
from accelerated_scan.warp import scan  # a pure C++ kernel, faster than CUB
# from accelerated_scan.triton import scan  # uses tl.associative_scan
# from accelerated_scan.ref import scan  # reference Torch implementation

# sequence lengths must be a power of 2 between 32 and 65536
# hit me up if you need different lengths!
batch_size, dim, seqlen = 3, 1536, 4096
# gates and tokens have shape (batch_size, dim, seqlen); the recurrence runs along seqlen
gates = 0.999 + 0.001 * torch.rand(batch_size, dim, seqlen, device="cuda")
tokens = torch.rand(batch_size, dim, seqlen, device="cuda")
out = scan(gates, tokens)
```
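If your sequence length is not a supported power of 2, one option is to right-pad and slice the result. The recurrence is causal (x[t] depends only on steps up to t), so padding the end cannot change the first seqlen outputs. The helpers below are hypothetical, not part of accelerated_scan, and use plain lists so the idea is easy to check. With tensors, the same padding could be done on the last dimension with `torch.nn.functional.pad(gates, (0, extra), value=1.0)`.

```python
def next_supported_length(seqlen, lo=32, hi=65536):
    """Smallest power of 2 in [lo, hi] that is >= seqlen (hypothetical helper)."""
    length = lo
    while length < seqlen:
        length *= 2
    if length > hi:
        raise ValueError(f"seqlen {seqlen} exceeds the supported maximum of {hi}")
    return length


def pad_sequence(gates, tokens):
    """Right-pad with gate=1, token=0 (any values would do, since outputs are sliced off)."""
    extra = next_supported_length(len(gates)) - len(gates)
    return gates + [1.0] * extra, tokens + [0.0] * extra


def scan_sequential(gates, tokens):
    # x[t] = gates[t] * x[t-1] + tokens[t], assuming x[-1] = 0
    x, out = 0.0, []
    for g, tok in zip(gates, tokens):
        x = g * x + tok
        out.append(x)
    return out


gates, tokens = [0.9] * 100, [1.0] * 100
padded_gates, padded_tokens = pad_sequence(gates, tokens)
out = scan_sequential(padded_gates, padded_tokens)[:len(gates)]
assert len(padded_gates) == 128
assert out == scan_sequential(gates, tokens)
```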
To ensure numerical equivalence, a reference implementation for trees is provided in Torch. It can be sped up using torch.compile.
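The reference covers both passes. For intuition about the backward pass, here is a plain-Python sketch using hypothetical names `scan_forward` and `scan_backward` (not the package's API), again assuming x[-1] = 0. The gradient flows backward in time as another first-order recurrence whose gates are shifted by one step, so the backward pass can be computed as a scan run in reverse.

```python
def scan_forward(gates, tokens):
    x, out = 0.0, []
    for g, tok in zip(gates, tokens):
        x = g * x + tok
        out.append(x)
    return out


def scan_backward(gates, outputs, grad_out):
    """Given dL/dx[t] for every t, return (dL/dgates, dL/dtokens) via a reverse-time scan."""
    n = len(gates)
    grad_tokens = [0.0] * n
    acc = 0.0
    for t in reversed(range(n)):
        # Gradient reaching x[t] from later steps passes through gate[t+1].
        next_gate = gates[t + 1] if t + 1 < n else 0.0
        acc = grad_out[t] + next_gate * acc
        grad_tokens[t] = acc
    # x[t] = gate[t] * x[t-1] + token[t], so dL/dgate[t] = dL/dtoken[t] * x[t-1].
    grad_gates = [grad_tokens[t] * (outputs[t - 1] if t > 0 else 0.0) for t in range(n)]
    return grad_gates, grad_tokens
```

A finite-difference check on a loss such as sum(w[t] * x[t]) is a quick way to validate a backward pass like this one.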
Benchmarks:
See more benchmarks in nanokitchen: https://github.com/proger/nanokitchen
Forward-pass time for inputs of shape (8, 1536, seqlen), inference mode (lower is better):
SEQUENCE_LENGTH | accelerated_scan.triton (triton 2.2.0) | accelerated_scan.ref | accelerated_scan.warp
---|---|---|---
128 | 0.027382 | 0.380874 | 0.026844
256 | 0.049104 | 0.567916 | 0.048593
512 | 0.093008 | 1.067906 | 0.092923
1024 | 0.181856 | 2.048471 | 0.183581
2048 | 0.358250 | 3.995369 | 0.355414
4096 | 0.713511 | 7.897022 | 0.714536
8192 | 1.433052 | 15.698944 | 1.411390
16384 | 3.260965 | 31.305046 | 2.817152
32768 | 31.459671 | 62.557182 | 5.645697
65536 | 66.787331 | 125.208572 | 11.297921
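To make the long-sequence rows easier to read, the ratios below are computed directly from the table values. Up to 8192 the Triton and warp versions are close. From 16384 to 32768 the Triton time jumps by roughly 9.6x, while the warp kernel roughly doubles, as expected for linear work.

```python
# Timings copied from the table above (same units as the table).
timings = {
    16384: {"triton": 3.260965, "ref": 31.305046, "warp": 2.817152},
    32768: {"triton": 31.459671, "ref": 62.557182, "warp": 5.645697},
    65536: {"triton": 66.787331, "ref": 125.208572, "warp": 11.297921},
}
speedups = {
    n: {"vs_ref": t["ref"] / t["warp"], "vs_triton": t["triton"] / t["warp"]}
    for n, t in timings.items()
}
for n, s in speedups.items():
    print(f"seqlen {n}: warp is {s['vs_ref']:.1f}x faster than ref, "
          f"{s['vs_triton']:.1f}x faster than triton")

# Growth from 16384 to 32768: Triton jumps sharply, warp roughly doubles.
triton_jump = timings[32768]["triton"] / timings[16384]["triton"]
warp_jump = timings[32768]["warp"] / timings[16384]["warp"]
```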
Notes on Precision
When gates and tokens are sampled uniformly from 0..1, the limited precision of bfloat16 dominates the error (compared to the reference implementation).
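The effect can be reproduced on the CPU by emulating reduced precision in plain Python. This sketch is pessimistic: it rounds every input and intermediate to bfloat16 (round-half-to-even via the float32 bit pattern), which is an assumption and not necessarily how the kernels compute. It compares against a float64 loop on the same uniform 0..1 inputs, with float32 rounding as a point of comparison.

```python
import random
import struct


def round_to_fp32(x):
    return struct.unpack("<f", struct.pack("<f", x))[0]


def round_to_bf16(x):
    """Round to the nearest bfloat16 value (ties to even) by trimming the float32 bit pattern."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def scan_rounded(gates, tokens, rnd):
    # x[t] = gates[t] * x[t-1] + tokens[t] with every operand and result passed through rnd
    x, out = 0.0, []
    for g, tok in zip(gates, tokens):
        x = rnd(rnd(rnd(g) * x) + rnd(tok))
        out.append(x)
    return out


random.seed(0)
n = 4096
gates = [random.random() for _ in range(n)]
tokens = [random.random() for _ in range(n)]
exact = scan_rounded(gates, tokens, lambda v: v)  # float64 baseline
errors = {}
for name, rnd in [("float32", round_to_fp32), ("bfloat16", round_to_bf16)]:
    approx = scan_rounded(gates, tokens, rnd)
    errors[name] = max(abs(a - e) for a, e in zip(approx, exact))
    print(f"{name}: max abs error {errors[name]:.2e}")
```

bfloat16 keeps only 8 significant bits versus 24 for float32, so its rounding error is several orders of magnitude larger and dominates any difference between scan algorithms.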
File details
Details for the file accelerated_scan-0.2.0.tar.gz.
File metadata
- Download URL: accelerated_scan-0.2.0.tar.gz
- Upload date:
- Size: 74.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 967509e7e16ba2500184ab7160f8fbf850b32e813b65dfe926cb07e4f9eb8d4b
MD5 | 0626d7c1bfcfacb600341e23e1036aca
BLAKE2b-256 | eeee8920c5f35b2cccb8fae8ae35866983ee92bd9835c8c37edfbb107ea9cea9
File details
Details for the file accelerated_scan-0.2.0-py3-none-any.whl.
File metadata
- Download URL: accelerated_scan-0.2.0-py3-none-any.whl
- Upload date:
- Size: 11.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 3eae1ce400392ad1e37eddb6fe0262184bf28775e20e72977ab6dbd0ec5c09ae
MD5 | c2d89fc2bd8d44136e8103c6db300cc6
BLAKE2b-256 | 01472760f95bcac10ee85aaf6e8e18fd89c5be4243f337d2d74f44db01add364