MPS Flash Attention
Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
O(N) memory instead of O(N²), enabling 8K+ sequence lengths on unified memory.
Features
- Forward pass: 2-5x faster than PyTorch SDPA
- Backward pass: Full gradient support for training
- Causal masking: Native kernel support (only 5% overhead)
- FP16/FP32: Native fp16 output (no conversion overhead)
- Pre-compiled kernels: Zero-compilation cold start (~6ms)
Performance
Tested on M1 Max, N=2048, B=4, H=8, D=64:
| Operation | MPS Flash Attn | PyTorch SDPA | Speedup |
|---|---|---|---|
| Forward | 5.3ms | 15ms | 2.8x |
| Forward+Backward | 55ms | 108ms | 2.0x |
| Memory | 80MB | 592MB | 7.4x less |
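For reference, a minimal timing harness along these lines reproduces the comparison (shapes match the table above; absolute numbers will vary by machine, and `flash_attention` is the public API shown under Usage):

```python
import time
import torch
import torch.nn.functional as F
from mps_flash_attn import flash_attention

# Shapes from the table: B=4, H=8, N=2048, D=64, fp16 on MPS.
q, k, v = (torch.randn(4, 8, 2048, 64, device="mps", dtype=torch.float16)
           for _ in range(3))

def bench(fn, iters=20):
    fn()                      # warm-up call (excludes kernel compilation)
    torch.mps.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.mps.synchronize()   # wait for queued GPU work before stopping the clock
    return (time.perf_counter() - start) / iters * 1e3  # ms per call

print(f"flash: {bench(lambda: flash_attention(q, k, v)):.1f} ms")
print(f"sdpa:  {bench(lambda: F.scaled_dot_product_attention(q, k, v)):.1f} ms")
```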
Installation
Prerequisites
- macOS 14 (Sonoma) or later, including macOS 15 (Sequoia)
- Xcode Command Line Tools (`xcode-select --install`)
- Python 3.10+ with PyTorch 2.0+
Build from source
```bash
# Clone with submodules
git clone --recursive https://github.com/mpsops/mps-flash-attention.git
cd mps-flash-attention

# Build the Swift bridge
cd swift-bridge
swift build -c release
cd ..

# Install the Python package
pip install -e .
```
Set environment variable
```bash
export MFA_BRIDGE_PATH=/path/to/mps-flash-attention/swift-bridge/.build/release/libMFABridge.dylib
```
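Once the variable is set, a quick smoke test confirms that the bridge loads and a kernel runs (this assumes the `flash_attention` API shown in the Usage section; N=128 and D=64 respect the block-size and head-dimension constraints listed under Known limitations):

```python
import torch
from mps_flash_attn import flash_attention

assert torch.backends.mps.is_available(), "MPS backend not available"

# Small shapes: N=128 (a multiple of the 64-wide block), D=64.
q = k = v = torch.randn(1, 4, 128, 64, device="mps", dtype=torch.float16)
out = flash_attention(q, k, v)
print(out.shape, out.dtype)  # torch.Size([1, 4, 128, 64]) torch.float16
```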
Usage
Basic usage
```python
import torch
from mps_flash_attn import flash_attention

# Standard attention, tensors shaped (B, H, N, D)
q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)

out = flash_attention(q, k, v)
```
Causal masking (for autoregressive models)
```python
out = flash_attention(q, k, v, is_causal=True)
```
Training with gradients
```python
q.requires_grad = True
k.requires_grad = True
v.requires_grad = True

out = flash_attention(q, k, v, is_causal=True)
loss = out.sum()
loss.backward()  # computes dQ, dK, dV
```
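One way to sanity-check the gradients is to compare them against PyTorch's SDPA on the same inputs. This is a sketch, not part of the package, and the fp16 tolerances are deliberately loose:

```python
import torch
import torch.nn.functional as F
from mps_flash_attn import flash_attention

q = torch.randn(2, 8, 256, 64, device="mps", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Reference path: PyTorch SDPA with causal masking on detached copies.
q_ref, k_ref, v_ref = (t.detach().clone().requires_grad_(True) for t in (q, k, v))
F.scaled_dot_product_attention(q_ref, k_ref, v_ref, is_causal=True).sum().backward()

# Flash Attention path.
flash_attention(q, k, v, is_causal=True).sum().backward()

for got, ref in ((q.grad, q_ref.grad), (k.grad, k_ref.grad), (v.grad, v_ref.grad)):
    assert torch.allclose(got, ref, atol=1e-2, rtol=1e-2)
```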
Drop-in replacement for SDPA
```python
from mps_flash_attn import replace_sdpa

# Monkey-patch F.scaled_dot_product_attention
replace_sdpa()

# Now all attention ops use Flash Attention on MPS
```
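After the patch, Python code that calls `F.scaled_dot_product_attention` on MPS tensors goes through the flash kernel without further changes. A minimal illustration (sizes chosen to satisfy the block-size constraint listed under Known limitations):

```python
import torch
import torch.nn.functional as F
from mps_flash_attn import replace_sdpa

replace_sdpa()

q = torch.randn(2, 8, 1024, 64, device="mps", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Looks like a normal SDPA call, but now dispatches to Flash Attention on MPS.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)
```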
Architecture
```
+----------------------------------------------------------+
|                        Python API                         |
|                mps_flash_attn/__init__.py                 |
|          (flash_attention, autograd Function)             |
+----------------------------+-------------------------------+
                             |
+----------------------------v-------------------------------+
|                      C++ Extension                        |
|          mps_flash_attn/csrc/mps_flash_attn.mm            |
|     (PyTorch bindings, MTLBuffer handling, offsets)       |
+----------------------------+-------------------------------+
                             | dlopen + dlsym
+----------------------------v-------------------------------+
|                       Swift Bridge                        |
|              swift-bridge/Sources/MFABridge/              |
|          (MFABridge.swift, MetallibCache.swift)           |
|  @_cdecl exports: mfa_init, mfa_create_kernel,            |
|                   mfa_forward, mfa_backward               |
+----------------------------+-------------------------------+
                             |
+----------------------------v-------------------------------+
|                   Metal Flash Attention                   |
|      metal-flash-attention/Sources/FlashAttention/        |
|       (AttentionDescriptor, AttentionKernel, etc.)        |
|                                                            |
|        Generates Metal shader source at runtime,          |
|         compiles to .metallib, caches pipelines           |
+----------------------------------------------------------+
```
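The dlopen + dlsym hop between the C++ extension and the Swift bridge can be pictured with a `ctypes` sketch. Only the dylib path variable and the exported symbol names come from this document; the real argument and return types live in the C++ extension and are not reproduced here:

```python
import ctypes
import os

# Load the Swift bridge dylib, as the C++ extension does via dlopen.
bridge = ctypes.CDLL(os.environ["MFA_BRIDGE_PATH"])

# Resolve the @_cdecl exports (dlsym). A missing symbol raises AttributeError.
for name in ("mfa_init", "mfa_create_kernel", "mfa_forward", "mfa_backward"):
    print(name, "->", getattr(bridge, name))
```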
Project Structure
```
mps-flash-attention/
├── mps_flash_attn/              # Python package
│   ├── __init__.py              # Public API (flash_attention, replace_sdpa)
│   ├── csrc/
│   │   └── mps_flash_attn.mm    # PyTorch C++ extension
│   └── kernels/                 # Pre-compiled metallibs (optional)
│
├── swift-bridge/                # Swift -> C bridge
│   ├── Package.swift
│   └── Sources/MFABridge/
│       ├── MFABridge.swift      # C-callable API (@_cdecl)
│       └── MetallibCache.swift  # Disk caching for metallibs
│
├── metal-flash-attention/       # Upstream (git submodule)
│   └── Sources/FlashAttention/
│       └── Attention/
│           ├── AttentionDescriptor/   # Problem configuration
│           ├── AttentionKernel/       # Metal shader generation
│           └── ...
│
├── scripts/
│   └── build_metallibs.py       # Pre-compile kernels for distribution
│
└── setup.py                     # Python package setup
```
Changes from upstream metal-flash-attention
We made the following modifications to metal-flash-attention:
1. macOS 15+ compatibility (MTLLibraryCompiler.swift)
Apple restricted the use of `__asm` in runtime-compiled Metal shaders on macOS 15. We added a fallback that compiles via the `xcrun metal` command-line tool when runtime compilation fails.
2. Causal masking support
Added causal flag to AttentionDescriptor and kernel generation:
- `AttentionDescriptor.swift`: added `causal: Bool` property
- `AttentionKernelDescriptor.swift`: added `causal: Bool` property
- `AttentionKernel.swift`: added `causal` field
- `AttentionKernel+Softmax.swift`: added `maskCausal()` function
- `AttentionKernel+Source.swift`: added causal masking to the forward/backward loops
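For reference, the semantics these changes implement are the standard causal rule: query position i may only attend to key positions j <= i. A naive PyTorch equivalent of what the fused kernel's `maskCausal()` step enforces:

```python
import torch

def causal_attention_reference(q, k, v):
    # q, k, v: (B, H, N, D). O(N^2)-memory reference: scores at key positions
    # j > i are set to -inf before the softmax, so each query attends only to
    # itself and earlier positions.
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale                     # (B, H, N, N)
    n = q.shape[-2]
    future = torch.triu(torch.ones(n, n, dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v
```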
Next Steps
1. PR to upstream metal-flash-attention
The macOS 15 fix and causal masking should be contributed back:
```bash
cd metal-flash-attention
git checkout -b macos15-causal-support
# Commit changes to:
# - Sources/FlashAttention/Utilities/MTLLibraryCompiler.swift (new file)
# - Sources/FlashAttention/Attention/AttentionDescriptor/*.swift
# - Sources/FlashAttention/Attention/AttentionKernel/*.swift
git push origin macos15-causal-support
# Open PR at https://github.com/philipturner/metal-flash-attention
```
2. Publish mps-flash-attention to PyPI
```bash
# Add pyproject.toml with proper metadata
# Build a wheel with the pre-compiled Swift bridge
python -m build
twine upload dist/*
```
3. Pre-compile kernels for zero cold start
```bash
python scripts/build_metallibs.py
# Copies metallibs to mps_flash_attn/kernels/
# These get shipped with the wheel
```
Current Status (Jan 2025)
Working:
- Forward pass (fp16/fp32)
- Backward pass (dQ, dK, dV gradients)
- Causal masking
- Metallib disk caching
- Pipeline binary caching (MTLBinaryArchive)
Known limitations:
- Sequence length must be divisible by the block size (typically 64); see the padding sketch below
- Head dimension: best results with 32, 64, 96, or 128
- No arbitrary attention masks (only causal or none)
- No dropout
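A simple workaround for the block-size constraint is to pad the sequence dimension up to the next multiple of 64 and slice the output back. The hypothetical wrapper below is only valid for causal attention, where the padded key positions sit after every real query and therefore never receive attention weight:

```python
import torch.nn.functional as F
from mps_flash_attn import flash_attention

def flash_attention_causal_padded(q, k, v, block=64):
    # q, k, v: (B, H, N, D). Pad N up to a multiple of `block`, run the causal
    # kernel, then drop the padded output rows.
    n = q.shape[-2]
    pad = (-n) % block
    if pad:
        q, k, v = (F.pad(t, (0, 0, 0, pad)) for t in (q, k, v))
    out = flash_attention(q, k, v, is_causal=True)
    return out[..., :n, :]
```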
Credits
- metal-flash-attention by Philip Turner
- Flash Attention paper by Tri Dao et al.
License
MIT
Download files
File details
Details for the file mps_flash_attn-0.1.5.tar.gz.
File metadata
- Download URL: mps_flash_attn-0.1.5.tar.gz
- Upload date:
- Size: 337.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.0
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | `2af0bd0d08b5f6e905bfd4a264cd22bf71c4fc934c514981d807e9506598abc8` |
| MD5 | `1f276437eecbcd56a2a73e1bfd702891` |
| BLAKE2b-256 | `19351ae03ea72918483ab0fc84aaad837ef73650083389af037d6902043eca4b` |
File details
Details for the file mps_flash_attn-0.1.5-cp314-cp314-macosx_15_0_arm64.whl.
File metadata
- Download URL: mps_flash_attn-0.1.5-cp314-cp314-macosx_15_0_arm64.whl
- Upload date:
- Size: 525.0 kB
- Tags: CPython 3.14, macOS 15.0+ ARM64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.0
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | `632fb68399e2780bb8433cd8e439307a54b67bde51332f5511ff5c485fecadcc` |
| MD5 | `e2da15a4aee6d8fe0196e856f9b4d9da` |
| BLAKE2b-256 | `8f2ce0582ac39035305ff29b7ee93b6fa6f7083448535602752a107354d127c6` |