
MPS Flash Attention

Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).

O(N) memory instead of O(N²), enabling 100K+ sequence lengths on unified memory.

Performance

Benchmarked on Apple Silicon (M1/M2/M3/M4):

Seq length   Speedup vs. PyTorch SDPA   Notes
1024         1.1-2.0x                   Crossover point
2048         1.7-3.7x                   Sweet spot
4096         2.0-3.9x                   Peak performance
8192+        3-4x                       SDPA often OOMs

Average speedup: 1.8x across all configurations.

Installation

pip install mps-flash-attn

Build from source

git clone --recursive https://github.com/mpsops/mps-flash-attention.git
cd mps-flash-attention

# Build Swift bridge
cd swift-bridge && swift build -c release && cd ..

# Install
pip install -e .

# Set bridge path
export MFA_BRIDGE_PATH=$PWD/swift-bridge/.build/release/libMFABridge.dylib

Usage

Basic Attention

import torch
from mps_flash_attn import flash_attention

# (B, H, N, D) format
q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)

out = flash_attention(q, k, v)

Causal Masking

out = flash_attention(q, k, v, is_causal=True)

Sliding Window (Mistral/Llama 3.2)

# Only attend to last 4096 tokens
out = flash_attention(q, k, v, is_causal=True, window_size=4096)
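The combined causal + sliding-window constraint can be sketched in NumPy (illustrative only; the library applies this restriction inside the Metal kernel, and whether the window includes the current token is an implementation detail):

```python
import numpy as np

def sliding_window_causal_mask(n, window_size):
    """Boolean mask: True where query position i may attend to key j.

    Each position attends only to itself and the previous
    window_size - 1 tokens (causal + local window).
    """
    i = np.arange(n)[:, None]  # query positions
    j = np.arange(n)[None, :]  # key positions
    return (j <= i) & (j > i - window_size)

mask = sliding_window_causal_mask(8, window_size=4)
# Each row has at most 4 True entries and never looks ahead.
print(mask.sum(axis=1))  # [1 2 3 4 4 4 4 4]
```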

Quantized KV Cache (2-4x memory savings)

from mps_flash_attn import flash_attention_fp8, quantize_kv_fp8

# Quantize K/V to FP8
k_quant, k_scale = quantize_kv_fp8(k)
v_quant, v_scale = quantize_kv_fp8(v)

# Run attention with quantized KV
out = flash_attention_fp8(q, k_quant, v_quant, k_scale, v_scale)
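The scale/round-trip mechanics behind quantized KV storage can be sketched in NumPy, using int8 as a stand-in for FP8 (the real library stores K/V in FP8 on device; this is only the concept):

```python
import numpy as np

def quantize_symmetric(x, qmax=127):
    """Per-tensor symmetric quantization (int8 stand-in for FP8)."""
    scale = np.abs(x).max() / qmax
    q = np.clip(np.round(x / scale), -qmax, qmax).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

x = np.random.randn(4, 16).astype(np.float32)
q, s = quantize_symmetric(x)

# 8-bit storage halves memory vs. FP16, and the round-trip error
# is bounded by half a quantization step.
err = np.abs(dequantize(q, s) - x).max()
assert err <= s / 2 + 1e-6
```

Attention then runs on the dequantized values (or fused dequantize-in-kernel), trading a small precision loss for 2-4x smaller KV cache.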

100K+ Long Sequences

import torch
from mps_flash_attn import flash_attention_chunked

# Process 100K tokens without OOM
q = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)
k = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)
v = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)

out = flash_attention_chunked(q, k, v, chunk_size=8192)
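The reason chunking does not change the result is the online softmax: scores are processed one K/V chunk at a time while running statistics keep the output exact. A minimal single-head NumPy sketch of the idea (not the library's Metal kernel):

```python
import numpy as np

def naive_attention(q, k, v):
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def chunked_attention(q, k, v, chunk_size):
    """Stream over K/V chunks with an online softmax, so only
    O(chunk_size) scores are materialized at any time."""
    d = q.shape[-1]
    row_max = np.full(q.shape[0], -np.inf)  # running row max
    denom = np.zeros(q.shape[0])            # running softmax denominator
    acc = np.zeros_like(q)                  # running weighted sum of V
    for start in range(0, k.shape[0], chunk_size):
        kc, vc = k[start:start + chunk_size], v[start:start + chunk_size]
        s = q @ kc.T / np.sqrt(d)
        new_max = np.maximum(row_max, s.max(axis=-1))
        correction = np.exp(row_max - new_max)  # rescale old partials
        p = np.exp(s - new_max[:, None])
        denom = denom * correction + p.sum(axis=-1)
        acc = acc * correction[:, None] + p @ vc
        row_max = new_max
    return acc / denom[:, None]

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((32, 16)) for _ in range(3))
assert np.allclose(chunked_attention(q, k, v, 8), naive_attention(q, k, v))
```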

Drop-in SDPA Replacement

from mps_flash_attn import replace_sdpa

replace_sdpa()  # Patches F.scaled_dot_product_attention

# Now all PyTorch attention uses Flash Attention on MPS
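Conceptually, the patch swaps the function object on the module, keeping a handle to the original for fallback and restoration. A minimal sketch of the monkey-patching idea, using a stand-in namespace instead of torch.nn.functional (illustrative only, not the library's actual code):

```python
import types

# Stand-in for torch.nn.functional.
F = types.SimpleNamespace()
F.scaled_dot_product_attention = lambda q, k, v: "original sdpa"

def flash_attention(q, k, v):
    return "flash attention"

def replace_sdpa():
    """Route SDPA calls through the flash kernel; return the original
    so it can be restored (or used as a fallback)."""
    original = F.scaled_dot_product_attention
    def patched(q, k, v):
        # A real patch would fall back to `original` for cases the
        # flash kernel does not support (e.g. dropout).
        return flash_attention(q, k, v)
    F.scaled_dot_product_attention = patched
    return original

replace_sdpa()
print(F.scaled_dot_product_attention(None, None, None))  # flash attention
```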

torch.compile() Support

from mps_flash_attn import register_custom_op

register_custom_op()

@torch.compile
def my_attention(q, k, v):
    return torch.ops.mfa.flash_attention(q, k, v, False, None, None)

Training with BF16 Backward

out = flash_attention(q, k, v, bf16_backward=True)  # 2x faster backward
loss = out.sum()
loss.backward()

Benchmarking

# Quick benchmark
python -m mps_flash_attn.benchmark --suite quick

# Full suite with report
python -m mps_flash_attn.benchmark --suite full --output report.html

Python API:

from mps_flash_attn.benchmark import run_suite, compare_vs_sdpa

results = run_suite(seq_lengths=[1024, 2048, 4096])
compare_vs_sdpa()

Features

Feature            Status         Notes
Forward pass       Supported      FP16/BF16/FP32
Backward pass      Supported      Full gradient support
Causal masking     Supported      Native kernel support
Attention masks    Supported      Boolean masks
Sliding window     Supported      For local attention models
GQA/MQA            Supported      Grouped-query attention
Quantized KV       Supported      FP8, INT8, NF4
Chunked attention  Supported      100K+ tokens
torch.compile()    Supported      Custom op backend
Dropout            Not supported
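GQA/MQA in the table above means the kernel accepts fewer K/V heads than query heads, with each K/V head shared by a group of query heads. The equivalence can be sketched in NumPy (the real kernel indexes the shared head directly instead of copying):

```python
import numpy as np

def expand_kv_for_gqa(k, n_q_heads):
    """Repeat each K/V head to cover its group of query heads.

    k: (H_kv, N, D). A fused kernel avoids this copy by mapping
    query head h_q to KV head h_q // (H_q // H_kv); this sketch
    just shows what that mapping is equivalent to.
    """
    h_kv = k.shape[0]
    assert n_q_heads % h_kv == 0
    return np.repeat(k, n_q_heads // h_kv, axis=0)

k = np.random.randn(2, 16, 8)     # 2 KV heads
k_full = expand_kv_for_gqa(k, 8)  # expanded for 8 query heads
assert k_full.shape == (8, 16, 8)
assert (k_full[3] == k[0]).all()  # heads 0-3 share KV head 0
assert (k_full[4] == k[1]).all()  # heads 4-7 share KV head 1
```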

Architecture

Python API (mps_flash_attn)
         │
    C++ Extension (mps_flash_attn.mm)
         │ dlopen
    Swift Bridge (MFABridge.swift)
         │
    Metal Flash Attention (kernel generation)
         │
    Metal GPU Shaders

Requirements

  • macOS 14 (Sonoma) or later
  • Apple Silicon (M1/M2/M3/M4)
  • Python 3.10+
  • PyTorch 2.0+

TODO / Future Optimizations

  • Batched kernel dispatch - Currently dispatches B×H separate kernels per attention call; a single 3D grid dispatch could handle all batch/head pairs at once (a major perf win for small sequences such as Swin Transformer windows)
  • Fused QKV projection + attention - Single kernel from input to output, avoid intermediate buffers
  • Pre-scaled bias option - Allow passing pre-scaled bias to avoid per-call scaling overhead
  • LoRA fusion - Fuse adapter weights into attention computation

Credits

License

MIT



Download files

Download the file for your platform.

Source Distribution

mps_flash_attn-0.3.2.tar.gz (385.1 kB)

Uploaded: Source

Built Distribution


mps_flash_attn-0.3.2-cp314-cp314-macosx_15_0_arm64.whl (585.7 kB)

Uploaded: CPython 3.14, macOS 15.0+, ARM64

File details

Details for the file mps_flash_attn-0.3.2.tar.gz.

File metadata

  • Download URL: mps_flash_attn-0.3.2.tar.gz
  • Upload date:
  • Size: 385.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for mps_flash_attn-0.3.2.tar.gz
Algorithm Hash digest
SHA256 64a9703bb5c9d223c37807ecc39375718f27def17b67efb102beb0557bda1ec0
MD5 3c29c6144c41dbe5f411d957de0694f1
BLAKE2b-256 ca183a9731e2586005ed4f44d8516a25c480df9ee7ce1054e6cba68a79ebb133


File details

Details for the file mps_flash_attn-0.3.2-cp314-cp314-macosx_15_0_arm64.whl.

File metadata

File hashes

Hashes for mps_flash_attn-0.3.2-cp314-cp314-macosx_15_0_arm64.whl
Algorithm Hash digest
SHA256 715e1eb82d616dddcd08e110a38c2a63441326291be2aa445528277e70eac24a
MD5 488a8b46a170f0ba39e6779c3ab47db8
BLAKE2b-256 0e3dc968cba84d765bb58e35d4f13788cf78d81ed8120a4397f0f9f1b525eab4

