mlx-scan

Associative scan implementation with MLX.

Why associative scan?

A scan combines every prefix of a sequence under a binary operator (*), returning all of the running results:

[a, b, c, d] -> [a, a (*) b, a (*) b (*) c, a (*) b (*) c (*) d]

If (*) is associative, the prefixes can be computed with parallel-prefix algorithms such as Hillis-Steele or Blelloch scan. This matters for more than sums: matrix products, affine recurrences, and some sequence-model operations can be expressed as associative scans.
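To make the parallel-prefix idea concrete, here is a minimal pure-Python sketch of a Hillis-Steele inclusive scan. It works on plain lists rather than MLX arrays, the name hillis_steele_scan is hypothetical, and it is not the code this package uses; it only illustrates the doubling pattern that associativity makes possible.

```python
import operator


def hillis_steele_scan(op, xs):
    """Inclusive scan in about log2(n) rounds (illustrative sketch only)."""
    out = list(xs)
    n = len(out)
    offset = 1
    while offset < n:
        # Each position i >= offset combines with the value `offset` slots to
        # its left. Every update reads only the previous round, so all
        # positions in a round could be computed in parallel.
        prev = out
        out = [
            prev[i] if i < offset else op(prev[i - offset], prev[i])
            for i in range(n)
        ]
        offset *= 2
    return out


print(hillis_steele_scan(operator.add, [1, 2, 3, 4]))  # [1, 3, 6, 10]
```

The left operand always covers the earlier elements, so the sketch also works for associative operators that are not commutative, such as string concatenation.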

Quickstart

To install from a source checkout, including the dev extra:

uv sync --extra dev

To test and lint:

uv run pytest
uv run ruff check .

Example

The canonical example is a cumulative sum:

import mlx.core as mx
from mlx_scan import associative_scan

x = mx.array([1, 2, 3, 4])
associative_scan(lambda a, b: a + b, x)
# array([1, 3, 6, 10], dtype=int32)

associative_scan accepts an MLX array or a pytree of MLX arrays.

Performance

This implementation is useful as a generic, MLX-native associative scan, but it is not a replacement for a compiler-native scan primitive. Benchmarks were run on macOS 14.2.1 arm64 with Python 3.10.12, MLX 0.31.2, and JAX 0.6.2 CPU-only.

For addition, mx.cumsum is the right baseline and is much faster than generic scan composition. On CPU, compiled median times in milliseconds:

| shape | MLX Hillis-Steele | MLX Blelloch | MLX cumsum | JAX lax.associative_scan |
|---|---|---|---|---|
| 1024 | 0.182 | 0.360 | 0.023 | 0.010 |
| 65536 | 0.536 | 0.481 | 0.080 | 0.149 |
| 512x1024, axis=1 | 3.145 | 1.481 | 0.483 | 0.823 |

For an SSM/linear-RNN style affine recurrence, there is no mx.cumsum equivalent. Each scan element is a pair (a, b) representing h -> a*h + b, with associative composition:

compose((a_left, b_left), (a_right, b_right)) = (
    a_right * a_left,
    a_right * b_left + b_right,
)
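As a sanity check on that composition rule, the following pure-Python sketch compares a running composition of (a, b) pairs against the step-by-step recurrence. The helper names (sequential_states, prefix_compose) and the integer inputs are made up for illustration; with MLX the pairs would be arrays and the running composition would come from a scan.

```python
def compose(left, right):
    """Compose two affine maps h -> a*h + b, applying `left` first."""
    a_left, b_left = left
    a_right, b_right = right
    return (a_right * a_left, a_right * b_left + b_right)


def sequential_states(pairs, h0):
    """Reference: run h_t = a_t * h_{t-1} + b_t one step at a time."""
    h, states = h0, []
    for a, b in pairs:
        h = a * h + b
        states.append(h)
    return states


def prefix_compose(pairs):
    """Inclusive scan of `compose`, written as a plain loop for clarity."""
    out, acc = [], None
    for pair in pairs:
        acc = pair if acc is None else compose(acc, pair)
        out.append(acc)
    return out


pairs = [(2, 1), (3, -1), (1, 4), (2, 5)]
h0 = 0

# Applying each prefix map to h0 reproduces every intermediate state.
scanned = [a * h0 + b for a, b in prefix_compose(pairs)]
print(scanned)  # [1, 2, 6, 17]
print(scanned == sequential_states(pairs, h0))  # True
```

Because compose is associative, the prefixes can be grouped in any order, which is what lets a parallel scan replace the sequential loop.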

Compiled CPU median times in milliseconds:

| shape | MLX Hillis-Steele | MLX Blelloch | JAX lax.associative_scan |
|---|---|---|---|
| 1024 | 0.255 | 0.541 | 0.035 |
| 8192 | 0.345 | 0.929 | 0.106 |
| 128x256, axis=1 | 0.471 | 0.534 | 0.228 |

Once compiled, the main bottleneck is not Python loop overhead but the graph produced by expressing scan through repeated slice, concat, stack, reshape, and user-operator applications. MLX evaluates those composed graphs well enough to make the implementation usable, but it does not appear to apply the same scan-specific fusion/lowering that XLA applies to jax.lax.associative_scan.

For a Mamba-style prefill recurrence, the isolated recurrence state_t = decay_t * state_{t-1} + update_t is a better fit for the work-efficient Blelloch scan than for Hillis-Steele or a step-by-step loop once sequences reach 512 steps. With shape (batch=1, sequence, channels=256, state=16), compiled GPU median times in milliseconds:

| sequence | loop | Hillis-Steele | Blelloch |
|---|---|---|---|
| 128 | 2.654 | 0.888 | 0.995 |
| 512 | 8.469 | 1.973 | 0.899 |
| 1024 | 18.481 | 5.152 | 1.665 |
| 2048 | 43.396 | 10.206 | 1.898 |
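One way to see why work efficiency matters as sequences grow is to count operator applications. The sketch below is pure Python on lists, not the package's implementation. work_efficient_scan is a recursive pair-and-recurse scan in the spirit of Blelloch's up-sweep/down-sweep rather than a literal transcription of it, and all function names are hypothetical.

```python
from itertools import accumulate


def hillis_steele_scan(op, xs):
    """Doubling scan: about n*log2(n) operator calls."""
    out, offset = list(xs), 1
    while offset < len(out):
        prev = out
        out = [
            prev[i] if i < offset else op(prev[i - offset], prev[i])
            for i in range(len(prev))
        ]
        offset *= 2
    return out


def work_efficient_scan(op, xs):
    """Pair-and-recurse scan: fewer than 2*n operator calls."""
    xs = list(xs)
    n = len(xs)
    if n < 2:
        return xs
    # Combine adjacent pairs, then scan the half-length sequence. The result
    # holds the prefixes that end at odd indices of the original input.
    pairs = [op(xs[i], xs[i + 1]) for i in range(0, n - 1, 2)]
    odd = work_efficient_scan(op, pairs)
    out = [xs[0]] * n
    for i in range(1, n):
        if i % 2 == 1:
            out[i] = odd[i // 2]
        else:
            # Even positions extend the previous odd prefix by one element.
            out[i] = op(odd[i // 2 - 1], xs[i])
    return out


def count_calls(scan, n):
    """Count how many times `scan` applies the operator on n integers."""
    calls = 0

    def counted_add(a, b):
        nonlocal calls
        calls += 1
        return a + b

    result = scan(counted_add, list(range(n)))
    assert result == list(accumulate(range(n)))
    return calls


for n in (128, 512, 2048):
    print(n, count_calls(hillis_steele_scan, n), count_calls(work_efficient_scan, n))
```

Operator counts are only part of the picture: Hillis-Steele needs fewer sequential rounds, which is consistent with it staying competitive at sequence length 128 in the table above.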

Practical takeaways:

  • For addition, use mx.cumsum.
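  • For generic associative functions, Blelloch is the best default for long scans, where its work efficiency pays off (see the GPU table above).
  • Hillis-Steele is still useful for small scans and cases where benchmarking shows it wins, such as the CPU affine results above.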

The benchmark tables above were generated with:

uv run python benchmarks/bench_scan.py --compiled --op add --device cpu \
  --include-jax \
  --algorithm hillis_steele \
  --algorithm blelloch \
  --algorithm mlx_cumsum \
  --algorithm jax_lax_associative_scan \
  --runs 50 --warmup 10

uv run python benchmarks/bench_scan.py --compiled --op affine --device cpu \
  --include-jax \
  --algorithm hillis_steele \
  --algorithm blelloch \
  --algorithm jax_lax_associative_scan \
  --runs 50 --warmup 10

uv run python benchmarks/bench_mamba_prefill.py --compiled --device gpu \
  --runs 10 --warmup 3
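For readers who want to time their own operators, here is a minimal sketch of a median-with-warmup measurement. It is an assumption about what the --runs and --warmup flags control, not a copy of the benchmark scripts, and median_ms is a hypothetical helper name.

```python
import statistics
import time


def median_ms(fn, runs=50, warmup=10):
    """Median wall-clock time of fn() in milliseconds, after warmup calls."""
    for _ in range(warmup):
        fn()  # warm caches and any compilation before measuring
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


# Stand-in workload; replace with the scan call being measured.
elapsed = median_ms(lambda: sum(range(10_000)), runs=20, warmup=5)
print(f"{elapsed:.3f} ms")
```

MLX evaluates lazily, so a timed callable should force the result inside fn, for example with mx.eval, or the measurement covers only graph construction.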
