
MPS Flash Attention

Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).

O(N) memory instead of O(N²), enabling 100K+ sequence lengths on unified memory.
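
The quadratic term is what kills long contexts: materializing even one head's full FP16 score matrix at N = 100,000 takes about 18.6 GiB, while the flash approach only keeps O(N) running statistics. A quick back-of-the-envelope check:

N = 100_000
score_matrix_bytes = N * N * 2       # one (N, N) FP16 score matrix, single head
print(score_matrix_bytes / 2**30)    # ≈ 18.6 GiB per head, per batch element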

Performance

Benchmarked on Apple Silicon (M1/M2/M3/M4):

Seq Length   vs PyTorch SDPA   Notes
1024         1.1-2.0x faster   Crossover point
2048         1.7-3.7x faster   Sweet spot
4096         2.0-3.9x faster   Peak performance
8192+        3-4x faster       SDPA often OOMs

Average speedup: 1.8x across all configurations.

Installation

pip install mps-flash-attn

Build from source

git clone --recursive https://github.com/mpsops/mps-flash-attention.git
cd mps-flash-attention

# Build Swift bridge
cd swift-bridge && swift build -c release && cd ..

# Install
pip install -e .

# Set bridge path
export MFA_BRIDGE_PATH=$PWD/swift-bridge/.build/release/libMFABridge.dylib

Usage

Basic Attention

import torch
from mps_flash_attn import flash_attention

# (B, H, N, D) format
q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)

out = flash_attention(q, k, v)
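
To sanity-check the output, compare against PyTorch's built-in SDPA on the same tensors (the tolerance below is illustrative; FP16 accumulation order differs between kernels):

import torch.nn.functional as F

ref = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out, ref, atol=1e-2))  # should agree within FP16 noise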

Causal Masking

out = flash_attention(q, k, v, is_causal=True)

Sliding Window (Mistral/Llama 3.2)

# Only attend to last 4096 tokens
out = flash_attention(q, k, v, is_causal=True, window_size=4096)
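
The intended semantics (a sketch; whether window_size counts the current token is an assumption here) are that query i attends only to keys j with i - window_size < j <= i. In plain PyTorch mask form:

import torch

def sliding_window_causal_mask(n, window, device='mps'):
    # True = attend. Query i sees keys j with i - window < j <= i.
    i = torch.arange(n, device=device).unsqueeze(1)
    j = torch.arange(n, device=device).unsqueeze(0)
    return (j <= i) & (j > i - window)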

Quantized KV Cache (2-4x memory savings)

from mps_flash_attn import flash_attention_fp8, quantize_kv_fp8

# Quantize K/V to FP8
k_quant, k_scale = quantize_kv_fp8(k)
v_quant, v_scale = quantize_kv_fp8(v)

# Run attention with quantized KV
out = flash_attention_fp8(q, k_quant, v_quant, k_scale, v_scale)
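
The 2-4x figure is easy to sanity-check: FP8 stores one byte per element versus two for FP16 (scales add a small, near-negligible overhead), and the 4x end presumably corresponds to 4-bit (NF4) storage. For the 100K-token shapes used below:

B, H, N, D = 1, 8, 100_000, 64
fp16_kv = 2 * B * H * N * D * 2    # K and V at 2 bytes/element
fp8_kv  = 2 * B * H * N * D * 1    # K and V at 1 byte/element
print(fp16_kv / 2**20, fp8_kv / 2**20)   # ≈ 195 MiB vs ≈ 98 MiB, per layer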

100K+ Long Sequences

from mps_flash_attn import flash_attention_chunked

# Process 100K tokens without OOM
q = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)
k = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)
v = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)

out = flash_attention_chunked(q, k, v, chunk_size=8192)
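
Chunking works because of the standard streaming ("online") softmax: each KV chunk updates a running row max and denominator, so the full score matrix never exists. A pure-PyTorch sketch of that merge rule (illustrative only; the library does this inside a Metal kernel, and this function is not its API):

import torch

def chunked_attention_reference(q, k, v, chunk_size=8192):
    # Streaming-softmax reference: O(N * chunk_size) memory, exact result.
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q, dtype=torch.float32)
    m = torch.full(q.shape[:-1] + (1,), float('-inf'), device=q.device)  # running row max
    l = torch.zeros(q.shape[:-1] + (1,), device=q.device)                # running denominator
    for start in range(0, k.shape[-2], chunk_size):
        kc = k[..., start:start + chunk_size, :]
        vc = v[..., start:start + chunk_size, :]
        s = (q @ kc.transpose(-2, -1)).float() * scale        # (B, H, N, chunk)
        m_new = torch.maximum(m, s.amax(dim=-1, keepdim=True))
        p = torch.exp(s - m_new)
        fix = torch.exp(m - m_new)                            # rescale old statistics
        l = l * fix + p.sum(dim=-1, keepdim=True)
        out = out * fix + p @ vc.float()
        m = m_new
    return (out / l).to(q.dtype)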

Drop-in SDPA Replacement

from mps_flash_attn import replace_sdpa

replace_sdpa()  # Patches F.scaled_dot_product_attention

# Now all PyTorch attention uses Flash Attention on MPS
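
After patching, anything that routes through F.scaled_dot_product_attention with MPS tensors picks up the fast path without further changes, e.g.:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 2048, 64, device='mps', dtype=torch.float16)
out = F.scaled_dot_product_attention(x, x, x, is_causal=True)  # now served by Flash Attention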

torch.compile() Support

from mps_flash_attn import register_custom_op

register_custom_op()

@torch.compile
def my_attention(q, k, v):
    return torch.ops.mfa.flash_attention(q, k, v, False, None, None)
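
The compiled function is then invoked like any other; torch.compile traces through the registered op instead of graph-breaking on it:

out = my_attention(q, k, v)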

Training with BF16 Backward

out = flash_attention(q, k, v, bf16_backward=True)  # 2x faster backward
loss = out.sum()
loss.backward()

Benchmarking

# Quick benchmark
python -m mps_flash_attn.benchmark --suite quick

# Full suite with report
python -m mps_flash_attn.benchmark --suite full --output report.html

The suite can also be driven from Python:

from mps_flash_attn.benchmark import run_suite, compare_vs_sdpa

results = run_suite(seq_lengths=[1024, 2048, 4096])
compare_vs_sdpa()

Features

Feature             Status          Notes
Forward pass        Supported       FP16/BF16/FP32
Backward pass       Supported       Full gradient support
Causal masking      Supported       Native kernel support
Attention masks     Supported       Boolean masks
Sliding window      Supported       For local attention models
GQA/MQA             Supported       Grouped-query attention (sketch below)
Quantized KV        Supported       FP8, INT8, NF4
Chunked attention   Supported       100K+ tokens
torch.compile()     Supported       Custom op backend
Dropout             Not supported
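
GQA/MQA means K/V can carry fewer heads than Q. Whether flash_attention consumes mismatched head counts directly or expects pre-expanded tensors isn't shown above, so as a conservative sketch, here is the conventional expansion that always works:

import torch
from mps_flash_attn import flash_attention

B, H_q, H_kv, N, D = 2, 8, 2, 4096, 64          # 4 query heads share each KV head
q = torch.randn(B, H_q, N, D, device='mps', dtype=torch.float16)
k = torch.randn(B, H_kv, N, D, device='mps', dtype=torch.float16)
v = torch.randn(B, H_kv, N, D, device='mps', dtype=torch.float16)

# Conventional GQA expansion: repeat each KV head across its query group.
out = flash_attention(q,
                      k.repeat_interleave(H_q // H_kv, dim=1),
                      v.repeat_interleave(H_q // H_kv, dim=1))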

Architecture

Python API (mps_flash_attn)
         │
    C++ Extension (mps_flash_attn.mm)
         │ dlopen
    Swift Bridge (MFABridge.swift)
         │
    Metal Flash Attention (kernel generation)
         │
    Metal GPU Shaders

Requirements

  • macOS 14+ (Sonoma or later)
  • Apple Silicon (M1/M2/M3/M4)
  • Python 3.10+
  • PyTorch 2.0+

TODO / Future Optimizations

  • Batched kernel dispatch - Currently dispatches B×H separate kernels per attention call; a 3D grid could handle all batch/heads in one dispatch (a major perf win for small sequences such as Swin Transformer windows)
  • Fused QKV projection + attention - Single kernel from input to output, avoiding intermediate buffers
  • Pre-scaled bias option - Allow passing pre-scaled bias to avoid per-call scaling overhead
  • LoRA fusion - Fuse adapter weights into attention computation

License

MIT
