# mps-sdpa
Fast drop-in scaled_dot_product_attention for Apple silicon (MPS). Wraps
Apple's native MPSGraph.scaledDotProductAttention op with a zero-copy C++ /
Objective-C++ bridge. 5–7× inference speedup, 2–2.5× training speedup, 16–170×
less driver memory per call — with identical math and checkpoint compatibility.
```python
from mps_sdpa import sdpa_opt

# Drop-in replacement — same signature as torch.nn.functional.scaled_dot_product_attention
out = sdpa_opt(query, key, value,
               attn_mask=None, dropout_p=0.0, is_causal=False, scale=None,
               backend="auto")
```
## Why
PyTorch's MPS backend dispatches `torch.nn.functional.scaled_dot_product_attention`
to `sdpa_general_mps`, which builds a naive matmul → softmax → matmul graph.
It does not call Apple's dedicated `MPSGraph.scaledDotProductAttention` op,
which is present on macOS 15+ and is significantly faster.

mps-sdpa wraps that op directly, with an `autograd.Function` for training,
shape-threshold auto-calibration, graceful fallbacks, and thread-safe graph
caching. The C++ extension uses ATen's `getMTLBufferStorage` to hand torch
tensors to MPSGraph without a CPU memcpy.
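To sanity-check the numbers below on your own machine, a minimal timing loop along these lines works (the `bench` helper and iteration count are illustrative, not part of the package):

```python
import time
import torch
import torch.nn.functional as F
from mps_sdpa import sdpa_opt

q, k, v = (torch.randn(1, 8, 4096, 64, dtype=torch.bfloat16, device="mps")
           for _ in range(3))

def bench(fn, iters=20):
    fn(q, k, v)                      # warm-up: graph build / extension JIT
    torch.mps.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(q, k, v)
    torch.mps.synchronize()          # MPS dispatch is async; sync before reading the clock
    return (time.perf_counter() - t0) / iters * 1e3

print(f"stock:    {bench(F.scaled_dot_product_attention):.2f} ms")
print(f"mps-sdpa: {bench(sdpa_opt):.2f} ms")
```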
## Measured performance
All numbers on M4 / macOS 26.4.1 / torch 2.13 nightly, bfloat16.
### Inference (forward only, B=1, H=8, D=64)
| L | stock | mps-sdpa | speedup |
|---|---|---|---|
| 1024 | 5.78 ms | 0.90 ms | 6.42× |
| 2048 | 19.0 ms | 3.82 ms | 4.97× |
| 4096 | 76.1 ms | 11.79 ms | 6.45× |
| 8192 | 317 ms | 44.3 ms | 7.17× |
Weighted geomean across a realistic audio-model shape suite: 4.88×.
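"Weighted geomean" is the standard weighted geometric mean of the per-shape speedups; a minimal sketch, assuming the weights reflect each shape's share of the suite:

```python
import math

def weighted_geomean(speedups, weights):
    # exp(sum(w_i * ln(s_i)) / sum(w_i)); equal weights reduce this to the plain geomean
    total = sum(weights)
    return math.exp(sum(w * math.log(s) for s, w in zip(speedups, weights)) / total)
```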
### Training (forward + backward, same shapes)
| L | stock | mps-sdpa | speedup |
|---|---|---|---|
| 1024 | 9.93 ms | 5.06 ms | 1.96× |
| 2048 | 38.6 ms | 17.1 ms | 2.25× |
| 4096 | 154 ms | 64.8 ms | 2.38× |
| 8192 | 608 ms | 247 ms | 2.46× |
### Training with dropout (`dropout_p=0.1`)
| L | stock | mps-sdpa | speedup |
|---|---|---|---|
| 1024 | 14.19 ms | 7.63 ms | 1.86× |
| 2048 | 55.83 ms | 28.49 ms | 1.96× |
| 4096 | 228 ms | 101 ms | 2.26× |
### Driver memory per call
Apple's fused op doesn't materialize the [Lq, Lkv] attention matrix. The
zero-copy C++ bridge removes the CPU-side intermediate buffer too.
| L | stock | mps-sdpa | reduction |
|---|---|---|---|
| 2048 | 1024 MB | <1 MB | ≫128× |
| 4096 | 1024 MB | <1 MB | ≫64× |
| 8192 | 1024 MB | 32 MB | 32× |
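You can reproduce the per-call measurement with `torch.mps.driver_allocated_memory()`, a stock torch API; the warm-up call keeps one-time graph builds out of the delta:

```python
import torch
from mps_sdpa import sdpa_opt

q = k = v = torch.randn(1, 8, 4096, 64, dtype=torch.bfloat16, device="mps")

sdpa_opt(q, k, v)                    # warm-up: exclude one-time graph build
torch.mps.synchronize()
torch.mps.empty_cache()
before = torch.mps.driver_allocated_memory()
out = sdpa_opt(q, k, v)
torch.mps.synchronize()
delta = torch.mps.driver_allocated_memory() - before
print(f"driver memory delta: {delta / 2**20:.1f} MB")
```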
## Install

```bash
pip install mps-sdpa
```
Requires macOS 15+ on Apple silicon (M1–M4) with PyTorch ≥ 2.11. `ninja` is
pulled in automatically for the zero-copy backend's JIT compile (~6 s on first
import, then cached to `~/.cache/torch_extensions`).
Development install from source:

```bash
git clone https://github.com/crlandsc/mps-sdpa.git
cd mps-sdpa
pip install -e ".[dev]"
pytest tests/
```
## Usage

The API exactly mirrors `torch.nn.functional.scaled_dot_product_attention`:
```python
import torch
from mps_sdpa import sdpa_opt

q = torch.randn(1, 8, 2048, 64, dtype=torch.bfloat16, device="mps")
k = torch.randn(1, 8, 2048, 64, dtype=torch.bfloat16, device="mps")
v = torch.randn(1, 8, 2048, 64, dtype=torch.bfloat16, device="mps")

out = sdpa_opt(q, k, v)                          # basic
out = sdpa_opt(q, k, v, is_causal=True)          # causal mask
out = sdpa_opt(q, k, v, attn_mask=bool_mask)     # bool mask
out = sdpa_opt(q, k, v, dropout_p=0.1)           # training dropout
out = sdpa_opt(q, k, v, backend="mpsgraph_zc")   # force a specific backend
```
As a drop-in swap inside an existing model:
```python
import torch.nn as nn
import torch.nn.functional as F
from mps_sdpa import sdpa_opt

class MyAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)
        self.dropout_p = 0.1

    def forward(self, x):
        B, L, _ = x.shape
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        p = self.dropout_p if self.training else 0.0  # standard train/eval dropout pattern
        y = sdpa_opt(q, k, v, dropout_p=p)            # <— was F.scaled_dot_product_attention
        y = y.transpose(1, 2).reshape(B, L, self.d_model)
        return self.out(y)
```
Gate the behavior at model-construction time (useful for A/B comparison):
```python
class Attn(nn.Module):
    def __init__(self, ..., use_mps_sdpa_opt: bool = False):
        self.use_opt = use_mps_sdpa_opt

    def forward(self, q, k, v):
        if self.use_opt:
            return sdpa_opt(q, k, v, ...)
        return F.scaled_dot_product_attention(q, k, v, ...)
```
Checkpoints are fully interchangeable — use_opt changes behavior only, not
parameters.
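For third-party models whose attention modules you can't edit, ordinary Python monkey-patching also works. This is plain patching, not a package feature; note that any kwargs outside `sdpa_opt`'s signature (e.g. newer ones on the stock op) would need explicit handling:

```python
import torch.nn.functional as F
from mps_sdpa import sdpa_opt

_stock = F.scaled_dot_product_attention

def _patched(query, key, value, *args, **kwargs):
    # Pass everything through; sdpa_opt itself routes non-MPS tensors and
    # unsupported configurations back to stock (see Backends below).
    return sdpa_opt(query, key, value, *args, **kwargs)

F.scaled_dot_product_attention = _patched
# ... run the third-party model ...
F.scaled_dot_product_attention = _stock  # restore when done
```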
## CLI

```bash
mps-sdpa self-test       # quick validation (<1s)
mps-sdpa list-backends   # show available backends
mps-sdpa correctness --backend mpsgraph_zc --device mps --suite realistic
mps-sdpa benchmark --backend mpsgraph_zc --baseline stock --device mps --suite realistic
```
## Backends

`backend="auto"` picks the best available backend. In preference order:
| Name | Impl | Best for |
|---|---|---|
| `mpsgraph_zc` | Obj-C++ torch extension, `getMTLBufferStorage` zero-copy | Default — forward, training, masks, dropout |
| `mpsgraph` | pyobjc + CPU-copy bridge | Fallback when the ext can't build |
| `stock` | `torch.nn.functional.scaled_dot_product_attention` | Always-available final fallback |
| `metal_proto` | Naive single-thread Metal kernel via `torch.mps.compile_shader` | Reference / experimentation; not auto-selected |
Fallbacks are transparent — a call that can't go through `mpsgraph_zc` (e.g.
short sequence, CPU device, unsupported dtype) routes down the list automatically.
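Conceptually, selection is just a preference-ordered registry with per-call capability checks. A toy illustration, not the package's actual internals (see ARCHITECTURE.md for those):

```python
# Toy model of backend="auto" — illustrative only, not mps-sdpa's real code.
PREFERENCE = ["mpsgraph_zc", "mpsgraph", "stock"]

def pick_backend(available, supports_call):
    # available: set of backend names that registered successfully at import
    # supports_call: per-call predicate (device, dtype, shape thresholds, ...)
    for name in PREFERENCE:
        if name in available and supports_call(name):
            return name
    return "stock"  # always-available final fallback
```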
## Correctness contract

Same tolerance bar as CUDA flash-attention vs the math backend:
- fp32: atol=5e-6, rtol=5e-5
- fp16/bf16: atol=5e-3, rtol=5e-2
- gradients: 2× the forward tolerance
Not bitwise identical. Checkpoints trained with one backend load cleanly into the other (verified).
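That contract maps directly onto `torch.testing.assert_close`; for example, a bf16 forward check against stock using the tolerances listed above:

```python
import torch
import torch.nn.functional as F
from mps_sdpa import sdpa_opt

q, k, v = (torch.randn(1, 8, 1024, 64, dtype=torch.bfloat16, device="mps")
           for _ in range(3))

torch.testing.assert_close(
    sdpa_opt(q, k, v),
    F.scaled_dot_product_attention(q, k, v),
    atol=5e-3, rtol=5e-2,  # fp16/bf16 bar from the contract above
)
```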
## Scope / supported configs
| Category | Maintainer-tested | Should work (untested by maintainer) | Not supported |
|---|---|---|---|
| Apple silicon | M4 mini, M3 Max | M1, M2 (should work via auto-calibration) | — |
| macOS | 26.x | 15.x (all API surfaces present) | 14.x (op missing — backend registers unavailable) |
| torch | 2.11 stable + 2.13 nightly | — | — |
| dtypes | bf16, fp16, fp32 | — | fp64 (MPS doesn't support it) |
Maintainer testing covers M-series Apple silicon on macOS 26+ only. Reports from other configurations are welcome, but there is no commitment to test them.
## Architecture
See ARCHITECTURE.md for the internal structure: backend registry, graph cache, threshold auto-calibration, C++ extension build system, autograd wiring, dropout path.
## What doesn't work (yet)

- GQA (Hq ≠ Hkv): routes to stock with `repeat_interleave`; the MPSGraph op is MHA-only. Emits a one-time warning. A manual workaround sketch follows this list.
- Second-order gradients (`create_graph=True` on our output's grad): raises a clear error. Backward uses MPSGraph, which is opaque to torch autograd; true higher-order support would require a differentiable backward impl.
- `torch.compile` (full graph capture): untested. The JIT-compiled C++ extension is compatible with eager mode; inductor integration is future work.
- macOS 14: the `MPSGraph.scaledDotProductAttention` method doesn't exist before macOS 15 (Sequoia). The backend registers as unavailable with a clear reason; calls fall back to stock.
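Until native GQA lands, you can expand K/V heads manually before the call. This is a sketch of the same `repeat_interleave` trick the stock fallback uses (the `sdpa_gqa` helper is illustrative; it assumes [B, H, L, D] layout with Hq a multiple of Hkv):

```python
from mps_sdpa import sdpa_opt

def sdpa_gqa(q, k, v, **kwargs):
    # Expand grouped K/V heads to match the query head count (plain MHA),
    # then dispatch normally. Heads live on dim 1 in [B, H, L, D].
    rep = q.shape[1] // k.shape[1]
    if rep > 1:
        k = k.repeat_interleave(rep, dim=1)
        v = v.repeat_interleave(rep, dim=1)
    return sdpa_opt(q, k, v, **kwargs)
```

Note this materializes the expanded K/V tensors, so it trades memory for compatibility.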
## Correctness — what's tested

235 tests across 39 files. Highlights:

- Shape matrix: D ∈ {32, 64, 96, 128, 192, 256}; H ∈ {1..32}; B ∈ {1..32}; Lq, Lkv ∈ {powers of 2, 777, 1345, 3141}.
- Masks: causal, bool, additive float, all-True, mostly-False, per-head [B, H, Lq, Lkv], per-batch, broadcast variants, dtype coercion.
- Autograd: partial requires_grad subsets, AMP autocast (bf16/fp16), `torch.utils.checkpoint` re-entry, retain_graph, grad accumulation, second-order detection, `once_differentiable` guard.
- Training: generic transformer smoke test (forward + backward), 1000-step long-horizon convergence, mid-run backend toggle, 500-step fp16 AMP run.
- Numerical extremes: Q/K std ∈ {10, 1e-4}, grad_out ∈ {1e3, 1e4} — all NaN/Inf-free.
- Edge cases: degenerate shapes, non-contiguous inputs, strided views, cache thrash (9 shapes), 200-call leak probe, 16k-seq OOM recovery.

See tests/ for the full test suite.
## Citation / credit

- Apple's fused SDPA op: `MPSGraph.scaledDotProductAttention...` (macOS 15+).
- Related prior art: `philipturner/metal-flash-attention` (Swift, not vetted here) and PyPI `mps-flash-attn` (8 stars, unverified). Neither is used or wrapped.
## License
MIT.
## Contributing

See CONTRIBUTING.md. Bug reports with reproducible shapes are the most helpful;
the full suite (`pytest tests/`) should pass on any supported config.