Benchmark-backed Metal Flash Attention backends for MLX on Apple Silicon
mlx-mfa
mlx-mfa is a Metal Flash Attention + serving-oriented runtime layer for MLX on
Apple Silicon. It provides high-performance attention kernels, runtime helpers,
and cache abstractions for dense training/inference plus modern serving flows.
Current version: 2.31.0 (V34 NAX-direct rewrite). On M5 Max, V6 NAX now reaches SDPA parity at D=128, and SeedVR2-small, at 0.89× SDPA time, actually beats SDPA.
Foreword
MLX Metal Flash Attention - Why?
I've been porting Video Super Resolution (VSR) and Video Reconstruction (VR) models for months as personal projects, but kept getting frustrated by slow inference on my M1 Max MacBook Pro. Rather than buy a brand-new, very expensive M4 (and later M5) Max, I decided to at least try porting Flash Attention to the Mac in the hope of better results. I targeted MLX because my VSR/VR ports were already getting better results on MLX than on MPS.
Although the results are lower than I had hoped, I'm still fairly satisfied with what I get on my M1 Max MBP.
I'll only be working on this project in a reduced capacity until June 2026, when I'll upgrade from my M1 Max to an M5 Max MBP, on which I expect much better results thanks to the improvements Apple has been adding to its silicon.
v2.31.0 ships the V34 NAX-direct rewrite. V6 NAX's forward hot path now
uses Apple's NAXFrag::mma and NAXTile<T, TQ, TD> primitives directly
(the pattern from steel_attention_nax.h), bypassing MPP cooperative_tensor
constraints that previously imposed execution_simdgroups<1>. Multi-SG
parallelism comes from per-SG row partitioning at the kernel level
(tm = 16 * TQ * sgid), not via cooperative_tensor distribution — so the
V33 cross-SG opacity issue disappears entirely.
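The per-SG row partitioning described above can be pictured with a small host-side sketch. It assumes, beyond what these notes state, that each simdgroup owns a contiguous block of 16 * TQ query rows starting at `tm`; `simdgroup_row_ranges` is a hypothetical helper for illustration, not part of mlx-mfa or the Metal kernel.

```python
def simdgroup_row_ranges(num_simdgroups: int, tq: int) -> list:
    """Query rows owned by each simdgroup, assuming simdgroup `sgid` starts
    at tm = 16 * TQ * sgid and owns 16 * TQ consecutive rows."""
    rows_per_sg = 16 * tq
    ranges = []
    for sgid in range(num_simdgroups):
        tm = 16 * tq * sgid  # row offset formula quoted in the release notes
        ranges.append(range(tm, tm + rows_per_sg))
    return ranges


# 4 simdgroups with TQ = 1: rows 0-15, 16-31, 32-47 and 48-63.
parts = simdgroup_row_ranges(4, 1)
```

Because the row blocks are disjoint, each simdgroup can run its own row-wise softmax without exchanging partial results with other simdgroups.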
The historic D=128 long-N gap is closed: production VSR/DiT shapes
that were stuck at 1.5–1.7× SDPA now run at SDPA parity. SeedVR2-small
at 0.89× SDPA actually beats SDPA, the first time V6 NAX has dipped
below 1.0× on a production shape. Numerics also improve 4–30× over legacy
because the manual simd_shuffle_xor row reductions on FP32 accumulators
inside NAXFrag::row_reduce are bit-exact, vs MPP's reduce_rows which
had tile-boundary FP rounding artifacts. Dispatch is shape-aware: V34 is
default for D=128 and D=64 N≥2048, legacy stays for D=64 small-N
(FlashVSR-dense regresses under V34 — root cause TBD).
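The shape-aware dispatch rule above can be restated as a tiny decision function. This is a hypothetical sketch of the rule as written in these notes, not mlx-mfa's actual dispatcher; it assumes `seq_len` is the sequence length N the notes refer to and returns `None` for head dims the notes do not cover.

```python
from typing import Optional


def pick_v6_nax_path(head_dim: int, seq_len: int) -> Optional[str]:
    """Hypothetical restatement of the v2.31.0 shape-aware V6 NAX dispatch."""
    if head_dim == 128:
        return "v34"  # V34 (NAX-direct) is the default for D=128
    if head_dim == 64:
        # V34 only from N >= 2048; small-N D=64 stays on legacy V6 NAX
        return "v34" if seq_len >= 2048 else "legacy"
    return None  # not specified by the release notes


print(pick_v6_nax_path(64, 1024))
```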
v2.30.0 extended v2.29.0's V6 NAX work along three axes: (1) GQA
single-Otile: the BHND rewriter now handles Hq % Hk == 0, so GQA
shapes use the single-Otile kernel directly, gaining 7–14% over the
v2.29.0 legacy fallback; (2) dispatch v5 (a v6 dispatch was attempted but
reverted after a thermally controlled re-benchmark); (3) threadgroup memory
(tgmem) allocation cleanup: single-Otile + bypass no longer allocates the
unused P_buf threadgroup memory.
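The Hq % Hk == 0 condition is the usual grouped-query attention (GQA) requirement: each K/V head is shared by a contiguous group of Hq / Hk query heads. The sketch below shows that conventional head mapping; `kv_head_for_query_head` is a hypothetical helper, and mlx-mfa's BHND rewriter is not shown and may organize this differently internally.

```python
def kv_head_for_query_head(h_q: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Standard GQA mapping: consecutive groups of num_q_heads // num_kv_heads
    query heads read the same K/V head."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("GQA requires Hq % Hk == 0")
    group = num_q_heads // num_kv_heads
    return h_q // group


# Hq=32, Hk=8: query heads 0-3 -> KV head 0, 4-7 -> KV head 1, ..., 28-31 -> KV head 7
mapping = [kv_head_for_query_head(h, 32, 8) for h in range(32)]
```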
v2.29.0 shipped V6 NAX single-Otile for M5+ hardware: an Apple-style
single-buffer kernel (loopForwardSingleTile) with an autoresearch-tuned
default tile config (BQ=16 for every head dim; BK and SG count tuned per D).
v2.27.0 added native Metal attn_bias kernel support (additive bias on
attention logits without SDPA fallback), a dispatch audit for 11 DiT/UNet
architectures, and varlen validation for token merging workflows.
See CHANGELOG.md for full details per version.
Thank you for your interest, and let me know if you've been able to improve on my work!
Current Repository Status
- V2 dense is the main production path.
- Strongest dense wins on M1 Max remain causal D=64/128 and tile-skip regimes (window/sparse).
- D=256 is used only in narrow, benchmark-backed cases (not broadly promoted).
- D=512 remains SDPA-default.
- Native dense backward was benchmarked and not promoted.
- Sage is a specialized decode backend (narrow, benchmark-gated use).
- V3/V4/V5 remain experimental/hardware-dependent.
- TurboQuant KV cache compression (Phases 1–4) is production-ready.
- SVDQuantLinear: W4A16 quantization with an optional SVD low-rank correction for DiT models.
- GNA native kernel: inline 3D window attention (D=128, f16/bf16, forward-only).
- Native attn_bias: additive bias on logits via a Metal kernel (modes 1/2: per-KV and per-head per-KV broadcast).
- Serving/runtime capability surface is now substantially expanded:
  - paged KV + packed varlen query support
  - paged continuous batching/remap
  - explicit chunked prefill
  - runtime-managed prefix reuse
  - runtime speculative draft/verify flow
  - deeper splitfuse runtime integration
  - KV cache abstraction layer
  - minimal real hybrid/offload-capable cache behavior (local offload tier)
  - TurboQuant compressed KV serving (create_decode_runtime(turboquant=True))
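Paged KV caches generally rest on a block-table indirection: K/V for each sequence are stored in fixed-size pages, and a per-sequence table maps logical page index to physical page. The sketch below shows that generic bookkeeping (the idea popularized by vLLM's PagedAttention); `PagedKVLayout` and its methods are hypothetical and are not mlx-mfa's data structures or API.

```python
class PagedKVLayout:
    """Hypothetical paged-KV bookkeeping (not mlx-mfa's API)."""

    def __init__(self, page_size: int, num_pages: int):
        self.page_size = page_size
        self.free_pages = list(range(num_pages))  # physical pages not in use
        self.block_tables = {}  # seq_id -> physical page ids, in logical order
        self.lengths = {}       # seq_id -> number of tokens stored so far

    def append_token(self, seq_id: int):
        """Reserve a slot for one new token; return (physical_page, offset)."""
        table = self.block_tables.setdefault(seq_id, [])
        pos = self.lengths.get(seq_id, 0)
        if pos % self.page_size == 0:  # no page yet, or the last page is full
            if not self.free_pages:
                raise RuntimeError("out of KV pages")
            table.append(self.free_pages.pop(0))
        self.lengths[seq_id] = pos + 1
        return table[pos // self.page_size], pos % self.page_size


# Six tokens for sequence 0 fill page 0 (offsets 0-3) and start page 1.
layout = PagedKVLayout(page_size=4, num_pages=8)
slots = [layout.append_token(seq_id=0) for _ in range(6)]
```

Because sequences only hold whole pages, a batch can grow and shrink without copying K/V, which is what makes continuous batching and prefix reuse cheap.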
Limitations
- Main validation hardware is Apple M1 Max.
- Broad parity claims against CUDA FlashAttention ecosystems are not made.
- Some advanced paths are intentionally narrow, bridge-based, or explicit-only.
- Hybrid offload is currently a local offload milestone, not remote/distributed cache infrastructure.
- Future major hardware-specific optimization work is deferred pending newer Apple hardware (M5+).
[See the v2.31.0 V6 NAX foreword above and the "Best M5 Max Benchmark Highlights (v2.31.0)" table below for current numbers.]
Best M1 Max Benchmark Highlights
Representative benchmark-backed outcomes (see RESULTS.md and
docs/benchmarks/RESULTS.md for details):
| Area | Representative result (M1 Max; speedup vs SDPA, higher is better) | Interpretation |
|---|---|---|
| Dense causal V2 | up to ~1.82x vs SDPA (D=64, N=8192) | Primary production win regime |
| Dense causal V2 | up to ~1.75x vs SDPA (D=128, N=16384) | Strong long-sequence causal performance |
| Sliding window | up to ~21x vs full SDPA | Tile-skip regime remains strongest |
| D=256 | narrow causal long-N wins (for example ~1.16x at N=16384 f16) | Keep narrow policy only |
| D=512 | decision pass found no broad wins | SDPA-default remains correct |
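Tile-skip regimes win because the kernel only visits (query tile, key tile) pairs that contain at least one unmasked entry, while full SDPA computes every pair. The sketch below counts visited tiles for a causal sliding window; the 64×64 tile size and the window semantics (key j visible to query i when 0 <= i - j < window) are assumptions for illustration, not mlx-mfa's exact kernel configuration.

```python
def visited_tiles(n, block_q, block_k, window=None):
    """Count tiles with at least one (i, j) where 0 <= i - j < window.
    window=None counts every tile, i.e. full non-causal attention."""
    total = 0
    for q0 in range(0, n, block_q):
        q1 = min(q0 + block_q, n)  # exclusive end of the query tile
        for k0 in range(0, n, block_k):
            k1 = min(k0 + block_k, n)  # exclusive end of the key tile
            if window is None:
                total += 1
                continue
            # i - j spans [q0 - (k1 - 1), (q1 - 1) - k0]; keep the tile
            # if that range overlaps the allowed band [0, window - 1].
            lo, hi = q0 - (k1 - 1), (q1 - 1) - k0
            if hi >= 0 and lo <= window - 1:
                total += 1
    return total


full = visited_tiles(16384, 64, 64)
sliding = visited_tiles(16384, 64, 64, window=512)
```

With these assumed sizes a 512-token window visits 2,268 of 65,536 tiles, roughly 29× fewer; measured speedups also depend on per-tile overheads and memory traffic, so the tile count is only a guide.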
Best M5 Max Benchmark Highlights (v2.31.0)
V6 NAX path on production VSR/DiT shapes (cross-session multi-run, iStat performance fan profile). The shape-aware dispatch picks V34 (NAX-direct) where it wins, legacy V6 NAX otherwise.
| Shape | D | Path | V6 NAX time vs SDPA (lower is better) |
|---|---|---|---|
| FlashVSR-dense | 64 | legacy | 1.23× SDPA |
| LTX2-cross | 64 | V34 | 1.07× SDPA |
| SeedVR2-small | 128 | V34 | 0.89× SDPA ⭐ (beats SDPA) |
| CogVideoX | 128 | V34 | 1.03× SDPA (parity) |
| SeedVR2-large | 128 | V34 | 1.01× SDPA (parity) |
GQA shapes (Sprint B single-Otile path, legacy V6 NAX):
| Shape | V6 NAX time vs SDPA (lower is better) |
|---|---|
| GQA-Hq32-Hk8 D=128 | 1.06× ⭐ |
| GQA-Hq16-Hk4 D=64 | 1.17× |
| GQA-Hq40-Hk8 D=128 | 1.16× |
| GQA-Hq8-Hk2 D=64 | 1.18× |
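The M1 Max highlights are quoted as speedups over SDPA (higher is better), whereas the M5 Max figures read as V6 NAX time relative to SDPA (lower is better), which is how the "0.89× SDPA (beats SDPA)" annotation reads. To compare the two, invert one of them; `time_ratio_to_speedup` is a hypothetical helper for that arithmetic.

```python
def time_ratio_to_speedup(time_ratio: float) -> float:
    """Convert 'kernel time / SDPA time' (lower is better) into
    'SDPA time / kernel time' (higher is better)."""
    if time_ratio <= 0:
        raise ValueError("time ratio must be positive")
    return 1.0 / time_ratio


for name, ratio in [("SeedVR2-small", 0.89), ("FlashVSR-dense", 1.23)]:
    print(f"{name}: {ratio:.2f}x SDPA time -> {time_ratio_to_speedup(ratio):.2f}x speedup")
```

So SeedVR2-small at 0.89× SDPA time corresponds to about a 1.12× speedup, and FlashVSR-dense at 1.23× corresponds to about 0.81×.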
Numerics: V34's FP32 RMSE against the SDPA reference is 9e-7 to 4e-6 across all 5 shapes, 4–30× lower than legacy V6 NAX (6e-6 to 1.5e-5). The manual simd_shuffle_xor row reductions on FP32 accumulators are bit-exact, whereas MPP's reduce_rows introduced FP rounding artifacts at tile boundaries.
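A simd_shuffle_xor row reduction is a butterfly: at each step every lane combines its value with the lane whose index differs by one bit, and after log2(width) steps every lane holds the full result. The CPU emulation below shows that generic 32-lane pattern; the actual lane layout inside NAXFrag::row_reduce is not documented here, and `butterfly_allreduce` is a hypothetical helper.

```python
def butterfly_allreduce(lanes, op):
    """Emulate a simd_shuffle_xor butterfly over a power-of-two lane count:
    at each step lane i combines its value with lane (i ^ offset)."""
    n = len(lanes)
    if n == 0 or n & (n - 1):
        raise ValueError("lane count must be a power of two")
    vals = list(lanes)
    offset = n // 2
    while offset:
        # Every lane reads its partner's value from the previous step.
        vals = [op(vals[i], vals[i ^ offset]) for i in range(n)]
        offset //= 2
    return vals


row_max = butterfly_allreduce(list(range(32)), max)
row_sum = butterfly_allreduce([0.1 * i for i in range(32)], lambda a, b: a + b)
```

The combination order is fixed by the lane pattern and partners always add the same two operands, so every lane ends with the same bits and the result does not vary from run to run.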
Serving/Runtime Capability Summary
| Capability | Maturity | Current status |
|---|---|---|
| Paged KV decode runtime | Fully usable | Explicit runtime/API usage; no broad auto-promotion |
| Paged + packed varlen queries | Production (fused kernel) | Single-dispatch fused kernel for all query/KV length combinations |
| Paged continuous batching remap | Fully usable | Explicit cache_batch_idx semantics + runtime helpers |
| Chunked prefill | Fully usable (scheduler-oriented) | Operational capability; not a throughput win on current matrix |
| Runtime prefix caching | Fully usable | Register/seed/reuse path integrated with runtime metadata |
| Runtime speculative decode | Fully usable (narrow) | speculative_step + verify integration; scheduler engine still future work |
| Splitfuse runtime integration | Narrow/conditional | Runtime path exists; performance remains shape-sensitive |
| Hybrid KV cache + local offload tier | Narrow/conditional milestone | Real hot/cold/offloaded behavior locally; remote offload future work |
| TurboQuant KV compression (Phase 4) | Production | 5.33× K compression, WHT fused in kernel (1.1–1.4× faster) |
| SVDQuantLinear | Production | W4A16 + rank-r FP16 correction; quantize_model() tree walker |
| GNA native kernel | Production | Inline 3D window attention (D=128); exact per-element masking |
| Native attn_bias | Production | Modes 1/2 via V2 STEEL; modes 0/3 SDPA fallback |
| External cache adapter layer | Experimental groundwork | Concrete local backend provided; external backend integrations pending |
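The TurboQuant row mentions a Walsh-Hadamard transform (WHT) fused into the kernel. A WHT is an orthogonal rotation commonly applied before low-bit quantization because it spreads outlier coordinates evenly across the vector. The sketch below is the standard orthonormal fast WHT in plain Python; how mlx-mfa applies it (block size, sign flips, placement inside the kernel) is not shown here, and `fwht` is a hypothetical helper.

```python
import math


def fwht(x):
    """Orthonormal fast Walsh-Hadamard transform of a power-of-two-length list."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    y = list(x)
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = y[i], y[i + h]
                y[i], y[i + h] = a + b, a - b  # butterfly on the pair
        h *= 2
    scale = 1.0 / math.sqrt(n)  # makes the transform orthonormal (self-inverse)
    return [v * scale for v in y]


# A single outlier of 8.0 is spread evenly: every output is 8 / sqrt(8).
spread = fwht([8.0] + [0.0] * 7)
```

Because the orthonormal transform is its own inverse, applying it again recovers the original vector, so the rotation can be undone after dequantization.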
Repository Guide
- Feature coverage: docs/FEATURE_COVERAGE.md
- API manual: docs/API_MANUAL.md
- Architecture: docs/ARCHITECTURE.md
- Inventory map: docs/INVENTORY.md
- Benchmark interpretation: docs/benchmarks/RESULTS.md
- Root benchmark summary: RESULTS.md
- Changelog: CHANGELOG.md
- Historical development archive: devnotes/
- Examples: examples/
Production vs Narrow vs Experimental
| Status | Components |
|---|---|
| Production | V2 dense causal small-D path; window/sparse tile-skip; SDPA fallback policy; TurboQuant KV compression; SVDQuantLinear; GNA native kernel; native attn_bias |
| Narrow / conditional | D=256 causal long-N policy; Sage decode regimes; splitfuse/page-native runtime paths; hybrid local offload behavior |
| Experimental | V3/V4/V5 families; external/LMCache-like backend extensions beyond local adapter |
Recommended Usage
- Use backend="auto" for dense attention and let policy route between V2 and SDPA.
- Use create_decode_runtime(...) for serving flows instead of stitching helper calls manually.
- Treat paged/packed/chunked/prefix/speculative features as explicit runtime capabilities.
- Use Sage as a specialized decode backend only when your workload matches the benchmark-backed regime.
Installation
```bash
pip install -e .
```
Minimal Usage
```python
import mlx.core as mx
from mlx_mfa import flash_attention, flash_attention_gna, create_decode_runtime
from mlx_mfa import SVDQuantLinear, quantize_model

# Dense attention (tensors laid out as [B, H, N, D])
q = mx.random.normal((1, 8, 1024, 128)).astype(mx.float16)
k = mx.random.normal((1, 8, 1024, 128)).astype(mx.float16)
v = mx.random.normal((1, 8, 1024, 128)).astype(mx.float16)
out = flash_attention(q, k, v, causal=True)

# Token merging proportional attention (native Metal, no SDPA fallback)
merge_counts = mx.ones((1, 1, 1, 1024), dtype=mx.float16)
merge_counts[..., :256] = 2.0  # first 256 tokens are merged pairs
bias = mx.log(merge_counts)    # [1, 1, 1, N_kv]: mode 1 broadcast
out_biased = flash_attention(q, k, v, attn_bias=bias)

# GNA (Generalized Neighborhood Attention): 3D window
# Video: 8 frames of 32x32 (8 * 32 * 32 = 8192 tokens), local 3D sliding window
q_vid = mx.random.normal((1, 8, 8192, 128)).astype(mx.float16)
k_vid = mx.random.normal((1, 8, 8192, 128)).astype(mx.float16)
v_vid = mx.random.normal((1, 8, 8192, 128)).astype(mx.float16)
out_gna = flash_attention_gna(q_vid, k_vid, v_vid,
                              seq_shape=(8, 32, 32),
                              window_size=(2, 8, 8),
                              stride=(1, 1, 1))

# SVDQuantLinear: W4A16 + SVD low-rank correction
# (quantize_model replaces nn.Linear layers in-place)
# model = quantize_model(model, group_size=64, bits=4, rank=32)

# Serving-oriented runtime
rt = create_decode_runtime(
    backend="auto",
    paged=False,
    quantized_kv=False,
    B=1,
    H_q=8,
    H_kv=8,
    D=128,
    max_seq_len=4096,
)
out_prefill = rt.prefill(q, k, v)
# Decode one token: single-position q, k, v of shape [1, 8, 1, 128]
out_step = rt.step(
    mx.random.normal((1, 8, 1, 128)).astype(mx.float16),
    mx.random.normal((1, 8, 1, 128)).astype(mx.float16),
    mx.random.normal((1, 8, 1, 128)).astype(mx.float16),
)

# MLX evaluates lazily; force the attention computations to run
mx.eval(out, out_biased, out_gna)
```
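The token-merging example adds bias = log(merge_counts) to the logits. Since exp(s_j + log c_j) = c_j · exp(s_j), this is the same as weighting each key by its merge count inside the softmax (proportional attention). The plain-Python check below verifies that identity for one query row; the scores and counts are made-up illustration values.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of logits."""
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    total = sum(e)
    return [v / total for v in e]


scores = [0.3, -1.2, 2.0, 0.5]  # one query's logits over 4 keys
counts = [2.0, 1.0, 1.0, 3.0]   # number of original tokens merged into each key

# Additive log-count bias on the logits (what attn_bias does above)
biased = softmax([s + math.log(c) for s, c in zip(scores, counts)])

# Explicit proportional attention: weight exp(score) by the merge count
raw = [c * math.exp(s) for s, c in zip(scores, counts)]
z = sum(raw)
weighted = [w / z for w in raw]
```

Both lists agree to floating-point precision, which is why a broadcast log-count bias is enough to implement proportional attention without a separate kernel path.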
License
MIT. See LICENSE.
Download files
Source Distribution
Built Distribution
File details
Details for the file mlx_mfa-2.31.0.tar.gz.
File metadata
- Download URL: mlx_mfa-2.31.0.tar.gz
- Upload date:
- Size: 5.2 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 56c19ba71f787729651f809af2c0156370a8840fb43d7edbd1cccc5f3b05dbba |
| MD5 | d872dea2da84f46f41faec1862922570 |
| BLAKE2b-256 | 0e3f9a1e5cac76d82cbb260d6107d49ab66e3181e7286c3f00accb76f5e5d636 |
File details
Details for the file mlx_mfa-2.31.0-cp311-cp311-macosx_26_0_arm64.whl.
File metadata
- Download URL: mlx_mfa-2.31.0-cp311-cp311-macosx_26_0_arm64.whl
- Upload date:
- Size: 466.4 kB
- Tags: CPython 3.11, macOS 26.0+ ARM64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a8729dc000e423198022ef0c8079e83b0a4423d67f87b50c6bdb5a5187d3b5a0 |
| MD5 | 77e31f0eb49062976fb94f9e1acdd1f8 |
| BLAKE2b-256 | f6a730513ffcd4376e6c3bfdd799a857eb9cf110e47d47974266bf0fba93dd7f |