mlx-scan
Associative scan implementation with MLX.
Quickstart
To install:
```bash
uv add mlx-scan  # uv
pip install mlx-scan  # pip
```
For local development:
```bash
uv sync --extra dev --extra bench
```
To test and lint:
```bash
uv run pytest
uv run ruff check .
```
Associative Scan
A scan computes every prefix of a sequence:
```text
[a, b, c, d] -> [a, a (*) b, a (*) b (*) c, a (*) b (*) c (*) d]
```
If the binary operator (*) is associative, the prefixes can be computed in O(log n) parallel steps with parallel-prefix algorithms such as the Hillis-Steele and Blelloch scans. Only associativity is required, not commutativity. This matters for more than sums: matrix products, affine recurrences, and sequence-model operations can all be expressed as scans.
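To make the parallel-prefix idea concrete, here is a minimal pure-Python sketch of the Hillis-Steele pattern. It is not the mlx-scan implementation: `hillis_steele_scan` is a hypothetical helper working on plain lists, and each `while` iteration stands in for one parallel step (about log2(n) of them). String concatenation is used as a second operator because it is associative but not commutative, so operand order visibly matters.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def hillis_steele_scan(fn: Callable[[T, T], T], xs: List[T]) -> List[T]:
    """Inclusive scan: out[i] = xs[0] (*) xs[1] (*) ... (*) xs[i].

    Hypothetical helper for illustration, not part of mlx-scan.
    """
    out = list(xs)
    offset = 1
    while offset < len(out):
        # One "parallel" step: every position i >= offset combines the partial
        # result ending at i - offset (left operand) with its own (right operand).
        # Reading only from the previous step's list mimics a simultaneous update.
        prev = out
        out = prev[:offset] + [
            fn(prev[i - offset], prev[i]) for i in range(offset, len(prev))
        ]
        offset *= 2
    return out


print(hillis_steele_scan(lambda a, b: a + b, [1, 2, 3, 4]))  # [1, 3, 6, 10]
print(hillis_steele_scan(lambda a, b: a + b, ["a", "b", "c", "d"]))  # ['a', 'ab', 'abc', 'abcd']
```

Each step doubles the span that every partial result covers, which is why the number of steps grows with log2(n) while every step touches almost every element.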
Example
The canonical example is a cumulative sum:
```python
import mlx.core as mx
from mlx_scan import associative_scan

x = mx.array([1, 2, 3, 4])
associative_scan(lambda a, b: a + b, x)
# array([1, 3, 6, 10], dtype=int32)
```
associative_scan accepts an MLX array or a pytree of MLX arrays.
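The pytree contract can be illustrated without MLX. The sketch below is a sequential reference, not mlx-scan code, and it assumes the same convention as jax.lax.associative_scan: every leaf is scanned along the same axis, and the operator receives and returns a structure of the same shape. Here the "pytree" is just a tuple of two equal-length lists, `tree_scan` and `add_pairs` are hypothetical names, and a (count, total) pair yields a running mean.

```python
from itertools import accumulate
from typing import Callable, List, Tuple

Pair = Tuple[int, float]


def tree_scan(fn: Callable[[Pair, Pair], Pair], tree: Tuple[List[int], List[float]]):
    """Sequential inclusive scan over a tuple of equal-length leaves (hypothetical helper)."""
    # Regroup by position: ([c0, c1, ...], [s0, s1, ...]) -> [(c0, s0), (c1, s1), ...]
    elems = list(zip(*tree))
    scanned = list(accumulate(elems, fn))
    # Regroup back into one list per leaf, mirroring the input structure.
    return tuple(list(leaf) for leaf in zip(*scanned))


def add_pairs(left: Pair, right: Pair) -> Pair:
    # Combine (count, total) summaries; associative because each component is a sum.
    return (left[0] + right[0], left[1] + right[1])


counts, totals = tree_scan(add_pairs, ([1, 1, 1, 1], [2.0, 4.0, 6.0, 8.0]))
print(counts)  # [1, 2, 3, 4]
print([t / c for c, t in zip(counts, totals)])  # [2.0, 3.0, 4.0, 5.0]
```

The operator never sees a bare number: it always gets two (count, total) structures and returns one, which is what lets a single scan carry several related arrays at once.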
Performance
This implementation is useful as a generic, MLX-native associative scan, but it is not a replacement for a compiler-native scan primitive. Benchmarks were run on macOS 14.2.1 (arm64) with Python 3.10.12, MLX 0.31.2, and a CPU-only JAX 0.6.2.
For addition, the baseline mx.cumsum is much faster than generic scan composition. On CPU, compiled median times in milliseconds:
| shape | MLX Hillis-Steele | MLX Blelloch | MLX cumsum | JAX lax.associative_scan |
|---|---|---|---|---|
| 1024 | 0.182 | 0.360 | 0.023 | 0.010 |
| 65536 | 0.536 | 0.481 | 0.080 | 0.149 |
| 512x1024 axis=1 | 3.145 | 1.481 | 0.483 | 0.823 |
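For an SSM/linear-RNN style affine recurrence, there is no mx.cumsum equivalent. Each scan element is a pair (a, b) representing the affine map h -> a*h + b. Composing an earlier element (left) with a later one (right) applies the left map first, and this composition is associative: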
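```python
def compose(left, right):
    # Apply `left` first, then `right`: h -> a_right * (a_left * h + b_left) + b_right
    a_left, b_left = left
    a_right, b_right = right
    return (a_right * a_left, a_right * b_left + b_right)
```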
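A quick way to convince yourself that this composition is right is to check it against the sequential loop. This pure-Python sketch uses integers instead of MLX arrays and itertools.accumulate as a stand-in sequential scan (not mlx-scan's parallel one). It shows that each scanned prefix (A_t, B_t) maps the initial state h0 to h_t = A_t*h0 + B_t, and checks associativity on one triple; `recurrence_loop` is a hypothetical reference helper.

```python
from itertools import accumulate


def compose(left, right):
    """Apply `left` first, then `right`: h -> a_right * (a_left * h + b_left) + b_right."""
    a_left, b_left = left
    a_right, b_right = right
    return (a_right * a_left, a_right * b_left + b_right)


def recurrence_loop(a, b, h0=0):
    """Hypothetical reference: h_t = a_t * h_{t-1} + b_t, computed step by step."""
    hs, h = [], h0
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        hs.append(h)
    return hs


a = [2, 3, 1, 5]
b = [1, 0, 4, 2]
h0 = 5

# Each prefix (A_t, B_t) is the composed map h0 -> A_t * h0 + B_t.
prefixes = list(accumulate(zip(a, b), compose))
h_scan = [A * h0 + B for A, B in prefixes]

print(h_scan)                        # [11, 33, 37, 187]
print(recurrence_loop(a, b, h0=h0))  # [11, 33, 37, 187]

# Associativity on one triple: (x . y) . z == x . (y . z)
x, y, z = (2, 1), (3, 0), (1, 4)
print(compose(compose(x, y), z) == compose(x, compose(y, z)))  # True
```

Because the scan produces the whole affine map rather than the state, the same prefixes can be reused for any initial state h0.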
Compiled CPU median times in milliseconds:
| shape | MLX Hillis-Steele | MLX Blelloch | JAX lax.associative_scan |
|---|---|---|---|
| 1024 | 0.255 | 0.541 | 0.035 |
| 8192 | 0.345 | 0.929 | 0.106 |
| 128x256 axis=1 | 0.471 | 0.534 | 0.228 |
The main bottleneck is not plain Python loop overhead once compiled; it is the graph produced by expressing scan through repeated slice, concat, stack, reshape, and user-operator applications. MLX evaluates those composed graphs well enough to make the implementation usable, but it does not appear to apply the same scan-specific fusion/lowering that XLA applies to jax.lax.associative_scan.
For a Mamba-style prefill recurrence, adapted from the SSM loop in mlx-lm's mamba.py, the isolated recurrence state_t = decay_t * state_{t-1} + update_t is a better fit for the work-efficient Blelloch scan once sequences get long: Blelloch is slightly slower than Hillis-Steele at 128 steps but clearly faster from 512 steps on. With shape (batch=1, sequence, channels=256, state=16), compiled GPU median times in milliseconds:
| sequence length | sequential loop | Hillis-Steele | Blelloch |
|---|---|---|---|
| 128 | 2.654 | 0.888 | 0.995 |
| 512 | 8.469 | 1.973 | 0.899 |
| 1024 | 18.481 | 5.152 | 1.665 |
| 2048 | 43.396 | 10.206 | 1.898 |
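The "work-efficient" label refers to the number of operator applications: Hillis-Steele performs O(n log n) of them, while a Blelloch-style scan performs O(n). The pure-Python sketch below counts calls for both. The work-efficient variant here is a recursive odd/even formulation (combine adjacent pairs, scan the half-length sequence, then fill in the even positions), which shares the asymptotic idea of Blelloch's up-sweep/down-sweep but is not necessarily the exact structure mlx-scan uses. All function names are hypothetical, and fewer operator calls do not translate one-to-one into the timings above.

```python
from typing import Callable, List, Tuple, TypeVar

T = TypeVar("T")


def counting(fn: Callable[[T, T], T]) -> Tuple[Callable[[T, T], T], List[int]]:
    """Wrap a binary operator so that every call is counted."""
    calls = [0]

    def wrapped(a: T, b: T) -> T:
        calls[0] += 1
        return fn(a, b)

    return wrapped, calls


def hillis_steele_scan(fn, xs):
    """Inclusive scan with O(n log n) operator applications."""
    out, offset = list(xs), 1
    while offset < len(out):
        # The right-hand side reads the previous step's list before reassignment.
        out = out[:offset] + [fn(out[i - offset], out[i]) for i in range(offset, len(out))]
        offset *= 2
    return out


def work_efficient_scan(fn, xs):
    """Inclusive scan with O(n) operator applications (recursive odd/even form)."""
    n = len(xs)
    if n < 2:
        return list(xs)
    # Combine adjacent pairs: (x0*x1, x2*x3, ...).
    pairs = [fn(xs[i], xs[i + 1]) for i in range(0, n - 1, 2)]
    # Scanning the pairs yields the prefixes ending at odd indices 1, 3, 5, ...
    odd = work_efficient_scan(fn, pairs)
    out = [xs[0]]
    for i in range(1, n):
        if i % 2 == 1:
            out.append(odd[i // 2])
        else:
            # Even index: prefix ending at i - 1, extended by xs[i].
            out.append(fn(odd[i // 2 - 1], xs[i]))
    return out


n = 1024
xs = list(range(n))
hs_fn, hs_calls = counting(lambda a, b: a + b)
we_fn, we_calls = counting(lambda a, b: a + b)
print(hillis_steele_scan(hs_fn, xs) == work_efficient_scan(we_fn, xs))  # True
print(hs_calls[0], we_calls[0])  # 9217 2036
```

For n = 1024 the counts follow from the structure: Hillis-Steele does n - offset calls at each of its 10 steps, while the odd/even recursion does n - 1 calls per level on geometrically shrinking inputs.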
The benchmark tables above were generated with:
```bash
uv run python benchmarks/bench_scan.py --compiled --op add --device cpu \
  --include-jax \
  --algorithm hillis_steele \
  --algorithm blelloch \
  --algorithm mlx_cumsum \
  --algorithm jax_lax_associative_scan \
  --runs 50 --warmup 10

uv run python benchmarks/bench_scan.py --compiled --op affine --device cpu \
  --include-jax \
  --algorithm hillis_steele \
  --algorithm blelloch \
  --algorithm jax_lax_associative_scan \
  --runs 50 --warmup 10

uv run python benchmarks/bench_mamba_prefill.py --compiled --device gpu \
  --runs 10 --warmup 3
```
References
```bibtex
@article{HillisSteele1986,
  author  = {Hillis, W. Daniel and Steele, Guy L., Jr.},
  title   = {Data Parallel Algorithms},
  journal = {Communications of the ACM},
  volume  = {29},
  number  = {12},
  pages   = {1170--1183},
  year    = {1986},
  doi     = {10.1145/7902.7903}
}

@incollection{Blelloch1993,
  author    = {Blelloch, Guy E.},
  title     = {Prefix Sums and Their Applications},
  booktitle = {Synthesis of Parallel Algorithms},
  publisher = {Morgan Kaufmann},
  year      = {1993}
}

@inproceedings{MartinCundy2018,
  author    = {Martin, Eric and Cundy, Chris},
  title     = {Parallelizing Linear Recurrent Neural Nets Over Sequence Length},
  booktitle = {International Conference on Learning Representations},
  year      = {2018}
}

@article{GuDao2023,
  author  = {Gu, Albert and Dao, Tri},
  title   = {Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
  journal = {arXiv preprint arXiv:2312.00752},
  year    = {2023}
}
```