
mps-sdpa

Fast drop-in scaled_dot_product_attention for Apple silicon (MPS). Wraps Apple's native MPSGraph.scaledDotProductAttention op with a zero-copy C++ / Objective-C++ bridge. 5–7× inference speedup, 2–2.5× training speedup, 16–170× less driver memory per call — with identical math and checkpoint compatibility.

from mps_sdpa import sdpa_opt

# Drop-in replacement — same signature as torch.nn.functional.scaled_dot_product_attention
out = sdpa_opt(query, key, value,
               attn_mask=None, dropout_p=0.0, is_causal=False, scale=None,
               backend="auto")

Why

PyTorch's MPS backend dispatches torch.nn.functional.scaled_dot_product_attention to sdpa_general_mps, which builds a naive matmul → softmax → matmul graph. It does not call Apple's dedicated MPSGraph.scaledDotProductAttention op, which is present on macOS 15+ and is significantly faster.

mps-sdpa wraps that op directly, with an autograd.Function for training, shape-threshold auto-calibration, graceful fallbacks, and thread-safe graph caching. The C++ extension uses ATen's getMTLBufferStorage to hand PyTorch tensors to MPSGraph without a CPU memcpy.
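The training path follows the standard torch.autograd.Function pattern. A minimal sketch of that wiring (the _ext.forward / _ext.backward names below are hypothetical stand-ins for the real extension entry points):

import torch

class _SDPAFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, scale):
        # Hypothetical binding into the MPSGraph fused op via the C++ extension.
        out = _ext.forward(q, k, v, scale)
        ctx.save_for_backward(q, k, v, out)
        ctx.scale = scale
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v, out = ctx.saved_tensors
        # Hypothetical: the fused backward graph produces dQ/dK/dV in one dispatch.
        dq, dk, dv = _ext.backward(grad_out, q, k, v, out, ctx.scale)
        return dq, dk, dv, None  # no gradient for the scale argument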

Measured performance

All numbers on M4 / macOS 26.4.1 / torch 2.13 nightly, bfloat16.

Inference (forward only, B=1, H=8, D=64)

L (seq len)   stock      mps-sdpa   speedup
1024          5.78 ms    0.90 ms    6.42×
2048          19.0 ms    3.82 ms    4.97×
4096          76.1 ms    11.79 ms   6.45×
8192          317 ms     44.3 ms    7.17×

Weighted geomean across a realistic audio-model shape suite: 4.88×.

Training (forward + backward, same shapes)

L (seq len)   stock      mps-sdpa   speedup
1024          9.93 ms    5.06 ms    1.96×
2048          38.6 ms    17.1 ms    2.25×
4096          154 ms     64.8 ms    2.38×
8192          608 ms     247 ms     2.46×

Training with dropout (dropout_p=0.1)

L (seq len)   stock       mps-sdpa    speedup
1024          14.19 ms    7.63 ms     1.86×
2048          55.83 ms    28.49 ms    1.96×
4096          228 ms      101 ms      2.26×
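These timings come from the project's benchmark CLI (see below). To spot-check on your own machine, remember that MPS kernels are queued asynchronously, so wall-clock timing must bracket the loop with torch.mps.synchronize(). A minimal loop, using the shapes from the tables above:

import time
import torch
from mps_sdpa import sdpa_opt

def bench(fn, *args, warmup=10, iters=50):
    for _ in range(warmup):
        fn(*args)
    torch.mps.synchronize()                  # drain queued work before starting the clock
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    torch.mps.synchronize()                  # include all queued kernels in the measurement
    return (time.perf_counter() - t0) / iters * 1e3   # ms per call

q, k, v = (torch.randn(1, 8, 2048, 64, device="mps", dtype=torch.bfloat16) for _ in range(3))
print(f"{bench(sdpa_opt, q, k, v):.2f} ms")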

Driver memory per call

Apple's fused op doesn't materialize the [Lq, Lkv] attention matrix. The zero-copy C++ bridge removes the CPU-side intermediate buffer too.

L (seq len)   stock      mps-sdpa   reduction
2048          1024 MB    <1 MB      ≫128×
4096          1024 MB    <1 MB      ≫64×
8192          1024 MB    32 MB      32×
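One way to approximate the per-call measurement is with torch.mps.driver_allocated_memory() (a real PyTorch API; exact accounting may differ from the project's benchmark harness):

import torch
from mps_sdpa import sdpa_opt

def driver_delta_mb(fn, *args):
    torch.mps.synchronize()
    before = torch.mps.driver_allocated_memory()
    out = fn(*args)
    torch.mps.synchronize()
    # Driver memory still held after the call completes, in MiB.
    return (torch.mps.driver_allocated_memory() - before) / 2**20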

Install

pip install mps-sdpa

Requires macOS 15+ on Apple silicon (M1–M4) with PyTorch ≥ 2.11. ninja is pulled in automatically for the zero-copy backend's JIT compile (~6s on first import, then cached to ~/.cache/torch_extensions).
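A quick post-install smoke check (the first import triggers the one-time JIT build):

import torch
from mps_sdpa import sdpa_opt  # first import JIT-compiles the zero-copy extension

assert torch.backends.mps.is_available(), "requires Apple silicon with MPS"
q, k, v = (torch.randn(1, 8, 256, 64, device="mps", dtype=torch.bfloat16) for _ in range(3))
print(sdpa_opt(q, k, v).shape)  # torch.Size([1, 8, 256, 64])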

Development install from source:

git clone https://github.com/crlandsc/mps-sdpa.git
cd mps-sdpa
pip install -e ".[dev]"
pytest tests/

Usage

The API exactly mirrors torch.nn.functional.scaled_dot_product_attention:

import torch
from mps_sdpa import sdpa_opt

q = torch.randn(1, 8, 2048, 64, dtype=torch.bfloat16, device="mps")
k = torch.randn(1, 8, 2048, 64, dtype=torch.bfloat16, device="mps")
v = torch.randn(1, 8, 2048, 64, dtype=torch.bfloat16, device="mps")

out = sdpa_opt(q, k, v)                                  # basic
out = sdpa_opt(q, k, v, is_causal=True)                  # causal mask
out = sdpa_opt(q, k, v, attn_mask=bool_mask)             # bool mask
out = sdpa_opt(q, k, v, dropout_p=0.1)                   # training dropout
out = sdpa_opt(q, k, v, backend="mpsgraph_zc")           # force specific backend
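To route a third-party model through the optimized path without editing its source, one option is to monkey-patch the functional entry point. This is a sketch, not an official API of the package; newer torch builds add kwargs (e.g. enable_gqa) that sdpa_opt may not accept, so anything unexpected falls through to stock:

import torch.nn.functional as F
from mps_sdpa import sdpa_opt

_stock = F.scaled_dot_product_attention

def _patched(query, key, value, attn_mask=None, dropout_p=0.0,
             is_causal=False, scale=None, **extra):
    # Non-MPS tensors or unknown kwargs: use the stock implementation.
    if query.device.type != "mps" or extra:
        return _stock(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,
                      is_causal=is_causal, scale=scale, **extra)
    return sdpa_opt(query, key, value, attn_mask=attn_mask,
                    dropout_p=dropout_p, is_causal=is_causal, scale=scale)

F.scaled_dot_product_attention = _patched

Note this only affects call sites that go through F.scaled_dot_product_attention; modules that imported the function directly keep the original binding.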

As a drop-in swap inside an existing model:

import torch.nn as nn
import torch.nn.functional as F
from mps_sdpa import sdpa_opt

class MyAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)
        self.dropout_p = 0.1

    def forward(self, x):
        B, L, _ = x.shape
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        p = self.dropout_p if self.training else 0.0  # standard pattern
        y = sdpa_opt(q, k, v, dropout_p=p)            # <— was F.scaled_dot_product_attention
        y = y.transpose(1, 2).reshape(B, L, self.d_model)
        return self.out(y)

Gate the behavior at model-construction time (useful for A/B comparison):

class Attn(nn.Module):
    def __init__(self, ..., use_mps_sdpa_opt: bool = False):
        super().__init__()
        self.use_opt = use_mps_sdpa_opt

    def forward(self, q, k, v):
        if self.use_opt:
            return sdpa_opt(q, k, v, ...)
        return F.scaled_dot_product_attention(q, k, v, ...)

Checkpoints are fully interchangeable — use_opt changes behavior only, not parameters.
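For example, a state dict round-trips between the two configurations without key mismatches (other constructor arguments elided as above):

baseline = Attn(use_mps_sdpa_opt=False)
fast = Attn(use_mps_sdpa_opt=True)
fast.load_state_dict(baseline.state_dict())  # strict=True passes: the flag is not a parameter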

CLI

mps-sdpa self-test                      # quick validation (<1s)
mps-sdpa list-backends                  # show available backends
mps-sdpa correctness --backend mpsgraph_zc --device mps --suite realistic
mps-sdpa benchmark --backend mpsgraph_zc --baseline stock --device mps --suite realistic

Backends

backend="auto" picks the best available. Available backends (in preference order):

Name Impl Best for
mpsgraph_zc Obj-C++ torch extension, getMTLBufferStorage zero-copy Default — forward, training, masks, dropout
mpsgraph pyobjc + CPU-copy bridge Fallback when the ext can't build
stock torch.nn.functional.scaled_dot_product_attention Always available final fallback
metal_proto naive single-thread Metal kernel via torch.mps.compile_shader Reference / experimentation, not auto-selected

Fallbacks are transparent: a call that can't go through mpsgraph_zc (e.g. a sequence length below the calibrated threshold, a CPU tensor, or an unsupported dtype) routes down the list automatically.
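Conceptually, the auto-selection is a first-match scan over the preference order. A sketch of that routing (the REGISTRY name is hypothetical; the real registry lives in the backend module):

PREFERENCE = ["mpsgraph_zc", "mpsgraph", "stock"]

def route(q, k, v, **kwargs):
    for name in PREFERENCE:
        backend = REGISTRY[name]  # hypothetical: maps name -> backend object
        if backend.available and backend.supports(q, k, v, **kwargs):
            return backend(q, k, v, **kwargs)
    raise RuntimeError("unreachable: stock always matches")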

Correctness contract

Same tolerance bar that CUDA flash-attention is held to against the math backend:

  • fp32: atol=5e-6, rtol=5e-5
  • fp16/bf16: atol=5e-3, rtol=5e-2
  • gradients: 2× the forward tolerance

Not bitwise identical. Checkpoints trained with one backend load cleanly into the other (verified).
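To run the same bar against the stock path on your own shapes:

import torch
from mps_sdpa import sdpa_opt

q, k, v = (torch.randn(2, 8, 512, 64, device="mps", dtype=torch.bfloat16) for _ in range(3))
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
out = sdpa_opt(q, k, v, backend="mpsgraph_zc")
torch.testing.assert_close(out, ref, atol=5e-3, rtol=5e-2)  # fp16/bf16 forward tolerance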

Scope / supported configs

Category        Tested                       Planned   Not planned
Apple silicon   M4                           M3        M1, M2 (should work via auto-calibration)
macOS           26.4.1                       15.x      14.x (documented unsupported — backend registers unavailable)
torch           2.11 stable + 2.13 nightly
dtypes          bf16, fp16, fp32                       fp64 (MPS doesn't support it)

Architecture

See ARCHITECTURE.md for the internal structure: backend registry, graph cache, threshold auto-calibration, C++ extension build system, autograd wiring, dropout path.
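For orientation, the thread-safe graph cache mentioned above amounts to a shape-keyed dictionary guarded by a lock. A conceptual sketch (build_sdpa_graph is a hypothetical stand-in for the real MPSGraph builder):

import threading

_graphs = {}
_lock = threading.Lock()

def get_graph(key):
    # key: e.g. (B, H, Lq, Lkv, D, dtype, is_causal, has_mask)
    with _lock:
        graph = _graphs.get(key)
        if graph is None:
            graph = build_sdpa_graph(*key)  # hypothetical builder around MPSGraph
            _graphs[key] = graph
    return graph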

What doesn't work (yet)

  • GQA (Hq ≠ Hkv): routes to stock after a repeat_interleave head expansion; the MPSGraph op is MHA-only. Emits a one-time warning. A manual workaround sketch follows this list.
  • Second-order gradients (create_graph=True on our output's grad): raises a clear error. Backward goes through MPSGraph, which is opaque to torch autograd; true higher-order support would require a differentiable backward implementation.
  • torch.compile (full graph capture): untested. The JIT-compiled C++ extension is compatible with eager mode; inductor integration is future work.
  • macOS 14 (Sonoma): MPSGraph.scaledDotProductAttention doesn't exist before macOS 15 (Sequoia). The backend registers as unavailable with a clear reason; calls fall back to stock.
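If you want GQA on the fast path today, a manual workaround is to expand K/V heads to match Q before the call, assuming memory allows. This is the same repeat_interleave expansion the automatic route applies, but it keeps mpsgraph_zc eligible because Hq == Hkv afterwards:

import torch
from mps_sdpa import sdpa_opt

def sdpa_gqa(q, k, v, **kwargs):
    Hq, Hkv = q.shape[1], k.shape[1]
    if Hq != Hkv:
        rep = Hq // Hkv                      # assumes Hq is a multiple of Hkv
        k = k.repeat_interleave(rep, dim=1)  # materializes the expanded heads
        v = v.repeat_interleave(rep, dim=1)
    return sdpa_opt(q, k, v, **kwargs)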

Correctness — what's tested

244 tests across 20 files. Highlights:

  • Shape matrix: D ∈ {32, 64, 96, 128, 192, 256}; H ∈ {1..32}; B ∈ {1..32}; Lq, Lkv ∈ {powers of 2, 777, 1345, 3141}.
  • Masks: causal, bool, additive float, all-True, mostly-False, per-head [B,H,Lq,Lkv], per-batch, broadcast variants, dtype coercion.
  • Autograd: partial requires_grad subsets, AMP autocast (bf16/fp16), torch.utils.checkpoint re-entry, retain_graph, grad accumulation, second-order detection, once_differentiable guard.
  • Training: generic transformer smoke (forward + backward), 1000-step long-horizon convergence, mid-run backend toggle, fp16 500-step AMP.
  • Numerical extremes: Q/K std ∈ {10, 1e-4}, grad_out ∈ {1e3, 1e4} — all NaN/Inf-free.
  • Edge cases: degenerate shapes, non-contiguous inputs, strided views, cache thrash (9 shapes), 200-call leak probe, 16k-seq OOM recovery.

See tests/ for the full test suite.

Citation / credit

  • Apple's fused SDPA op: MPSGraph.scaledDotProductAttention... (macOS 15+).
  • Related prior art: philipturner/metal-flash-attention (Swift, not vetted here); PyPI mps-flash-attn (8 stars, unverified). Neither is used or wrapped.

License

MIT.

Contributing

See CONTRIBUTING.md. Bug reports with reproducible shapes are the most helpful; the full suite (pytest tests/) should pass on any supported config.
