
MPS Flash Attention

Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).

O(N) memory instead of O(N²), enabling 8K+ sequence lengths on unified memory.

Features

  • Forward pass: 2-5x faster than PyTorch SDPA
  • Backward pass: Full gradient support for training (fp32 precision)
  • Causal masking: Native kernel support (only 5% overhead)
  • Attention masks: Full boolean mask support for arbitrary masking patterns
  • FP16/FP32: Native fp16 output (no conversion overhead)
  • Pre-compiled kernels: Zero-compilation cold start (~6ms)

Performance

Tested on M1 Max, N=2048, B=4, H=8, D=64:

Operation          MPS Flash Attn   PyTorch SDPA   Speedup
Forward            5.3 ms           15 ms          2.8x
Forward+Backward   55 ms            108 ms         2.0x
Memory             80 MB            592 MB         7.4x less
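
Timings of this kind can be reproduced with a simple loop that synchronizes the MPS queue around the measured region. The script below is an illustrative sketch (not the benchmark used for the table above), assuming the package and an MPS-capable PyTorch build are installed:

import time
import torch
import torch.nn.functional as F
from mps_flash_attn import flash_attention

B, H, N, D = 4, 8, 2048, 64
q = torch.randn(B, H, N, D, device='mps', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

def bench(fn, iters=50):
    # Warm up so kernel compilation/caching is not included in the timing
    for _ in range(5):
        fn()
    torch.mps.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.mps.synchronize()
    return (time.perf_counter() - start) / iters * 1e3  # ms per call

print("flash:", bench(lambda: flash_attention(q, k, v)))
print("sdpa: ", bench(lambda: F.scaled_dot_product_attention(q, k, v)))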

Installation

Prerequisites

  • macOS 14 (Sonoma) or later, including macOS 15 (Sequoia)
  • Xcode Command Line Tools (xcode-select --install)
  • Python 3.10+ with PyTorch 2.0+

Build from source

# Clone with submodules
git clone --recursive https://github.com/mpsops/mps-flash-attention.git
cd mps-flash-attention

# Build Swift bridge
cd swift-bridge
swift build -c release
cd ..

# Install Python package
pip install -e .

Set environment variable

export MFA_BRIDGE_PATH=/path/to/mps-flash-attention/swift-bridge/.build/release/libMFABridge.dylib

Usage

Basic usage

import torch
from mps_flash_attn import flash_attention

# Standard attention (B, H, N, D)
q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)

out = flash_attention(q, k, v)
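
As a quick sanity check, the result can be compared against PyTorch's built-in scaled_dot_product_attention; fp16 accumulation differs slightly between the two paths, so use loose tolerances (an illustrative check, not part of the package):

import torch.nn.functional as F

ref = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out.float(), ref.float(), atol=1e-2, rtol=1e-2))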

Causal masking (for autoregressive models)

out = flash_attention(q, k, v, is_causal=True)

Attention masks (for custom masking patterns)

# Boolean mask: True = masked (don't attend), False = attend
B, N = 2, 4096  # match the q/k/v shapes above
mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
mask[:, :, :, 512:] = True  # Mask out positions after 512

out = flash_attention(q, k, v, attn_mask=mask)

Training with gradients

q.requires_grad = True
k.requires_grad = True
v.requires_grad = True

out = flash_attention(q, k, v, is_causal=True)
loss = out.sum()
loss.backward()  # Computes dQ, dK, dV

Drop-in replacement for SDPA

from mps_flash_attn import replace_sdpa

# Monkey-patch F.scaled_dot_product_attention
replace_sdpa()

# Now all attention ops use Flash Attention on MPS
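
After the patch, code that calls F.scaled_dot_product_attention on MPS tensors takes the Flash Attention path without further changes. A minimal illustration (the tensor shapes are arbitrary):

import torch
import torch.nn.functional as F
from mps_flash_attn import replace_sdpa

replace_sdpa()

q = torch.randn(2, 8, 1024, 64, device='mps', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Routed through flash_attention by the monkey-patch
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)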

Architecture

+----------------------------------------------------------+
|                    Python API                            |
|              mps_flash_attn/__init__.py                  |
|         (flash_attention, autograd Function)             |
+----------------------------+-----------------------------+
                             |
+----------------------------v-----------------------------+
|                 C++ Extension                            |
|            mps_flash_attn/csrc/mps_flash_attn.mm         |
|    (PyTorch bindings, MTLBuffer handling, offsets)       |
+----------------------------+-----------------------------+
                             | dlopen + dlsym
+----------------------------v-----------------------------+
|                 Swift Bridge                             |
|         swift-bridge/Sources/MFABridge/                  |
|   (MFABridge.swift, MetallibCache.swift)                 |
|   @_cdecl exports: mfa_init, mfa_create_kernel,          |
|                    mfa_forward, mfa_backward             |
+----------------------------+-----------------------------+
                             |
+----------------------------v-----------------------------+
|              Metal Flash Attention                       |
|    metal-flash-attention/Sources/FlashAttention/         |
|     (AttentionDescriptor, AttentionKernel, etc.)         |
|                                                          |
|   Generates Metal shader source at runtime,              |
|   compiles to .metallib, caches pipelines                |
+----------------------------------------------------------+
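
The C++ extension links to the Swift bridge at runtime via dlopen/dlsym rather than at build time. The ctypes snippet below only illustrates that loading pattern from Python (the real binding lives in mps_flash_attn.mm, and no argument signatures are assumed here):

import ctypes
import os

# The same dylib the C++ extension dlopen()s, located via MFA_BRIDGE_PATH
bridge = ctypes.CDLL(os.environ["MFA_BRIDGE_PATH"])

# @_cdecl exports are plain C symbols, so dlsym (here via ctypes) can resolve them
for name in ("mfa_init", "mfa_create_kernel", "mfa_forward", "mfa_backward"):
    print(name, "->", getattr(bridge, name))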

Project Structure

mps-flash-attention/
├── mps_flash_attn/              # Python package
│   ├── __init__.py              # Public API (flash_attention, replace_sdpa)
│   ├── csrc/
│   │   └── mps_flash_attn.mm    # PyTorch C++ extension
│   └── kernels/                 # Pre-compiled metallibs (optional)
│
├── swift-bridge/                # Swift -> C bridge
│   ├── Package.swift
│   └── Sources/MFABridge/
│       ├── MFABridge.swift      # C-callable API (@_cdecl)
│       └── MetallibCache.swift  # Disk caching for metallibs
│
├── metal-flash-attention/       # Upstream (git submodule)
│   └── Sources/FlashAttention/
│       └── Attention/
│           ├── AttentionDescriptor/  # Problem configuration
│           ├── AttentionKernel/      # Metal shader generation
│           └── ...
│
├── scripts/
│   └── build_metallibs.py       # Pre-compile kernels for distribution
│
└── setup.py                     # Python package setup

Changes from upstream metal-flash-attention

We made the following modifications to metal-flash-attention:

1. macOS 15+ compatibility (MTLLibraryCompiler.swift)

Apple restricted __asm in runtime-compiled Metal shaders on macOS 15. We added a fallback that uses xcrun metal CLI compilation when runtime compilation fails.
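
The fallback amounts to shelling out to the Metal toolchain instead of using the runtime compiler. A rough Python-flavoured sketch of that flow (the actual implementation is the Swift code in MTLLibraryCompiler.swift; the helper name and paths here are illustrative):

import os
import subprocess
import tempfile

def compile_metallib_via_cli(source: str) -> bytes:
    # Compile Metal source to a .metallib with the xcrun command-line tools,
    # mirroring the idea of the fallback (not the actual Swift implementation).
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "kernel.metal")
        air = os.path.join(tmp, "kernel.air")
        lib = os.path.join(tmp, "kernel.metallib")
        with open(src, "w") as f:
            f.write(source)
        subprocess.run(["xcrun", "-sdk", "macosx", "metal", "-c", src, "-o", air], check=True)
        subprocess.run(["xcrun", "-sdk", "macosx", "metallib", air, "-o", lib], check=True)
        with open(lib, "rb") as f:
            return f.read()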

2. Causal masking support

Added causal flag to AttentionDescriptor and kernel generation:

  • AttentionDescriptor.swift: Added causal: Bool property
  • AttentionKernelDescriptor.swift: Added causal: Bool property
  • AttentionKernel.swift: Added causal field
  • AttentionKernel+Softmax.swift: Added maskCausal() function
  • AttentionKernel+Source.swift: Added causal masking to forward/backward loops
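
For reference, the rule these changes implement is the standard causal mask: a query at position i may attend only to key positions j <= i. A plain PyTorch formulation of the same rule, for comparison only (this is not the generated Metal code):

import torch

def causal_reference(q, k, v):
    # q, k, v: (B, H, N, D); scores masked so position i only attends to j <= i
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale
    i = torch.arange(q.shape[-2], device=q.device)
    j = torch.arange(k.shape[-2], device=k.device)
    scores = scores.masked_fill(j[None, :] > i[:, None], float("-inf"))
    return torch.softmax(scores, dim=-1) @ v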

Next Steps

1. PR to upstream metal-flash-attention

The macOS 15 fix and causal masking should be contributed back:

cd metal-flash-attention
git checkout -b macos15-causal-support
# Commit changes to:
#   - Sources/FlashAttention/Utilities/MTLLibraryCompiler.swift (new file)
#   - Sources/FlashAttention/Attention/AttentionDescriptor/*.swift
#   - Sources/FlashAttention/Attention/AttentionKernel/*.swift
git push origin macos15-causal-support
# Open PR at https://github.com/philipturner/metal-flash-attention

2. Publish mps-flash-attention to PyPI

# Add pyproject.toml with proper metadata
# Build wheel with pre-compiled Swift bridge
python -m build
twine upload dist/*

3. Pre-compile kernels for zero cold start

python scripts/build_metallibs.py
# Copies metallibs to mps_flash_attn/kernels/
# These get shipped with the wheel

Current Status (Jan 2025)

Working:

  • Forward pass (fp16/fp32)
  • Backward pass (dQ, dK, dV gradients)
  • Causal masking
  • Metallib disk caching
  • Pipeline binary caching (MTLBinaryArchive)

Known limitations:

  • Sequence length must be divisible by the block size (typically 64); a padding workaround is sketched below
  • Head dimension: works best with 32, 64, 96, or 128
  • No dropout
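
One way to work around the block-size restriction is to pad the sequence dimension up to the next multiple of the block size and mask out the padded key positions. This is a hedged workaround sketch, not part of the package:

import torch
from mps_flash_attn import flash_attention

def padded_flash_attention(q, k, v, block=64):
    # Pad N up to a multiple of the block size, mask out padded keys,
    # then slice the output back to the original length.
    B, H, N, D = q.shape
    pad = (-N) % block
    if pad == 0:
        return flash_attention(q, k, v)
    pad_fn = lambda t: torch.nn.functional.pad(t, (0, 0, 0, pad))
    qp, kp, vp = pad_fn(q), pad_fn(k), pad_fn(v)
    mask = torch.zeros(B, 1, N + pad, N + pad, dtype=torch.bool, device=q.device)
    mask[:, :, :, N:] = True  # don't attend to padded key positions
    out = flash_attention(qp, kp, vp, attn_mask=mask)
    return out[:, :, :N, :]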

Credits

License

MIT

