
MPS Flash Attention

Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).

O(N) memory instead of O(N²), enabling 100K+ sequence lengths on unified memory.
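To put the quadratic term in perspective: materializing the full attention matrix for a 100K-token sequence in FP16 already exceeds any Apple Silicon machine's unified memory, while the streaming approach only keeps fixed-size tiles resident. A back-of-the-envelope estimate (assuming batch size 1 and 8 heads, values not taken from the benchmarks below):

# Rough memory needed to materialize the N x N attention scores in naive attention
B, H, N = 1, 8, 100_000
naive_bytes = B * H * N * N * 2            # FP16 scores, 2 bytes per element
print(f"{naive_bytes / 2**30:.0f} GiB")    # ~149 GiB, far beyond unified memory
# Flash attention streams over K/V tiles, so its extra memory grows with N, not N^2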

Performance

Benchmarked on Apple Silicon (M1/M2/M3/M4):

Seq Length   vs PyTorch SDPA    Notes
1024         1.1-2.0x faster    Crossover point
2048         1.7-3.7x faster    Sweet spot
4096         2.0-3.9x faster    Peak performance
8192+        3-4x faster        SDPA often OOMs

Average speedup: 1.8x across all configurations.

Installation

pip install mps-flash-attn

Build from source

git clone --recursive https://github.com/mpsops/mps-flash-attention.git
cd mps-flash-attention

# Build Swift bridge
cd swift-bridge && swift build -c release && cd ..

# Install
pip install -e .

# Set bridge path
export MFA_BRIDGE_PATH=$PWD/swift-bridge/.build/release/libMFABridge.dylib
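
A quick way to confirm the bridge loaded is a minimal smoke test (a sketch, assuming the editable install above succeeded):

import torch
from mps_flash_attn import flash_attention

q = torch.randn(1, 2, 128, 64, device='mps', dtype=torch.float16)
out = flash_attention(q, q, q)   # tiny self-attention call as a smoke test
print(out.shape)                 # expected: torch.Size([1, 2, 128, 64])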

Usage

Basic Attention

import torch
from mps_flash_attn import flash_attention

# (B, H, N, D) format
q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)

out = flash_attention(q, k, v)
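
To sanity-check the result, it can be compared against PyTorch's reference SDPA on the same tensors (a sketch; the tolerance below is a loose FP16 bound and assumes the default 1/sqrt(D) scaling on both sides):

import torch.nn.functional as F

ref = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out.float(), ref.float(), atol=1e-2))  # expect True within FP16 noise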

Causal Masking

out = flash_attention(q, k, v, is_causal=True)

Sliding Window (Mistral/Llama 3.2)

# Only attend to last 4096 tokens
out = flash_attention(q, k, v, is_causal=True, window_size=4096)
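
Conceptually, causal attention with window_size=W restricts each query to the previous W keys, i.e. a banded causal mask. A reference construction with an explicit mask and stock SDPA is sketched below (illustration only: it materializes the full N x N mask, and the kernel's exact inclusive/exclusive window convention may differ):

import torch
import torch.nn.functional as F

B, H, N, D, W = 1, 8, 8192, 64, 4096
q = torch.randn(B, H, N, D, device='mps', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

idx = torch.arange(N, device=q.device)
# key j is visible from query i iff j <= i and i - j < W (banded causal mask)
band = (idx[None, :] <= idx[:, None]) & (idx[:, None] - idx[None, :] < W)
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=band)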

Quantized KV Cache (2-4x memory savings)

from mps_flash_attn import flash_attention_fp8, quantize_kv_fp8

# Quantize K/V to FP8
k_quant, k_scale = quantize_kv_fp8(k)
v_quant, v_scale = quantize_kv_fp8(v)

# Run attention with quantized KV
out = flash_attention_fp8(q, k_quant, v_quant, k_scale, v_scale)
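
For a rough sense of the footprint: storing K and V in FP8 halves the cache relative to FP16, with only a small scale-tensor overhead; the 4x end of the range corresponds to the 4-bit NF4 path listed in the feature table below (back-of-the-envelope figures, not measured numbers):

B, H, N, D = 1, 8, 32_768, 64
fp16_kv = 2 * B * H * N * D * 2   # K and V at 2 bytes/element -> 64 MiB
fp8_kv  = 2 * B * H * N * D * 1   # K and V at 1 byte/element  -> 32 MiB
print(fp16_kv / 2**20, fp8_kv / 2**20)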

100K+ Long Sequences

from mps_flash_attn import flash_attention_chunked

# Process 100K tokens without OOM
q = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)
k = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)
v = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)

out = flash_attention_chunked(q, k, v, chunk_size=8192)

Drop-in SDPA Replacement

from mps_flash_attn import replace_sdpa

replace_sdpa()  # Patches F.scaled_dot_product_attention

# Now all PyTorch attention uses Flash Attention on MPS
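
After the patch, existing code that calls F.scaled_dot_product_attention picks up the Flash Attention kernel whenever the tensors live on MPS (a minimal illustration; tensors on other devices presumably fall back to the stock implementation):

import torch
import torch.nn.functional as F
from mps_flash_attn import replace_sdpa

replace_sdpa()

q = torch.randn(1, 8, 2048, 64, device='mps', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Routed through the MPS Flash Attention kernel by the patch above
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)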

torch.compile() Support

from mps_flash_attn import register_custom_op

register_custom_op()

@torch.compile
def my_attention(q, k, v):
    return torch.ops.mfa.flash_attention(q, k, v, False, None, None)

Training with BF16 Backward

q.requires_grad_(); k.requires_grad_(); v.requires_grad_()  # gradients flow back to q/k/v

out = flash_attention(q, k, v, bf16_backward=True)  # 2x faster backward
loss = out.sum()
loss.backward()

Benchmarking

# Quick benchmark
python -m mps_flash_attn.benchmark --suite quick

# Full suite with report
python -m mps_flash_attn.benchmark --suite full --output report.html

Or from Python:

from mps_flash_attn.benchmark import run_suite, compare_vs_sdpa

results = run_suite(seq_lengths=[1024, 2048, 4096])
compare_vs_sdpa()

Features

Feature             Status          Notes
Forward pass        Supported       FP16/BF16/FP32
Backward pass       Supported       Full gradient support
Causal masking      Supported       Native kernel support
Attention masks     Supported       Boolean masks
Sliding window      Supported       For local attention models
GQA/MQA             Supported       Grouped-query attention
Quantized KV        Supported       FP8, INT8, NF4
Chunked attention   Supported       100K+ tokens
torch.compile()     Supported       Custom op backend
Dropout             Not supported

Architecture

Python API (mps_flash_attn)
         │
    C++ Extension (mps_flash_attn.mm)
         │ dlopen
    Swift Bridge (MFABridge.swift)
         │
    Metal Flash Attention (kernel generation)
         │
    Metal GPU Shaders

Requirements

  • macOS 14 (Sonoma) or later
  • Apple Silicon (M1/M2/M3/M4)
  • Python 3.10+
  • PyTorch 2.0+

TODO / Future Optimizations

  • Batched kernel dispatch - Currently dispatches B×H separate kernels per attention call; a 3D grid could handle all batch/head pairs in one dispatch (a major perf win for small sequences such as Swin Transformer windows)
  • Fused QKV projection + attention - Single kernel from input to output, avoid intermediate buffers
  • Pre-scaled bias option - Allow passing pre-scaled bias to avoid per-call scaling overhead
  • LoRA fusion - Fuse adapter weights into attention computation

Credits

License

MIT

Download files

Download the file for your platform.

Source Distributions

No source distribution files are available for this release.

Built Distributions


mps_flash_attn-0.6.0-cp314-cp314-macosx_11_0_arm64.whl (968.2 kB), CPython 3.14, macOS 11.0+ ARM64
mps_flash_attn-0.6.0-cp313-cp313-macosx_11_0_arm64.whl (958.4 kB), CPython 3.13, macOS 11.0+ ARM64
mps_flash_attn-0.6.0-cp312-cp312-macosx_11_0_arm64.whl (958.2 kB), CPython 3.12, macOS 11.0+ ARM64
mps_flash_attn-0.6.0-cp311-cp311-macosx_11_0_arm64.whl (962.4 kB), CPython 3.11, macOS 11.0+ ARM64
mps_flash_attn-0.6.0-cp310-cp310-macosx_11_0_arm64.whl (956.4 kB), CPython 3.10, macOS 11.0+ ARM64

File hashes

mps_flash_attn-0.6.0-cp314-cp314-macosx_11_0_arm64.whl
  SHA256       8b9d04ea528af235bcc125466e2ab254f1ae0fe8376881c52ac9103b3ec478ec
  MD5          50047472dd333aa70f625a2680a4c088
  BLAKE2b-256  2be9d8733609ec00b824cbed9bbfaccb61cbbfffe355dea832d8b541d001d2d2

mps_flash_attn-0.6.0-cp313-cp313-macosx_11_0_arm64.whl
  SHA256       01ecdfc4f1bf253faa1f4c361711553fb9923e3640f4945a3195e7119ef0f091
  MD5          016faf8d7c8aa6f167d7bf468c871d72
  BLAKE2b-256  ffb5e2702624fc0ee54445910f28d8f8a96c8e9268c3d4cd07fe93fc38160f57

mps_flash_attn-0.6.0-cp312-cp312-macosx_11_0_arm64.whl
  SHA256       40fa9e00d6349a747eb9f38af8d3bd4288c9b34d922e44baa4c430f07601cfd0
  MD5          a691a7d43165c4720826fba4007d3a3a
  BLAKE2b-256  8cd2e78f4e46e188174d73056458521eef584a27ee78a1d9d2174a9b3589bab6

mps_flash_attn-0.6.0-cp311-cp311-macosx_11_0_arm64.whl
  SHA256       0a2fc28be3ac34807e797154c2371673e93ef651256447907a9f1e16a3edb05c
  MD5          41b871c4d5d9f85161244c649b0403a5
  BLAKE2b-256  b0fc690cf9405a5e142b30d304b6057503a66c818a6d4e55a9f21a400546dc7a

mps_flash_attn-0.6.0-cp310-cp310-macosx_11_0_arm64.whl
  SHA256       0723b50e3af9faa219714343d61b70ea86dc75a103f779a9435622943a10bdad
  MD5          6e3fe9895955591e2ae533036d2ce1d9
  BLAKE2b-256  d4924d12deaf1ad52610c82989c88418df9f0dc59730932877f6abaf7936edc0
