Memory-efficient Metal flash-attention for PyTorch MPS, with a guarded flash_attn drop-in shim
Project description
mtlflashattn
Memory-efficient, fast flash-attention for PyTorch on Apple Silicon (MPS) — with a
guarded flash_attn drop-in shim so existing code uses it automatically.
mtlflashattn never materializes the Lq×Lk score matrix (fixing OOM/SIGKILL on long
attention), and on M5 / macOS 27 it drives Apple's TensorOps / Neural Accelerators via
matmul2d to run 3–11× faster than the stock fused MPS SDPA — which, separately, is
silently numerically wrong past ~4k tokens. Pure-Python torch.mps.compile_shader; no C++/
ObjC/Swift extension, no .metallib, no xcrun at runtime.
Likely the first working
matmul2d/TensorOps flash-attention outside Apple.
Why
- OOM rescue. Stock attention allocates
Lq×Lkscores; at long sequence it blows past unified memory and the process is SIGKILL'd. This kernel is O(L·D) memory — it never forms the score matrix. - Correctness. The MPS fused
scaled_dot_product_attentiondiverges past ~4k tokens (per-element errors up to ~28 on real DiT q/k/v → grid artifacts / variance collapse). This kernel accumulates softmax/output state in fp32 and matches a chunked-fp32 reference to ~1e-4–6e-4 at 1k–32k tokens. - Speed. On M5 it uses the Neural Accelerators (TensorOps
matmul2d) with a register-resident online-softmax pipeline. - Zero-friction adoption. A guarded
flash_attnimport shim means libraries thatfrom flash_attn import flash_attn_funcget the Metal version on Apple Silicon with no code change — while a real CUDAflash_attninstall always wins.
Requirements
- Apple Silicon Mac,
torch >= 2.5(torch.mps.compile_shader). - v2 / v2r (TensorOps): M5+ and macOS 26+ (MetalPerformancePrimitives). Auto-detected.
- v1 (
simdgroup_matrix): any Apple Silicon (M1+), used when TensorOps is unavailable. - v0 / torch fallback: anywhere MPS runs.
Install
pip install mtlflashattn # or: uv add mtlflashattn
The guarded flash_attn shim auto-activates at interpreter start via a .pth file. Kill it
with MTLFLASHATTN_SHIM=off.
Usage
1. As the flash_attn drop-in (most code needs nothing)
import torch
from flash_attn import flash_attn_func # served by mtlflashattn on MPS
q = torch.randn(1, 8192, 16, 64, device="mps", dtype=torch.bfloat16) # [B, S, H, D]
k = torch.randn_like(q); v = torch.randn_like(q)
out = flash_attn_func(q, k, v, causal=True)
Exposes the CUDA flash-attn surface: flash_attn_func, flash_attn_varlen_func,
flash_attn_qkvpacked_func, flash_attn_kvpacked_func, flash_attn_varlen_qkvpacked_func,
flash_attn_varlen_kvpacked_func, and the flash_attn.flash_attn_interface submodule.
Unsupported features (dropout, alibi, softcap, return_attn_probs, sliding window, D>128,
backward) raise NotImplementedError so callers fall back rather than get wrong results.
2. Direct API
import metal_flash_attn as mfa
out = mfa.flash_attn_func(q, k, v, causal=True, softmax_scale=0.125)
3. As an F.scaled_dot_product_attention patch
import metal_flash_attn.sdpa as sdpa
sdpa.install() # reroutes F.scaled_dot_product_attention on MPS, gated; opt-in
...
sdpa.uninstall()
Fires on three gates: correctness (max(Lq,Lk) ≥ MTLFLASHATTN_SDPA_MIN_SEQ, default 4096),
speed (a fast TensorOps tier above MTLFLASHATTN_SDPA_FAST_MIN_SEQ, default 1024), and
memory (MTLFLASHATTN_SDPA_MIN_GB, default 12). Tiny attention stays on stock. Never
crashes the caller — any kernel error falls through to the original op.
Kernel tiers
Selected automatically by dtype, head dim, and sequence length (MTLFLASHATTN_KERNEL=auto):
| Tier | Hardware | Notes |
|---|---|---|
| v2r | M5+ / macOS 27+ | TensorOps matmul2d, register-resident P (no threadgroup round-trip). Fastest. bf16 all D; fp32 D≤64; fp16 D≤64. Gated to Lk≥256. |
| v2 | M5+ / macOS 26+ | TensorOps matmul2d, threadgroup-staged P. fp16/bf16/fp32 (v2_fp32/v2_bf16). |
| v1 | M1+ | simdgroup_matrix 8×8 FA-2. fp16 fallback when TensorOps is unavailable. |
| torch | any MPS | Chunked fp32 matmul-softmax-matmul. The safe fp32 short-sequence path. |
| v0 | any MPS | Scalar one-thread-per-row. Exact, memory-safe debug baseline. |
Force a tier with MTLFLASHATTN_KERNEL=v0|v1|v2|v2_fp32|v2_bf16|v2_dtype|torch.
Benchmarks
M5 Max, macOS 27.0, torch 2.12, B=1 H=16, vs stock fused SDPA (ratio = flash/stock, <1 is faster; effective TF/s in parens).
fp16, non-causal:
| shape | stock | v1 | v2 / v2r |
|---|---|---|---|
| D=64 L=8k | 92 ms (3.0) | 0.55× (5.4) | 0.09× (30.8 TF/s, v2r) |
| D=128 L=8k | 101 ms (5.4) | 0.97× (5.6) | 0.29× (18.9 TF/s, v2) |
dtype-specialized v2r (register-resident P), effective TF/s:
| shape | bf16 | fp32 | precision |
|---|---|---|---|
| D=64 L=8k | 30.6 (2.5× over TG round-trip) | 12.5 (1.45×) | fp32 bit-exact; bf16 ~bf16 noise |
| D=128 L=26k | 22.5 (1.23×) | 8.8 (TG) | bf16 ~bf16 noise |
Causal adds block-skipping the MPS path lacks: fp16 D=64 L=8k v2r ≈ 23 TF/s (~20× stock).
Reproduce: python bench/bench_attn.py (fp16) and python dev/bench_dtype_kernels.py (all dtypes).
Diagnostics
MTLFLASHATTN_TRACE=1 makes every flash_attn_forward accumulate per
(dtype, D, Lq, Lk, causal, resolved-kernel) call counts and print a summary to stderr at exit
— so you can see exactly which tier each stage of a real workload hits. Zero overhead when off.
[MTLFLASHATTN_TRACE] 4560 attention call(s), 6 distinct shape/kernel combos:
calls=990 bfloat16 D=128 Lq=26136 Lk=26136 causal=F -> v2r(bfloat)
calls=990 bfloat16 D=128 Lq=26136 Lk=5 causal=F -> v2_bf16(TG)
...
Environment knobs
| Variable | Default | Effect |
|---|---|---|
MTLFLASHATTN_KERNEL |
auto |
Force a tier. |
MTLFLASHATTN_SHIM |
auto |
off disables the flash_attn import shim. |
MTLFLASHATTN_TRACE |
off | 1 prints a per-shape/tier call summary at exit. |
MTLFLASHATTN_V2_PREUSE |
auto |
Force/disable the register-resident-P (v2r) path. |
MTLFLASHATTN_V2_FP32_MIN_SEQ |
2048 |
fp32 length gate: above → v2_fp32, below → torch fallback. |
MTLFLASHATTN_TORCH_CHUNK |
2048 |
Query-chunk size for the torch fallback. |
MTLFLASHATTN_SDPA / _MIN_GB / _MIN_SEQ / _FAST_MIN_SEQ |
— | SDPA-patch gating. |
Scope
Inference forward pass only (no backward). Supports softmax_scale, bottom-right causal,
GQA/MQA, fp16/bf16/fp32, D≤128. Sliding window, dropout, alibi, softcap, KV-cache decode, and
FlexAttention are not implemented (the shim raises so callers fall back).
How it works / engineering notes
The full build log — the TensorOps/MPP API gotchas discovered empirically on macOS 27 beta
(tensor::slice() reads wrong data as a matmul2d operand; operand extents must be clamped;
tensor destinations always accumulate; reduce_rows requires single-simdgroup scope and its
row-reduction output is per-thread co-located with the source — the reduced value for row r
sits on every lane that owns row r's columns, with the reduction's row index at idx[0] vs the
source's at idx[1]; the register-resident P recipe forces the S element type to match the left
input), the per-tier design, and the speed/precision analysis are documented inline in the kernel
source (metal_flash_attn/_kernel.py).
License
MIT.
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file mtlflashattn-0.1.0.tar.gz.
File metadata
- Download URL: mtlflashattn-0.1.0.tar.gz
- Upload date:
- Size: 60.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.11
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
e0e874efcf70b4e9707a802a8f2dcdfc9624cf2ef444a046ab315896fe34c6d9
|
|
| MD5 |
59526d4f52e83619d8255a9c43b65570
|
|
| BLAKE2b-256 |
b81dbc9fd884908f6ee37ac9cd721a691a4ffe676fc19b6522376e4de6a231b9
|
File details
Details for the file mtlflashattn-0.1.0-py3-none-any.whl.
File metadata
- Download URL: mtlflashattn-0.1.0-py3-none-any.whl
- Upload date:
- Size: 23.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.11
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
39a4b09b91d5e10f50f62ed01b7af1a8ec818370160b19ff9b0bd9b8579a328e
|
|
| MD5 |
e8e024f2a4b9633c2fdd8df116c97da9
|
|
| BLAKE2b-256 |
21fe98da7c12e2f13ffe9cf85d0d83b574b1e140c9b8e1110abfee633df0c906
|