
FlashAttention-3 forward


Flash-Attention-3 Forward-Only Kernel

This repository bundles the Flash-Attention-3 forward-only kernel and the tooling required to build a lightweight Python wheel. It is intended for inference scenarios where backward operators and optional features are unnecessary.

Highlights

  • Ships only the Flash-Attention-3 forward path while disabling backward kernels, local attention, paged KV cache, FP16 kernels, and other extras to minimize the wheel size.
  • Applies a patch that renames the public interface to fa3_fwd_interface, making the forward kernel easy to import from Python.
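Because the public module is named fa3_fwd_interface, code that also needs to run against a full upstream Flash-Attention-3 install can try both import paths. The sketch below assumes that the upstream hopper build exposes its API as flash_attn_interface. The helper name load_flash_attn_func is hypothetical.

```python
import importlib

# Module names to try, in order of preference. "flash_attn_interface" is the
# module name assumed for an upstream Flash-Attention-3 (hopper) install.
CANDIDATE_MODULES = ("fa3_fwd_interface", "flash_attn_interface")


def load_flash_attn_func():
    """Return flash_attn_func from the first importable candidate, or None."""
    for module_name in CANDIDATE_MODULES:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue  # not installed; try the next candidate
        func = getattr(module, "flash_attn_func", None)
        if func is not None:
            return func
    return None


if __name__ == "__main__":
    attn = load_flash_attn_func()
    print("flash_attn_func found" if attn else "no FA3 interface installed")
```

Putting fa3_fwd_interface first means the forward-only wheel wins when both packages are installed.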

Prerequisites (same as upstream)

  • Python: 3.9 or later
  • PyTorch: 2.9.0
  • Build dependencies: ninja, packaging, wheel
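The prerequisites above can be checked before starting a long build. The following is a minimal sketch that treats the listed PyTorch version as an exact pin and ignores local version tags such as +cu128. The function name check_prerequisites is hypothetical.

```python
import importlib.metadata
import importlib.util
import sys

BUILD_DEPS = ("ninja", "packaging", "wheel")  # build dependencies listed above
REQUIRED_TORCH = "2.9.0"                      # PyTorch version listed above


def check_prerequisites() -> list[str]:
    """Return human-readable problems; an empty list means all checks passed."""
    problems = []
    if sys.version_info < (3, 9):
        problems.append(f"Python 3.9+ required, found {sys.version.split()[0]}")
    for dep in BUILD_DEPS:
        if importlib.util.find_spec(dep) is None:
            problems.append(f"missing build dependency: {dep}")
    try:
        torch_version = importlib.metadata.version("torch")
    except importlib.metadata.PackageNotFoundError:
        problems.append("PyTorch is not installed")
    else:
        base = torch_version.split("+")[0]  # strip local tags such as +cu128
        if base != REQUIRED_TORCH:
            problems.append(f"PyTorch {REQUIRED_TORCH} expected, found {torch_version}")
    return problems


if __name__ == "__main__":
    issues = check_prerequisites()
    print("\n".join(issues) if issues else "all prerequisites satisfied")
```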

Quick Start

  1. Clone the repository and initialize submodules:

    git clone --recursive <repo-url>
    cd fa3-fwd
    # If --recursive was omitted during clone, run:
    git submodule update --init --recursive
    
  2. Create a Python virtual environment and install dependencies:

    uv venv --python 3.12 --seed
    source .venv/bin/activate
    uv pip install -r requirements.txt
    
  3. Build the forward-only wheel:

    bash build_fa3.sh
    

    The script:

    • Sources set_compile_env.sh to compute MAX_JOBS and NVCC_THREADS
    • Applies the custom patch and interface rename inside the Flash-Attention submodule
    • Runs python setup.py bdist_wheel under flash-attention/hopper
  4. Install the generated wheel (example):

    pip install flash-attention/hopper/dist/fa3_fwd-*.whl
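In step 3, build_fa3.sh leaves the choice of MAX_JOBS and NVCC_THREADS to set_compile_env.sh. MAX_JOBS bounds how many compilation jobs run in parallel, and NVCC_THREADS sets how many threads each nvcc invocation uses. The sketch below illustrates the trade-off by capping parallel jobs using both CPU count and physical memory. It is not the script's actual formula. The per-job memory cost (gb_per_job=16), the default nvcc_threads=4, and the function name compile_env are all assumptions.

```python
import os


def compile_env(cpus, mem_gb, nvcc_threads=4, gb_per_job=16):
    """Choose MAX_JOBS so that parallel nvcc jobs fit the CPU and memory budgets.

    gb_per_job is an assumed per-job memory cost, not a measured figure.
    """
    by_cpu = cpus // nvcc_threads     # each job uses nvcc_threads cores
    by_mem = int(mem_gb // gb_per_job)
    max_jobs = max(1, min(by_cpu, by_mem))  # always allow at least one job
    return {"MAX_JOBS": str(max_jobs), "NVCC_THREADS": str(nvcc_threads)}


if __name__ == "__main__":
    cpus = os.cpu_count() or 1
    # Physical memory via sysconf; these names are available on Linux.
    mem_gb = os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") / 1024**3
    env = compile_env(cpus, mem_gb)
    print(" ".join(f"{key}={value}" for key, value in env.items()))
```

The printed assignments can be prefixed to the build, for example `env MAX_JOBS=16 NVCC_THREADS=4 bash build_fa3.sh`. This assumes the script keeps preset values, which the Troubleshooting advice about `MAX_JOBS=1` suggests.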
    

Python Usage Example

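import torch
from fa3_fwd_interface import flash_attn_func

# Inputs must already live on CUDA and satisfy Flash-Attention-3 constraints.
# Illustrative layout (batch, seqlen, nheads, headdim) in bfloat16, since FP16 kernels are disabled in this build.
q, k, v = (torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.bfloat16) for _ in range(3))
# Depending on the upstream revision, this may return (out, softmax_lse) rather than out alone.
out = flash_attn_func(q, k, v, causal=True)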

This package exposes only the forward kernel. For backward support or additional features, depend on the upstream Flash-Attention project instead.
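The usage example requires inputs that "satisfy Flash-Attention-3 constraints" but does not list them. The shape check below sketches the constraints assumed from upstream FA3's argument checks:

* 4-D (batch, seqlen, nheads, headdim) tensors.
* A k/v head count that divides the q head count.
* A head dimension that is a multiple of 8 and at most 256 for 16-bit inputs.

The trimmed kernel set in this wheel may accept fewer head dimensions than these limits allow. The sketch does not check device, dtype, or memory layout, and the helper check_fa3_inputs is hypothetical.

```python
# Assumed limits, modeled on checks in upstream FA3's hopper C++ API; this
# trimmed build may compile fewer head dimensions than these allow.
MAX_HEADDIM = 256
HEADDIM_ALIGNMENT = 8  # assumed alignment for 16-bit dtypes such as bfloat16


def check_fa3_inputs(q_shape, k_shape, v_shape):
    """Return a list of problems with (batch, seqlen, nheads, headdim) shapes."""
    if not (len(q_shape) == len(k_shape) == len(v_shape) == 4):
        return ["q, k and v must all be 4-D: (batch, seqlen, nheads, headdim)"]
    batch_q, _, heads_q, dim_q = q_shape
    batch_k, seq_k, heads_k, dim_k = k_shape
    batch_v, seq_v, heads_v, _ = v_shape
    problems = []
    if not batch_q == batch_k == batch_v:
        problems.append("batch size must match across q, k and v")
    if (seq_k, heads_k) != (seq_v, heads_v):
        problems.append("k and v must have the same seqlen and nheads")
    if dim_q != dim_k:
        problems.append("q and k must have the same headdim")
    if heads_k == 0 or heads_q % heads_k != 0:
        problems.append("q nheads must be a multiple of k/v nheads (MQA/GQA)")
    if dim_q % HEADDIM_ALIGNMENT != 0:
        problems.append(f"headdim must be a multiple of {HEADDIM_ALIGNMENT}")
    if dim_q > MAX_HEADDIM:
        problems.append(f"headdim must be at most {MAX_HEADDIM}")
    return problems
```

With torch tensors, pass the shapes as tuples, for example `check_fa3_inputs(tuple(q.shape), tuple(k.shape), tuple(v.shape))`.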

Troubleshooting

  • Out-of-memory during compilation: The build script already throttles concurrency, but you can force fully serial compilation by running MAX_JOBS=1 NVCC_THREADS=1 bash build_fa3.sh.
  • CUDA mismatch errors: Confirm that the CUDA release reported by nvcc --version matches torch.version.cuda.
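The CUDA comparison above can be scripted. The sketch below parses the "release X.Y" field that nvcc --version prints and compares major.minor with torch.version.cuda. Matching both fields is a conservative reading of "matches" rather than a documented requirement of this build. The helper names are hypothetical.

```python
import re
import subprocess


def nvcc_release(nvcc_output):
    """Extract (major, minor) from `nvcc --version` text, e.g. 'release 12.8'."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    return (int(match.group(1)), int(match.group(2))) if match else None


def cuda_versions_match(nvcc_output, torch_cuda):
    """Compare nvcc's major.minor release with torch.version.cuda (e.g. '12.8')."""
    if torch_cuda is None:  # CPU-only PyTorch builds report None
        return False
    torch_release = tuple(int(part) for part in torch_cuda.split(".")[:2])
    return nvcc_release(nvcc_output) == torch_release


if __name__ == "__main__":
    import torch

    result = subprocess.run(["nvcc", "--version"], capture_output=True, text=True, check=True)
    print("nvcc:", nvcc_release(result.stdout), "torch:", torch.version.cuda)
    print("match:", cuda_versions_match(result.stdout, torch.version.cuda))
```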

Repository Layout

Customize further by editing environment variables in the build script or modifying the submodule before the patch is applied (for example to re-enable additional datatypes or kernels).

Download files

Download the file for your platform.

Source Distributions

No source distribution files are available for this release.

Built Distributions


fa3_fwd-0.0.2-cp39-abi3-manylinux_2_24_x86_64.whl (27.4 MB)

Uploaded for CPython 3.9+ (abi3), manylinux: glibc 2.24+, x86-64

fa3_fwd-0.0.2-cp39-abi3-manylinux_2_24_aarch64.whl (27.9 MB)

Uploaded for CPython 3.9+ (abi3), manylinux: glibc 2.24+, ARM64

File details

Details for the file fa3_fwd-0.0.2-cp39-abi3-manylinux_2_24_x86_64.whl.


File hashes

Hashes for fa3_fwd-0.0.2-cp39-abi3-manylinux_2_24_x86_64.whl

Algorithm     Hash digest
SHA256        1fddac2fb32bcb446ebef7359387e863b46323961e53c4975038eebb8b9f2ceb
MD5           1d005747183e18100f6208adbe56726b
BLAKE2b-256   9548ad4d2a53235738b2395546eab65c547de0523e450dc0a7e5f04b9d1e17a3


File details

Details for the file fa3_fwd-0.0.2-cp39-abi3-manylinux_2_24_aarch64.whl.


File hashes

Hashes for fa3_fwd-0.0.2-cp39-abi3-manylinux_2_24_aarch64.whl

Algorithm     Hash digest
SHA256        869225e83021a42e6af1694ffb9f13755beb03829d469745d21f19e7d2cee2fc
MD5           4eec46660a2f49f7ceb95a925a29d693
BLAKE2b-256   418adc89594b1a332c1f421a67e235dc956f116c7241fbd2c04bbbee96ad4f9f

