
Accelerated Scan


This package implements the fastest first-order parallel associative scan on the GPU, for both the forward and backward passes.

The scan efficiently solves first-order recurrences of the form x[t] = gate[t] * x[t-1] + token[t], common in state space models and linear RNNs.
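For reference, the recurrence can be evaluated with a plain sequential loop. This pure-Python sketch assumes the initial state x[-1] = 0 (an assumption for illustration; the function name scan_sequential is hypothetical and not part of the package). Each step depends on the previous one, and that serial dependency is exactly what a parallel scan removes.

```python
def scan_sequential(gates, tokens):
    """Compute x[t] = gates[t] * x[t-1] + tokens[t], assuming x[-1] = 0."""
    if len(gates) != len(tokens):
        raise ValueError("gates and tokens must have the same length")
    out = []
    x = 0.0  # assumed initial state x[-1] = 0
    for g, u in zip(gates, tokens):
        x = g * x + u  # one step of the first-order recurrence
        out.append(x)
    return out


print(scan_sequential([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]))  # [1.0, 1.5, 1.75]
```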

The accelerated_scan.warp C++ CUDA kernel uses a chunked algorithm that relies on the fastest GPU communication primitive available at each level of the hierarchy: warp shuffles among the 32 threads of a warp, and shared memory (SRAM) between warps within a thread block. The full sequence for each channel is processed within a single thread block.
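The two-level idea behind chunked processing can be sketched in plain Python: scan each chunk independently from a zero state while tracking the running product of its gates, then carry the true state into each chunk and correct its local results. This is a hypothetical illustration, not the package's kernel; scan_chunked and chunk_size are made-up names, the initial state is assumed to be 0, and the real kernel's chunk layout across warps and shared memory is not reproduced here. In this sketch the carry pass runs sequentially over chunks; since the carries obey the same recurrence (with each chunk's gate product as the gate), that pass could itself be parallelized.

```python
def scan_chunked(gates, tokens, chunk_size=4):
    """Two-pass chunked evaluation of x[t] = gates[t] * x[t-1] + tokens[t], x[-1] = 0."""
    if len(gates) != len(tokens):
        raise ValueError("gates and tokens must have the same length")
    if chunk_size < 1:
        raise ValueError("chunk_size must be positive")
    n = len(tokens)

    # Pass 1 (independent per chunk): scan each chunk from a zero state and
    # track the running product of gates inside the chunk.
    chunks = []
    for start in range(0, n, chunk_size):
        x, decay = 0.0, 1.0
        states, decays = [], []
        for t in range(start, min(start + chunk_size, n)):
            x = gates[t] * x + tokens[t]
            decay *= gates[t]
            states.append(x)
            decays.append(decay)
        chunks.append((states, decays))

    # Pass 2: the true state entering a chunk contributes decay * carry
    # to every position in that chunk.
    out, carry = [], 0.0
    for states, decays in chunks:
        out.extend(s + d * carry for s, d in zip(states, decays))
        carry = out[-1]  # true state at the end of this chunk
    return out


print(scan_chunked([0.5, 0.25, 1.0, 0.0, 2.0], [1.0, 2.0, 3.0, 4.0, 5.0], chunk_size=2))
```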

The Chunked Scan derivation was used to extend the tree-level Blelloch algorithm to the thread-block level.
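A tree-based scan applies because each step x -> gate * x + token is an affine map, and composing affine maps is associative: applying (a1, b1) and then (a2, b2) gives (a1 * a2, a2 * b1 + b2). The sketch below is a textbook, sequentially simulated Blelloch up-sweep/down-sweep over such pairs. It is not the package's kernel; the names combine, IDENTITY and blelloch_scan are hypothetical, and it assumes a power-of-two length and x[-1] = 0. Because the operator is not commutative, the down-sweep must place the parent prefix before the left subtree.

```python
def combine(left, right):
    """Compose two affine steps x -> a*x + b: apply `left` first, then `right`."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


IDENTITY = (1.0, 0.0)  # the map x -> x


def blelloch_scan(gates, tokens):
    """Work-efficient (Blelloch) scan of x[t] = gates[t] * x[t-1] + tokens[t], x[-1] = 0."""
    n = len(tokens)
    if n == 0 or n & (n - 1) or len(gates) != n:
        raise ValueError("sketch assumes equal, power-of-two lengths")
    tree = list(zip(gates, tokens))

    # Up-sweep: index i accumulates the composition of the
    # length-(2 * step) segment ending at i.
    step = 1
    while step < n:
        for i in range(2 * step - 1, n, 2 * step):
            tree[i] = combine(tree[i - step], tree[i])
        step *= 2

    # Down-sweep: the left child receives the parent's prefix; the right child
    # receives the parent's prefix followed by the left subtree.
    tree[n - 1] = IDENTITY
    step = n // 2
    while step >= 1:
        for i in range(2 * step - 1, n, 2 * step):
            left_subtree = tree[i - step]
            tree[i - step] = tree[i]
            tree[i] = combine(tree[i], left_subtree)
        step //= 2

    # tree[t] now composes steps 0..t-1 (exclusive scan); apply step t and
    # evaluate at x[-1] = 0, which leaves only the offset b.
    return [combine(tree[t], (gates[t], tokens[t]))[1] for t in range(n)]


print(blelloch_scan([0.5] * 4, [1.0] * 4))  # [1.0, 1.5, 1.75, 1.875]
```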

A similar implementation is available in accelerated_scan.scalar, built on Triton's tl.associative_scan primitive. It requires Triton 2.2 or later for its enable_fp_fusion flag.

Quick Start:

```bash
pip install accelerated-scan
```

```python
import torch

from accelerated_scan.scalar import scan  # uses tl.associative_scan in chunks
# from accelerated_scan.warp import scan  # a pure C++ kernel, faster than CUB
# from accelerated_scan.ref import scan  # reference PyTorch implementation

batch_size, dim, seqlen = 1, 512, 131072
# Gates close to 1 give a slowly decaying recurrence.
forget = 0.999 + 0.001 * torch.rand(batch_size, dim, seqlen, device="cuda")
inputs = torch.rand(batch_size, dim, seqlen, device="cuda")

# Scans along the last (sequence) dimension:
# out[..., t] = forget[..., t] * out[..., t-1] + inputs[..., t]
out = scan(forget, inputs)
```

To check numerical equivalence, a tree-based reference implementation is provided in PyTorch (accelerated_scan.ref). It can be sped up with torch.compile.

Benchmarks:

[Figure: benchmark plot (bench.png)]

See more benchmarks in nanokitchen: https://github.com/proger/nanokitchen

Forward-only timings for inputs of shape (batch, dim, seqlen) = (8, 1536, seqlen), accelerated-scan version 0.2.0 (lower is better). Columns: triton = accelerated_scan.triton (Triton 2.2.0), ref = accelerated_scan.ref, warp = accelerated_scan.warp.

seqlen      triton         ref        warp
   128    0.027382    0.380874    0.026844
   256    0.049104    0.567916    0.048593
   512    0.093008    1.067906    0.092923
  1024    0.181856    2.048471    0.183581
  2048    0.358250    3.995369    0.355414
  4096    0.713511    7.897022    0.714536
  8192    1.433052   15.698944    1.411390
 16384    3.260965   31.305046    2.817152
 32768   31.459671   62.557182    5.645697
 65536   66.787331  125.208572   11.297921

Notes on Precision

When gates and tokens are sampled uniformly from [0, 1], the limited precision of bfloat16 dominates the error relative to the reference implementation:

[Figure: maximum absolute error relative to the reference implementation (max-abs-error.png)]
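To see why bfloat16 rounding alone produces a visible error, the recurrence can be emulated in pure Python. This sketch rounds float32 bit patterns to their upper 16 bits (round half to even; NaN and infinity handling omitted) and compares against a float64 loop. The rounding model, rounding inputs and state to bfloat16 after every step, is an assumption and may not match where the package's kernels actually round; the helper names are hypothetical, and the sketch does not reproduce the plot.

```python
import random
import struct


def to_bfloat16(x):
    """Round a float to the nearest bfloat16 value (round half to even).

    bfloat16 keeps the upper 16 bits of the float32 bit pattern.
    NaN/infinity handling is omitted in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add just under half a bfloat16 ulp, plus 1 when the kept part is odd,
    # so exact ties round to the even neighbour.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def max_abs_error_bf16(gates, tokens):
    """Max |x_ref - x_bf16| over the sequence, assuming x[-1] = 0."""
    x_ref, x_bf16, worst = 0.0, 0.0, 0.0
    for g, u in zip(gates, tokens):
        x_ref = g * x_ref + u  # float64 reference
        # Assumed rounding model: inputs and state rounded to bfloat16 each step.
        x_bf16 = to_bfloat16(to_bfloat16(g) * x_bf16 + to_bfloat16(u))
        worst = max(worst, abs(x_ref - x_bf16))
    return worst


random.seed(0)
seqlen = 1024
gates = [random.random() for _ in range(seqlen)]  # uniform in [0, 1)
tokens = [random.random() for _ in range(seqlen)]
print(max_abs_error_bf16(gates, tokens))
```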



Download files


Source Distribution

accelerated_scan-0.3.1.tar.gz (77.4 kB)

Built Distribution


accelerated_scan-0.3.1-py3-none-any.whl (13.8 kB)

File details

Details for the file accelerated_scan-0.3.1.tar.gz.

File metadata

  • Download URL: accelerated_scan-0.3.1.tar.gz
  • Upload date:
  • Size: 77.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for accelerated_scan-0.3.1.tar.gz
Algorithm Hash digest
SHA256 cb1f38993e5c81f52ecc3ed4b8060ab8ee7292beacc818f173b22341331c9c9e
MD5 dfa70235f1d8989555f1f5aafd798aab
BLAKE2b-256 8778f6a8ea3be54e0272388ac089d7819dbd1c4f47c7d0026d9e92278c775dcd


File details

Details for the file accelerated_scan-0.3.1-py3-none-any.whl.


File hashes

Hashes for accelerated_scan-0.3.1-py3-none-any.whl
Algorithm Hash digest
SHA256 81676a54162bb7b522bbf9eee3a093b18c327be92f3f4c82aab3bc3d6372ec97
MD5 68ee76bfef2d1caf0fc92ec113384eca
BLAKE2b-256 e389cc53b59d0a6cd6ba1d2270764121c7b1ab66c5553ecbe445adb6f872435e

