
Accelerated Scan


This package implements the fastest first-order parallel associative scan on the GPU, for both the forward and backward passes.

The scan efficiently solves first-order recurrences of the form x[t] = gate[t] * x[t-1] + token[t], common in state space models and linear RNNs.
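For checking results, the recurrence and its gradient can be written as plain sequential loops. The sketch below is illustrative only: the names scan_forward and scan_backward are hypothetical (not part of the accelerated_scan API), it assumes the initial state x[-1] = 0, and it makes no claim about how the package's kernels compute gradients. It relies on a standard identity: the gradient with respect to token[t] obeys the same first-order recurrence run in reverse time, with the gates shifted by one step, and the gradient with respect to gate[t] is that quantity times x[t-1].

```python
from typing import List, Sequence, Tuple


def scan_forward(gates: Sequence[float], tokens: Sequence[float]) -> List[float]:
    """x[t] = gate[t] * x[t-1] + token[t], assuming x[-1] = 0 (hypothetical helper)."""
    x, out = 0.0, []
    for g, tok in zip(gates, tokens):
        x = g * x + tok
        out.append(x)
    return out


def scan_backward(
    gates: Sequence[float], outputs: Sequence[float], grad_out: Sequence[float]
) -> Tuple[List[float], List[float]]:
    """Gradients of a loss L w.r.t. tokens and gates, given grad_out[t] = dL/dx[t].

    d[t] = dL/dtoken[t] satisfies the recurrence run backwards in time:
        d[t] = grad_out[t] + gate[t+1] * d[t+1]
    and dL/dgate[t] = d[t] * x[t-1], with x[-1] = 0.
    """
    n = len(gates)
    d = [0.0] * n
    acc = 0.0
    for t in reversed(range(n)):
        next_gate = gates[t + 1] if t + 1 < n else 0.0
        acc = grad_out[t] + next_gate * acc
        d[t] = acc
    grad_gates = [d[t] * (outputs[t - 1] if t > 0 else 0.0) for t in range(n)]
    return d, grad_gates
```

For example, with gates [0.5, 0.6, 0.7, 0.8], tokens [1, 2, 3, 4] and the loss L = sum(x), scan_backward gives dL/dtoken[1] = 2.26 (up to floating-point rounding), since token[1] reaches x[1], x[2] and x[3] with weights 1, 0.7 and 0.7 · 0.8.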

The accelerated_scan.warp C++ CUDA kernel uses a chunked processing algorithm that leverages the fastest GPU communication primitive available at each level of the hierarchy: warp shuffles within a warp of 32 threads, and shared memory (SRAM) between warps within a thread block. Each sequence (one per channel) is processed entirely by a single thread block.
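The two-level idea can be illustrated on the CPU. The sketch below is plain Python, not the package's CUDA code, and the names local_scan and chunked_scan and the default chunk size of 32 are illustrative assumptions. Each chunk stands in for a warp: chunks are scanned independently from a zero state, a short scan over the chunk totals produces each chunk's carry-in, and a final pass applies the carry. In a kernel of this kind, the first and last phases run in parallel across warps, and the middle phase is the step that needs cross-warp communication (for example through shared memory); here every phase is just a loop.

```python
from typing import List, Sequence, Tuple


def local_scan(gates: Sequence[float], tokens: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Scan one chunk starting from a zero state.

    For each position i in the chunk, returns
      a[i] = gates[0] * ... * gates[i]  (factor applied to the chunk's carry-in)
      b[i] = value of the recurrence at i when the carry-in is 0
    so that the true value is a[i] * carry_in + b[i].
    """
    a: List[float] = []
    b: List[float] = []
    prod, acc = 1.0, 0.0
    for g, tok in zip(gates, tokens):
        prod *= g
        acc = g * acc + tok
        a.append(prod)
        b.append(acc)
    return a, b


def chunked_scan(gates: Sequence[float], tokens: Sequence[float], chunk: int = 32) -> List[float]:
    """Two-level scan: independent chunk scans, a scan over chunk totals, then a fix-up."""
    starts = range(0, len(gates), chunk)

    # Phase 1: scan every chunk independently (these would run in parallel on a GPU).
    partials = [local_scan(gates[s:s + chunk], tokens[s:s + chunk]) for s in starts]

    # Phase 2: sequential scan over chunk totals gives each chunk's carry-in (x[-1] = 0).
    carries: List[float] = []
    carry = 0.0
    for a, b in partials:
        carries.append(carry)
        carry = a[-1] * carry + b[-1]

    # Phase 3: apply each chunk's carry-in to all of its elements.
    out: List[float] = []
    for (a, b), c in zip(partials, carries):
        out.extend(ai * c + bi for ai, bi in zip(a, b))
    return out
```

With chunk=1 the sketch degenerates to the plain sequential loop, which makes it easy to check that the chunked version agrees with it for any chunk size.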

The derivation of Chunked Scan was used to extend the tree-level Blelloch algorithm to the block level.
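The recurrence admits a parallel scan because one step can be read as the affine map x ↦ gate[t]·x + token[t], and composing two such maps gives another: (a1, b1) followed by (a2, b2) is (a1·a2, a2·b1 + b2), with identity (1, 0). This composition is associative, so parallel prefix algorithms apply. The sketch below pairs it with a textbook Blelloch up-sweep/down-sweep over a power-of-2 length in plain Python. It illustrates the tree-level algorithm only, not the Chunked Scan derivation or the package's kernel, and the names combine and blelloch_scan are hypothetical.

```python
from typing import List, Sequence, Tuple

# (a, b) stands for the affine map x -> a * x + b.
Pair = Tuple[float, float]
IDENTITY: Pair = (1.0, 0.0)


def combine(earlier: Pair, later: Pair) -> Pair:
    """Compose two steps: apply `earlier`, then `later`. Associative, not commutative."""
    a1, b1 = earlier
    a2, b2 = later
    return (a1 * a2, a2 * b1 + b2)


def blelloch_scan(gates: Sequence[float], tokens: Sequence[float]) -> List[float]:
    n = len(gates)
    if n == 0 or n & (n - 1):
        raise ValueError("this sketch needs a power-of-2 length")
    tree: List[Pair] = list(zip(gates, tokens))

    # Up-sweep: each right child accumulates the combined value of its subtree.
    stride = 1
    while stride < n:
        for k in range(0, n, 2 * stride):
            left, right = k + stride - 1, k + 2 * stride - 1
            tree[right] = combine(tree[left], tree[right])
        stride *= 2

    # Down-sweep: turn the tree into exclusive prefixes (everything before position i).
    tree[n - 1] = IDENTITY
    stride = n // 2
    while stride >= 1:
        for k in range(0, n, 2 * stride):
            left, right = k + stride - 1, k + 2 * stride - 1
            left_total = tree[left]
            tree[left] = tree[right]                        # prefix before the left subtree
            tree[right] = combine(tree[right], left_total)  # same prefix, extended by the left subtree
        stride //= 2

    # Inclusive prefix = exclusive prefix followed by element i; with x[-1] = 0, x[i] is its b part.
    return [combine(prefix, (g, tok))[1] for prefix, g, tok in zip(tree, gates, tokens)]
```

Because combine is not commutative, the order of its arguments in both sweeps matters: the earlier part of the sequence is always passed first.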

A similar implementation is available in accelerated_scan.triton, built on Triton's tl.associative_scan primitive. It requires Triton 2.2 for its enable_fp_fusion flag.

Quick Start:

pip install accelerated-scan

import torch
from accelerated_scan.warp import scan  # a pure C++ CUDA kernel, faster than CUB
#from accelerated_scan.triton import scan  # uses tl.associative_scan
#from accelerated_scan.ref import scan  # reference PyTorch implementation

# sequence lengths must be powers of 2 between 32 and 65536
# hit me up if you need different lengths!

batch_size, dim, seqlen = 3, 1536, 4096
gates = 0.999 + 0.001 * torch.rand(batch_size, dim, seqlen, device="cuda")
tokens = torch.rand(batch_size, dim, seqlen, device="cuda")

out = scan(gates, tokens)  # out[..., t] = gates[..., t] * out[..., t-1] + tokens[..., t]

To check numerical equivalence, a reference implementation of the tree algorithm is provided in PyTorch (accelerated_scan.ref). It can be sped up with torch.compile.

Benchmarks:

[Figure: benchmark plot (bench.png)]

See more benchmarks in nanokitchen: https://github.com/proger/nanokitchen

Forward pass time for inputs of shape (8, 1536, seqlen) in inference mode (lower is better):

seqlen   accelerated_scan.triton (triton 2.2.0)   accelerated_scan.ref   accelerated_scan.warp
   128                                 0.027382               0.380874                0.026844
   256                                 0.049104               0.567916                0.048593
   512                                 0.093008               1.067906                0.092923
  1024                                 0.181856               2.048471                0.183581
  2048                                 0.358250               3.995369                0.355414
  4096                                 0.713511               7.897022                0.714536
  8192                                 1.433052              15.698944                1.411390
 16384                                 3.260965              31.305046                2.817152
 32768                                31.459671              62.557182                5.645697
 65536                                66.787331             125.208572               11.297921

Notes on Precision

When gates and tokens are sampled uniformly from [0, 1], the limited precision of bfloat16 dominates the error relative to the reference implementation:

[Figure: maximum absolute error relative to the reference implementation (max-abs-error.png)]
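The effect can be reproduced in miniature without a GPU. The sketch below is a simplified model under stated assumptions, not the benchmark behind the plot: it emulates bfloat16 by rounding through float32 to the upper 16 bits (round half to even, NaN/inf ignored), rounds after every multiply and every add, and compares against a float64 sequential loop rather than accelerated_scan.ref. The helper names to_bfloat16 and scan_with_rounding are hypothetical.

```python
import random
import struct
from typing import Callable, List, Optional, Sequence


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round half to even), via float32.

    bfloat16 keeps the upper 16 bits of an IEEE float32: sign, 8 exponent bits
    and 7 explicit mantissa bits. NaN/inf handling is omitted for brevity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1  # lowest bit that survives truncation
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def scan_with_rounding(
    gates: Sequence[float],
    tokens: Sequence[float],
    rounding: Optional[Callable[[float], float]] = None,
) -> List[float]:
    """Sequential scan with x[-1] = 0; `rounding` is applied after every multiply and add."""
    r = rounding or (lambda v: v)
    x, out = 0.0, []
    for g, tok in zip(gates, tokens):
        x = r(r(g * x) + tok)
        out.append(x)
    return out


random.seed(0)
n = 4096
# Inputs sampled uniformly from [0, 1) and stored in bfloat16.
gates = [to_bfloat16(random.random()) for _ in range(n)]
tokens = [to_bfloat16(random.random()) for _ in range(n)]

reference = scan_with_rounding(gates, tokens)                   # float64 arithmetic
low_precision = scan_with_rounding(gates, tokens, to_bfloat16)  # bfloat16 after each op
max_abs_error = max(abs(ref_v - low_v) for ref_v, low_v in zip(reference, low_precision))
print(f"max abs error over {n} steps: {max_abs_error:.3e}")
```

Because the inputs are already bfloat16 values in both runs, the reported error comes only from rounding the intermediate results, which is the precision effect the plot is about.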



Download files


Source Distribution

accelerated_scan-0.2.0.tar.gz (74.5 kB)

Uploaded Source

Built Distribution

accelerated_scan-0.2.0-py3-none-any.whl (11.4 kB)

Uploaded Python 3

File details

Details for the file accelerated_scan-0.2.0.tar.gz.

File metadata

  • Download URL: accelerated_scan-0.2.0.tar.gz
  • Upload date:
  • Size: 74.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.3

File hashes

Hashes for accelerated_scan-0.2.0.tar.gz
Algorithm Hash digest
SHA256 967509e7e16ba2500184ab7160f8fbf850b32e813b65dfe926cb07e4f9eb8d4b
MD5 0626d7c1bfcfacb600341e23e1036aca
BLAKE2b-256 eeee8920c5f35b2cccb8fae8ae35866983ee92bd9835c8c37edfbb107ea9cea9


File details

Details for the file accelerated_scan-0.2.0-py3-none-any.whl.

File metadata

File hashes

Hashes for accelerated_scan-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 3eae1ce400392ad1e37eddb6fe0262184bf28775e20e72977ab6dbd0ec5c09ae
MD5 c2d89fc2bd8d44136e8103c6db300cc6
BLAKE2b-256 01472760f95bcac10ee85aaf6e8e18fd89c5be4243f337d2d74f44db01add364

