mlx-scan

Associative scan implementation with MLX.

Quickstart

To install:

uv add mlx-scan       # uv
pip install mlx-scan  # pip

For local development:

uv sync --extra dev --extra bench

To test and lint:

uv run pytest
uv run ruff check .

Associative Scan

A scan computes every prefix of a sequence:

[a, b, c, d] -> [a, a (*) b, a (*) b (*) c, a (*) b (*) c (*) d]
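For reference, the definition above is exactly what a sequential left-to-right scan computes, and Python's standard library already provides one: `itertools.accumulate`. The string example shows that operand order is preserved, so the operator only needs to be associative, not commutative. This is plain Python for illustration, not part of mlx-scan.

```python
import operator
from itertools import accumulate

# Sequential inclusive scan: each output is the previous prefix (*) the next input.
print(list(accumulate([1, 2, 3, 4], operator.add)))          # [1, 3, 6, 10]

# String concatenation is associative but not commutative; order is kept.
print(list(accumulate(["a", "b", "c", "d"], operator.add)))  # ['a', 'ab', 'abc', 'abcd']
```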

If the binary operator (*) is associative, the prefixes can be regrouped and computed with parallel-prefix algorithms such as Hillis-Steele scan (about log2(n) parallel steps, but O(n log n) operator applications) or Blelloch scan (about 2*log2(n) steps and O(n) operator applications). This matters for more than sums: matrix products, affine recurrences, and sequence-model operations can all be expressed as scans.
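Associativity is what makes the regrouping legal: a parallel algorithm combines elements pairwise rather than strictly left to right. Since the last prefix is the full reduction, comparing a left fold with a pairwise tree grouping on the same inputs shows the difference. The helper names `left_fold` and `tree_fold` are illustrative only, not mlx-scan functions.

```python
import operator
from functools import reduce


def left_fold(op, xs):
    # Sequential grouping, as in a left-to-right loop: ((x0 op x1) op x2) op x3
    return reduce(op, xs)


def tree_fold(op, xs):
    # Pairwise grouping, as in a parallel reduction: (x0 op x1) op (x2 op x3)
    if len(xs) == 1:
        return xs[0]
    mid = len(xs) // 2
    return op(tree_fold(op, xs[:mid]), tree_fold(op, xs[mid:]))


xs = [1, 2, 3, 4]
print(left_fold(operator.add, xs), tree_fold(operator.add, xs))  # 10 10
print(left_fold(operator.sub, xs), tree_fold(operator.sub, xs))  # -8 0
```

Addition gives the same answer under both groupings; subtraction, which is not associative, does not, so it cannot be scanned in parallel this way.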

Example

The canonical example is a cumulative sum:

import mlx.core as mx
from mlx_scan import associative_scan

x = mx.array([1, 2, 3, 4])
associative_scan(lambda a, b: a + b, x)
# array([1, 3, 6, 10], dtype=int32)

associative_scan accepts an MLX array or a pytree of MLX arrays.

Performance

This implementation is useful as a generic, MLX-native associative scan, but it is not a replacement for a compiler-native scan primitive. Benchmarks were run on macOS 14.2.1 (arm64) with Python 3.10.12, MLX 0.31.2, and JAX 0.6.2 (CPU only).

For addition, the built-in mx.cumsum is much faster than generic scan composition: roughly 3x to 16x faster than either MLX scan in the measurements below. On CPU, compiled median times in milliseconds:

| shape | MLX Hillis-Steele | MLX Blelloch | MLX cumsum | JAX lax.associative_scan |
|---|---|---|---|---|
| 1024 | 0.182 | 0.360 | 0.023 | 0.010 |
| 65536 | 0.536 | 0.481 | 0.080 | 0.149 |
| 512x1024, axis=1 | 3.145 | 1.481 | 0.483 | 0.823 |

For an SSM/linear-RNN style affine recurrence, there is no mx.cumsum equivalent. Each scan element is a pair (a, b) representing h -> a*h + b, with associative composition:

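def compose(left, right):
    # Apply `left` first, then `right`: h -> a_right * (a_left * h + b_left) + b_right
    a_left, b_left = left
    a_right, b_right = right
    return (a_right * a_left, a_right * b_left + b_right)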
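A quick plain-Python check of this composition rule (illustrative, not mlx-scan code): scanning the (a, b) pairs with `compose` and applying each prefix to an initial state h0 reproduces the step-by-step recurrence h_t = a_t * h_{t-1} + b_t. The sequential `itertools.accumulate` is used only for convenience; because `compose` is associative, any correct parallel scan must return the same prefixes. The names `affine_prefixes` and `recurrence_loop` are hypothetical helpers.

```python
from itertools import accumulate


def compose(left, right):
    # (a, b) represents h -> a*h + b; apply `left` first, then `right`.
    a_left, b_left = left
    a_right, b_right = right
    return (a_right * a_left, a_right * b_left + b_right)


def affine_prefixes(pairs):
    # Inclusive scan of the pairs under `compose`.
    return list(accumulate(pairs, compose))


def recurrence_loop(pairs, h0):
    # Direct evaluation of h_t = a_t * h_{t-1} + b_t.
    hs, h = [], h0
    for a, b in pairs:
        h = a * h + b
        hs.append(h)
    return hs


pairs = [(2, 1), (3, -1), (1, 4), (-2, 5)]
h0 = 7
via_scan = [A * h0 + B for A, B in affine_prefixes(pairs)]
print(via_scan)                      # [15, 44, 48, -91]
print(recurrence_loop(pairs, h0))    # [15, 44, 48, -91]
```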

Compiled CPU median times in milliseconds:

| shape | MLX Hillis-Steele | MLX Blelloch | JAX lax.associative_scan |
|---|---|---|---|
| 1024 | 0.255 | 0.541 | 0.035 |
| 8192 | 0.345 | 0.929 | 0.106 |
| 128x256, axis=1 | 0.471 | 0.534 | 0.228 |

Once compiled, the main bottleneck is not Python loop overhead but the graph produced by expressing the scan through repeated slice, concat, stack, reshape, and user-operator applications. MLX evaluates these composed graphs well enough to make the implementation usable, but it does not appear to fuse or lower them as effectively as XLA does for the similar slice-and-combine graph that jax.lax.associative_scan builds.
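To make the graph-size argument concrete, here is a hypothetical Hillis-Steele scan written with NumPy standing in for MLX; it is not mlx-scan's implementation. As simplifying assumptions, it scans along axis 0 only and takes a tuple of arrays, loosely mirroring pytree inputs. Each of the ceil(log2(n)) rounds issues two slices, one call to the user operator, and one concatenate per leaf. NumPy executes these eagerly, whereas a lazy framework such as MLX records each of them as graph nodes.

```python
import numpy as np


def hillis_steele_scan(fn, elems):
    """Inclusive Hillis-Steele scan along axis 0 of a tuple of arrays (hypothetical helper).

    `fn(earlier, later)` receives and returns tuples of arrays with the same
    structure as `elems`; `earlier` holds the elements `offset` positions back.
    """
    elems = tuple(np.asarray(e) for e in elems)
    n = elems[0].shape[0]
    offset = 1
    while offset < n:
        earlier = tuple(e[:-offset] for e in elems)  # slice: x[i - offset]
        later = tuple(e[offset:] for e in elems)     # slice: x[i]
        combined = fn(earlier, later)                # user operator on all pairs at once
        # Concatenate the already-finished head with the newly combined tail.
        elems = tuple(
            np.concatenate([e[:offset], c], axis=0) for e, c in zip(elems, combined)
        )
        offset *= 2
    return elems


def compose(left, right):
    # (a, b) represents h -> a*h + b; apply `left` first, then `right`.
    a_left, b_left = left
    a_right, b_right = right
    return (a_right * a_left, a_right * b_left + b_right)


a = np.array([0.5, 0.9, 0.8, 0.7, 0.6])
b = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
_, h = hillis_steele_scan(compose, (a, b))  # h[t] = a[t] * h[t-1] + b[t], with h[-1] = 0
```

A plain array is passed as a one-element tuple, for example `hillis_steele_scan(lambda earlier, later: (earlier[0] + later[0],), (x,))` for a cumulative sum.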

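For a Mamba-style prefill recurrence, adapted from the SSM loop in mlx-lm's mamba.py, the isolated recurrence state_t = decay_t * state_{t-1} + update_t is a better fit for the work-efficient Blelloch scan, which applies the operator O(n) times instead of O(n log n) times. Hillis-Steele is still slightly faster at 128 steps, but Blelloch pulls ahead from 512 steps onward; at 2048 steps it is roughly 5x faster than Hillis-Steele and 23x faster than the sequential loop. Compiled GPU median times in milliseconds, with shape (batch=1, sequence, channels=256, state=16):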

| sequence length | sequential loop | Hillis-Steele | Blelloch |
|---|---|---|---|
| 128 | 2.654 | 0.888 | 0.995 |
| 512 | 8.469 | 1.973 | 0.899 |
| 1024 | 18.481 | 5.152 | 1.665 |
| 2048 | 43.396 | 10.206 | 1.898 |
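The work-efficient Blelloch scan can be sketched in pure Python as below, applied to a scalar version of the decay/update recurrence. This is an illustrative reference on Python lists, not mlx-scan's implementation. It assumes an identity element (here (1.0, 0.0), the map state -> state) and pads the input to a power-of-two length with it; the library's handling of identities and non-power-of-two lengths may differ. The function names are hypothetical.

```python
def compose(left, right):
    # (decay, update) represents state -> decay * state + update; apply `left` first.
    a_left, b_left = left
    a_right, b_right = right
    return (a_right * a_left, a_right * b_left + b_right)


IDENTITY = (1.0, 0.0)  # state -> 1 * state + 0


def blelloch_exclusive_scan(op, xs, identity):
    """Exclusive scan with O(n) calls to `op` over about 2*log2(n) rounds."""
    n = len(xs)
    size = 1
    while size < n:  # pad to a power of two with the identity element
        size *= 2
    tree = list(xs) + [identity] * (size - n)

    # Up-sweep: each right node accumulates the reduction of its subtree.
    step = 1
    while step < size:
        for k in range(0, size, 2 * step):
            tree[k + 2 * step - 1] = op(tree[k + step - 1], tree[k + 2 * step - 1])
        step *= 2

    # Down-sweep: replace the root with the identity and push prefixes down.
    tree[size - 1] = identity
    step = size // 2
    while step >= 1:
        for k in range(0, size, 2 * step):
            left_sum = tree[k + step - 1]
            parent_prefix = tree[k + 2 * step - 1]
            tree[k + step - 1] = parent_prefix                    # left child: same prefix
            tree[k + 2 * step - 1] = op(parent_prefix, left_sum)  # right child: prefix (*) left
        step //= 2
    return tree[:n]


def blelloch_inclusive_scan(op, xs, identity):
    # Inclusive prefix i = exclusive prefix i (*) element i.
    return [op(p, x) for p, x in zip(blelloch_exclusive_scan(op, xs, identity), xs)]


# Scalar recurrence state_t = decay_t * state_{t-1} + update_t with state_{-1} = 0.
decay = [0.5, 0.5, 0.25, 0.5, 0.75]
update = [1.0, 2.0, 4.0, 8.0, 16.0]
prefixes = blelloch_inclusive_scan(compose, list(zip(decay, update)), IDENTITY)
states = [b for _, b in prefixes]  # with state_{-1} = 0 only the offset term remains
```

Compared with the Hillis-Steele pattern, this trades roughly twice as many rounds for O(n) rather than O(n log n) operator applications, which is consistent with Blelloch pulling ahead as the sequence grows in the measurements above.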

The benchmark tables above were generated with:

uv run python benchmarks/bench_scan.py --compiled --op add --device cpu \
  --include-jax \
  --algorithm hillis_steele \
  --algorithm blelloch \
  --algorithm mlx_cumsum \
  --algorithm jax_lax_associative_scan \
  --runs 50 --warmup 10

uv run python benchmarks/bench_scan.py --compiled --op affine --device cpu \
  --include-jax \
  --algorithm hillis_steele \
  --algorithm blelloch \
  --algorithm jax_lax_associative_scan \
  --runs 50 --warmup 10

uv run python benchmarks/bench_mamba_prefill.py --compiled --device gpu \
  --runs 10 --warmup 3

References

@article{HillisSteele1986,
  author = {Hillis, W. Daniel and Steele, Guy L., Jr.},
  title = {Data Parallel Algorithms},
  journal = {Communications of the ACM},
  volume = {29},
  number = {12},
  pages = {1170--1183},
  year = {1986},
  doi = {10.1145/7902.7903}
}
@incollection{Blelloch1993,
  author = {Blelloch, Guy E.},
  title = {Prefix Sums and Their Applications},
  booktitle = {Synthesis of Parallel Algorithms},
  publisher = {Morgan Kaufmann},
  year = {1993}
}
@inproceedings{MartinCundy2018,
  author = {Martin, Eric and Cundy, Chris},
  title = {Parallelizing Linear Recurrent Neural Nets Over Sequence Length},
  booktitle = {International Conference on Learning Representations},
  year = {2018}
}
@article{GuDao2023,
  author = {Gu, Albert and Dao, Tri},
  title = {Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
  journal = {arXiv preprint arXiv:2312.00752},
  year = {2023}
}

Download files

Source Distribution

mlx_scan-0.0.2.tar.gz (58.5 kB)

Uploaded Source

Built Distribution

mlx_scan-0.0.2-py3-none-any.whl (8.1 kB)

Uploaded Python 3

File details

Details for the file mlx_scan-0.0.2.tar.gz.

File metadata

  • Download URL: mlx_scan-0.0.2.tar.gz
  • Upload date:
  • Size: 58.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.8.4

File hashes

Hashes for mlx_scan-0.0.2.tar.gz
Algorithm Hash digest
SHA256 dcc1bb744a3307204ff955f329c7477adde426b8917195e0d9119611a3035cfa
MD5 ac0191ffe0216e625c9638c2bee9fc43
BLAKE2b-256 6f44506a6d7cc1b0a70be358928cb87658dc1bd55a78947e2cfcf147f22270fe

File details

Details for the file mlx_scan-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: mlx_scan-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 8.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.8.4

File hashes

Hashes for mlx_scan-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 35fbd9853e97882f234ac058c2eb8f72984fe0ddcda47e91fcc244d5696e095d
MD5 db3d21127f43a8cc5a0966aea58bfd6c
BLAKE2b-256 314ef593ae5f4fc746307081ce2cccb75e41482b55a755b2d9b8a182dcf191f1
