mlx-scan
Associative scan implementation with MLX.
Why associative scan?
A scan computes every prefix of a sequence:
[a, b, c, d] -> [a, a (*) b, a (*) b (*) c, a (*) b (*) c (*) d]
If (*) is associative, the prefixes can be computed with parallel-prefix algorithms such as Hillis-Steele or Blelloch scan. This matters for more than sums: matrix products, affine recurrences, and some sequence-model operations can be expressed as associative scans.
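As a rough illustration of the parallel-prefix idea, here is a minimal pure-Python sketch of a Hillis-Steele inclusive scan. It is not the mlx-scan implementation: the name `hillis_steele_scan` is hypothetical, and plain lists stand in for arrays. Each round combines every element with the one `offset` positions to its left, then doubles the offset, so n elements need about log2(n) rounds whose element-wise work could run in parallel.

```python
import operator


def hillis_steele_scan(op, xs):
    """Inclusive scan of xs under an associative op (illustrative sketch)."""
    out = list(xs)
    n = len(out)
    offset = 1
    while offset < n:
        # Build the next round from the previous one. Positions below
        # `offset` are already final; the rest absorb the partial result
        # sitting `offset` steps to their left. The left operand stays on
        # the left, so non-commutative operators keep their order.
        out = [
            out[i] if i < offset else op(out[i - offset], out[i])
            for i in range(n)
        ]
        offset *= 2
    return out


print(hillis_steele_scan(operator.add, [1, 2, 3, 4]))  # [1, 3, 6, 10]
```

The string case `hillis_steele_scan(operator.add, list("abc"))` is a quick way to confirm that operand order is preserved, since concatenation is associative but not commutative.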
Quickstart
To install:
uv sync --extra dev
To test & lint:
uv run pytest
uv run ruff check .
Example
The canonical example is a cumulative sum:
import mlx.core as mx
from mlx_scan import associative_scan
x = mx.array([1, 2, 3, 4])
associative_scan(lambda a, b: a + b, x)
# array([1, 3, 6, 10], dtype=int32)
associative_scan accepts an MLX array or a pytree of MLX arrays.
Performance
This implementation is useful as a generic, MLX-native associative scan, but it is not a replacement for a compiler-native scan primitive. Benchmarks were run on macOS 14.2.1 arm64 with Python 3.10.12, MLX 0.31.2, and JAX 0.6.2 CPU-only.
For addition, mx.cumsum is the right baseline and is much faster than generic scan composition. On CPU, compiled median times in milliseconds:
| shape | MLX Hillis-Steele | MLX Blelloch | MLX cumsum | JAX lax.associative_scan |
|---|---|---|---|---|
| 1024 | 0.182 | 0.360 | 0.023 | 0.010 |
| 65536 | 0.536 | 0.481 | 0.080 | 0.149 |
| 512x1024 axis=1 | 3.145 | 1.481 | 0.483 | 0.823 |
For an SSM/linear-RNN style affine recurrence, there is no mx.cumsum equivalent. Each scan element is a pair (a, b) representing h -> a*h + b, with associative composition:
compose((a_left, b_left), (a_right, b_right)) = (
    a_right * a_left,
    a_right * b_left + b_right,
)
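To make the composition rule concrete, the following pure-Python sketch checks it against the sequential recurrence. It does not use mlx-scan: `itertools.accumulate` stands in for the scan, the helper names are hypothetical, and the integer inputs are made up so the comparison is exact. With initial state h0, the state after step t is A_t * h0 + B_t, where (A_t, B_t) is the t-th prefix.

```python
from itertools import accumulate


def compose(left, right):
    """Compose two affine maps h -> a*h + b, applying `left` first."""
    a_left, b_left = left
    a_right, b_right = right
    return (a_right * a_left, a_right * b_left + b_right)


def affine_prefixes(a, b):
    """All prefix compositions, computed sequentially for reference."""
    return list(accumulate(zip(a, b), compose))


def sequential_states(a, b, h0=0):
    """Direct evaluation of h_t = a_t * h_{t-1} + b_t."""
    h, states = h0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        states.append(h)
    return states


a = [2, 3, 1, 4]  # made-up coefficients
b = [1, 0, 5, 2]  # made-up offsets
prefixes = affine_prefixes(a, b)
print(prefixes)                 # [(2, 1), (6, 3), (6, 8), (24, 34)]
print(sequential_states(a, b))  # [1, 3, 8, 34]
```

Because `compose` is associative, any parallel scan that respects operand order should produce the same prefixes as this sequential reference.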
Compiled CPU median times in milliseconds:
| shape | MLX Hillis-Steele | MLX Blelloch | JAX lax.associative_scan |
|---|---|---|---|
| 1024 | 0.255 | 0.541 | 0.035 |
| 8192 | 0.345 | 0.929 | 0.106 |
| 128x256 axis=1 | 0.471 | 0.534 | 0.228 |
The main bottleneck is not plain Python loop overhead once compiled. It is the graph produced by expressing scan through repeated slice, concat, stack, reshape, and user-operator applications. MLX evaluates those composed graphs well enough to make the implementation usable, but it does not appear to apply the same scan-specific fusion/lowering that XLA applies to jax.lax.associative_scan.
For a Mamba-style prefill recurrence, the isolated recurrence state_t = decay_t * state_{t-1} + update_t is a better fit for the work-efficient Blelloch scan, which pulls ahead of Hillis-Steele from sequence length 512 onward. With shape (batch=1, sequence, channels=256, state=16), compiled GPU median times in milliseconds:
| sequence | loop | Hillis-Steele | Blelloch |
|---|---|---|---|
| 128 | 2.654 | 0.888 | 0.995 |
| 512 | 8.469 | 1.973 | 0.899 |
| 1024 | 18.481 | 5.152 | 1.665 |
| 2048 | 43.396 | 10.206 | 1.898 |
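For readers unfamiliar with the work-efficient variant, here is a minimal pure-Python sketch of a textbook Blelloch scan (up-sweep, then down-sweep) applied to the decay/update recurrence. It is an illustration under stated assumptions, not the mlx-scan code: the function names are hypothetical, the input values are made up, the sketch pads to a power of two, and it requires an explicit identity element, which the library may not need.

```python
def compose(left, right):
    """Compose (decay, update) steps, applying `left` first."""
    d_left, u_left = left
    d_right, u_right = right
    return (d_right * d_left, d_right * u_left + u_right)


def blelloch_scan(op, xs, identity):
    """Inclusive scan via up-sweep and down-sweep (illustrative sketch)."""
    n = len(xs)
    if n == 0:
        return []
    size = 1
    while size < n:
        size *= 2
    tree = list(xs) + [identity] * (size - n)  # pad on the right

    # Up-sweep: each right child accumulates the total of its subtree.
    stride = 1
    while stride < size:
        for right in range(2 * stride - 1, size, 2 * stride):
            tree[right] = op(tree[right - stride], tree[right])
        stride *= 2

    # Down-sweep: push the prefix of everything to the left of each
    # subtree back down, which yields an exclusive scan.
    tree[size - 1] = identity
    stride = size // 2
    while stride >= 1:
        for right in range(2 * stride - 1, size, 2 * stride):
            left = right - stride
            left_total = tree[left]
            tree[left] = tree[right]
            tree[right] = op(tree[right], left_total)
        stride //= 2

    # Fold each input into its exclusive prefix to get the inclusive scan.
    return [op(tree[i], xs[i]) for i in range(n)]


decay = [2, 3, 1, 4]   # made-up values
update = [1, 0, 5, 2]  # made-up values
prefixes = blelloch_scan(compose, list(zip(decay, update)), (1, 0))
print([u for _, u in prefixes])  # states from state_0 = 0: [1, 3, 8, 34]
```

The up-sweep and down-sweep together apply the operator O(n) times, compared with O(n log n) for Hillis-Steele, which is consistent with Blelloch gaining ground as the sequence grows.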
Practical takeaways:
- For addition, use mx.cumsum.
- For generic associative functions, Blelloch is the best default.
- Hillis-Steele is still useful for small scans and cases where benchmarking shows it wins.
The benchmark tables above were generated with:
uv run python benchmarks/bench_scan.py --compiled --op add --device cpu \
--include-jax \
--algorithm hillis_steele \
--algorithm blelloch \
--algorithm mlx_cumsum \
--algorithm jax_lax_associative_scan \
--runs 50 --warmup 10
uv run python benchmarks/bench_scan.py --compiled --op affine --device cpu \
--include-jax \
--algorithm hillis_steele \
--algorithm blelloch \
--algorithm jax_lax_associative_scan \
--runs 50 --warmup 10
uv run python benchmarks/bench_mamba_prefill.py --compiled --device gpu \
--runs 10 --warmup 3
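The benchmark scripts themselves are not reproduced here. As a hedged sketch of what a median-of-runs measurement with warmup could look like, the helper below times a callable in milliseconds. The name `median_ms` and its defaults are hypothetical and not taken from benchmarks/bench_scan.py. Because MLX evaluates lazily, a timed MLX function would need to force evaluation inside `fn` (for example with `mx.eval`), otherwise only graph construction is measured.

```python
import statistics
import time


def median_ms(fn, runs=50, warmup=10):
    """Median wall-clock time of fn() in milliseconds (illustrative sketch)."""
    # Warmup calls absorb one-time costs such as compilation or caching.
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


# Example: time a small pure-Python workload.
elapsed = median_ms(lambda: sum(range(10_000)), runs=20, warmup=5)
print(f"{elapsed:.3f} ms")
```

The median is less sensitive than the mean to occasional slow runs caused by scheduling or thermal effects.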