
Accelerated Scan


This package implements the fastest first-order parallel associative scan on the GPU, for both the forward and backward passes.

The scan efficiently solves first-order recurrences of the form x[t] = gate[t] * x[t-1] + token[t], common in state space models and linear RNNs.
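A minimal pure-Python sketch of this recurrence, evaluated one step at a time, can serve as ground truth when checking the parallel versions. It assumes the state before the first step is zero (x[-1] = 0), so x[0] = token[0]. The function name is illustrative and not part of the package API.

```python
from typing import List


def scan_sequential(gates: List[float], tokens: List[float]) -> List[float]:
    """Step-by-step x[t] = gates[t] * x[t-1] + tokens[t], assuming x[-1] = 0."""
    out: List[float] = []
    x = 0.0
    for g, tok in zip(gates, tokens):
        x = g * x + tok  # one affine step: decay the previous state, add the new token
        out.append(x)
    return out


# A constant gate of 0.5 gives an exponentially decaying running sum.
print(scan_sequential([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]))  # [1.0, 1.5, 1.75]
```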

The accelerated_scan.warp C++ CUDA kernel uses a chunked processing algorithm that leverages the fastest GPU communication primitive available at each level of the hierarchy: warp shuffles within a warp of 32 threads, and shared memory (SRAM) between warps within a thread block. Each sequence (one per batch element and channel) is processed entirely within a single thread block.
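The chunked, two-level structure can be simulated sequentially in Python to show the data flow. This is an illustration under stated assumptions, not the CUDA kernel: each chunk plays the role of a warp and is scanned independently, the scan over chunk totals stands in for the shared-memory exchange between warps, and a final pass applies each chunk's incoming state. It assumes x[-1] = 0; `chunk_size` (defaulting to 32 to mirror the warp width) and all function names are hypothetical.

```python
from typing import List, Tuple

Pair = Tuple[float, float]  # (gate, token): the affine map x -> gate * x + token


def combine(left: Pair, right: Pair) -> Pair:
    """Compose two steps: apply `left` first, then `right`."""
    g1, t1 = left
    g2, t2 = right
    return g1 * g2, g2 * t1 + t2


def local_scan(pairs: List[Pair]) -> List[Pair]:
    """Inclusive scan of one chunk (stands in for a warp-level shuffle scan)."""
    out: List[Pair] = []
    for p in pairs:
        out.append(p if not out else combine(out[-1], p))
    return out


def chunked_scan(gates: List[float], tokens: List[float], chunk_size: int = 32) -> List[float]:
    """Two-level scan of x[t] = gates[t] * x[t-1] + tokens[t], assuming x[-1] = 0."""
    pairs = list(zip(gates, tokens))
    chunks = [pairs[i:i + chunk_size] for i in range(0, len(pairs), chunk_size)]
    # Level 1: scan every chunk independently (warps would do this in parallel).
    scanned = [local_scan(c) for c in chunks]
    # Level 2: scan the per-chunk totals (the step that crosses warp boundaries).
    totals = local_scan([c[-1] for c in scanned])
    # Level 3: feed each chunk the state left behind by all earlier chunks.
    out: List[float] = []
    for k, chunk in enumerate(scanned):
        carry = totals[k - 1][1] if k > 0 else 0.0
        out.extend(g * carry + t for g, t in chunk)
    return out


print(chunked_scan([2.0, 3.0, 0.5, 1.0, 4.0], [1.0] * 5, chunk_size=2))  # [1.0, 4.0, 3.0, 4.0, 17.0]
```

Here the per-chunk scans run in a loop, so only the dependency structure is illustrated, not the speed: on a GPU, level 1 runs concurrently across warps and only the short level-2 scan needs cross-warp communication.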

The derivation of Chunked Scan was used to extend the tree-level Blelloch algorithm to the thread-block level.
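The recurrence admits a parallel tree scan because each step is an affine map x -> g * x + t, and composing affine maps is associative. A minimal sequential simulation of the textbook Blelloch (up-sweep/down-sweep) scan over such (gate, token) pairs is sketched below. It assumes x[-1] = 0 and pads the input to a power of two with the identity pair (1, 0); it is not the library's block-level kernel, and the names are illustrative.

```python
from typing import List, Tuple

Pair = Tuple[float, float]  # (gate, token)
IDENTITY: Pair = (1.0, 0.0)  # combine(IDENTITY, p) == combine(p, IDENTITY) == p


def combine(left: Pair, right: Pair) -> Pair:
    """Compose two affine maps: first `left`, then `right` (not commutative)."""
    g1, t1 = left
    g2, t2 = right
    return g1 * g2, g2 * t1 + t2


def blelloch_scan(gates: List[float], tokens: List[float]) -> List[float]:
    n = len(gates)
    if n == 0:
        return []
    size = 1
    while size < n:
        size *= 2
    # Pad with the identity so the reduction tree is complete.
    orig = list(zip(gates, tokens)) + [IDENTITY] * (size - n)
    a = list(orig)
    # Up-sweep: a[i] accumulates the segment of length 2*d that ends at i.
    d = 1
    while d < size:
        for i in range(2 * d - 1, size, 2 * d):
            a[i] = combine(a[i - d], a[i])
        d *= 2
    # Down-sweep: turn the reduction tree into exclusive prefixes.
    a[size - 1] = IDENTITY
    d = size // 2
    while d >= 1:
        for i in range(2 * d - 1, size, 2 * d):
            left = a[i - d]
            a[i - d] = a[i]             # prefix before the left half
            a[i] = combine(a[i], left)  # prefix before the right half (order matters)
        d //= 2
    # Apply each element on top of its exclusive prefix; with x[-1] = 0 the state is the token part.
    return [combine(a[i], orig[i])[1] for i in range(n)]


print(blelloch_scan([2.0, 3.0, 0.5, 1.0, 4.0], [1.0] * 5))  # [1.0, 4.0, 3.0, 4.0, 17.0]
```

The operand order in `combine` matters: swapping the arguments in the down-sweep gives wrong results whenever the gates vary, which is the main difference from a scan with a commutative operator such as addition.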

A similar implementation is available in accelerated_scan.scalar, built on Triton's tl.associative_scan primitive. It requires Triton 2.2 or later for the enable_fp_fusion flag.

Quick Start:

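```shell
pip install accelerated-scan
```

```python
import torch

# Pick one backend; the three imports are interchangeable.
from accelerated_scan.scalar import scan  # Triton: tl.associative_scan applied in chunks
# from accelerated_scan.warp import scan  # pure C++/CUDA kernel, faster than CUB
# from accelerated_scan.ref import scan   # reference PyTorch implementation

batch_size, dim, seqlen = 1, 512, 131072
# Gates ("forget" values) in [0.999, 1.0), so the state decays slowly over the sequence.
forget = 0.999 + 0.001 * torch.rand(batch_size, dim, seqlen, device="cuda")
inputs = torch.rand(batch_size, dim, seqlen, device="cuda")

# Scans along the last axis: out[..., t] = forget[..., t] * out[..., t-1] + inputs[..., t]
out = scan(forget, inputs)
```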

To check numerical equivalence, a reference PyTorch implementation of the tree algorithm is provided in accelerated_scan.ref. It can be sped up with torch.compile.

Benchmarks:

[Figure: benchmark results (bench.png)]

See more benchmarks in nanokitchen: https://github.com/proger/nanokitchen

Forward-pass runtime (lower is better) for inputs of shape (8, 1536, seqlen), accelerated-scan version 0.2.0:

| Sequence length | accelerated_scan.triton (Triton 2.2.0) | accelerated_scan.ref | accelerated_scan.warp |
| --- | --- | --- | --- |
| 128 | 0.027382 | 0.380874 | 0.026844 |
| 256 | 0.049104 | 0.567916 | 0.048593 |
| 512 | 0.093008 | 1.067906 | 0.092923 |
| 1024 | 0.181856 | 2.048471 | 0.183581 |
| 2048 | 0.358250 | 3.995369 | 0.355414 |
| 4096 | 0.713511 | 7.897022 | 0.714536 |
| 8192 | 1.433052 | 15.698944 | 1.411390 |
| 16384 | 3.260965 | 31.305046 | 2.817152 |
| 32768 | 31.459671 | 62.557182 | 5.645697 |
| 65536 | 66.787331 | 125.208572 | 11.297921 |

Notes on Precision

When gates and tokens are sampled uniformly from [0, 1], the limited precision of bfloat16 dominates the error relative to the reference implementation:

[Figure: maximum absolute error relative to the reference implementation (max-abs-error.png)]
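The size of this bfloat16 effect can be reproduced on the CPU with a small simulation. This sketch emulates bfloat16 by rounding the float32 bit pattern to its upper 16 bits (round to nearest even), and assumes a pessimistic storage model in which gates, tokens, and the running state are all rounded to bfloat16 at every step; the actual kernels may keep intermediates in higher precision. Since bfloat16 keeps 8 significant bits, each rounding has a relative error of up to 2^-8 (about 0.4%). Function names are illustrative.

```python
import random
import struct


def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 (ties to even) via the float32 bit pattern.

    Assumes finite inputs; NaN/Inf handling is omitted for brevity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest even on the 16 dropped bits
    bits &= 0xFFFF0000                   # keep sign, exponent, and 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def max_abs_error_bf16(n: int = 4096, seed: int = 0) -> float:
    """Max |bf16 scan - float64 scan| for gates and tokens drawn uniformly from [0, 1)."""
    rng = random.Random(seed)
    gates = [rng.random() for _ in range(n)]
    tokens = [rng.random() for _ in range(n)]
    x_ref = 0.0
    x_bf16 = 0.0
    worst = 0.0
    for g, t in zip(gates, tokens):
        x_ref = g * x_ref + t
        # Pessimistic model: quantize inputs, then round the new state once per step.
        x_bf16 = to_bf16(to_bf16(g) * x_bf16 + to_bf16(t))
        worst = max(worst, abs(x_bf16 - x_ref))
    return worst


print(f"max abs error vs float64: {max_abs_error_bf16():.4f}")
```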


Download files

Source Distribution

accelerated_scan-0.3.0.tar.gz (76.6 kB)

Built Distribution

accelerated_scan-0.3.0-py3-none-any.whl (13.2 kB)

File details

Details for the file accelerated_scan-0.3.0.tar.gz.

File metadata

  • Download URL: accelerated_scan-0.3.0.tar.gz
  • Size: 76.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for accelerated_scan-0.3.0.tar.gz
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 0d4ee4e33c48dcebb4fdcc41f413607c5025dd5695785df6c3a1973ba613c3ce |
| MD5 | 6c410465dc66c8afbeb12eca5b558b8d |
| BLAKE2b-256 | 4ed6afa3cecd646d31aa73bd45abd615b4504c973e5928757f313e904884aa9c |

File details

Details for the file accelerated_scan-0.3.0-py3-none-any.whl.

File hashes

Hashes for accelerated_scan-0.3.0-py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 78b7737b72d80ae12b3e815c6d0455f955a89c1b1c0538385f4fd72ad4b166a2 |
| MD5 | f167ca1ed2c94609b84687ec9bdac857 |
| BLAKE2b-256 | 13bb32981af0e832da08d77fa2c6f991fe137e33e6738dbcfb5535fcf787c557 |
