
mps-sdpa

Fast drop-in scaled_dot_product_attention for Apple silicon (MPS). Wraps Apple's native MPSGraph.scaledDotProductAttention op with a zero-copy C++ / Objective-C++ bridge. 5–7× inference speedup, 2–2.5× training speedup, 16–170× less driver memory per call — with identical math and checkpoint compatibility.

```python
from mps_sdpa import sdpa_opt

# Drop-in replacement — same signature as torch.nn.functional.scaled_dot_product_attention
out = sdpa_opt(query, key, value,
               attn_mask=None, dropout_p=0.0, is_causal=False, scale=None,
               backend="auto")
```

Why

PyTorch's MPS backend dispatches torch.nn.functional.scaled_dot_product_attention to sdpa_general_mps, which builds a naive matmul → softmax → matmul graph. It does not call Apple's dedicated MPSGraph.scaledDotProductAttention op, which is present on macOS 15+ and is significantly faster.

mps-sdpa wraps that op directly, with an autograd.Function for training, shape-threshold auto-calibration, graceful fallbacks, and thread-safe graph caching. The C++ extension uses ATen's getMTLBufferStorage to hand torch tensors to MPSGraph without CPU memcpy.
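As a rough illustration of that routing, here is a minimal hedged sketch with made-up names, not the package's internals (the real dispatch sits behind a backend registry, a calibrated shape threshold, and an MPSGraph-backed autograd.Function):

```python
import torch.nn.functional as F

FUSED_MIN_SEQ = 512  # stand-in for the auto-calibrated shape threshold

def _fused_sdpa_stub(q, k, v, **kwargs):
    # In mps-sdpa this would be the MPSGraph-backed autograd.Function;
    # here it is a stand-in so the sketch runs anywhere.
    return F.scaled_dot_product_attention(q, k, v, **kwargs)

def dispatch_sdpa(q, k, v, **kwargs):
    # The fused path only pays off above a shape threshold, and only on MPS.
    if q.device.type == "mps" and q.shape[-2] >= FUSED_MIN_SEQ:
        try:
            return _fused_sdpa_stub(q, k, v, **kwargs)
        except RuntimeError:
            pass  # graceful fallback for configs the fused op can't take
    return F.scaled_dot_product_attention(q, k, v, **kwargs)
```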

Measured performance

All numbers on M4 / macOS 26.4.1 / torch 2.13 nightly, bfloat16.

Inference (forward only, B=1, H=8, D=64)

| L (seq) | stock    | mps-sdpa | speedup |
|---------|----------|----------|---------|
| 1024    | 5.78 ms  | 0.90 ms  | 6.42×   |
| 2048    | 19.0 ms  | 3.82 ms  | 4.97×   |
| 4096    | 76.1 ms  | 11.79 ms | 6.45×   |
| 8192    | 317 ms   | 44.3 ms  | 7.17×   |

Weighted geomean across a realistic audio-model shape suite: 4.88×.
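These tables come from the repo's benchmark CLI; a hedged sketch of equivalent manual timing on MPS (warm up first, then synchronize the queue around the timed region) looks like this:

```python
import time
import torch
import torch.nn.functional as F
from mps_sdpa import sdpa_opt

def time_sdpa(fn, L, iters=20, B=1, H=8, D=64):
    q, k, v = (torch.randn(B, H, L, D, dtype=torch.bfloat16, device="mps")
               for _ in range(3))
    for _ in range(3):                       # warm-up: graph build, JIT, caches
        fn(q, k, v)
    torch.mps.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(q, k, v)
    torch.mps.synchronize()                  # wait for all queued GPU work
    return (time.perf_counter() - t0) / iters * 1e3  # ms per call

for L in (1024, 2048, 4096, 8192):
    stock = time_sdpa(F.scaled_dot_product_attention, L)
    opt = time_sdpa(sdpa_opt, L)
    print(f"L={L}: stock {stock:.2f} ms, mps-sdpa {opt:.2f} ms, {stock/opt:.2f}x")
```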

Training (forward + backward, same shapes)

| L (seq) | stock    | mps-sdpa | speedup |
|---------|----------|----------|---------|
| 1024    | 9.93 ms  | 5.06 ms  | 1.96×   |
| 2048    | 38.6 ms  | 17.1 ms  | 2.25×   |
| 4096    | 154 ms   | 64.8 ms  | 2.38×   |
| 8192    | 608 ms   | 247 ms   | 2.46×   |

Training with dropout (dropout_p=0.1)

| L (seq) | stock    | mps-sdpa | speedup |
|---------|----------|----------|---------|
| 1024    | 14.19 ms | 7.63 ms  | 1.86×   |
| 2048    | 55.83 ms | 28.49 ms | 1.96×   |
| 4096    | 228 ms   | 101 ms   | 2.26×   |

Driver memory per call

Apple's fused op doesn't materialize the [Lq, Lkv] attention matrix. The zero-copy C++ bridge removes the CPU-side intermediate buffer too.

| L (seq) | stock   | mps-sdpa | reduction |
|---------|---------|----------|-----------|
| 2048    | 1024 MB | <1 MB    | ≫128×     |
| 4096    | 1024 MB | <1 MB    | ≫64×      |
| 8192    | 1024 MB | 32 MB    | 32×       |
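A hedged way to spot-check the table above: sample torch.mps.driver_allocated_memory() around a single call. The MPS caching allocator can make deltas order-dependent, so for clean numbers run each backend in a fresh process.

```python
import torch
import torch.nn.functional as F
from mps_sdpa import sdpa_opt

def driver_delta_mb(fn, L, B=1, H=8, D=64):
    q, k, v = (torch.randn(B, H, L, D, dtype=torch.bfloat16, device="mps")
               for _ in range(3))
    torch.mps.synchronize()
    before = torch.mps.driver_allocated_memory()
    fn(q, k, v)
    torch.mps.synchronize()
    return (torch.mps.driver_allocated_memory() - before) / 2**20

for L in (2048, 4096, 8192):
    print(f"L={L}: stock {driver_delta_mb(F.scaled_dot_product_attention, L):.0f} MB, "
          f"mps-sdpa {driver_delta_mb(sdpa_opt, L):.0f} MB")
```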

Install

```bash
pip install mps-sdpa
```

Requires macOS 15+ on Apple silicon (M1–M4) with PyTorch ≥ 2.11. ninja is pulled in automatically for the zero-copy backend's JIT compile (~6s on first import, then cached to ~/.cache/torch_extensions).
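A quick sanity check of the prerequisites (standard-library and core torch calls only):

```python
import platform
import torch

print("macOS:", platform.mac_ver()[0])                      # needs 15+
print("arch:", platform.machine())                          # 'arm64' on Apple silicon
print("torch:", torch.__version__)                          # needs >= 2.11
print("MPS available:", torch.backends.mps.is_available())
```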

Development install from source:

```bash
git clone https://github.com/crlandsc/mps-sdpa.git
cd mps-sdpa
pip install -e ".[dev]"
pytest tests/
```

Usage

The API exactly mirrors torch.nn.functional.scaled_dot_product_attention:

```python
import torch
from mps_sdpa import sdpa_opt

q = torch.randn(1, 8, 2048, 64, dtype=torch.bfloat16, device="mps")
k = torch.randn(1, 8, 2048, 64, dtype=torch.bfloat16, device="mps")
v = torch.randn(1, 8, 2048, 64, dtype=torch.bfloat16, device="mps")

out = sdpa_opt(q, k, v)                                  # basic
out = sdpa_opt(q, k, v, is_causal=True)                  # causal mask
out = sdpa_opt(q, k, v, attn_mask=bool_mask)             # bool mask
out = sdpa_opt(q, k, v, dropout_p=0.1)                   # training dropout
out = sdpa_opt(q, k, v, backend="mpsgraph_zc")           # force specific backend
```

As a drop-in swap inside an existing model:

```python
import torch.nn as nn
import torch.nn.functional as F
from mps_sdpa import sdpa_opt

class MyAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)
        self.dropout_p = 0.1

    def forward(self, x):
        B, L, _ = x.shape
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        p = self.dropout_p if self.training else 0.0  # standard pattern
        y = sdpa_opt(q, k, v, dropout_p=p)            # <— was F.scaled_dot_product_attention
        y = y.transpose(1, 2).reshape(B, L, self.d_model)
        return self.out(y)
```

Gate the behavior at model-construction time (useful for A/B comparison):

```python
class Attn(nn.Module):
    def __init__(self, ..., use_mps_sdpa_opt: bool = False):
        super().__init__()
        self.use_opt = use_mps_sdpa_opt

    def forward(self, q, k, v):
        if self.use_opt:
            return sdpa_opt(q, k, v, ...)
        return F.scaled_dot_product_attention(q, k, v, ...)
```

Checkpoints are fully interchangeable — use_opt changes behavior only, not parameters.

CLI

```bash
mps-sdpa self-test                      # quick validation (<1s)
mps-sdpa list-backends                  # show available backends
mps-sdpa correctness --backend mpsgraph_zc --device mps --suite realistic
mps-sdpa benchmark --backend mpsgraph_zc --baseline stock --device mps --suite realistic
```

Backends

backend="auto" picks the best available. Available backends (in preference order):

| Name        | Impl                                                         | Best for                                       |
|-------------|--------------------------------------------------------------|------------------------------------------------|
| mpsgraph_zc | Obj-C++ torch extension, getMTLBufferStorage zero-copy       | Default — forward, training, masks, dropout    |
| mpsgraph    | pyobjc + CPU-copy bridge                                     | Fallback when the ext can't build              |
| stock       | torch.nn.functional.scaled_dot_product_attention             | Always available final fallback                |
| metal_proto | naive single-thread Metal kernel via torch.mps.compile_shader | Reference / experimentation, not auto-selected |

Fallbacks are transparent — a call that can't go through mpsgraph_zc (e.g. short seq, CPU device, unsupported dtype) routes down the list automatically.
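To watch the routing, one hedged approach is to force each documented backend name on identical inputs and compare against stock. Whether forcing an unavailable backend raises or silently falls back isn't specified here, so the except clause is defensive:

```python
import torch
from mps_sdpa import sdpa_opt

q, k, v = (torch.randn(1, 8, 2048, 64, dtype=torch.bfloat16, device="mps")
           for _ in range(3))
ref = sdpa_opt(q, k, v, backend="stock")

for name in ("mpsgraph_zc", "mpsgraph", "metal_proto"):
    try:
        out = sdpa_opt(q, k, v, backend=name)
        print(f"{name}: max |diff| vs stock = {(out - ref).abs().max().item():.2e}")
    except Exception as e:  # backend may be unavailable on this machine
        print(f"{name}: unavailable ({e})")
```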

Correctness contract

Same tolerance bar as CUDA flash-attention vs math:

  • fp32: atol=5e-6, rtol=5e-5
  • fp16/bf16: atol=5e-3, rtol=5e-2
  • gradients: 2× the forward tolerance

Not bitwise identical. Checkpoints trained with one backend load cleanly into the other (verified).
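The bar can be checked directly with torch.testing.assert_close using the stated fp16/bf16 tolerances:

```python
import torch
import torch.nn.functional as F
from mps_sdpa import sdpa_opt

q, k, v = (torch.randn(1, 8, 1024, 64, dtype=torch.bfloat16, device="mps")
           for _ in range(3))

ref = F.scaled_dot_product_attention(q, k, v)
out = sdpa_opt(q, k, v)

# fp16/bf16 bar from the contract above: not bitwise, but within tolerance.
torch.testing.assert_close(out, ref, atol=5e-3, rtol=5e-2)
```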

Scope / supported configs

| Category      | Maintainer-tested          | Should work (untested by maintainer) | Not supported                                    |
|---------------|----------------------------|--------------------------------------|--------------------------------------------------|
| Apple silicon | M4 mini, M3 Max            | M1, M2 (via auto-calibration)        |                                                  |
| macOS         | 26.x                       | 15.x (all API surfaces present)      | 14.x (op missing — backend registers unavailable) |
| torch         | 2.11 stable + 2.13 nightly |                                      |                                                  |
| dtypes        | bf16, fp16, fp32           |                                      | fp64 (MPS doesn't support it)                    |

Maintainer testing covers M-series Apple silicon on macOS 26+ only. Reports from other configurations are welcome, but there is no commitment to test them.

Architecture

See ARCHITECTURE.md for the internal structure: backend registry, graph cache, threshold auto-calibration, C++ extension build system, autograd wiring, dropout path.

What doesn't work (yet)

  • GQA (Hq ≠ Hkv): routes to stock with repeat_interleave, since the mpsgraph op is MHA-only; a one-time warning is emitted. A manual workaround sketch follows this list.
  • Second-order gradients (create_graph=True on our output's grad): raises a clear error. Backward uses MPSGraph which is opaque to torch autograd; true higher-order would require a differentiable backward impl.
  • macOS 14: the MPSGraph.scaledDotProductAttention method doesn't exist before macOS 15. The backend registers as unavailable with a clear reason; calls fall back to stock.
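For GQA today, a hedged manual workaround is to expand the K/V heads yourself before the call, so the MHA-only fused op applies instead of the stock route:

```python
import torch
from mps_sdpa import sdpa_opt

B, Hq, Hkv, L, D = 1, 8, 2, 1024, 64
q = torch.randn(B, Hq, L, D, dtype=torch.bfloat16, device="mps")
k = torch.randn(B, Hkv, L, D, dtype=torch.bfloat16, device="mps")
v = torch.randn(B, Hkv, L, D, dtype=torch.bfloat16, device="mps")

# Expand KV heads to match Q heads so the fused MHA op applies.
k = k.repeat_interleave(Hq // Hkv, dim=1)
v = v.repeat_interleave(Hq // Hkv, dim=1)
out = sdpa_opt(q, k, v)
```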

What works

  • torch.compile: supported via torch.library.custom_op + register_autograd + register_fake (added in v0.2.0). sdpa_opt traces cleanly under torch.compile(..., dynamic=False); output is numerically identical to eager mode within cross_impl_atol(dtype). Six dedicated equivalence tests in tests/test_torch_compile.py cover forward, backward, masked, dropout-entropy, and opcheck smoke.
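A minimal usage sketch of that compile path:

```python
import torch
from mps_sdpa import sdpa_opt

@torch.compile(dynamic=False)
def attn(q, k, v):
    return sdpa_opt(q, k, v, is_causal=True)

q, k, v = (torch.randn(1, 8, 1024, 64, dtype=torch.bfloat16, device="mps")
           for _ in range(3))
compiled = attn(q, k, v)
eager = sdpa_opt(q, k, v, is_causal=True)
print((compiled - eager).abs().max().item())  # expected ~0 within dtype tolerance
```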

Correctness — what's tested

254 tests across 43 files. Highlights:

  • Shape matrix: D ∈ {32, 64, 96, 128, 192, 256}; H ∈ {1..32}; B ∈ {1..32}; Lq, Lkv ∈ {powers of 2, 777, 1345, 3141}.
  • Masks: causal, bool, additive float, all-True, mostly-False, per-head [B,H,Lq,Lkv], per-batch, broadcast variants, dtype coercion.
  • Autograd: partial requires_grad subsets, AMP autocast (bf16/fp16), torch.utils.checkpoint re-entry, retain_graph, grad accumulation, second-order detection, once_differentiable guard.
  • Training: generic transformer smoke (forward + backward), 1000-step long-horizon convergence, mid-run backend toggle, fp16 500-step AMP.
  • Numerical extremes: Q/K std ∈ {10, 1e-4}, grad_out ∈ {1e3, 1e4} — all NaN/Inf-free.
  • Edge cases: degenerate shapes, non-contiguous inputs, strided views, cache thrash (9 shapes), 200-call leak probe, 16k-seq OOM recovery.

See tests/ for the full test suite.

Citation / credit

  • Apple's fused SDPA op: MPSGraph.scaledDotProductAttention... (macOS 15+).
  • Related prior art: philipturner/metal-flash-attention (Swift, not vetted here); PyPI mps-flash-attn (8 stars, unverified). Neither is used or wrapped.

License

MIT.

Contributing

See CONTRIBUTING.md. Bug reports with reproducible shapes are the most helpful; the full-suite pytest tests/ should pass on any supported config.
